outputs.py 9.37 KB
Newer Older
1
import time
2
from dataclasses import dataclass
3
4
5
from typing import List, Optional
from typing import Sequence as GenericSequence
from typing import Union
6

7
from vllm.lora.request import LoRARequest
8
9
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                           SequenceGroup, SequenceStatus)
10
11


12
@dataclass
class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
        stop_reason: The stop string or token id that caused the completion
            to stop, None if the completion finished for some other reason
            including encountering the EOS token.
        lora_request: The LoRA request that was used to generate the output.
    """

    index: int
    text: str
    token_ids: GenericSequence[int]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        """Return True once a finish reason has been recorded."""
        return self.finish_reason is not None

    def __repr__(self) -> str:
        # Hand-written instead of the dataclass default: `text` is shown via
        # repr() and `lora_request` is left out of the string.
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason}, "
                f"stop_reason={self.stop_reason})")
51
52


53
@dataclass
class EmbeddingOutput:
    """The output data of one embedding output of a request.

    Args:
        embedding: The embedding vector, which is a list of floats. The
            length of vector depends on the model as listed in the embedding
            guide.
    """

    embedding: List[float]

    def __repr__(self) -> str:
        # Only the vector length is rendered, not the individual floats.
        return (f"EmbeddingOutput("
                f"embedding={len(self.embedding)})")
67
68


69
class RequestOutput:
    """The output data of a completion request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
                For encoder/decoder models, this is the
                decoder input prompt.
        prompt_token_ids: The token IDs of the prompt.
                          For encoder/decoder models, this is the
                          decoder input prompt token ids.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
        encoder_prompt: The encoder prompt string of the request;
                        None if decoder-only
        encoder_prompt_token_ids: The token IDs of the encoder prompt;
                                  None if decoder-only
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: List[int],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[List[int]] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        """Build a RequestOutput from the current state of ``seq_group``.

        When the group holds more than one sequence, the sequences are sorted
        in descending order by beam-search score (if beam search is enabled)
        or by cumulative logprob, and only the first ``sampling_params.n``
        are kept.

        Side effect: ``seq_group.set_finished_time`` is called on every
        invocation, with the current time when the group is finished and
        with None otherwise.

        Raises:
            ValueError: If ``seq_group.sampling_params`` is None.
        """
        sampling_params = seq_group.sampling_params
        if sampling_params is None:
            raise ValueError(
                "Sampling parameters are missing for a CompletionRequest.")

        seqs = seq_group.get_seqs()
        if len(seqs) == 1:
            top_n_seqs = seqs
        else:
            # Get the top-n sequences.
            n = sampling_params.n
            if sampling_params.use_beam_search:
                length_penalty = sampling_params.length_penalty

                def sorting_key(seq):
                    return seq.get_beam_search_score(length_penalty)
            else:

                def sorting_key(seq):
                    return seq.get_cumulative_logprob()

            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
            top_n_seqs = sorted_seqs[:n]

        # Create the outputs.
        # NOTE: We need to omit logprobs here explicitly because the sequence
        # always has the logprobs of the sampled tokens even if the
        # logprobs are not requested.
        include_logprobs = sampling_params.logprobs is not None
        text_buffer_length = sampling_params.output_text_buffer_length
        outputs = [
            CompletionOutput(
                # The index refers to the position in the unsorted `seqs`
                # list, not the rank after sorting.
                seqs.index(seq),
                seq.get_output_text_to_return(text_buffer_length),
                seq.data._output_token_ids,
                seq.get_cumulative_logprob() if include_logprobs else None,
                seq.output_logprobs if include_logprobs else None,
                SequenceStatus.get_finished_reason(seq.status),
                seq.stop_reason) for seq in top_n_seqs
        ]

        # Every sequence in the sequence group should have the same prompt.
        prompt = seq_group.prompt
        prompt_token_ids = seq_group.prompt_token_ids
        encoder_prompt = seq_group.encoder_prompt
        encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
        # Record the finish time on the group before its metrics are read
        # below; an unfinished group gets None.
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)
        return cls(seq_group.request_id,
                   prompt,
                   prompt_token_ids,
                   prompt_logprobs,
                   outputs,
                   finished,
                   seq_group.metrics,
                   lora_request=seq_group.lora_request,
                   encoder_prompt=encoder_prompt,
                   encoder_prompt_token_ids=encoder_prompt_token_ids)

    def __repr__(self) -> str:
        # Prompt strings are rendered with repr(); request_id is not.
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"encoder_prompt={self.encoder_prompt!r}, "
                f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request})")
182
183
184
185
186
187
188
189
190
191
192
193
194


class EmbeddingRequestOutput:
    """
    The output data of an embedding request to the LLM.

    Args:
        request_id (str): A unique identifier for the embedding request.
        outputs (EmbeddingOutput): The embedding results for the given input.
        prompt_token_ids (List[int]): A list of token IDs used in the prompt.
        finished (bool): A flag indicating whether the embedding is completed.
    """

    def __init__(self, request_id: str, outputs: "EmbeddingOutput",
                 prompt_token_ids: List[int], finished: bool) -> None:
        self.request_id = request_id
        self.prompt_token_ids = prompt_token_ids
        self.finished = finished
        self.outputs = outputs

    @classmethod
    def from_seq_group(cls,
                       seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
        """Build an EmbeddingRequestOutput from ``seq_group``.

        The group's ``embeddings`` are wrapped in an ``EmbeddingOutput``; the
        request id, prompt token ids and finished flag are read from the
        group.

        Raises:
            ValueError: If ``seq_group.embeddings`` is None.
        """
        if seq_group.embeddings is None:
            raise ValueError(
                "Embeddings are missing in seq_group for EmbeddingRequest.")
        output = EmbeddingOutput(seq_group.embeddings)
        prompt_token_ids = seq_group.prompt_token_ids
        finished = seq_group.is_finished()

        return cls(seq_group.request_id, output, prompt_token_ids, finished)

    def __repr__(self) -> str:
        """
        Returns a string representation of an EmbeddingRequestOutput instance.

        The representation includes the request_id, the repr of the outputs,
        the prompt token ids and the finished flag.

        Returns:
            str: A string representation of the EmbeddingRequestOutput instance.
        """
        return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
                f"outputs={repr(self.outputs)}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"finished={self.finished})")


class RequestOutputFactory:
    """Builds the request-output object that matches a sequence group."""

    @staticmethod
    def create(seq_group):
        """Return an output object built from ``seq_group``.

        A group exposing a non-None ``embeddings`` attribute is converted via
        ``EmbeddingRequestOutput.from_seq_group``; any other group goes
        through ``RequestOutput.from_seq_group``.
        """
        embeddings = getattr(seq_group, 'embeddings', None)
        if embeddings is None:
            # Attribute absent or unset: treat as a completion request.
            return RequestOutput.from_seq_group(seq_group)
        return EmbeddingRequestOutput.from_seq_group(seq_group)