# outputs.py
1
import time
2
from dataclasses import dataclass
3
from typing import List, Optional, Tuple, Union
4

5
from vllm.lora.request import LoRARequest
6
7
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                           SequenceGroup, SequenceStatus)
8
9


10
@dataclass
class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
        stop_reason: The stop string or token id that caused the completion
            to stop, None if the completion finished for some other reason
            including encountering the EOS token.
        lora_request: The LoRA request that was used to generate the output.
    """

    index: int
    text: str
    token_ids: Tuple[int, ...]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        """Return True once a finish reason has been recorded."""
        return self.finish_reason is not None

    def __repr__(self) -> str:
        """Return a one-line summary of this completion.

        The text is rendered with ``!r`` so whitespace and quotes stay
        visible. ``lora_request`` is not part of the representation.
        """
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason}, "
                f"stop_reason={self.stop_reason})")
49
50


51
@dataclass
class EmbeddingOutput:
    """The output data of one embedding output of a request.

    Args:
        embedding: The embedding vector, which is a list of floats. The
            length of vector depends on the model as listed in the
            embedding guide.
    """

    embedding: List[float]

    def __repr__(self) -> str:
        """Return a compact summary of this output.

        Only the length of the vector is printed (under the ``embedding``
        label), not the individual values.
        """
        return (f"EmbeddingOutput("
                f"embedding={len(self.embedding)})")
65
66


67
class RequestOutput:
    """The output data of a completion request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
                For encoder/decoder models, this is the
                decoder input prompt.
        prompt_token_ids: The token IDs of the prompt.
                          For encoder/decoder models, this is the
                          decoder input prompt token ids.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
        encoder_prompt: The encoder prompt string of the request;
                        None if decoder-only
        encoder_prompt_token_ids: The token IDs of the encoder prompt;
                                  None if decoder-only
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: List[int],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[List[int]] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        """Build a RequestOutput from the current state of ``seq_group``.

        When the group holds more than one sequence, the sequences are
        ranked (by beam-search score if beam search is enabled, otherwise
        by cumulative logprob) and only the best
        ``sampling_params.n`` are returned.

        As a side effect, ``seq_group.set_finished_time`` is called with
        the current time if the group is finished, and with None otherwise.

        Raises:
            ValueError: If ``seq_group.sampling_params`` is None.
        """
        sampling_params = seq_group.sampling_params
        if sampling_params is None:
            raise ValueError(
                "Sampling parameters are missing for a CompletionRequest.")

        seqs = seq_group.get_seqs()
        if len(seqs) == 1:
            top_n_seqs = seqs
        else:
            # Get the top-n sequences.
            n = sampling_params.n
            if sampling_params.use_beam_search:
                length_penalty = sampling_params.length_penalty

                def sorting_key(seq):
                    return seq.get_beam_search_score(length_penalty)
            else:

                def sorting_key(seq):
                    return seq.get_cumulative_logprob()

            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
            top_n_seqs = sorted_seqs[:n]

        # Create the outputs.
        # NOTE: We need to omit logprobs here explicitly because the
        # sequence always has the logprobs of the sampled tokens even if
        # the logprobs are not requested.
        include_logprobs = sampling_params.logprobs is not None
        text_buffer_length = sampling_params.output_text_buffer_length
        outputs = [
            CompletionOutput(
                # The reported index is the position in the unsorted list,
                # not the rank after sorting.
                seqs.index(seq),
                seq.get_output_text_to_return(text_buffer_length),
                seq.get_output_token_ids(),
                seq.get_cumulative_logprob() if include_logprobs else None,
                seq.output_logprobs if include_logprobs else None,
                SequenceStatus.get_finished_reason(seq.status),
                seq.stop_reason) for seq in top_n_seqs
        ]

        # Every sequence in the sequence group should have the same prompt.
        prompt = seq_group.prompt
        prompt_token_ids = seq_group.prompt_token_ids
        encoder_prompt = seq_group.encoder_prompt
        encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)
        return cls(seq_group.request_id,
                   prompt,
                   prompt_token_ids,
                   prompt_logprobs,
                   outputs,
                   finished,
                   seq_group.metrics,
                   lora_request=seq_group.lora_request,
                   encoder_prompt=encoder_prompt,
                   encoder_prompt_token_ids=encoder_prompt_token_ids)

    def __repr__(self) -> str:
        """Return a one-line summary listing every stored field."""
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"encoder_prompt={self.encoder_prompt!r}, "
                f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request})")
180
181
182
183
184
185
186
187
188
189
190
191
192


class EmbeddingRequestOutput:
    """
    The output data of an embedding request to the LLM.

    Args:
        request_id (str): A unique identifier for the embedding request.
        outputs (EmbeddingOutput): The embedding results for the given input.
        prompt_token_ids (List[int]): A list of token IDs used in the prompt.
        finished (bool): A flag indicating whether the embedding is completed.
    """

    def __init__(self, request_id: str, outputs: "EmbeddingOutput",
                 prompt_token_ids: List[int], finished: bool) -> None:
        self.request_id = request_id
        self.prompt_token_ids = prompt_token_ids
        self.finished = finished
        self.outputs = outputs

    @classmethod
    def from_seq_group(cls,
                       seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
        """Build an EmbeddingRequestOutput from ``seq_group``.

        The group's ``embeddings`` are wrapped in an EmbeddingOutput; the
        request id, prompt token ids and finished flag are copied from the
        group.

        Raises:
            ValueError: If ``seq_group.embeddings`` is None.
        """
        if seq_group.embeddings is None:
            raise ValueError(
                "Embeddings are missing in seq_group for EmbeddingRequest.")
        output = EmbeddingOutput(seq_group.embeddings)
        prompt_token_ids = seq_group.prompt_token_ids
        finished = seq_group.is_finished()

        return cls(seq_group.request_id, output, prompt_token_ids, finished)

    def __repr__(self) -> str:
        """
        Returns a string representation of an EmbeddingRequestOutput instance.

        The representation includes the request_id, the repr of the outputs,
        the prompt token ids and the finished flag.

        Returns:
            str: A string representation of the EmbeddingRequestOutput instance.
        """
        return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
                f"outputs={repr(self.outputs)}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"finished={self.finished})")


class RequestOutputFactory:
    """Builds the request-output object that matches a sequence group."""

    @staticmethod
    def create(seq_group):
        """Return an output object built from ``seq_group``.

        A group that exposes a non-None ``embeddings`` attribute yields an
        EmbeddingRequestOutput; any other group yields a RequestOutput.
        """
        embeddings = getattr(seq_group, 'embeddings', None)
        if embeddings is None:
            return RequestOutput.from_seq_group(seq_group)
        return EmbeddingRequestOutput.from_seq_group(seq_group)