outputs.py 10.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import Callable
6
from dataclasses import dataclass, field
7
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypeVar
8

9
import numpy as np
10
11
import torch

12
from vllm.compilation.cuda_graph import CUDAGraphStat
13
14
from vllm.v1.core.sched.output import SchedulerOutput

15
if TYPE_CHECKING:
16
    from vllm.distributed.kv_events import KVConnectorKVEvents
17
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
18
19
else:
    KVConnectorStats = object
20
    KVConnectorKVEvents = object
21

22

23
class LogprobsLists(NamedTuple):
24
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
25
    logprob_token_ids: np.ndarray
26
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
27
    logprobs: np.ndarray
28
    # [num_reqs x num_generated_tokens]
29
    sampled_token_ranks: np.ndarray
30
31
32
33
34
    # [num_reqs]
    # Used for slicing the logprobs in cases like speculative
    # decoding where the number of generated tokens may be
    # different for each request.
    cu_num_generated_tokens: list[int] | None = None
35

36
37
38
39
    def slice_request(self, req_idx: int, num_positions: int):
        if self.cu_num_generated_tokens is not None:
            req_idx = self.cu_num_generated_tokens[req_idx]
        end_idx = req_idx + num_positions
40
        return LogprobsLists(
41
42
43
44
            self.logprob_token_ids[req_idx:end_idx],
            self.logprobs[req_idx:end_idx],
            self.sampled_token_ranks[req_idx:end_idx],
            None,
45
46
47
48
        )


class LogprobsTensors(NamedTuple):
49
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
50
    logprob_token_ids: torch.Tensor
51
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
52
    logprobs: torch.Tensor
53
    # [num_reqs x num_generated_tokens]
54
    selected_token_ranks: torch.Tensor
55
56
    # [num_reqs]
    cu_num_generated_tokens: list[int] | None = None
57

58
    def tolists(self, cu_num_generated_tokens: list[int] | None = None):
59
        return LogprobsLists(
60
61
62
            self.logprob_token_ids.cpu().numpy(),
            self.logprobs.cpu().numpy(),
            self.selected_token_ranks.cpu().numpy(),
63
64
65
            cu_num_generated_tokens
            if cu_num_generated_tokens is not None
            else self.cu_num_generated_tokens,
66
67
        )

68
69
70
71
72
73
74
    def to_cpu_nonblocking(self) -> "LogprobsTensors":
        if self.logprob_token_ids.device.type == "cpu":
            return self
        return LogprobsTensors(
            self.logprob_token_ids.to("cpu", non_blocking=True),
            self.logprobs.to("cpu", non_blocking=True),
            self.selected_token_ranks.to("cpu", non_blocking=True),
75
            self.cu_num_generated_tokens,
76
77
        )

78
79
    def filter(self, mask: torch.Tensor) -> "LogprobsTensors":
        """Filter the logprobs tensors with the given bool mask."""
80
81
82
        assert self.cu_num_generated_tokens is None, (
            "filter can't be used with cu_num_generated_tokens"
        )
83
84
85
86
87
88
        return LogprobsTensors(
            self.logprob_token_ids[mask],
            self.logprobs[mask],
            self.selected_token_ranks[mask],
        )

89
    @staticmethod
90
91
92
    def empty_cpu(
        num_positions: int, num_tokens_per_position: int
    ) -> "LogprobsTensors":
93
94
95
        """Create empty LogprobsTensors on CPU."""

        logprob_token_ids = torch.empty(
96
97
            (num_positions, num_tokens_per_position), dtype=torch.int32, device="cpu"
        )
98
        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
99
100
101
        selected_token_ranks = torch.empty(
            num_positions, dtype=torch.int32, device="cpu"
        )
102
103
104
105
106
107
        return LogprobsTensors(
            logprob_token_ids=logprob_token_ids,
            logprobs=logprobs,
            selected_token_ranks=selected_token_ranks,
        )

108

109
110
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
111
PoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor] | list[torch.Tensor | None]
112
113


114
115
@dataclass
class SamplerOutput:
    """Sampled token ids together with their optional logprob tensors."""

    # [num_reqs, max_num_generated_tokens]
    # Different requests can have different number of generated tokens.
    # All requests are padded to max_num_generated_tokens.
    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
    sampled_token_ids: torch.Tensor
    # None when no logprob tensors are attached to this output.
    logprobs_tensors: LogprobsTensors | None
122
123


124
125
126
127
128
129
130
131
132
133
134
135
136
137
T = TypeVar("T")


def _combine_non_none(f: Callable[[T, T], T], items: list[T | None]) -> T | None:
    non_none = [item for item in items if item is not None]
    if len(non_none) == 0:
        return None

    combined = non_none[0]
    for item in non_none[1:]:
        combined = f(combined, item)
    return combined


138
139
140
@dataclass
class KVConnectorOutput:
    """Per-step results reported by a KV connector."""

    # [req_ids]
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None
    kv_connector_stats: KVConnectorStats | None = None
    kv_cache_events: KVConnectorKVEvents | None = None
    # IDs of externally computed KV blocks that failed to load.
    # Requests referencing these blocks should be rescheduled to recompute them
    invalid_block_ids: set[int] = field(default_factory=set)

    # Configuration describing how many finished sending/receiving
    # notifications should be expected for each request. This allows
    # handshake-based connectors like Nixl to update the KVOutputAggregator.
    # It captures a static setup info and should almost always remain constant
    # for a given connector after discovery. Default value entails no change.
    expected_finished_count: int = 0

    def is_empty(self) -> bool:
        """Return True when every payload field is falsy.

        ``expected_finished_count`` is not taken into account.
        """
        return (
            not self.finished_sending
            and not self.finished_recving
            and not self.kv_connector_stats
            and not self.kv_cache_events
            and not self.invalid_block_ids
        )

    @classmethod
    def merge(cls, *outputs: "KVConnectorOutput") -> "KVConnectorOutput":
        """Combine several outputs into a single one.

        Set-valued fields are unioned, stats are combined with
        ``aggregate`` and cache events with ``merge``; ``None`` entries
        are skipped, and a field that is ``None`` on every input stays
        ``None``.

        Raises:
            AssertionError: If no outputs are given, or if the inputs
                disagree on ``expected_finished_count``.
        """
        assert len(outputs) > 0, "Cannot merge empty outputs"
        finished_sending = _combine_non_none(
            set.union, [output.finished_sending for output in outputs]
        )
        finished_recving = _combine_non_none(
            set.union, [output.finished_recving for output in outputs]
        )
        kv_connector_stats = _combine_non_none(
            lambda x, y: x.aggregate(y),
            [output.kv_connector_stats for output in outputs],
        )
        kv_cache_events = _combine_non_none(
            lambda x, y: x.merge(y),
            [output.kv_cache_events for output in outputs],
        )
        invalid_block_ids = _combine_non_none(
            set.union, [output.invalid_block_ids for output in outputs]
        )
        # invalid_block_ids defaults to an empty set rather than None, so at
        # least one input is expected to contribute a value here.
        assert invalid_block_ids is not None

        expected_finished_count = outputs[0].expected_finished_count
        assert all(
            output.expected_finished_count == expected_finished_count
            for output in outputs
        ), "expected_finished_count must be identical across merged outputs"

        return cls(
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            kv_connector_stats=kv_connector_stats,
            kv_cache_events=kv_cache_events,
            invalid_block_ids=invalid_block_ids,
            expected_finished_count=expected_finished_count,
        )

201

202
203
204
205
206
207
208
@dataclass
class ECConnectorOutput:
    # [mm_hash]
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None


209
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
class ModelRunnerOutput:
    """Results of one model-runner step, keyed by request."""

    # [num_reqs]
    req_ids: list[str]
    # req_id -> index
    req_id_to_index: dict[str, int]

    # num_reqs x num_generated_tokens
    # num_generated_tokens is the number of tokens
    # generated in the current step. It can be different for
    # each request due to speculative/jump decoding.
    sampled_token_ids: list[list[int]] = field(default_factory=list)

    # Members of the LogprobsLists tuple, in order:
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs]
    logprobs: LogprobsLists | None = None

    # req_id -> (token_ids, logprobs, ranks)
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len]
    prompt_logprobs_dict: dict[str, LogprobsTensors | None] = field(
        default_factory=dict
    )

    # [num_reqs, hidden_size]
    pooler_output: list[torch.Tensor | None] | None = None

    kv_connector_output: KVConnectorOutput | None = None

    ec_connector_output: ECConnectorOutput | None = None

    # req_id -> num_nans_in_logits
    num_nans_in_logits: dict[str, int] | None = None

    # information related to cudagraph execution
    cudagraph_stats: CUDAGraphStat | None = None
249

Robert Shaw's avatar
Robert Shaw committed
250

251
252
253
254
255
# ModelRunnerOutput wrapper for async scheduling.
class AsyncModelRunnerOutput(ABC):
    """Abstract handle whose ModelRunnerOutput is obtained via get_output()."""

    @abstractmethod
    def get_output(self) -> ModelRunnerOutput:
        """Get the ModelRunnerOutput for this async output.

        This is a blocking call that waits until the results are ready, which
        might involve copying device tensors to the host.
        This method should only be called once per AsyncModelRunnerOutput.
        """


264
265
266
267
268
269
270
271
@dataclass
class DraftTokenIds:
    """Draft token ids grouped per request."""

    # [num_reqs]
    req_ids: list[str]

    # num_reqs x num_draft_tokens
    draft_token_ids: list[list[int]]


272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
def make_empty_encoder_model_runner_output(
    scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
    """
    Create a ModelRunnerOutput stub that contains the correct
    per-request bookkeeping but no generated data yet.

    When ``scheduler_output.num_scheduled_tokens`` is empty, the shared
    ``EMPTY_MODEL_RUNNER_OUTPUT`` instance is returned instead of a new
    object.
    """
    if not scheduler_output.num_scheduled_tokens:
        return EMPTY_MODEL_RUNNER_OUTPUT

    # Convert to list so we get a deterministic, indexable sequence
    req_ids: list[str] = list(scheduler_output.num_scheduled_tokens.keys())

    # Give every request its own contiguous index
    req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)}

    # One single-element list holding token id 0 per request.
    # NOTE(review): this is a placeholder entry rather than an empty list -
    # confirm that consumers expect exactly one token per request here.
    sampled_token_ids: list[list[int]] = [[0] for _ in req_ids]

    # Pooler outputs are not available yet ⇒ use None placeholders
    pooler_output: list[torch.Tensor | None] = [None] * len(req_ids)

    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_id_to_index,
        sampled_token_ids=sampled_token_ids,
        pooler_output=pooler_output,
    )


302
# Module-level ModelRunnerOutput with no requests. It is a single shared
# instance, so treat it as read-only.
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
    req_ids=[],
    req_id_to_index={},
)