outputs.py 19.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
6
from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence
7
from dataclasses import dataclass
Robert Shaw's avatar
Robert Shaw committed
8
from typing import Any, Generic, Optional, Union
9

10
import torch
11
from typing_extensions import TypeVar
12

13
from vllm.logger import init_logger
14
from vllm.logprobs import PromptLogprobs, SampleLogprobs
15
from vllm.lora.request import LoRARequest
16
from vllm.multimodal.inputs import MultiModalPlaceholderDict
17
from vllm.sampling_params import RequestOutputKind
18
19
from vllm.sequence import (RequestMetrics, SequenceGroup, SequenceGroupBase,
                           SequenceStatus)
20

21
22
logger = init_logger(__name__)

23

24
@dataclass
class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
        stop_reason: The stop string or token id that caused the completion
            to stop, None if the completion finished for some other reason
            including encountering the EOS token.
        lora_request: The LoRA request that was used to generate the output.
    """

    index: int
    text: str
    token_ids: GenericSequence[int]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        """Return True once a finish reason has been recorded."""
        return self.finish_reason is not None

    def __repr__(self) -> str:
        # Note: lora_request is not part of the repr.
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason}, "
                f"stop_reason={self.stop_reason})")
63
64


65
@dataclass
class PoolingOutput:
    """The output data of one pooling output of a request.

    Args:
        data: The extracted hidden states.
    """
    data: torch.Tensor

    def __repr__(self) -> str:
        return (f"PoolingOutput(data={self.data})")

    def __eq__(self, other: object) -> bool:
        """Return True if *other* is a PoolingOutput holding an equal tensor.

        Shapes are compared before values. Elementwise ``==`` on tensors
        broadcasts, so without this check tensors of different shapes could
        compare equal (e.g. ``[1.]`` vs ``[1., 1.]``). Non-broadcastable
        shapes would raise instead of returning False.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.data.shape != other.data.shape:
            return False
        return bool((self.data == other.data).all())

81

82
class RequestOutput:
    """The output data of a completion request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
                For encoder/decoder models, this is the
                decoder input prompt.
        prompt_token_ids: The token IDs of the prompt.
                          For encoder/decoder models, this is the
                          decoder input prompt token ids.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
        encoder_prompt: The encoder prompt string of the request.
                        None if decoder-only.
        encoder_prompt_token_ids: The token IDs of the encoder prompt.
                                  None if decoder-only.
        num_cached_tokens: The number of tokens with prefix cache hit.
        multi_modal_placeholders: Multi-modal placeholder mapping; stored
                                  as an empty dict when None is passed.
        kv_transfer_params: The params for remote K/V transfer.
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[list[int]],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: list[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[list[int]] = None,
        num_cached_tokens: Optional[int] = None,
        *,
        multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
        kv_transfer_params: Optional[dict[str, Any]] = None,
        # Forward compatibility, code that uses args added in new release can
        # still run with older versions of vLLM without breaking.
        **kwargs: Any,
    ) -> None:
        # Unknown keyword arguments are ignored (with a one-time warning)
        # rather than rejected, for forward compatibility.
        if kwargs:
            logger.warning_once("RequestOutput: Ignoring extra arguments: %s",
                                str(kwargs))
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.multi_modal_placeholders = multi_modal_placeholders or {}
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids
        self.num_cached_tokens = num_cached_tokens
        self.kv_transfer_params = kv_transfer_params

    def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
        """Merge subsequent RequestOutput into this one.

        ``finished`` is OR-ed in and ``kv_transfer_params`` is overwritten.
        Each completion in ``next_output`` is matched to an existing one by
        ``index``:

        * ``aggregate=True``: text, token ids and logprobs are appended.
          ``cumulative_logprob``, ``finish_reason`` and ``stop_reason`` take
          the newer values.
        * ``aggregate=False``: the existing completion is replaced outright.

        Completions with an unseen index are appended.

        Raises:
            AssertionError: when aggregating logprobs into a completion whose
                own ``logprobs`` is None.
        """

        self.finished |= next_output.finished
        self.kv_transfer_params = next_output.kv_transfer_params

        for next_completion in next_output.outputs:
            for i, completion in enumerate(self.outputs):
                if completion.index == next_completion.index:
                    if aggregate:
                        # Merge outputs with same index
                        completion.text += next_completion.text
                        # token_ids may be an immutable sequence; convert
                        # before extending in place.
                        if not isinstance(completion.token_ids,
                                          MutableSequence):
                            completion.token_ids = list(completion.token_ids)
                        completion.token_ids.extend(next_completion.token_ids)
                        if next_completion.logprobs:
                            assert completion.logprobs is not None
                            completion.logprobs.extend(
                                next_completion.logprobs)
                        completion.cumulative_logprob = (
                            next_completion.cumulative_logprob)
                        completion.finish_reason = next_completion.finish_reason
                        completion.stop_reason = next_completion.stop_reason
                    else:
                        # Replace the output with the new one
                        self.outputs[i] = next_completion
                    break
            else:
                self.outputs.append(next_completion)

    @classmethod
    def from_seq_group(
        cls, seq_group: SequenceGroup, use_cache: bool,
        seq_id_to_seq_group: dict[str, SequenceGroupBase]
    ) -> Optional["RequestOutput"]:
        """Build a RequestOutput from the current state of ``seq_group``.

        Returns None in three cases:

        * ``seq_group`` is a member of a group registered in
          ``seq_id_to_seq_group`` that could not be assembled yet.
        * ``output_kind`` is FINAL_ONLY and the group is not finished.

        When ``use_cache`` is True, the RequestOutput and CompletionOutput
        objects stored on ``seq_group.cached_request_output`` are
        re-initialised in place and returned, instead of new objects being
        allocated.

        Side effects:

        * Records the finished time on ``seq_group``.
        * Once a registered group has no members left to finish, deletes
          that group's sub-request ids from ``seq_id_to_seq_group``.

        Raises:
            ValueError: if ``seq_group.sampling_params`` is None.
        """
        finished = seq_group.is_finished()

        if seq_group.request_id in seq_id_to_seq_group:
            group: SequenceGroupBase = seq_id_to_seq_group[
                seq_group.request_id]
            assembled_seq_group = group.maybe_assemble_group(seq_group)
            if finished:
                group.finish_seq(seq_group)
            if assembled_seq_group is None:
                return None

            # clear finished seq in seq_id_to_seq_group
            if len(group.to_be_finished) == 0:
                for sub_request_id in list(group.seq_id_to_index.keys()):
                    if sub_request_id in seq_id_to_seq_group:
                        del seq_id_to_seq_group[sub_request_id]

            return cls.from_seq_group(assembled_seq_group, use_cache,
                                      seq_id_to_seq_group)

        sampling_params = seq_group.sampling_params
        if sampling_params is None:
            raise ValueError(
                "Sampling parameters are missing for a CompletionRequest.")

        if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                not finished):
            return None

        # Init cache (if needed)
        if use_cache and seq_group.cached_request_output is None:
            seq_group.cached_request_output = RequestOutput(  # type: ignore
                request_id="",
                prompt=None,
                prompt_token_ids=[],
                prompt_logprobs=None,
                outputs=[],
                finished=False)

        top_n_seqs = seq_group.get_seqs()

        # Create the outputs.
        # NOTE: We need to omit logprobs here explicitly because the sequence
        # always has the logprobs of the sampled tokens even if the
        # logprobs are not requested.
        include_logprobs = sampling_params.logprobs is not None
        text_buffer_length = sampling_params.output_text_buffer_length
        delta = sampling_params.output_kind == RequestOutputKind.DELTA

        outputs = []
        include_prompt = True
        # num_cached_tokens should be the same for all the sequences
        num_cached_tokens = None
        for i, seq in enumerate(top_n_seqs):
            output_text = seq.get_output_text_to_return(
                text_buffer_length, delta)

            # May be a single int (one new token) or a sequence of ids.
            output_token_ids = seq.get_output_token_ids_to_return(delta)
            num_output_tokens = 1 if isinstance(output_token_ids,
                                                int) else len(output_token_ids)
            num_cached_tokens = seq.data.get_num_cached_tokens()

            output_logprobs = seq.output_logprobs if include_logprobs else None

            if delta:
                # Slice logprobs delta if applicable
                if output_logprobs:
                    # num_output_tokens can be 0 when n > 1 and request finishes
                    # before the others
                    if num_output_tokens > 0:
                        output_logprobs = output_logprobs[-num_output_tokens:]
                    else:
                        output_logprobs = None

                # Don't include prompt if this is after the first output
                # containing decode token ids
                if include_prompt and seq.get_output_len() > num_output_tokens:
                    include_prompt = False

            if use_cache:
                # Get cached output object
                cached_outputs = seq_group.cached_request_output.outputs  # type: ignore
                if i >= len(cached_outputs):
                    cached_outputs.append(
                        CompletionOutput(index=i,
                                         text="",
                                         token_ids=[],
                                         cumulative_logprob=None,
                                         logprobs=None,
                                         finish_reason=None,
                                         stop_reason=None))
                output = cached_outputs[i]

                # Init cached output object
                assert output.index == i
                output.text = output_text

                if isinstance(output_token_ids, int):
                    # Reuse the cached list in place when possible. Fall back
                    # to a fresh list if an earlier step stored an immutable
                    # sequence (same guard add() applies before extend()).
                    if isinstance(output.token_ids, MutableSequence):
                        output.token_ids.clear()
                        output.token_ids.append(output_token_ids)
                    else:
                        output.token_ids = [output_token_ids]
                else:
                    output.token_ids = output_token_ids

                output.cumulative_logprob = seq.get_cumulative_logprob() \
                    if include_logprobs else None
                output.logprobs = output_logprobs
                output.finish_reason = SequenceStatus.get_finished_reason(
                    seq.status)
                output.stop_reason = seq.stop_reason

            else:
                # `i` is this sequence's position in top_n_seqs, the same
                # index the cached path uses. This avoids an O(n) list.index()
                # search per sequence.
                output = CompletionOutput(
                    i, output_text, [output_token_ids]
                    if isinstance(output_token_ids, int) else output_token_ids,
                    seq.get_cumulative_logprob() if include_logprobs else None,
                    output_logprobs,
                    SequenceStatus.get_finished_reason(seq.status),
                    seq.stop_reason)

            outputs.append(output)

        # Every sequence in the sequence group should have the same prompt.
        if include_prompt:
            prompt = seq_group.prompt
            prompt_token_ids = seq_group.prompt_token_ids
            encoder_prompt = seq_group.encoder_prompt
            encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
            prompt_logprobs = seq_group.prompt_logprobs
        else:
            prompt = None
            prompt_token_ids = None
            encoder_prompt = None
            encoder_prompt_token_ids = None
            prompt_logprobs = None
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)

        init_kwargs = {
            "request_id": seq_group.request_id,
            "prompt": prompt,
            "prompt_token_ids": prompt_token_ids,
            "prompt_logprobs": prompt_logprobs,
            "outputs": outputs,
            "finished": finished,
            "metrics": seq_group.metrics,
            "lora_request": seq_group.lora_request,
            "encoder_prompt": encoder_prompt,
            "encoder_prompt_token_ids": encoder_prompt_token_ids,
            "num_cached_tokens": num_cached_tokens,
            "multi_modal_placeholders": seq_group.multi_modal_placeholders
        }

        if use_cache:
            # Re-run __init__ on the cached object so every field is reset.
            request_output = seq_group.cached_request_output
            request_output.__init__(**init_kwargs)  # type: ignore
        else:
            request_output = cls(**init_kwargs)  # type: ignore

        return request_output

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"encoder_prompt={self.encoder_prompt!r}, "
                f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request}, "
                f"num_cached_tokens={self.num_cached_tokens}, "
                f"multi_modal_placeholders={self.multi_modal_placeholders})")
351
352


353
354
355
356
_O = TypeVar("_O", default=PoolingOutput)


class PoolingRequestOutput(Generic[_O]):
    """
    The output data of a pooling request to the LLM.

    Args:
        request_id (str): A unique identifier for the pooling request.
        outputs (_O): The pooling results for the given input. This is a
            PoolingOutput by default; subclasses specialise it (e.g. to
            embedding, classification or scoring outputs).
        prompt_token_ids (list[int]): A list of token IDs used in the prompt.
        finished (bool): A flag indicating whether the pooling is completed.
    """

    def __init__(self, request_id: str, outputs: _O,
                 prompt_token_ids: list[int], finished: bool):
        self.request_id = request_id
        self.prompt_token_ids = prompt_token_ids
        self.finished = finished
        self.outputs = outputs

    @staticmethod
    def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput":
        """Wrap ``seq_group.pooled_data`` in a PoolingRequestOutput.

        The pooled tensor is converted to float32 on the CPU before it is
        wrapped. The caller must ensure ``pooled_data`` is set; this is
        checked with an ``assert``.
        """
        pooled_data = seq_group.pooled_data
        assert pooled_data is not None

        data = pooled_data.to(dtype=torch.float32, device="cpu")
        output = PoolingOutput(data)
        prompt_token_ids = seq_group.prompt_token_ids
        finished = seq_group.is_finished()

        return PoolingRequestOutput(seq_group.request_id, output,
                                    prompt_token_ids, finished)

    def __repr__(self) -> str:
        # type(self).__name__ keeps the repr accurate for subclasses.
        return (f"{type(self).__name__}(request_id={self.request_id!r}, "
                f"outputs={self.outputs!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"finished={self.finished})")


394
395
396
397
class RequestOutputFactory:
    """Choose the output type that matches what a SequenceGroup produced."""

    @staticmethod
    def create(seq_group: SequenceGroup,
               seq_id_to_seq_group: dict[str, SequenceGroupBase],
               use_cache: bool = False):
        """Return a pooling output if pooled data is present, else a
        (possibly None) RequestOutput built by RequestOutput.from_seq_group.
        """
        if seq_group.pooled_data is not None:
            return PoolingRequestOutput.from_seq_group(seq_group)
        else:
            return RequestOutput.from_seq_group(seq_group, use_cache,
                                                seq_id_to_seq_group)


407
@dataclass
class EmbeddingOutput:
    """The output data of one embedding output of a request.

    Args:
        embedding: The embedding vector, which is a list of floats.
            Its length depends on the hidden dimension of the model.
    """
    embedding: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput) -> "EmbeddingOutput":
        """Build an EmbeddingOutput from a 1-D pooled tensor.

        Raises:
            ValueError: If the pooled tensor is not 1-D.
        """
        pooled_data = pooling_output.data
        if pooled_data.ndim != 1:
            raise ValueError("pooled_data should be a 1-D embedding vector")

        return EmbeddingOutput(pooled_data.tolist())

    @property
    def hidden_size(self) -> int:
        """Length of the embedding vector."""
        return len(self.embedding)

    def __repr__(self) -> str:
        # Only the size is shown, not the (potentially long) vector itself.
        return f"EmbeddingOutput(hidden_size={self.hidden_size})"
431
432


433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
    """Pooling request output whose ``outputs`` is an EmbeddingOutput."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Convert a generic pooling result into an embedding result.

        Only ``outputs`` is transformed, via EmbeddingOutput.from_base. The
        request id, prompt token ids and finished flag are copied unchanged.
        """
        embedding = EmbeddingOutput.from_base(request_output.outputs)
        return EmbeddingRequestOutput(
            request_id=request_output.request_id,
            outputs=embedding,
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )


@dataclass
class ClassificationOutput:
    """The output data of one classification output of a request.

    Args:
        probs: The probability vector, which is a list of floats.
            Its length depends on the number of classes.
    """
    probs: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput) -> "ClassificationOutput":
        """Build a ClassificationOutput from a 1-D pooled tensor.

        Raises:
            ValueError: If the pooled tensor is not 1-D.
        """
        # pooling_output shape: (num_classes)
        pooled_data = pooling_output.data
        if pooled_data.ndim != 1:
            raise ValueError("pooled_data should be a 1-D probability vector")

        return ClassificationOutput(pooled_data.tolist())

    @property
    def num_classes(self) -> int:
        """Length of the probability vector."""
        return len(self.probs)

    def __repr__(self) -> str:
        return f"ClassificationOutput(num_classes={self.num_classes})"
470
471


472
class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):
    """Pooling request output whose ``outputs`` is a ClassificationOutput."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Convert a generic pooling result into a classification result.

        ``outputs`` is converted with ClassificationOutput.from_base, which
        raises ValueError for a non 1-D tensor. The other fields are copied.
        """
        return ClassificationRequestOutput(
            request_id=request_output.request_id,
            outputs=ClassificationOutput.from_base(request_output.outputs),
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )
482
483


484
485
486
@dataclass
class ScoringOutput:
    """The output data of one scoring output of a request.

    Args:
        score: The similarity score, which is a scalar value.
    """
    score: float

    @staticmethod
    def from_base(pooling_output: PoolingOutput) -> "ScoringOutput":
        """Build a ScoringOutput from a single-element pooled tensor.

        All size-1 dimensions are squeezed away first. Any tensor with
        exactly one element is therefore accepted, whatever its rank.

        Raises:
            ValueError: If the squeezed tensor is not a scalar.
        """
        # pooling_output shape:
        #   classify task: (num_classes) num_classes == 1
        #   embed task: a scalar value
        pooled_data = pooling_output.data.squeeze()
        if pooled_data.ndim != 0:
            raise ValueError("pooled_data should be a scalar score")

        return ScoringOutput(pooled_data.item())

    def __repr__(self) -> str:
        return f"ScoringOutput(score={self.score})"
506
507


508
509
510
511
512
513
514
515
516
517
class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
    """Pooling request output whose ``outputs`` is a ScoringOutput."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Convert a generic pooling result into a scoring result.

        ``outputs`` is replaced by ScoringOutput.from_base(...). The request
        id, prompt token ids and finished flag carry over unchanged.
        """
        score = ScoringOutput.from_base(request_output.outputs)
        return ScoringRequestOutput(
            request_id=request_output.request_id,
            outputs=score,
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )