# outputs.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, NamedTuple
7
8
9

import torch

10
if TYPE_CHECKING:
11
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
12
13
else:
    KVConnectorStats = object
14

15

16
class LogprobsLists(NamedTuple):
17
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
18
    logprob_token_ids: list[list[int]]
19
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
20
    logprobs: list[list[float]]
21
    # [num_reqs x num_generated_tokens]
22
    sampled_token_ranks: list[int]
23
24
25
26
27
    # [num_reqs]
    # Used for slicing the logprobs in cases like speculative
    # decoding where the number of generated tokens may be
    # different for each request.
    cu_num_generated_tokens: list[int] | None = None
28

29
30
31
32
    def slice(self, start_req_idx: int, end_req_idx: int):
        if self.cu_num_generated_tokens:
            start = self.cu_num_generated_tokens[start_req_idx]
            end = self.cu_num_generated_tokens[end_req_idx]
33
34
35
36
37
38
39
40
            # Recompute cumulative array starting from 0
            cu_num_offset = self.cu_num_generated_tokens[start_req_idx]
            sliced_cu_num_generated_tokens = [
                cu_num - cu_num_offset
                for cu_num in self.cu_num_generated_tokens[
                    start_req_idx : end_req_idx + 1
                ]
            ]
41
42
43
        else:
            start = start_req_idx
            end = end_req_idx
44
            sliced_cu_num_generated_tokens = None
45
46
47
48
        return LogprobsLists(
            self.logprob_token_ids[start:end],
            self.logprobs[start:end],
            self.sampled_token_ranks[start:end],
49
            sliced_cu_num_generated_tokens,
50
51
52
53
        )


class LogprobsTensors(NamedTuple):
54
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
55
    logprob_token_ids: torch.Tensor
56
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
57
    logprobs: torch.Tensor
58
    # [num_reqs x num_generated_tokens]
59
    selected_token_ranks: torch.Tensor
60

61
    def tolists(self, cu_num_generated_tokens: list[int] | None = None):
62
63
64
65
        return LogprobsLists(
            self.logprob_token_ids.tolist(),
            self.logprobs.tolist(),
            self.selected_token_ranks.tolist(),
66
            cu_num_generated_tokens,
67
68
        )

69
70
71
72
73
74
75
76
77
    def to_cpu_nonblocking(self) -> "LogprobsTensors":
        if self.logprob_token_ids.device.type == "cpu":
            return self
        return LogprobsTensors(
            self.logprob_token_ids.to("cpu", non_blocking=True),
            self.logprobs.to("cpu", non_blocking=True),
            self.selected_token_ranks.to("cpu", non_blocking=True),
        )

78
    @staticmethod
79
80
81
    def empty_cpu(
        num_positions: int, num_tokens_per_position: int
    ) -> "LogprobsTensors":
82
83
84
        """Create empty LogprobsTensors on CPU."""

        logprob_token_ids = torch.empty(
85
86
            (num_positions, num_tokens_per_position), dtype=torch.int32, device="cpu"
        )
87
        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
88
89
90
        selected_token_ranks = torch.empty(
            num_positions, dtype=torch.int32, device="cpu"
        )
91
92
93
94
95
96
        return LogprobsTensors(
            logprob_token_ids=logprob_token_ids,
            logprobs=logprobs,
            selected_token_ranks=selected_token_ranks,
        )

97

98
99
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
100
PoolerOutput = torch.Tensor | list[torch.Tensor]
101
102


103
104
@dataclass
class SamplerOutput:
    """Tensors produced by the sampler."""

    # [num_reqs, max_num_generated_tokens]
    # Different requests can have a different number of generated tokens.
    # All requests are padded to max_num_generated_tokens.
    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
    sampled_token_ids: torch.Tensor
    # Logprobs of the sampled tokens, or None when they were not produced.
    logprobs_tensors: LogprobsTensors | None
111
112


113
114
115
@dataclass
class KVConnectorOutput:
    """Results reported by the KV connector.

    Carried in ``ModelRunnerOutput.kv_connector_output``.
    """

    # [req_ids] Sets of request IDs for the finished-sending and
    # finished-receiving events, respectively.
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None
    # Connector statistics, if the connector reports any.
    kv_connector_stats: KVConnectorStats | None = None
    # IDs of externally computed KV blocks that failed to load.
    # Requests referencing these blocks should be rescheduled to recompute them.
    invalid_block_ids: set[int] = field(default_factory=set)

    # Configuration describing how many finished sending/receiving
    # notifications should be expected for each request. This allows
    # handshake-based connectors like Nixl to update the KVOutputAggregator.
    # It captures static setup info and should almost always remain constant
    # for a given connector after discovery. The default value entails no change.
    expected_finished_count: int = 0

    def is_empty(self) -> bool:
        """Return True when there is nothing to report.

        That is the case when there are no finished sending/receiving request
        IDs, no connector stats, and no invalid block IDs.
        ``expected_finished_count`` is not consulted (per its comment, it is
        static setup information).
        """
        return (
            not self.finished_sending
            and not self.finished_recving
            and not self.kv_connector_stats
            and not self.invalid_block_ids
        )
136
137


138
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor, so prefer lists instead.
@dataclass
class ModelRunnerOutput:
    # [num_reqs]
    req_ids: list[str]
    # req_id -> index
    req_id_to_index: dict[str, int]

    # num_reqs x num_generated_tokens
    # num_generated_tokens is the number of tokens
    # generated in the current step. It can be different for
    # each request due to speculative/jump decoding.
    sampled_token_ids: list[list[int]]

    # Logprobs of the sampled tokens, or None. LogprobsLists fields:
    #   logprob_token_ids:   [num_reqs x num_generated_tokens, max_num_logprobs + 1]
    #   logprobs:            [num_reqs x num_generated_tokens, max_num_logprobs + 1]
    #   sampled_token_ranks: [num_reqs x num_generated_tokens]
    logprobs: LogprobsLists | None

    # req_id -> (token_ids, logprobs, ranks)
    #   token_ids: [prompt_len, num_prompt_logprobs]
    #   logprobs:  [prompt_len, num_prompt_logprobs]
    #   ranks:     [prompt_len]
    prompt_logprobs_dict: dict[str, LogprobsTensors | None]

    # [num_reqs, hidden_size]
    pooler_output: list[torch.Tensor | None]

    kv_connector_output: KVConnectorOutput | None = None

    # req_id -> num_nans_in_logits
    num_nans_in_logits: dict[str, int] | None = None
# ModelRunnerOutput wrapper for async scheduling.
class AsyncModelRunnerOutput(ABC):
    """Abstract handle to a ModelRunnerOutput whose results may not be ready yet."""

    @abstractmethod
    def get_output(self) -> ModelRunnerOutput:
        """Get the ModelRunnerOutput for this async output.

        This is a blocking call that waits until the results are ready, which
        might involve copying device tensors to the host.
        This method should only be called once per AsyncModelRunnerOutput.
        """


186
187
188
189
190
191
192
193
@dataclass
class DraftTokenIds:
    """Draft token IDs for a batch of requests.

    Presumably ``draft_token_ids[i]`` holds the drafts for ``req_ids[i]``.
    """

    req_ids: list[str]  # shape: [num_reqs]
    draft_token_ids: list[list[int]]  # shape: num_reqs x num_draft_tokens


194
195
196
197
198
199
200
201
202
# A ModelRunnerOutput with no requests and no results. This single instance
# is shared module-wide and its list/dict fields are mutable, so it should
# never be modified in place.
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
    req_ids=[],
    req_id_to_index={},
    sampled_token_ids=[],
    pooler_output=[],
    prompt_logprobs_dict={},
    logprobs=None,
    num_nans_in_logits=None,
)