# SPDX-License-Identifier: Apache-2.0

import time
from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass
from typing import Any, Generic, Optional, Union

import torch
from typing_extensions import TypeVar

from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                           SequenceGroup, SequenceGroupBase, SequenceStatus)

logger = init_logger(__name__)


@dataclass
class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text, or None if it was not computed.
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished, or None if
            the sequence has not finished yet.
        stop_reason: The stop string or token id that caused the completion
            to stop, None if the completion finished for some other reason
            including encountering the EOS token.
        lora_request: The LoRA request that was used to generate the output.
    """

    index: int
    text: str
    token_ids: GenericSequence[int]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        """Return True once a finish reason has been recorded."""
        return self.finish_reason is not None

    def __repr__(self) -> str:
        # lora_request is not part of the repr.
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason}, "
                f"stop_reason={self.stop_reason})")


63
@dataclass
64
65
class PoolingOutput:
    """The output data of one pooling output of a request.
66
67

    Args:
68
        data: The extracted hidden states.
69
    """
70
    data: torch.Tensor
71
72

    def __repr__(self) -> str:
73
74
75
76
77
78
        return (f"PoolingOutput(data={self.data})")

    def __eq__(self, other: object) -> bool:
        return (isinstance(other, self.__class__) and bool(
            (self.data == other.data).all()))

79

class RequestOutput:
    """The output data of a completion request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
                For encoder/decoder models, this is the
                decoder input prompt.
        prompt_token_ids: The token IDs of the prompt.
                          For encoder/decoder models, this is the
                          decoder input prompt token ids.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
        encoder_prompt: The encoder prompt string of the request.
                        None if decoder-only.
        encoder_prompt_token_ids: The token IDs of the encoder prompt.
                                  None if decoder-only.
        num_cached_tokens: The number of tokens with prefix cache hit.
        multi_modal_placeholders: Multi-modal placeholder information for
                                  the prompt. Stored as an empty dict when
                                  not provided.
        kv_transfer_params: The params for remote K/V transfer.
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[list[int]],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: list[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[list[int]] = None,
        num_cached_tokens: Optional[int] = None,
        *,
        multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
        kv_transfer_params: Optional[dict[str, Any]] = None,
        # Forward compatibility, code that uses args added in new release can
        # still run with older versions of vLLM without breaking.
        **kwargs: Any,
    ) -> None:
        if kwargs:
            logger.warning_once("RequestOutput: Ignoring extra arguments: %s",
                                str(kwargs))
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.multi_modal_placeholders = multi_modal_placeholders or {}
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids
        self.num_cached_tokens = num_cached_tokens
        self.kv_transfer_params = kv_transfer_params

    def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
        """Merge subsequent RequestOutput into this one.

        Args:
            next_output: A later output for the same request.
            aggregate: If True, completions with a matching ``index`` are
                concatenated: text, token ids and logprobs are appended.
                The cumulative logprob, finish reason and stop reason are
                taken from ``next_output``. If False, matching completions
                are replaced by the ones in ``next_output``.

        Completions in ``next_output`` whose index matches none of the
        existing completions are appended. ``finished`` is OR-ed in and
        ``kv_transfer_params`` is overwritten with ``next_output``'s value.
        """
        self.finished |= next_output.finished
        self.kv_transfer_params = next_output.kv_transfer_params

        for next_completion in next_output.outputs:
            for i, completion in enumerate(self.outputs):
                if completion.index == next_completion.index:
                    if aggregate:
                        # Merge outputs with same index
                        completion.text += next_completion.text
                        if not isinstance(completion.token_ids,
                                          MutableSequence):
                            completion.token_ids = list(completion.token_ids)
                        completion.token_ids.extend(next_completion.token_ids)
                        if next_completion.logprobs:
                            assert completion.logprobs is not None
                            completion.logprobs.extend(
                                next_completion.logprobs)
                        completion.cumulative_logprob = (
                            next_completion.cumulative_logprob)
                        completion.finish_reason = next_completion.finish_reason
                        completion.stop_reason = next_completion.stop_reason
                    else:
                        # Replace the output with the new one
                        self.outputs[i] = next_completion
                    break
            else:
                self.outputs.append(next_completion)

    @classmethod
    def from_seq_group(
        cls, seq_group: SequenceGroup, use_cache: bool,
        seq_id_to_seq_group: dict[str, SequenceGroupBase]
    ) -> Optional["RequestOutput"]:
        """Build a RequestOutput from a SequenceGroup.

        Args:
            seq_group: The sequence group to convert.
            use_cache: If True, reuse ``seq_group.cached_request_output``
                and its CompletionOutput objects across calls instead of
                allocating new ones. The cached object is created on first
                use.
            seq_id_to_seq_group: Maps request ids to the SequenceGroupBase
                they belong to, for requests that were split into several
                sub-requests. Once a group has no sub-requests left to
                finish, all of its sub-request ids are removed from this
                dict.

        Returns:
            The RequestOutput, or None when nothing should be emitted yet.
            That happens when the split group cannot be assembled yet, or
            when the request uses ``RequestOutputKind.FINAL_ONLY`` and has
            not finished.

        Raises:
            ValueError: If the sequence group has no sampling parameters.
        """
        finished = seq_group.is_finished()

        if seq_group.request_id in seq_id_to_seq_group:
            group: SequenceGroupBase = seq_id_to_seq_group[
                seq_group.request_id]
            assembled_seq_group = group.maybe_assemble_group(seq_group)
            if finished:
                group.finish_seq(seq_group)
            if assembled_seq_group is None:
                return None

            # clear finished seq in seq_id_to_seq_group
            if len(group.to_be_finished) == 0:
                for sub_request_id in list(group.seq_id_to_index.keys()):
                    if sub_request_id in seq_id_to_seq_group:
                        del seq_id_to_seq_group[sub_request_id]

            return cls.from_seq_group(assembled_seq_group, use_cache,
                                      seq_id_to_seq_group)

        sampling_params = seq_group.sampling_params
        if sampling_params is None:
            raise ValueError(
                "Sampling parameters are missing for a CompletionRequest.")

        if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                not finished):
            return None

        # Init cache (if needed)
        if use_cache and seq_group.cached_request_output is None:
            seq_group.cached_request_output = RequestOutput(  # type: ignore
                request_id="",
                prompt=None,
                prompt_token_ids=[],
                prompt_logprobs=None,
                outputs=[],
                finished=False)

        top_n_seqs = seq_group.get_seqs()

        # Create the outputs.
        # NOTE: We need omit logprobs here explicitly because the sequence
        # always has the logprobs of the sampled tokens even if the
        # logprobs are not requested.
        include_logprobs = sampling_params.logprobs is not None
        text_buffer_length = sampling_params.output_text_buffer_length
        delta = sampling_params.output_kind == RequestOutputKind.DELTA

        outputs = []
        include_prompt = True
        # num_cached_tokens should be the same for all the sequences
        num_cached_tokens = None
        for i, seq in enumerate(top_n_seqs):
            output_text = seq.get_output_text_to_return(
                text_buffer_length, delta)

            output_token_ids = seq.get_output_token_ids_to_return(delta)
            num_output_tokens = 1 if isinstance(output_token_ids,
                                                int) else len(output_token_ids)
            num_cached_tokens = seq.data.get_num_cached_tokens()

            output_logprobs = seq.output_logprobs if include_logprobs else None

            if delta:
                # Slice logprobs delta if applicable
                if output_logprobs:
                    # num_output_tokens can be 0 when n > 1 and request finishes
                    # before the others
                    if num_output_tokens > 0:
                        output_logprobs = output_logprobs[-num_output_tokens:]
                    else:
                        output_logprobs = None
                # Don't include prompt if this is after the first output
                # containing decode token ids
                if include_prompt and seq.get_output_len() > num_output_tokens:
                    include_prompt = False

            if use_cache:
                # Get cached output object
                cached_outputs = seq_group.cached_request_output.outputs  # type: ignore
                if i >= len(cached_outputs):
                    cached_outputs.append(
                        CompletionOutput(index=i,
                                         text="",
                                         token_ids=[],
                                         cumulative_logprob=None,
                                         logprobs=None,
                                         finish_reason=None,
                                         stop_reason=None))
                output = cached_outputs[i]

                # Init cached output object
                assert output.index == i
                output.text = output_text

                if isinstance(output_token_ids, int):
                    # Reuse the cached list for a single token id. Fall back
                    # to a fresh list if the cached value is not mutable
                    # (token_ids is only typed as a GenericSequence).
                    if isinstance(output.token_ids, MutableSequence):
                        output.token_ids.clear()
                        output.token_ids.append(output_token_ids)
                    else:
                        output.token_ids = [output_token_ids]
                else:
                    output.token_ids = output_token_ids

                output.cumulative_logprob = seq.get_cumulative_logprob() \
                    if include_logprobs else None
                output.logprobs = output_logprobs
                output.finish_reason = SequenceStatus.get_finished_reason(
                    seq.status)
                output.stop_reason = seq.stop_reason

            else:
                # ``i`` is the position of ``seq`` in ``top_n_seqs``; using
                # it avoids an O(n) ``list.index`` lookup per sequence.
                output = CompletionOutput(
                    i, output_text, [output_token_ids]
                    if isinstance(output_token_ids, int) else output_token_ids,
                    seq.get_cumulative_logprob() if include_logprobs else None,
                    output_logprobs,
                    SequenceStatus.get_finished_reason(seq.status),
                    seq.stop_reason)

            outputs.append(output)

        # Every sequence in the sequence group should have the same prompt.
        if include_prompt:
            prompt = seq_group.prompt
            prompt_token_ids = seq_group.prompt_token_ids
            encoder_prompt = seq_group.encoder_prompt
            encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
            prompt_logprobs = seq_group.prompt_logprobs
        else:
            prompt = None
            prompt_token_ids = None
            encoder_prompt = None
            encoder_prompt_token_ids = None
            prompt_logprobs = None
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)

        init_kwargs = {
            "request_id": seq_group.request_id,
            "prompt": prompt,
            "prompt_token_ids": prompt_token_ids,
            "prompt_logprobs": prompt_logprobs,
            "outputs": outputs,
            "finished": finished,
            "metrics": seq_group.metrics,
            "lora_request": seq_group.lora_request,
            "encoder_prompt": encoder_prompt,
            "encoder_prompt_token_ids": encoder_prompt_token_ids,
            "num_cached_tokens": num_cached_tokens,
            "multi_modal_placeholders": seq_group.multi_modal_placeholders
        }

        if use_cache:
            # Re-initialize the cached object in place so callers holding a
            # reference to it observe the update.
            request_output = seq_group.cached_request_output
            request_output.__init__(**init_kwargs)  # type: ignore
        else:
            request_output = cls(**init_kwargs)  # type: ignore

        return request_output

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"encoder_prompt={self.encoder_prompt!r}, "
                f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request}, "
                f"num_cached_tokens={self.num_cached_tokens}, "
                f"multi_modal_placeholders={self.multi_modal_placeholders})")


# Type of ``PoolingRequestOutput.outputs``; defaults to PoolingOutput.
_O = TypeVar("_O", default=PoolingOutput)


class PoolingRequestOutput(Generic[_O]):
    """
    The output data of a pooling request to the LLM.

    Args:
        request_id (str): A unique identifier for the pooling request.
        outputs (_O): The pooling results for the given input. A
            ``PoolingOutput`` by default; subclasses parametrize it with a
            task-specific output type.
        prompt_token_ids (list[int]): A list of token IDs used in the prompt.
        finished (bool): A flag indicating whether the pooling is completed.
    """

    def __init__(self, request_id: str, outputs: _O,
                 prompt_token_ids: list[int], finished: bool):
        self.request_id = request_id
        self.prompt_token_ids = prompt_token_ids
        self.finished = finished
        self.outputs = outputs

    @staticmethod
    def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput":
        """Build a PoolingRequestOutput from a pooled SequenceGroup.

        The group's pooled tensor is converted to float32 on CPU and
        wrapped in a ``PoolingOutput``. The group must carry pooled data.
        """
        pooled_data = seq_group.pooled_data
        assert pooled_data is not None

        data = pooled_data.to(dtype=torch.float32, device="cpu")
        output = PoolingOutput(data)
        prompt_token_ids = seq_group.prompt_token_ids
        finished = seq_group.is_finished()

        return PoolingRequestOutput(seq_group.request_id, output,
                                    prompt_token_ids, finished)

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(request_id={self.request_id!r}, "
                f"outputs={self.outputs!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"finished={self.finished})")


class RequestOutputFactory:
    """Creates the appropriate output object for a SequenceGroup.

    Groups carrying ``pooled_data`` become a ``PoolingRequestOutput``.
    All others are converted with ``RequestOutput.from_seq_group``.
    """

    @staticmethod
    def create(
        seq_group: SequenceGroup,
        seq_id_to_seq_group: dict[str, SequenceGroupBase],
        use_cache: bool = False,
    ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Create the output for ``seq_group``.

        ``seq_id_to_seq_group`` and ``use_cache`` are only used on the
        completion path. Returns None when ``RequestOutput.from_seq_group``
        has nothing to emit yet.
        """
        if seq_group.pooled_data is not None:
            return PoolingRequestOutput.from_seq_group(seq_group)
        return RequestOutput.from_seq_group(seq_group, use_cache,
                                            seq_id_to_seq_group)


@dataclass
class EmbeddingOutput:
    """The output data of one embedding output of a request.

    Args:
        embedding: The embedding vector, which is a list of floats.
            Its length depends on the hidden dimension of the model.
    """
    embedding: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput) -> "EmbeddingOutput":
        """Convert a generic PoolingOutput into an EmbeddingOutput.

        Raises:
            ValueError: If the pooled data is not a 1-D tensor.
        """
        pooled_data = pooling_output.data
        if pooled_data.ndim != 1:
            raise ValueError("pooled_data should be a 1-D embedding vector")

        return EmbeddingOutput(pooled_data.tolist())

    @property
    def hidden_size(self) -> int:
        """Length of the embedding vector."""
        return len(self.embedding)

    def __repr__(self) -> str:
        return f"EmbeddingOutput(hidden_size={self.hidden_size})"


class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
    """A pooling request output whose ``outputs`` is an EmbeddingOutput."""

    @staticmethod
    def from_base(
        request_output: PoolingRequestOutput,
    ) -> "EmbeddingRequestOutput":
        """Convert a generic PoolingRequestOutput to an embedding output.

        Raises:
            ValueError: Propagated from ``EmbeddingOutput.from_base`` when
                the pooled data is not 1-D.
        """
        return EmbeddingRequestOutput(
            request_id=request_output.request_id,
            outputs=EmbeddingOutput.from_base(request_output.outputs),
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )


@dataclass
class ClassificationOutput:
    """The output data of one classification output of a request.

    Args:
        probs: The probability vector, which is a list of floats.
            Its length depends on the number of classes.
    """
    probs: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput) -> "ClassificationOutput":
        """Convert a generic PoolingOutput into a ClassificationOutput.

        Raises:
            ValueError: If the pooled data is not a 1-D tensor.
        """
        pooled_data = pooling_output.data
        if pooled_data.ndim != 1:
            raise ValueError("pooled_data should be a 1-D probability vector")

        return ClassificationOutput(pooled_data.tolist())

    @property
    def num_classes(self) -> int:
        """Length of the probability vector."""
        return len(self.probs)

    def __repr__(self) -> str:
        return f"ClassificationOutput(num_classes={self.num_classes})"


class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):
    """A pooling request output whose ``outputs`` is a ClassificationOutput.
    """

    @staticmethod
    def from_base(
        request_output: PoolingRequestOutput,
    ) -> "ClassificationRequestOutput":
        """Convert a generic PoolingRequestOutput to a classification output.

        Raises:
            ValueError: Propagated from ``ClassificationOutput.from_base``
                when the pooled data is not 1-D.
        """
        return ClassificationRequestOutput(
            request_id=request_output.request_id,
            outputs=ClassificationOutput.from_base(request_output.outputs),
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )


@dataclass
class ScoringOutput:
    """The output data of one scoring output of a request.

    Args:
        score: The similarity score, which is a scalar value.
    """
    score: float

    @staticmethod
    def from_base(pooling_output: PoolingOutput) -> "ScoringOutput":
        """Convert a generic PoolingOutput into a ScoringOutput.

        Raises:
            ValueError: If the pooled data is not a 0-D (scalar) tensor.
        """
        pooled_data = pooling_output.data
        if pooled_data.ndim != 0:
            raise ValueError("pooled_data should be a scalar score")

        return ScoringOutput(pooled_data.item())

    def __repr__(self) -> str:
        return f"ScoringOutput(score={self.score})"


class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
    """A pooling request output whose ``outputs`` is a ScoringOutput."""

    @staticmethod
    def from_base(
        request_output: PoolingRequestOutput,
    ) -> "ScoringRequestOutput":
        """Convert a generic PoolingRequestOutput to a scoring output.

        Raises:
            ValueError: Propagated from ``ScoringOutput.from_base`` when the
                pooled data is not a scalar.
        """
        return ScoringRequestOutput(
            request_id=request_output.request_id,
            outputs=ScoringOutput.from_base(request_output.outputs),
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )