outputs.py 19.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
6
from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence
7
from dataclasses import dataclass
Robert Shaw's avatar
Robert Shaw committed
8
from typing import Any, Generic, Optional, Union
9

10
import torch
11
from typing_extensions import TypeVar
12

13
from vllm.logger import init_logger
14
from vllm.lora.request import LoRARequest
15
from vllm.multimodal.inputs import MultiModalPlaceholderDict
16
from vllm.sampling_params import RequestOutputKind
17
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
18
                           SequenceGroup, SequenceGroupBase, SequenceStatus)
19

20
21
# Module-level logger; RequestOutput.__init__ uses it to warn (once) about
# unrecognised keyword arguments.
logger = init_logger(__name__)

22

23
@dataclass
class CompletionOutput:
    """A single generated completion belonging to a request.

    Args:
        index: Position of this completion within the request's outputs.
        text: Text generated for this completion.
        token_ids: Token IDs that make up the generated text.
        cumulative_logprob: Cumulative log probability of the generated
            text (may be None).
        logprobs: Log probabilities of the top tokens at each position,
            present only when logprobs were requested.
        finish_reason: Why the sequence finished; None while it is still
            being generated.
        stop_reason: The stop string or token id that ended the completion,
            or None if it ended for any other reason (including reaching
            the EOS token).
        lora_request: The LoRA request used to produce this output, if any.
    """

    index: int
    text: str
    token_ids: GenericSequence[int]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        # Having any recorded finish reason marks the completion as done.
        has_finish_reason = self.finish_reason is not None
        return has_finish_reason

    def __repr__(self) -> str:
        # lora_request is not part of the representation.
        fields = (
            f"index={self.index}",
            f"text={self.text!r}",
            f"token_ids={self.token_ids}",
            f"cumulative_logprob={self.cumulative_logprob}",
            f"logprobs={self.logprobs}",
            f"finish_reason={self.finish_reason}",
            f"stop_reason={self.stop_reason}",
        )
        return "CompletionOutput(" + ", ".join(fields) + ")"
62
63


64
@dataclass
class PoolingOutput:
    """The output data of one pooling output of a request.

    Args:
        data: The extracted hidden states.
    """
    data: torch.Tensor

    def __repr__(self) -> str:
        return f"PoolingOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        """Return True if ``other`` is a PoolingOutput with identical data.

        Shapes are compared before the element-wise check: a bare
        ``self.data == other.data`` broadcasts, so a 1-element tensor would
        compare equal to a larger tensor filled with the same value, and
        incompatible shapes would raise instead of returning False.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.data.shape != other.data.shape:
            return False
        return bool((self.data == other.data).all())

80

81
class RequestOutput:
    """The output data of a completion request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
                For encoder/decoder models, this is the
                decoder input prompt.
        prompt_token_ids: The token IDs of the prompt.
                          For encoder/decoder models, this is the
                          decoder input prompt token ids.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
        encoder_prompt: The encoder prompt string of the request.
                        None if decoder-only.
        encoder_prompt_token_ids: The token IDs of the encoder prompt.
                                  None if decoder-only.
        num_cached_tokens: The number of tokens with prefix cache hit.
        multi_modal_placeholders: The multi-modal placeholder mapping for
                                  the prompt; stored as ``{}`` when not
                                  given.
        kv_transfer_params: The params for remote K/V transfer.
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[list[int]],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: list[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[list[int]] = None,
        num_cached_tokens: Optional[int] = None,
        *,
        multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
        kv_transfer_params: Optional[dict[str, Any]] = None,
        # Forward compatibility, code that uses args added in new release can
        # still run with older versions of vLLM without breaking.
        **kwargs: Any,
    ) -> None:
        if kwargs:
            logger.warning_once("RequestOutput: Ignoring extra arguments: %s",
                                str(kwargs))
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.multi_modal_placeholders = multi_modal_placeholders or {}
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids
        self.num_cached_tokens = num_cached_tokens
        self.kv_transfer_params = kv_transfer_params

    def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
        """Merge a subsequent RequestOutput into this one.

        Args:
            next_output: The newer output to fold into ``self``.
            aggregate: If True, a completion whose ``index`` already exists
                is extended in place: text is concatenated, token ids and
                logprobs are appended, and cumulative logprob, finish reason
                and stop reason are taken from the newer completion. If
                False, the newer completion replaces the existing one.
                Completions with an index not present yet are appended in
                either mode.
        """
        self.finished |= next_output.finished
        self.kv_transfer_params = next_output.kv_transfer_params

        for next_completion in next_output.outputs:
            for i, completion in enumerate(self.outputs):
                if completion.index == next_completion.index:
                    if aggregate:
                        # Merge outputs with same index
                        completion.text += next_completion.text
                        if not isinstance(completion.token_ids,
                                          MutableSequence):
                            completion.token_ids = list(completion.token_ids)
                        completion.token_ids.extend(next_completion.token_ids)
                        if next_completion.logprobs:
                            assert completion.logprobs is not None
                            completion.logprobs.extend(
                                next_completion.logprobs)
                        completion.cumulative_logprob = (
                            next_completion.cumulative_logprob)
                        completion.finish_reason = next_completion.finish_reason
                        completion.stop_reason = next_completion.stop_reason
                    else:
                        # Replace the output with the new one
                        self.outputs[i] = next_completion
                    break
            else:
                # No existing completion with this index: keep the new one.
                self.outputs.append(next_completion)

    @classmethod
    def from_seq_group(
        cls, seq_group: SequenceGroup, use_cache: bool,
        seq_id_to_seq_group: dict[str, SequenceGroupBase]
    ) -> Optional["RequestOutput"]:
        """Build a RequestOutput from the current state of ``seq_group``.

        Args:
            seq_group: The sequence group to report on.
            use_cache: If True, reuse ``seq_group.cached_request_output`` and
                its CompletionOutput objects, re-initialising them in place
                rather than allocating new ones.
            seq_id_to_seq_group: Maps request ids to the SequenceGroupBase
                they belong to. A group found here is assembled via
                ``maybe_assemble_group``; its sub-request entries are removed
                once ``group.to_be_finished`` is empty.

        Returns:
            The RequestOutput, or None when there is nothing to emit: the
            parent group could not be assembled yet, or the request uses
            ``RequestOutputKind.FINAL_ONLY`` and is not finished.

        Raises:
            ValueError: If the sequence group has no sampling parameters.
        """
        finished = seq_group.is_finished()

        if seq_group.request_id in seq_id_to_seq_group:
            group: SequenceGroupBase = seq_id_to_seq_group[
                seq_group.request_id]
            assembled_seq_group = group.maybe_assemble_group(seq_group)
            if finished:
                group.finish_seq(seq_group)
            if assembled_seq_group is None:
                return None

            # clear finished seq in seq_id_to_seq_group
            if len(group.to_be_finished) == 0:
                for sub_request_id in list(group.seq_id_to_index.keys()):
                    if sub_request_id in seq_id_to_seq_group:
                        del seq_id_to_seq_group[sub_request_id]

            return cls.from_seq_group(assembled_seq_group, use_cache,
                                      seq_id_to_seq_group)

        sampling_params = seq_group.sampling_params
        if sampling_params is None:
            raise ValueError(
                "Sampling parameters are missing for a CompletionRequest.")

        if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                not finished):
            return None

        # Init cache (if needed). ``cls`` is used so that the cached object
        # has the same type as the one the uncached path returns.
        if use_cache and seq_group.cached_request_output is None:
            seq_group.cached_request_output = cls(  # type: ignore
                request_id="",
                prompt=None,
                prompt_token_ids=[],
                prompt_logprobs=None,
                outputs=[],
                finished=False)

        top_n_seqs = seq_group.get_seqs()

        # Create the outputs.
        # NOTE: We need omit logprobs here explicitly because the sequence
        # always has the logprobs of the sampled tokens even if the
        # logprobs are not requested.
        include_logprobs = sampling_params.logprobs is not None
        text_buffer_length = sampling_params.output_text_buffer_length
        delta = sampling_params.output_kind == RequestOutputKind.DELTA

        outputs = []
        include_prompt = True
        # num_cached_tokens should be the same for all the sequences
        num_cached_tokens = None
        for i, seq in enumerate(top_n_seqs):
            output_text = seq.get_output_text_to_return(
                text_buffer_length, delta)

            # May be a single int (counted as one token) or a sequence.
            output_token_ids = seq.get_output_token_ids_to_return(delta)
            num_output_tokens = 1 if isinstance(output_token_ids,
                                                int) else len(output_token_ids)
            num_cached_tokens = seq.data.get_num_cached_tokens()

            output_logprobs = seq.output_logprobs if include_logprobs else None

            if delta:
                # Slice logprobs delta if applicable
                if output_logprobs:
                    # num_output_tokens can be 0 when n > 1 and request finishes
                    # before the others
                    if num_output_tokens > 0:
                        output_logprobs = output_logprobs[-num_output_tokens:]
                    else:
                        output_logprobs = None
                # Don't include prompt if this is after the first output
                # containing decode token ids
                if include_prompt and seq.get_output_len() > num_output_tokens:
                    include_prompt = False

            if use_cache:
                # Get cached output object
                cached_outputs = seq_group.cached_request_output.outputs  # type: ignore
                if i >= len(cached_outputs):
                    cached_outputs.append(
                        CompletionOutput(index=i,
                                         text="",
                                         token_ids=[],
                                         cumulative_logprob=None,
                                         logprobs=None,
                                         finish_reason=None,
                                         stop_reason=None))
                output = cached_outputs[i]

                # Init cached output object
                assert output.index == i
                output.text = output_text

                if isinstance(output_token_ids, int):
                    # Reuse the existing list when possible; fall back to a
                    # fresh list if a previous step stored an immutable
                    # sequence (mirrors the guard in ``add``).
                    if isinstance(output.token_ids, MutableSequence):
                        output.token_ids.clear()
                        output.token_ids.append(output_token_ids)
                    else:
                        output.token_ids = [output_token_ids]
                else:
                    output.token_ids = output_token_ids

                output.cumulative_logprob = seq.get_cumulative_logprob() \
                    if include_logprobs else None
                output.logprobs = output_logprobs
                output.finish_reason = SequenceStatus.get_finished_reason(
                    seq.status)
                output.stop_reason = seq.stop_reason

            else:
                # ``i`` is the sequence's position in ``top_n_seqs``;
                # no need for an O(n) ``list.index`` lookup per sequence.
                output = CompletionOutput(
                    i, output_text, [output_token_ids]
                    if isinstance(output_token_ids, int) else output_token_ids,
                    seq.get_cumulative_logprob() if include_logprobs else None,
                    output_logprobs,
                    SequenceStatus.get_finished_reason(seq.status),
                    seq.stop_reason)

            outputs.append(output)

        # Every sequence in the sequence group should have the same prompt.
        if include_prompt:
            prompt = seq_group.prompt
            prompt_token_ids = seq_group.prompt_token_ids
            encoder_prompt = seq_group.encoder_prompt
            encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
            prompt_logprobs = seq_group.prompt_logprobs
        else:
            prompt = None
            prompt_token_ids = None
            encoder_prompt = None
            encoder_prompt_token_ids = None
            prompt_logprobs = None
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)

        init_kwargs = {
            "request_id": seq_group.request_id,
            "prompt": prompt,
            "prompt_token_ids": prompt_token_ids,
            "prompt_logprobs": prompt_logprobs,
            "outputs": outputs,
            "finished": finished,
            "metrics": seq_group.metrics,
            "lora_request": seq_group.lora_request,
            "encoder_prompt": encoder_prompt,
            "encoder_prompt_token_ids": encoder_prompt_token_ids,
            "num_cached_tokens": num_cached_tokens,
            "multi_modal_placeholders": seq_group.multi_modal_placeholders
        }

        if use_cache:
            # Re-run __init__ on the cached object instead of allocating.
            request_output = seq_group.cached_request_output
            request_output.__init__(**init_kwargs)  # type: ignore
        else:
            request_output = cls(**init_kwargs)  # type: ignore

        return request_output

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"encoder_prompt={self.encoder_prompt!r}, "
                f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request}, "
                f"num_cached_tokens={self.num_cached_tokens}, "
                f"multi_modal_placeholders={self.multi_modal_placeholders}, "
                f"kv_transfer_params={self.kv_transfer_params})")
350
351


352
353
354
355
_O = TypeVar("_O", default=PoolingOutput)


class PoolingRequestOutput(Generic[_O]):
    """
    Result of a pooling request to the LLM.

    Args:
        request_id (str): Identifier of the pooling request.
        outputs (_O): The pooled result for the request; a PoolingOutput
            unless a subclass specialises the type parameter.
        prompt_token_ids (list[int]): Token IDs of the prompt.
        finished (bool): Whether pooling for the request has completed.
    """

    def __init__(self, request_id: str, outputs: _O,
                 prompt_token_ids: list[int], finished: bool):
        # Attribute assignment order is kept stable (affects vars() order).
        self.request_id = request_id
        self.prompt_token_ids = prompt_token_ids
        self.finished = finished
        self.outputs = outputs

    @staticmethod
    def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput":
        """Wrap the pooled data of ``seq_group`` in a PoolingRequestOutput.

        The pooled tensor is converted to float32 on the CPU before being
        wrapped in a PoolingOutput.
        """
        raw_pooled = seq_group.pooled_data
        assert raw_pooled is not None

        pooled_output = PoolingOutput(
            raw_pooled.to(dtype=torch.float32, device="cpu"))

        return PoolingRequestOutput(seq_group.request_id, pooled_output,
                                    seq_group.prompt_token_ids,
                                    seq_group.is_finished())

    def __repr__(self):
        fields = (
            f"request_id={self.request_id!r}",
            f"outputs={self.outputs!r}",
            f"prompt_token_ids={self.prompt_token_ids}",
            f"finished={self.finished}",
        )
        return f"{type(self).__name__}({', '.join(fields)})"


393
394
395
396
class RequestOutputFactory:
    """Choose the output type that matches a sequence group."""

    @staticmethod
    def create(seq_group: SequenceGroup,
               seq_id_to_seq_group: dict[str, SequenceGroupBase],
               use_cache: bool = False):
        # Groups without pooled data are reported as RequestOutput; groups
        # carrying pooled data become PoolingRequestOutput.
        if seq_group.pooled_data is None:
            return RequestOutput.from_seq_group(seq_group, use_cache,
                                                seq_id_to_seq_group)
        return PoolingRequestOutput.from_seq_group(seq_group)


406
@dataclass
class EmbeddingOutput:
    """The output data of one embedding output of a request.

    Args:
        embedding: The embedding vector as a list of floats. Its length
            depends on the hidden dimension of the model.
    """
    embedding: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        """Convert a 1-D PoolingOutput into an EmbeddingOutput.

        Raises:
            ValueError: If the pooled tensor is not 1-D.
        """
        vector = pooling_output.data
        if vector.ndim == 1:
            return EmbeddingOutput(vector.tolist())
        raise ValueError("pooled_data should be a 1-D embedding vector")

    @property
    def hidden_size(self) -> int:
        # Number of components in the embedding vector.
        return len(self.embedding)

    def __repr__(self) -> str:
        return f"EmbeddingOutput(hidden_size={self.hidden_size})"
430
431


432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
    """A PoolingRequestOutput whose ``outputs`` is an EmbeddingOutput."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Re-wrap a generic pooling result as an embedding result."""
        embedding = EmbeddingOutput.from_base(request_output.outputs)
        return EmbeddingRequestOutput(
            request_id=request_output.request_id,
            outputs=embedding,
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )


@dataclass
class ClassificationOutput:
    """The output data of one classification output of a request.

    Args:
        probs: The probability vector as a list of floats. Its length
            depends on the number of classes.
    """
    probs: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        """Convert a 1-D PoolingOutput (one entry per class) into a
        ClassificationOutput.

        Raises:
            ValueError: If the pooled tensor is not 1-D.
        """
        class_scores = pooling_output.data
        if class_scores.ndim == 1:
            return ClassificationOutput(class_scores.tolist())
        raise ValueError("pooled_data should be a 1-D probability vector")

    @property
    def num_classes(self) -> int:
        # One probability per class.
        return len(self.probs)

    def __repr__(self) -> str:
        return f"ClassificationOutput(num_classes={self.num_classes})"
469
470


471
class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):
    """A PoolingRequestOutput whose ``outputs`` is a ClassificationOutput."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Re-wrap a generic pooling result as a classification result."""
        classification = ClassificationOutput.from_base(
            request_output.outputs)
        return ClassificationRequestOutput(
            request_id=request_output.request_id,
            outputs=classification,
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )
481
482


483
484
485
@dataclass
class ScoringOutput:
    """The output data of one scoring output of a request.

    Args:
        score: The similarity score, which is a scalar value.
    """
    score: float

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        """Convert a PoolingOutput that squeezes to a scalar into a
        ScoringOutput.

        The pooled data is expected to be a single-class vector (classify
        task) or already a scalar (embed task); after ``squeeze()`` it must
        be 0-D.

        Raises:
            ValueError: If the squeezed tensor is not 0-D.
        """
        squeezed = pooling_output.data.squeeze()
        if squeezed.ndim == 0:
            return ScoringOutput(squeezed.item())
        raise ValueError("pooled_data should be a scalar score")

    def __repr__(self) -> str:
        return f"ScoringOutput(score={self.score})"
505
506


507
508
509
510
511
512
513
514
515
516
class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
    """A PoolingRequestOutput whose ``outputs`` is a ScoringOutput."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Re-wrap a generic pooling result as a scoring result."""
        scoring = ScoringOutput.from_base(request_output.outputs)
        return ScoringRequestOutput(
            request_id=request_output.request_id,
            outputs=scoring,
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )