outputs.py 8.01 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, NamedTuple, TypeAlias
7

8
import numpy as np
9
10
import torch

11
from vllm.compilation.cuda_graph import CUDAGraphStat
12
13
from vllm.v1.core.sched.output import SchedulerOutput

14
if TYPE_CHECKING:
    # These imports are only seen by static type checkers.
    from vllm.distributed.kv_events import KVConnectorKVEvents
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else:
    # At runtime both names just need to resolve (they appear in dataclass
    # annotations in this module), so bind them to a plain placeholder type.
    KVConnectorStats = KVConnectorKVEvents = object
20

21

22
class LogprobsLists(NamedTuple):
23
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
24
    logprob_token_ids: np.ndarray
25
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
26
    logprobs: np.ndarray
27
    # [num_reqs x num_generated_tokens]
28
    sampled_token_ranks: np.ndarray
29
30
31
32
33
    # [num_reqs]
    # Used for slicing the logprobs in cases like speculative
    # decoding where the number of generated tokens may be
    # different for each request.
    cu_num_generated_tokens: list[int] | None = None
34

35
36
37
38
    def slice_request(self, req_idx: int, num_positions: int):
        if self.cu_num_generated_tokens is not None:
            req_idx = self.cu_num_generated_tokens[req_idx]
        end_idx = req_idx + num_positions
39
        return LogprobsLists(
40
41
42
43
            self.logprob_token_ids[req_idx:end_idx],
            self.logprobs[req_idx:end_idx],
            self.sampled_token_ranks[req_idx:end_idx],
            None,
44
45
46
47
        )


class LogprobsTensors(NamedTuple):
48
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
49
    logprob_token_ids: torch.Tensor
50
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
51
    logprobs: torch.Tensor
52
    # [num_reqs x num_generated_tokens]
53
    selected_token_ranks: torch.Tensor
54

55
    def tolists(self, cu_num_generated_tokens: list[int] | None = None):
56
        return LogprobsLists(
57
58
59
            self.logprob_token_ids.cpu().numpy(),
            self.logprobs.cpu().numpy(),
            self.selected_token_ranks.cpu().numpy(),
60
            cu_num_generated_tokens,
61
62
        )

63
64
65
66
67
68
69
70
71
    def to_cpu_nonblocking(self) -> "LogprobsTensors":
        if self.logprob_token_ids.device.type == "cpu":
            return self
        return LogprobsTensors(
            self.logprob_token_ids.to("cpu", non_blocking=True),
            self.logprobs.to("cpu", non_blocking=True),
            self.selected_token_ranks.to("cpu", non_blocking=True),
        )

72
    @staticmethod
73
74
75
    def empty_cpu(
        num_positions: int, num_tokens_per_position: int
    ) -> "LogprobsTensors":
76
77
78
        """Create empty LogprobsTensors on CPU."""

        logprob_token_ids = torch.empty(
79
80
            (num_positions, num_tokens_per_position), dtype=torch.int32, device="cpu"
        )
81
        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
82
83
84
        selected_token_ranks = torch.empty(
            num_positions, dtype=torch.int32, device="cpu"
        )
85
86
87
88
89
90
        return LogprobsTensors(
            logprob_token_ids=logprob_token_ids,
            logprobs=logprobs,
            selected_token_ranks=selected_token_ranks,
        )

91

92
93
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
94
TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
95
96
TokenwisePoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
PoolerOutput: TypeAlias = TokenPoolerOutput | TokenwisePoolerOutput
97
98


99
100
@dataclass
class SamplerOutput:
    """Per-step sampler result, still held as tensors."""

    # Sampled token ids, [num_reqs, max_num_generated_tokens]. Requests may
    # generate different numbers of tokens, so every row is padded out to
    # max_num_generated_tokens with PLACEHOLDER_TOKEN_ID (-1 by default).
    sampled_token_ids: torch.Tensor
    # Logprobs for the sampled positions; None when not present.
    logprobs_tensors: LogprobsTensors | None
107
108


109
110
111
@dataclass
class KVConnectorOutput:
    # Request ids reported as finished sending / finished receiving. [req_ids]
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None
    # Statistics reported by the KV connector, if any.
    kv_connector_stats: KVConnectorStats | None = None
    # KV events reported by the KV connector, if any.
    kv_cache_events: KVConnectorKVEvents | None = None
    # IDs of externally computed KV blocks that failed to load.
    # Requests referencing these blocks should be rescheduled to recompute them.
    invalid_block_ids: set[int] = field(default_factory=set)
    # How many finished sending/receiving notifications to expect per request.
    # Handshake-based connectors such as Nixl use it to update the
    # KVOutputAggregator. It is static setup information that should almost
    # always stay constant for a given connector after discovery. The default
    # value (0) means no change.
    expected_finished_count: int = 0

    def is_empty(self) -> bool:
        """Return True when none of the data-carrying fields is truthy.

        ``expected_finished_count`` is configuration, not per-step data, and
        is not consulted.
        """
        # any() checks truthiness left to right and stops at the first
        # truthy value, just like a chain of `not x and not y ...`.
        return not any(
            (
                self.finished_sending,
                self.finished_recving,
                self.kv_connector_stats,
                self.kv_cache_events,
                self.invalid_block_ids,
            )
        )
134
135


136
137
138
139
140
141
142
@dataclass
class ECConnectorOutput:
    # [mm_hash]
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None


143
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
class ModelRunnerOutput:
    # Request ids covered by this output. [num_reqs]
    req_ids: list[str]
    # req_id -> index into req_ids.
    req_id_to_index: dict[str, int]

    # Tokens generated in the current step, num_reqs x num_generated_tokens.
    # The inner length can differ per request because of speculative/jump
    # decoding.
    sampled_token_ids: list[list[int]] = field(default_factory=list)

    # Sample logprobs. Inside LogprobsLists the token ids and logprobs are
    # [num_reqs, max_num_logprobs + 1] and the sampled ranks are [num_reqs].
    logprobs: LogprobsLists | None = None

    # req_id -> prompt logprobs as (token_ids, logprobs, ranks), shaped
    # [prompt_len, num_prompt_logprobs], [prompt_len, num_prompt_logprobs]
    # and [prompt_len] respectively.
    prompt_logprobs_dict: dict[str, LogprobsTensors | None] = field(
        default_factory=dict
    )

    # Pooling results, [num_reqs, hidden_size]; individual entries may be None.
    pooler_output: list[torch.Tensor | None] | None = None

    # Output reported by the KV connector, if any.
    kv_connector_output: KVConnectorOutput | None = None

    # Output reported by the EC connector, if any.
    ec_connector_output: ECConnectorOutput | None = None

    # req_id -> num_nans_in_logits
    num_nans_in_logits: dict[str, int] | None = None

    # Information related to cudagraph execution.
    cudagraph_stats: CUDAGraphStat | None = None
183

Robert Shaw's avatar
Robert Shaw committed
184

185
186
187
188
189
# ModelRunnerOutput wrapper for async scheduling.
class AsyncModelRunnerOutput(ABC):
    """Deferred handle to a ModelRunnerOutput, used with async scheduling."""

    @abstractmethod
    def get_output(self) -> ModelRunnerOutput:
        """Wait for the results and return the ModelRunnerOutput.

        This call blocks until the results are ready, which may involve
        copying device tensors to the host. It should only be called once
        per AsyncModelRunnerOutput.
        """


198
199
200
201
202
203
204
205
@dataclass
class DraftTokenIds:
    """Draft token ids proposed for a batch of requests."""

    # Request ids in batch order. [num_reqs]
    req_ids: list[str]
    # Draft tokens, num_reqs x num_draft_tokens; row i presumably belongs to
    # req_ids[i] (verify against producers).
    draft_token_ids: list[list[int]]


206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def make_empty_encoder_model_runner_output(
    scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
    """Create a ModelRunnerOutput stub with per-request bookkeeping only.

    Every request in ``scheduler_output.num_scheduled_tokens`` gets an entry
    in ``req_ids`` / ``req_id_to_index``, a single placeholder sampled token
    (id 0) and a ``None`` pooler output. No real generated data is included.

    Args:
        scheduler_output: Scheduler output. The keys of its
            ``num_scheduled_tokens`` mapping, in iteration order, become the
            request ids.

    Returns:
        The shared ``EMPTY_MODEL_RUNNER_OUTPUT`` when nothing is scheduled,
        otherwise a newly built ``ModelRunnerOutput``.
    """
    if not scheduler_output.num_scheduled_tokens:
        return EMPTY_MODEL_RUNNER_OUTPUT

    # Convert to list so we get a deterministic, indexable sequence.
    req_ids: list[str] = list(scheduler_output.num_scheduled_tokens.keys())

    # Give every request its own contiguous index.
    req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)}

    # No real tokens are generated here, yet each request gets ONE placeholder
    # token id (0) rather than an empty list.
    # NOTE(review): presumably the placeholder is needed so downstream
    # consumers treat the request as having produced output — confirm before
    # changing it to [].
    sampled_token_ids: list[list[int]] = [[0] for _ in req_ids]

    # Pooler outputs are not available yet, so use None placeholders.
    pooler_output: list[torch.Tensor | None] = [None for _ in req_ids]

    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_id_to_index,
        sampled_token_ids=sampled_token_ids,
        pooler_output=pooler_output,
    )


236
# Shared output with no requests. make_empty_encoder_model_runner_output
# returns it when nothing is scheduled. All callers share this one instance
# and its list/dict fields are mutable, so treat it as read-only.
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_id_to_index={}, req_ids=[])