# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
from abc import ABC, abstractmethod
5
from dataclasses import dataclass
6
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
7
8
9

import torch

10
11
12
13
if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
        KVConnectorStats)

14

15
class LogprobsLists(NamedTuple):
    """Logprob data for a batch of requests, held as plain Python lists.

    The three fields are parallel: index ``i`` along the first dimension of
    each field refers to the same request.
    """

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: list[list[int]]
    # [num_reqs, max_num_logprobs + 1]
    logprobs: list[list[float]]
    # [num_reqs]
    sampled_token_ranks: list[int]

    def slice(self, start: int, end: int) -> "LogprobsLists":
        """Return a new LogprobsLists restricted to rows ``[start, end)``.

        Standard Python slice semantics apply to every field, so
        out-of-range bounds are clamped and an empty range yields empty
        lists rather than raising.
        """
        return LogprobsLists(
            self.logprob_token_ids[start:end],
            self.logprobs[start:end],
            self.sampled_token_ranks[start:end],
        )


class LogprobsTensors(NamedTuple):
    """Logprob data for a batch of requests, held as torch tensors.

    Tensor counterpart of LogprobsLists; use ``tolists()`` to convert.
    """

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: torch.Tensor
    # [num_reqs, max_num_logprobs + 1]
    logprobs: torch.Tensor
    # [num_reqs]
    selected_token_ranks: torch.Tensor

    def tolists(self) -> "LogprobsLists":
        """Convert every tensor to nested Python lists via ``Tensor.tolist()``.

        Note that ``selected_token_ranks`` maps onto the
        ``sampled_token_ranks`` field of the returned LogprobsLists.
        """
        return LogprobsLists(
            self.logprob_token_ids.tolist(),
            self.logprobs.tolist(),
            self.selected_token_ranks.tolist(),
        )

    @staticmethod
    def empty_cpu(num_positions: int,
                  num_tokens_per_position: int) -> "LogprobsTensors":
        """Create empty LogprobsTensors on CPU.

        The tensors are allocated with ``torch.empty``, so their contents
        are uninitialized and must be filled in by the caller.

        Args:
            num_positions: Size of the first dimension of every tensor.
            num_tokens_per_position: Size of the second dimension of
                ``logprob_token_ids`` and ``logprobs``.

        Returns:
            A LogprobsTensors with int32 token ids of shape
            (num_positions, num_tokens_per_position), float32 logprobs of
            the same shape, and int32 ranks of shape (num_positions,).
        """
        logprob_token_ids = torch.empty(
            (num_positions, num_tokens_per_position),
            dtype=torch.int32,
            device="cpu")
        # Same shape and device as the token ids, but floating point.
        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
        selected_token_ranks = torch.empty(num_positions,
                                           dtype=torch.int32,
                                           device="cpu")
        return LogprobsTensors(
            logprob_token_ids=logprob_token_ids,
            logprobs=logprobs,
            selected_token_ranks=selected_token_ranks,
        )

67

68
69
70
71
72
# Pooler output, indexed by request: [num_reqs, <dynamic>].
# The shape of each element depends on the pooler that was used.
PoolerOutput = Union[
    torch.Tensor,
    list[torch.Tensor],
]


73
74
75
@dataclass
class SamplerOutput:
    """Sampled token ids plus optional logprob tensors."""

    # [num_reqs, max_num_generated_tokens]
    # Different requests can have different number of generated tokens.
    # All requests are padded to max_num_generated_tokens.
    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
    sampled_token_ids: torch.Tensor
    # Logprob tensors, or None when there are none to report.
    logprobs_tensors: Optional[LogprobsTensors]
82
83


84
85
86
87
88
@dataclass
class KVConnectorOutput:
    """Request ids that finished sending/receiving, plus optional stats."""

    # [req_ids]
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None
    kv_connector_stats: Optional["KVConnectorStats"] = None

    def is_empty(self) -> bool:
        """Return True when no field carries any information.

        Each field is tested by truthiness, so ``None`` and an empty set
        are treated alike. ``kv_connector_stats`` is likewise judged by its
        own truthiness.
        """
        return (not self.finished_sending and not self.finished_recving
                and not self.kv_connector_stats)
94
95


96
# ModelRunnerOutput is serialized and sent to the scheduler process.
97
# This is expensive for torch.Tensor so prefer to use list instead.
98
99
100
101
@dataclass
class ModelRunnerOutput:
    """Per-step output of the model runner, keyed by request.

    Fields prefer plain Python containers over tensors where possible.
    """

    # [num_reqs]
    req_ids: list[str]
    # req_id -> index
    req_id_to_index: dict[str, int]

    # num_reqs x num_generated_tokens
    # num_generated_tokens is the number of tokens
    # generated in the current step. It can be different for
    # each request due to speculative/jump decoding.
    sampled_token_ids: list[list[int]]

    # Shapes of the three LogprobsLists fields, in order:
    # logprob_token_ids:   [num_reqs, max_num_logprobs + 1]
    # logprobs:            [num_reqs, max_num_logprobs + 1]
    # sampled_token_ranks: [num_reqs]
    logprobs: Optional[LogprobsLists]

    # req_id -> (token_ids, logprobs, ranks)
    # token_ids: [prompt_len, num_prompt_logprobs]
    # logprobs:  [prompt_len, num_prompt_logprobs]
    # ranks:     [prompt_len]
    prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]

    # [num_reqs, hidden_size]
    pooler_output: list[Optional[torch.Tensor]]

    kv_connector_output: Optional[KVConnectorOutput] = None

    # req_id -> num_nans_in_logits
    num_nans_in_logits: Optional[dict[str, int]] = None

Robert Shaw's avatar
Robert Shaw committed
131

132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Wrapper around ModelRunnerOutput, used for async scheduling.
class AsyncModelRunnerOutput(ABC):
    """Abstract handle to a ModelRunnerOutput that may not be ready yet."""

    @abstractmethod
    def get_output(self) -> ModelRunnerOutput:
        """Return the ModelRunnerOutput behind this async output.

        The call blocks until the results are available, which may include
        copying device tensors to the host. It should be invoked only once
        per AsyncModelRunnerOutput instance.
        """
        ...


146
147
148
149
150
151
152
153
154
@dataclass
class DraftTokenIds:
    """Draft token ids for a batch of requests.

    ``req_ids`` and ``draft_token_ids`` are parallel lists: entry ``i`` of
    each refers to the same request.
    """

    # One id per request: [num_reqs]
    req_ids: list[str]

    # One list of draft tokens per request: num_reqs x num_draft_tokens
    draft_token_ids: list[list[int]]


Robert Shaw's avatar
Robert Shaw committed
155
156
157
158
159
# Module-level ModelRunnerOutput with no requests: every container is empty
# and every optional field is None (kv_connector_output keeps its default).
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                              req_id_to_index={},
                                              sampled_token_ids=[],
                                              logprobs=None,
                                              prompt_logprobs_dict={},
                                              pooler_output=[],
                                              num_nans_in_logits=None)