outputs.py 8.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, NamedTuple, TypeAlias
7

8
import numpy as np
9
10
import torch

11
from vllm.compilation.cuda_graph import CUDAGraphStat
12
13
from vllm.v1.core.sched.output import SchedulerOutput

14
if TYPE_CHECKING:
15
    from vllm.distributed.kv_events import KVConnectorKVEvents
16
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
17
18
else:
    KVConnectorStats = object
19
    KVConnectorKVEvents = object
20

21

22
class LogprobsLists(NamedTuple):
23
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
24
    logprob_token_ids: np.ndarray
25
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
26
    logprobs: np.ndarray
27
    # [num_reqs x num_generated_tokens]
28
    sampled_token_ranks: np.ndarray
29
30
31
32
33
    # [num_reqs]
    # Used for slicing the logprobs in cases like speculative
    # decoding where the number of generated tokens may be
    # different for each request.
    cu_num_generated_tokens: list[int] | None = None
34

35
36
37
38
    def slice_request(self, req_idx: int, num_positions: int):
        if self.cu_num_generated_tokens is not None:
            req_idx = self.cu_num_generated_tokens[req_idx]
        end_idx = req_idx + num_positions
39
        return LogprobsLists(
40
41
42
43
            self.logprob_token_ids[req_idx:end_idx],
            self.logprobs[req_idx:end_idx],
            self.sampled_token_ranks[req_idx:end_idx],
            None,
44
45
46
47
        )


class LogprobsTensors(NamedTuple):
48
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
49
    logprob_token_ids: torch.Tensor
50
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
51
    logprobs: torch.Tensor
52
    # [num_reqs x num_generated_tokens]
53
    selected_token_ranks: torch.Tensor
54
55
    # [num_reqs]
    cu_num_generated_tokens: list[int] | None = None
56

57
    def tolists(self, cu_num_generated_tokens: list[int] | None = None):
58
        return LogprobsLists(
59
60
61
            self.logprob_token_ids.cpu().numpy(),
            self.logprobs.cpu().numpy(),
            self.selected_token_ranks.cpu().numpy(),
62
63
64
            cu_num_generated_tokens
            if cu_num_generated_tokens is not None
            else self.cu_num_generated_tokens,
65
66
        )

67
68
69
70
71
72
73
    def to_cpu_nonblocking(self) -> "LogprobsTensors":
        if self.logprob_token_ids.device.type == "cpu":
            return self
        return LogprobsTensors(
            self.logprob_token_ids.to("cpu", non_blocking=True),
            self.logprobs.to("cpu", non_blocking=True),
            self.selected_token_ranks.to("cpu", non_blocking=True),
74
            self.cu_num_generated_tokens,
75
76
        )

77
78
    def filter(self, mask: torch.Tensor) -> "LogprobsTensors":
        """Filter the logprobs tensors with the given bool mask."""
79
80
81
        assert self.cu_num_generated_tokens is None, (
            "filter can't be used with cu_num_generated_tokens"
        )
82
83
84
85
86
87
        return LogprobsTensors(
            self.logprob_token_ids[mask],
            self.logprobs[mask],
            self.selected_token_ranks[mask],
        )

88
    @staticmethod
89
90
91
    def empty_cpu(
        num_positions: int, num_tokens_per_position: int
    ) -> "LogprobsTensors":
92
93
94
        """Create empty LogprobsTensors on CPU."""

        logprob_token_ids = torch.empty(
95
96
            (num_positions, num_tokens_per_position), dtype=torch.int32, device="cpu"
        )
97
        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
98
99
100
        selected_token_ranks = torch.empty(
            num_positions, dtype=torch.int32, device="cpu"
        )
101
102
103
104
105
106
        return LogprobsTensors(
            logprob_token_ids=logprob_token_ids,
            logprobs=logprobs,
            selected_token_ranks=selected_token_ranks,
        )

107

108
109
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
110
PoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor] | list[torch.Tensor | None]
111
112


113
114
@dataclass
class SamplerOutput:
    """Tensors produced by the sampler for a single step."""

    # [num_reqs, max_num_generated_tokens]
    # Different requests can have different number of generated tokens.
    # All requests are padded to max_num_generated_tokens.
    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
    sampled_token_ids: torch.Tensor
    # Logprobs for the sampled tokens, or None (presumably when no request
    # asked for logprobs — TODO confirm).
    logprobs_tensors: LogprobsTensors | None
121
122


123
124
125
@dataclass
class KVConnectorOutput:
    """KV-connector results attached to a model runner step."""

    # [req_ids]
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None
    kv_connector_stats: KVConnectorStats | None = None
    kv_cache_events: KVConnectorKVEvents | None = None
    # IDs of externally computed KV blocks that failed to load.
    # Requests referencing these blocks should be rescheduled to recompute them
    invalid_block_ids: set[int] = field(default_factory=set)
    # Configuration describing how many finished sending/receiving
    # notifications should be expected for each request. This allows
    # handshake-based connectors like Nixl to update the KVOutputAggregator.
    # It captures a static setup info and should almost always remain constant
    # for a given connector after discovery. Default value entails no change.
    expected_finished_count: int = 0

    def is_empty(self) -> bool:
        """Return True when every per-step field is falsy (None or empty).

        ``expected_finished_count`` is not consulted: according to its field
        comment, it is static connector setup rather than per-step data.
        """
        return (
            not self.finished_sending
            and not self.finished_recving
            and not self.kv_connector_stats
            and not self.kv_cache_events
            and not self.invalid_block_ids
        )
148
149


150
151
152
153
154
155
156
@dataclass
class ECConnectorOutput:
    # [mm_hash]
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None


157
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
class ModelRunnerOutput:
    """Per-step result that the model runner reports to the scheduler.

    Only ``req_ids`` and ``req_id_to_index`` are required; every other
    field defaults to an empty container or None.
    """

    # [num_reqs]
    req_ids: list[str]
    # req_id -> index
    req_id_to_index: dict[str, int]

    # num_reqs x num_generated_tokens
    # num_generated_tokens is the number of tokens
    # generated in the current step. It can be different for
    # each request due to speculative/jump decoding.
    sampled_token_ids: list[list[int]] = field(default_factory=list)

    # Logprobs of the sampled tokens. Shapes of the LogprobsLists members:
    #   logprob_token_ids:   [num_reqs, max_num_logprobs + 1]
    #   logprobs:            [num_reqs, max_num_logprobs + 1]
    #   sampled_token_ranks: [num_reqs]
    # (When requests generate several tokens per step, the leading dim covers
    # all generated tokens; see LogprobsLists.cu_num_generated_tokens.)
    logprobs: LogprobsLists | None = None

    # req_id -> (token_ids, logprobs, ranks)
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len]
    prompt_logprobs_dict: dict[str, LogprobsTensors | None] = field(
        default_factory=dict
    )

    # [num_reqs, hidden_size]
    pooler_output: list[torch.Tensor | None] | None = None

    kv_connector_output: KVConnectorOutput | None = None

    ec_connector_output: ECConnectorOutput | None = None

    # req_id -> num_nans_in_logits
    num_nans_in_logits: dict[str, int] | None = None

    # information related to cudagraph execution
    cudagraph_stats: CUDAGraphStat | None = None
# ModelRunnerOutput wrapper for async scheduling.
class AsyncModelRunnerOutput(ABC):
    """Handle to a ModelRunnerOutput whose results may not be ready yet."""

    @abstractmethod
    def get_output(self) -> ModelRunnerOutput:
        """Get the ModelRunnerOutput for this async output.

        This is a blocking call that waits until the results are ready, which
        might involve copying device tensors to the host.
        This method should only be called once per AsyncModelRunnerOutput.
        """
        pass


212
213
214
215
216
217
218
219
@dataclass
class DraftTokenIds:
    """Draft token ids for a batch of requests.

    ``req_ids`` and ``draft_token_ids`` are parallel lists with one entry
    per request.
    """

    req_ids: list[str]  # [num_reqs]
    draft_token_ids: list[list[int]]  # num_reqs x num_draft_tokens


220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
def make_empty_encoder_model_runner_output(
    scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
    """
    Create a ModelRunnerOutput stub that contains the correct
    per-request bookkeeping but no generated data yet.

    Args:
        scheduler_output: The scheduler output for this step. Only its
            ``num_scheduled_tokens`` mapping, keyed by request id, is read.

    Returns:
        The shared ``EMPTY_MODEL_RUNNER_OUTPUT`` when nothing is scheduled.
        Otherwise, a fresh ModelRunnerOutput with one entry per scheduled
        request, in which each request has:

        - a single placeholder token ``0`` in ``sampled_token_ids``;
        - ``None`` in ``pooler_output``.
    """
    if not scheduler_output.num_scheduled_tokens:
        return EMPTY_MODEL_RUNNER_OUTPUT

    # Convert to list so we get a deterministic, indexable sequence
    req_ids: list[str] = list(scheduler_output.num_scheduled_tokens)

    # Give every request its own contiguous index
    req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)}

    # No real tokens are generated yet, but each request still gets ONE
    # placeholder token id 0 (not an empty list).
    # NOTE(review): presumably downstream bookkeeping relies on seeing a
    # token for every request — confirm before changing this to [].
    # The comprehension (rather than [[0]] * n) keeps inner lists distinct.
    sampled_token_ids: list[list[int]] = [[0] for _ in req_ids]

    # Pooler outputs are not available yet ⇒ use None placeholders
    pooler_output: list[torch.Tensor | None] = [None] * len(req_ids)

    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_id_to_index,
        sampled_token_ids=sampled_token_ids,
        pooler_output=pooler_output,
    )


250
# Output with no requests, returned by make_empty_encoder_model_runner_output
# when nothing is scheduled. All other fields keep their defaults.
# NOTE(review): ModelRunnerOutput is a mutable dataclass and this single
# instance is shared by every caller — treat it as read-only.
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
    req_ids=[],
    req_id_to_index={},
)