outputs.py 11.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence
6
from dataclasses import dataclass
Robert Shaw's avatar
Robert Shaw committed
7
from typing import Any, Generic, Optional, Union
8

9
import torch
10
from typing_extensions import TypeVar
11

12
from vllm.logger import init_logger
13
from vllm.logprobs import PromptLogprobs, SampleLogprobs
14
from vllm.lora.request import LoRARequest
15
from vllm.multimodal.inputs import MultiModalPlaceholderDict
16
from vllm.sequence import RequestMetrics
17

18
19
# Module-level logger; RequestOutput uses it to warn about ignored extra kwargs.
logger = init_logger(__name__)

20

21
@dataclass
class CompletionOutput:
    """A single generated completion belonging to a request.

    Args:
        index: Position of this completion among the request's outputs.
        text: The generated text.
        token_ids: Token IDs that make up the generated text.
        cumulative_logprob: Cumulative log probability of the generated
            text, if available.
        logprobs: Log probabilities of the top tokens at each position,
            present when logprobs were requested.
        finish_reason: Why this sequence finished; ``None`` means it has
            not finished yet (see :meth:`finished`).
        stop_reason: The stop string or token id that ended the
            completion, or ``None`` if it ended for any other reason
            (including hitting the EOS token).
        lora_request: The LoRA request used to produce this output.
    """

    index: int
    text: str
    token_ids: GenericSequence[int]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        """Return True once a finish reason has been recorded."""
        has_reason = self.finish_reason is not None
        return has_reason

    def __repr__(self) -> str:
        # lora_request is intentionally left out of the summary.
        parts = (
            f"index={self.index}",
            f"text={self.text!r}",
            f"token_ids={self.token_ids}",
            f"cumulative_logprob={self.cumulative_logprob}",
            f"logprobs={self.logprobs}",
            f"finish_reason={self.finish_reason}",
            f"stop_reason={self.stop_reason}",
        )
        return f"CompletionOutput({', '.join(parts)})"
60
61


62
@dataclass
class PoolingOutput:
    """The output data of one pooling output of a request.

    Args:
        data: The extracted hidden states.
    """
    data: torch.Tensor

    def __repr__(self) -> str:
        return f"PoolingOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        """Return True if *other* is a PoolingOutput holding equal data.

        Shapes are compared before values. An element-wise ``==`` on its
        own would broadcast: a ``(1,)`` tensor would compare equal to a
        ``(3,)`` tensor filled with the same value, and non-broadcastable
        shapes would raise instead of simply being unequal.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.data.shape != other.data.shape:
            return False
        return bool((self.data == other.data).all())

78

79
class RequestOutput:
    """Result of a completion request sent to the LLM.

    Args:
        request_id: Unique identifier of the request.
        prompt: Prompt string of the request. For encoder/decoder models
            this is the decoder-side prompt.
        prompt_token_ids: Token IDs of the prompt. For encoder/decoder
            models these are the decoder-side prompt token IDs.
        prompt_logprobs: Log probabilities to return for each prompt token.
        outputs: The completion sequences produced for the request.
        finished: Whether the request as a whole has finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request used to generate the output.
        encoder_prompt: Encoder prompt string; None for decoder-only models.
        encoder_prompt_token_ids: Encoder prompt token IDs; None for
            decoder-only models.
        num_cached_tokens: Number of tokens that hit the prefix cache.
        multi_modal_placeholders: Multi-modal placeholder info; stored as
            an empty dict when not provided.
        kv_transfer_params: Parameters for remote K/V transfer.
        **kwargs: Unknown keyword arguments are accepted and ignored, with
            a warning logged through ``logger.warning_once``.
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[list[int]],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: list[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[list[int]] = None,
        num_cached_tokens: Optional[int] = None,
        *,
        multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
        kv_transfer_params: Optional[dict[str, Any]] = None,
        # Forward compatibility, code that uses args added in new release can
        # still run with older versions of vLLM without breaking.
        **kwargs: Any,
    ) -> None:
        if kwargs:
            logger.warning_once("RequestOutput: Ignoring extra arguments: %s",
                                str(kwargs))

        # Request identity and decoder-side prompt.
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.multi_modal_placeholders = (multi_modal_placeholders
                                         if multi_modal_placeholders else {})
        self.prompt_logprobs = prompt_logprobs

        # Generated completions and request status.
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request

        # Encoder-side prompt (None for decoder-only models).
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids

        self.num_cached_tokens = num_cached_tokens
        self.kv_transfer_params = kv_transfer_params

    def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
        """Fold a later ``RequestOutput`` for the same request into this one.

        Args:
            next_output: The subsequent output to merge in. Its
                ``kv_transfer_params`` replace ours unconditionally.
            aggregate: If True, a completion sharing an ``index`` with an
                existing one has its text, token IDs and logprobs appended
                to it, and its cumulative logprob, finish reason and stop
                reason adopted. If False, the existing completion is
                replaced outright. Completions with an unseen ``index`` are
                appended in both modes.
        """
        self.finished |= next_output.finished
        self.kv_transfer_params = next_output.kv_transfer_params

        for incoming in next_output.outputs:
            # Position of the first existing completion with the same index.
            pos = next((j for j, existing in enumerate(self.outputs)
                        if existing.index == incoming.index), None)

            if pos is None:
                self.outputs.append(incoming)
            elif not aggregate:
                # Replace the output with the new one.
                self.outputs[pos] = incoming
            else:
                merged = self.outputs[pos]
                merged.text += incoming.text
                if not isinstance(merged.token_ids, MutableSequence):
                    # Immutable sequences can't be extended in place.
                    merged.token_ids = list(merged.token_ids)
                merged.token_ids.extend(incoming.token_ids)
                if incoming.logprobs:
                    assert merged.logprobs is not None
                    merged.logprobs.extend(incoming.logprobs)
                merged.cumulative_logprob = incoming.cumulative_logprob
                merged.finish_reason = incoming.finish_reason
                merged.stop_reason = incoming.stop_reason

    def __repr__(self) -> str:
        fields = (
            f"request_id={self.request_id}",
            f"prompt={self.prompt!r}",
            f"prompt_token_ids={self.prompt_token_ids}",
            f"encoder_prompt={self.encoder_prompt!r}",
            f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}",
            f"prompt_logprobs={self.prompt_logprobs}",
            f"outputs={self.outputs}",
            f"finished={self.finished}",
            f"metrics={self.metrics}",
            f"lora_request={self.lora_request}",
            f"num_cached_tokens={self.num_cached_tokens}",
            f"multi_modal_placeholders={self.multi_modal_placeholders}",
        )
        return f"RequestOutput({', '.join(fields)})"
184
185


186
187
188
189
# Type of the pooling result carried by PoolingRequestOutput.outputs;
# PoolingOutput unless a subclass specializes it.
_O = TypeVar("_O", default=PoolingOutput)


class PoolingRequestOutput(Generic[_O]):
    """
    Result of a pooling request sent to the LLM.

    Args:
        request_id (str): Unique identifier of the pooling request.
        outputs (_O): The pooling result for the input; a
            :class:`PoolingOutput` unless a subclass specializes ``_O``.
        prompt_token_ids (list[int]): Token IDs of the prompt.
        finished (bool): Whether pooling has completed.
    """

    def __init__(self, request_id: str, outputs: _O,
                 prompt_token_ids: list[int], finished: bool):
        self.request_id = request_id
        self.prompt_token_ids = prompt_token_ids
        self.finished = finished
        self.outputs = outputs

    def __repr__(self) -> str:
        # type(self) is used so subclasses report their own name.
        pieces = [
            f"request_id={self.request_id!r}",
            f"outputs={self.outputs!r}",
            f"prompt_token_ids={self.prompt_token_ids}",
            f"finished={self.finished}",
        ]
        return f"{type(self).__name__}({', '.join(pieces)})"


214
@dataclass
class EmbeddingOutput:
    """A single embedding vector produced for a request.

    Args:
        embedding: The embedding vector as a list of floats. Its length
            depends on the hidden dimension of the model.
    """
    embedding: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        """Build an ``EmbeddingOutput`` from a raw pooling result.

        Raises:
            ValueError: If the pooled tensor is not one-dimensional.
        """
        vector = pooling_output.data
        if vector.ndim == 1:
            return EmbeddingOutput(vector.tolist())
        raise ValueError("pooled_data should be a 1-D embedding vector")

    @property
    def hidden_size(self) -> int:
        """Number of elements in the embedding."""
        return len(self.embedding)

    def __repr__(self) -> str:
        # Only the size is shown; the full vector would be too long.
        return f"EmbeddingOutput(hidden_size={self.hidden_size})"
238
239


240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
    """Pooling request output whose payload is an ``EmbeddingOutput``."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Convert a generic pooling request output into an embedding one.

        ``outputs`` is converted with ``EmbeddingOutput.from_base``, and any
        ValueError it raises propagates. The other fields are copied.
        """
        return EmbeddingRequestOutput(
            request_output.request_id,
            EmbeddingOutput.from_base(request_output.outputs),
            request_output.prompt_token_ids,
            request_output.finished,
        )


@dataclass
class ClassificationOutput:
    """A single classification result produced for a request.

    Args:
        probs: The probability vector as a list of floats. Its length
            depends on the number of classes.
    """
    probs: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        """Build a ``ClassificationOutput`` from a raw pooling result.

        The pooled tensor is expected to have shape ``(num_classes,)``.

        Raises:
            ValueError: If the pooled tensor is not one-dimensional.
        """
        scores = pooling_output.data
        if scores.ndim == 1:
            return ClassificationOutput(scores.tolist())
        raise ValueError("pooled_data should be a 1-D probability vector")

    @property
    def num_classes(self) -> int:
        """Number of classes, i.e. the length of ``probs``."""
        return len(self.probs)

    def __repr__(self) -> str:
        return f"ClassificationOutput(num_classes={self.num_classes})"
277
278


279
class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):
    """Pooling request output whose payload is a ``ClassificationOutput``."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Convert a generic pooling request output into a classification one.

        ``outputs`` is converted with ``ClassificationOutput.from_base``, and
        any ValueError it raises propagates. The other fields are copied.
        """
        return ClassificationRequestOutput(
            request_output.request_id,
            ClassificationOutput.from_base(request_output.outputs),
            request_output.prompt_token_ids,
            request_output.finished,
        )
289
290


291
292
293
@dataclass
class ScoringOutput:
    """A single similarity score produced for a request.

    Args:
        score: The similarity score, a scalar value.
    """
    score: float

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        """Build a ``ScoringOutput`` from a raw pooling result.

        Size-1 dimensions are squeezed away first. This accepts both a
        classify-task result of shape ``(1,)`` (``num_classes == 1``) and an
        embed-task scalar.

        Raises:
            ValueError: If more than a single value remains after squeezing.
        """
        squeezed = pooling_output.data.squeeze()
        if squeezed.ndim == 0:
            return ScoringOutput(squeezed.item())
        raise ValueError("pooled_data should be a scalar score")

    def __repr__(self) -> str:
        return f"ScoringOutput(score={self.score})"
313
314


315
316
317
318
319
320
321
322
323
324
class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
    """Pooling request output whose payload is a ``ScoringOutput``."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Convert a generic pooling request output into a scoring one.

        ``outputs`` is converted with ``ScoringOutput.from_base``, and any
        ValueError it raises propagates. The other fields are copied.
        """
        return ScoringRequestOutput(
            request_output.request_id,
            ScoringOutput.from_base(request_output.outputs),
            request_output.prompt_token_ids,
            request_output.finished,
        )