outputs.py 4.82 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, NamedTuple
7
8
9

import torch

10
if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else:
    # The real class is imported only for static type checking. At runtime the
    # name is bound to `object` so annotations that mention KVConnectorStats
    # (evaluated when the classes below are defined) still resolve.
    KVConnectorStats = object
14

15

16
17
class LogprobsLists(NamedTuple):
    """Logprob results held as plain Python lists.

    The three fields are parallel along their first dimension (``num_reqs``).
    """

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: list[list[int]]
    # [num_reqs, max_num_logprobs + 1]
    logprobs: list[list[float]]
    # [num_reqs]
    sampled_token_ranks: list[int]

    def slice(self, start: int, end: int) -> "LogprobsLists":
        """Return a new LogprobsLists covering rows ``[start, end)``.

        Each field is sliced with ordinary list slicing, so out-of-range
        bounds are clamped and the outer lists of the result are new objects
        (the inner per-row lists are shared with ``self``).

        Args:
            start: Index of the first row to keep.
            end: Index one past the last row to keep.

        Returns:
            A LogprobsLists whose three fields are the corresponding slices.
        """
        return LogprobsLists(
            logprob_token_ids=self.logprob_token_ids[start:end],
            logprobs=self.logprobs[start:end],
            sampled_token_ranks=self.sampled_token_ranks[start:end],
        )


class LogprobsTensors(NamedTuple):
    """Logprob results held as torch tensors.

    The three fields are parallel along their first dimension (``num_reqs``).
    """

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: torch.Tensor
    # [num_reqs, max_num_logprobs + 1]
    logprobs: torch.Tensor
    # [num_reqs]
    selected_token_ranks: torch.Tensor

    def tolists(self) -> "LogprobsLists":
        """Convert every tensor to nested Python lists.

        Returns:
            A LogprobsLists built from ``Tensor.tolist()`` of each field, in
            the order (token ids, logprobs, ranks).
        """
        return LogprobsLists(
            self.logprob_token_ids.tolist(),
            self.logprobs.tolist(),
            self.selected_token_ranks.tolist(),
        )

    @staticmethod
    def empty_cpu(
        num_positions: int, num_tokens_per_position: int
    ) -> "LogprobsTensors":
        """Create empty LogprobsTensors on CPU.

        The tensors are allocated with ``torch.empty`` and therefore hold
        uninitialized values; callers are expected to fill them.

        Args:
            num_positions: Size of the first dimension of every tensor.
            num_tokens_per_position: Size of the second dimension of the
                token-id and logprob tensors.

        Returns:
            A LogprobsTensors with int32 token ids, float32 logprobs and
            int32 ranks.
        """
        logprob_token_ids = torch.empty(
            (num_positions, num_tokens_per_position), dtype=torch.int32, device="cpu"
        )
        # Same shape and device as the token ids; only the dtype differs.
        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
        selected_token_ranks = torch.empty(
            num_positions, dtype=torch.int32, device="cpu"
        )
        return LogprobsTensors(
            logprob_token_ids=logprob_token_ids,
            logprobs=logprobs,
            selected_token_ranks=selected_token_ranks,
        )

66

67
68
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
69
PoolerOutput = torch.Tensor | list[torch.Tensor]
70
71


72
73
@dataclass
class SamplerOutput:
    """Sampled token ids plus optional logprob tensors."""

    # [num_reqs, max_num_generated_tokens]
    # Different requests can have different number of generated tokens.
    # All requests are padded to max_num_generated_tokens.
    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
    sampled_token_ids: torch.Tensor
    # None when no logprob tensors are attached to this output.
    logprobs_tensors: LogprobsTensors | None
80
81


82
83
84
@dataclass
class KVConnectorOutput:
    """Bundle of KV-connector results.

    Holds the request ids whose send/receive finished, optional connector
    stats, and the ids of KV blocks that failed to load.
    """

    # [req_ids]
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None
    kv_connector_stats: KVConnectorStats | None = None
    # IDs of externally computed KV blocks that failed to load.
    # Requests referencing these blocks should be rescheduled to recompute them.
    invalid_block_ids: set[int] = field(default_factory=set)

    def is_empty(self) -> bool:
        """Return True when every field is falsy (None or empty).

        NOTE(review): ``kv_connector_stats`` is tested by truthiness, so a
        non-None stats object counts as non-empty unless KVConnectorStats
        defines ``__bool__``/``__len__`` -- confirm this is intended.
        """
        return (
            not self.finished_sending
            and not self.finished_recving
            and not self.kv_connector_stats
            and not self.invalid_block_ids
        )
99
100


101
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
class ModelRunnerOutput:
    """Per-step results produced by the model runner."""

    # [num_reqs]
    req_ids: list[str]
    # req_id -> index
    req_id_to_index: dict[str, int]

    # num_reqs x num_generated_tokens
    # num_generated_tokens is the number of tokens
    # generated in the current step. It can be different for
    # each request due to speculative/jump decoding.
    sampled_token_ids: list[list[int]]

    # Shapes of the three LogprobsLists fields, in declaration order:
    # logprob_token_ids:   [num_reqs, max_num_logprobs + 1]
    # logprobs:            [num_reqs, max_num_logprobs + 1]
    # sampled_token_ranks: [num_reqs]
    logprobs: LogprobsLists | None

    # req_id -> (token_ids, logprobs, ranks), with shapes:
    # token_ids: [prompt_len, num_prompt_logprobs]
    # logprobs:  [prompt_len, num_prompt_logprobs]
    # ranks:     [prompt_len]
    prompt_logprobs_dict: dict[str, LogprobsTensors | None]

    # [num_reqs, hidden_size]
    pooler_output: list[torch.Tensor | None]

    kv_connector_output: KVConnectorOutput | None = None

    # req_id -> num_nans_in_logits
    num_nans_in_logits: dict[str, int] | None = None
134

Robert Shaw's avatar
Robert Shaw committed
135

136
137
138
139
140
class AsyncModelRunnerOutput(ABC):
    """ModelRunnerOutput wrapper for async scheduling."""

    @abstractmethod
    def get_output(self) -> ModelRunnerOutput:
        """Return the ModelRunnerOutput for this async output.

        This call blocks until the results are ready, which might involve
        copying device tensors to the host.
        It should only be called once per AsyncModelRunnerOutput.
        """


149
150
151
152
153
154
155
156
@dataclass
class DraftTokenIds:
    """Draft token ids grouped by request."""

    # One request id per request: [num_reqs]
    req_ids: list[str]
    # One list of draft token ids per request: num_reqs x num_draft_tokens
    draft_token_ids: list[list[int]]


157
158
159
160
161
162
163
164
165
# Module-level ModelRunnerOutput with no requests and no results.
# It is a single shared instance, so its containers are shared by every user.
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
    # Per-request containers: all empty.
    req_ids=[],
    sampled_token_ids=[],
    pooler_output=[],
    req_id_to_index={},
    prompt_logprobs_dict={},
    # Optional results: absent.
    logprobs=None,
    num_nans_in_logits=None,
)