# outputs.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, NamedTuple
7

8
import numpy as np
9
10
import torch

11
# KVConnectorStats is only imported for static type checking. At runtime the
# name is bound to `object` so that annotations which are evaluated when a
# class body runs (e.g. `KVConnectorStats | None` on a dataclass field) still
# resolve to a real type without importing the kv_connector metrics module.
if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else:
    KVConnectorStats = object
15

16

17
class LogprobsLists(NamedTuple):
18
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
19
    logprob_token_ids: np.ndarray
20
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
21
    logprobs: np.ndarray
22
    # [num_reqs x num_generated_tokens]
23
    sampled_token_ranks: np.ndarray
24
25
26
27
28
    # [num_reqs]
    # Used for slicing the logprobs in cases like speculative
    # decoding where the number of generated tokens may be
    # different for each request.
    cu_num_generated_tokens: list[int] | None = None
29

30
31
32
33
    def slice(self, start_req_idx: int, end_req_idx: int):
        if self.cu_num_generated_tokens:
            start = self.cu_num_generated_tokens[start_req_idx]
            end = self.cu_num_generated_tokens[end_req_idx]
34
35
36
37
38
39
40
41
            # Recompute cumulative array starting from 0
            cu_num_offset = self.cu_num_generated_tokens[start_req_idx]
            sliced_cu_num_generated_tokens = [
                cu_num - cu_num_offset
                for cu_num in self.cu_num_generated_tokens[
                    start_req_idx : end_req_idx + 1
                ]
            ]
42
43
44
        else:
            start = start_req_idx
            end = end_req_idx
45
            sliced_cu_num_generated_tokens = None
46
47
48
49
        return LogprobsLists(
            self.logprob_token_ids[start:end],
            self.logprobs[start:end],
            self.sampled_token_ranks[start:end],
50
            sliced_cu_num_generated_tokens,
51
52
53
54
        )


class LogprobsTensors(NamedTuple):
55
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
56
    logprob_token_ids: torch.Tensor
57
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
58
    logprobs: torch.Tensor
59
    # [num_reqs x num_generated_tokens]
60
    selected_token_ranks: torch.Tensor
61

62
    def tolists(self, cu_num_generated_tokens: list[int] | None = None):
63
        return LogprobsLists(
64
65
66
            self.logprob_token_ids.cpu().numpy(),
            self.logprobs.cpu().numpy(),
            self.selected_token_ranks.cpu().numpy(),
67
            cu_num_generated_tokens,
68
69
        )

70
71
72
73
74
75
76
77
78
    def to_cpu_nonblocking(self) -> "LogprobsTensors":
        if self.logprob_token_ids.device.type == "cpu":
            return self
        return LogprobsTensors(
            self.logprob_token_ids.to("cpu", non_blocking=True),
            self.logprobs.to("cpu", non_blocking=True),
            self.selected_token_ranks.to("cpu", non_blocking=True),
        )

79
    @staticmethod
80
81
82
    def empty_cpu(
        num_positions: int, num_tokens_per_position: int
    ) -> "LogprobsTensors":
83
84
85
        """Create empty LogprobsTensors on CPU."""

        logprob_token_ids = torch.empty(
86
87
            (num_positions, num_tokens_per_position), dtype=torch.int32, device="cpu"
        )
88
        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
89
90
91
        selected_token_ranks = torch.empty(
            num_positions, dtype=torch.int32, device="cpu"
        )
92
93
94
95
96
97
        return LogprobsTensors(
            logprob_token_ids=logprob_token_ids,
            logprobs=logprobs,
            selected_token_ranks=selected_token_ranks,
        )

98

99
100
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
101
PoolerOutput = torch.Tensor | list[torch.Tensor]
102
103


104
105
@dataclass
class SamplerOutput:
    """Result of one sampling step: sampled token ids plus optional logprobs."""

    # [num_reqs, max_num_generated_tokens]
    # Different requests can have different number of generated tokens.
    # All requests are padded to max_num_generated_tokens.
    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
    sampled_token_ids: torch.Tensor
    # Logprob tensors for the sampled tokens, or None when there are none.
    logprobs_tensors: LogprobsTensors | None
112
113


114
115
116
@dataclass
class KVConnectorOutput:
    """Per-step information reported by a KV connector.

    Every field has a default, so `KVConnectorOutput()` is an output that
    `is_empty()` reports as empty.
    """

    # [req_ids]
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None
    kv_connector_stats: KVConnectorStats | None = None
    # IDs of externally computed KV blocks that failed to load.
    # Requests referencing these blocks should be rescheduled to recompute them
    invalid_block_ids: set[int] = field(default_factory=set)
    # Configuration describing how many finished sending/receiving
    # notifications should be expected for each request. This allows
    # handshake-based connectors like Nixl to update the KVOutputAggregator.
    # It captures a static setup info and should almost always remain constant
    # for a given connector after discovery. Default value entails no change.
    expected_finished_count: int = 0

    def is_empty(self) -> bool:
        """Return True when none of the reporting fields carries anything.

        `finished_sending`, `finished_recving`, `kv_connector_stats` and
        `invalid_block_ids` are each tested for truthiness, so None and empty
        sets both count as "nothing". `expected_finished_count` is not
        consulted.

        NOTE(review): whether a present-but-empty stats object counts as
        empty depends on how KVConnectorStats defines truthiness -- confirm.
        """
        return not (
            self.finished_sending
            or self.finished_recving
            or self.kv_connector_stats
            or self.invalid_block_ids
        )
137
138


139
# ModelRunnerOutput is serialized and sent to the scheduler process.
140
# This is expensive for torch.Tensor so prefer to use list instead.
141
142
143
@dataclass
class ModelRunnerOutput:
    """Per-step results produced by the model runner for a batch of requests."""

    # [num_reqs]
    req_ids: list[str]
    # req_id -> index
    req_id_to_index: dict[str, int]

    # num_reqs x num_generated_tokens
    # num_generated_tokens is the number of tokens
    # generated in the current step. It can be different for
    # each request due to speculative/jump decoding.
    sampled_token_ids: list[list[int]]

    # Sample logprobs as (logprob_token_ids, logprobs, sampled_token_ranks):
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs]
    logprobs: LogprobsLists | None

    # req_id -> (token_ids, logprobs, ranks)
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len]
    prompt_logprobs_dict: dict[str, LogprobsTensors | None]

    # [num_reqs, hidden_size]
    pooler_output: list[torch.Tensor | None]

    kv_connector_output: KVConnectorOutput | None = None

    # req_id -> num_nans_in_logits
    num_nans_in_logits: dict[str, int] | None = None
173

174
175
176
177
178
# ModelRunnerOutput wrapper for async scheduling.
class AsyncModelRunnerOutput(ABC):
    """Abstract handle to a ModelRunnerOutput that may not be ready yet.

    Concrete subclasses must implement `get_output`.
    """

    @abstractmethod
    def get_output(self) -> ModelRunnerOutput:
        """Get the ModelRunnerOutput for this async output.

        This is a blocking call that waits until the results are ready, which
        might involve copying device tensors to the host.
        This method should only be called once per AsyncModelRunnerOutput.
        """


187
188
189
190
191
192
193
194
@dataclass
class DraftTokenIds:
    """Draft token ids for a batch of requests, alongside the request ids."""

    # One request id per request: [num_reqs]
    req_ids: list[str]
    # One list of draft token ids per request: num_reqs x num_draft_tokens
    draft_token_ids: list[list[int]]


195
196
197
198
199
200
201
202
203
# A ModelRunnerOutput with no requests: every container is empty and the
# optional fields are None (kv_connector_output keeps its None default).
# NOTE: this is one module-level instance holding mutable lists and dicts,
# so mutating it in place would be visible to every user of the constant.
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
    req_ids=[],
    req_id_to_index={},
    sampled_token_ids=[],
    logprobs=None,
    prompt_logprobs_dict={},
    pooler_output=[],
    num_nans_in_logits=None,
)