outputs.py 7.69 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, NamedTuple
7

8
import numpy as np
9
10
import torch

11
12
from vllm.v1.core.sched.output import SchedulerOutput

13
if TYPE_CHECKING:
    # Imported for static type checkers only; this branch never runs.
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else:
    # Runtime stand-in: annotations in this module that mention
    # KVConnectorStats are evaluated when their classes are created, so the
    # name has to resolve to something even without the real import.
    KVConnectorStats = object
17

18

19
class LogprobsLists(NamedTuple):
20
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
21
    logprob_token_ids: np.ndarray
22
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
23
    logprobs: np.ndarray
24
    # [num_reqs x num_generated_tokens]
25
    sampled_token_ranks: np.ndarray
26
27
28
29
30
    # [num_reqs]
    # Used for slicing the logprobs in cases like speculative
    # decoding where the number of generated tokens may be
    # different for each request.
    cu_num_generated_tokens: list[int] | None = None
31

32
33
34
35
    def slice_request(self, req_idx: int, num_positions: int):
        if self.cu_num_generated_tokens is not None:
            req_idx = self.cu_num_generated_tokens[req_idx]
        end_idx = req_idx + num_positions
36
        return LogprobsLists(
37
38
39
40
            self.logprob_token_ids[req_idx:end_idx],
            self.logprobs[req_idx:end_idx],
            self.sampled_token_ranks[req_idx:end_idx],
            None,
41
42
43
44
        )


class LogprobsTensors(NamedTuple):
45
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
46
    logprob_token_ids: torch.Tensor
47
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
48
    logprobs: torch.Tensor
49
    # [num_reqs x num_generated_tokens]
50
    selected_token_ranks: torch.Tensor
51

52
    def tolists(self, cu_num_generated_tokens: list[int] | None = None):
53
        return LogprobsLists(
54
55
56
            self.logprob_token_ids.cpu().numpy(),
            self.logprobs.cpu().numpy(),
            self.selected_token_ranks.cpu().numpy(),
57
            cu_num_generated_tokens,
58
59
        )

60
61
62
63
64
65
66
67
68
    def to_cpu_nonblocking(self) -> "LogprobsTensors":
        if self.logprob_token_ids.device.type == "cpu":
            return self
        return LogprobsTensors(
            self.logprob_token_ids.to("cpu", non_blocking=True),
            self.logprobs.to("cpu", non_blocking=True),
            self.selected_token_ranks.to("cpu", non_blocking=True),
        )

69
    @staticmethod
70
71
72
    def empty_cpu(
        num_positions: int, num_tokens_per_position: int
    ) -> "LogprobsTensors":
73
74
75
        """Create empty LogprobsTensors on CPU."""

        logprob_token_ids = torch.empty(
76
77
            (num_positions, num_tokens_per_position), dtype=torch.int32, device="cpu"
        )
78
        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
79
80
81
        selected_token_ranks = torch.empty(
            num_positions, dtype=torch.int32, device="cpu"
        )
82
83
84
85
86
87
        return LogprobsTensors(
            logprob_token_ids=logprob_token_ids,
            logprobs=logprobs,
            selected_token_ranks=selected_token_ranks,
        )

88

89
90
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
91
PoolerOutput = torch.Tensor | list[torch.Tensor]
92
93


94
95
@dataclass
class SamplerOutput:
    """Sampled token ids plus the optional logprobs tensors for them."""

    # [num_reqs, max_num_generated_tokens]
    # Different requests can have different number of generated tokens.
    # All requests are padded to max_num_generated_tokens.
    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
    sampled_token_ids: torch.Tensor
    # None when no logprobs tensors are attached.
    logprobs_tensors: LogprobsTensors | None
102
103


104
105
106
@dataclass
class KVConnectorOutput:
    """Per-step output reported by a KV connector."""

    # [req_ids]
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None
    kv_connector_stats: KVConnectorStats | None = None
    # IDs of externally computed KV blocks that failed to load.
    # Requests referencing these blocks should be rescheduled to recompute them
    invalid_block_ids: set[int] = field(default_factory=set)
    # Configuration describing how many finished sending/receiving
    # notifications should be expected for each request. This allows
    # handshake-based connectors like Nixl to update the KVOutputAggregator.
    # It captures a static setup info and should almost always remain constant
    # for a given connector after discovery. Default value entails no change.
    expected_finished_count: int = 0

    def is_empty(self) -> bool:
        """Return True when there is nothing to report.

        Checks that `finished_sending`, `finished_recving`,
        `kv_connector_stats` and `invalid_block_ids` are all falsy (None or
        empty). `expected_finished_count` is not taken into account.
        """
        return (
            not self.finished_sending
            and not self.finished_recving
            and not self.kv_connector_stats
            and not self.invalid_block_ids
        )
127
128


129
130
131
132
133
134
135
@dataclass
class ECConnectorOutput:
    # [mm_hash]
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None


136
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
class ModelRunnerOutput:
    """Per-step results produced by the model runner."""

    # [num_reqs]
    req_ids: list[str]
    # req_id -> index
    req_id_to_index: dict[str, int]

    # num_reqs x num_generated_tokens
    # num_generated_tokens is the number of tokens
    # generated in the current step. It can be different for
    # each request due to speculative/jump decoding.
    sampled_token_ids: list[list[int]]

    # Shapes of the three arrays held by LogprobsLists:
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs]
    logprobs: LogprobsLists | None

    # req_id -> (token_ids, logprobs, ranks)
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len]
    prompt_logprobs_dict: dict[str, LogprobsTensors | None]

    # [num_reqs, hidden_size]
    pooler_output: list[torch.Tensor | None]

    kv_connector_output: KVConnectorOutput | None = None

    ec_connector_output: ECConnectorOutput | None = None

    # req_id -> num_nans_in_logits
    num_nans_in_logits: dict[str, int] | None = None
171

Robert Shaw's avatar
Robert Shaw committed
172

173
174
175
176
177
# ModelRunnerOutput wrapper for async scheduling.
class AsyncModelRunnerOutput(ABC):
    @abstractmethod
    def get_output(self) -> ModelRunnerOutput:
        """Get the ModelRunnerOutput for this async output.

        This is a blocking call that waits until the results are ready, which
        might involve copying device tensors to the host.
        This method should only be called once per AsyncModelRunnerOutput.
        """


186
187
188
189
190
191
192
193
@dataclass
class DraftTokenIds:
    """Request ids paired with a list of draft token ids for each request."""

    # One request id per request: [num_reqs]
    req_ids: list[str]
    # One list of draft token ids per request: num_reqs x num_draft_tokens
    draft_token_ids: list[list[int]]


194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def make_empty_encoder_model_runner_output(
    scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
    """
    Create a ModelRunnerOutput stub that contains the correct
    per-request bookkeeping but no generated data yet.

    Args:
        scheduler_output: Scheduler output whose `num_scheduled_tokens`
            keys are used as the request ids.

    Returns:
        The shared EMPTY_MODEL_RUNNER_OUTPUT instance when
        `num_scheduled_tokens` is empty; otherwise a new ModelRunnerOutput
        with one entry per request, no logprobs, and None pooler outputs.
    """
    if not scheduler_output.num_scheduled_tokens:
        return EMPTY_MODEL_RUNNER_OUTPUT

    # Convert to list so we get a deterministic, indexable sequence
    req_ids: list[str] = list(scheduler_output.num_scheduled_tokens.keys())

    # Give every request its own contiguous index
    req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)}

    # One single-token placeholder list ([0]) per request. Each request gets
    # its own list object, so mutating one does not affect the others.
    # NOTE(review): this is a dummy token id 0, not an empty list; presumably
    # downstream consumers expect one token per request - confirm.
    sampled_token_ids: list[list[int]] = [[0] for _ in req_ids]

    # Pooler outputs are not available yet ⇒ use None placeholders
    pooler_output: list[torch.Tensor | None] = [None for _ in req_ids]

    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_id_to_index,
        sampled_token_ids=sampled_token_ids,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=pooler_output,
        kv_connector_output=None,
        ec_connector_output=None,
        num_nans_in_logits=None,
    )


229
230
231
232
233
234
235
236
237
# Module-level ModelRunnerOutput with no requests and no data.
# It is a single shared instance (make_empty_encoder_model_runner_output
# returns this very object), so callers should treat it as read-only.
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
    req_ids=[],
    req_id_to_index={},
    sampled_token_ids=[],
    logprobs=None,
    prompt_logprobs_dict={},
    pooler_output=[],
    kv_connector_output=None,
    ec_connector_output=None,
    num_nans_in_logits=None,
)