outputs.py 8.69 KB
Newer Older
1
import time
2
from typing import List, Optional, Union
3

4
from vllm.lora.request import LoRARequest
5
6
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                           SequenceGroup, SequenceStatus)
7
8
9


class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
        stop_reason: The stop string or token id that caused the completion
            to stop, None if the completion finished for some other reason
            including encountering the EOS token.
        lora_request: The LoRA request that was used to generate the output.
    """

    def __init__(
        self,
        index: int,
        text: str,
        token_ids: List[int],
        cumulative_logprob: float,
        logprobs: Optional[SampleLogprobs],
        finish_reason: Optional[str] = None,
        stop_reason: Union[int, str, None] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.index = index
        self.text = text
        self.token_ids = token_ids
        self.cumulative_logprob = cumulative_logprob
        self.logprobs = logprobs
        self.finish_reason = finish_reason
        self.stop_reason = stop_reason
        self.lora_request = lora_request

    def finished(self) -> bool:
        """Return True once a finish reason has been recorded."""
        return self.finish_reason is not None

    def __repr__(self) -> str:
        # NOTE: lora_request is stored on the instance but is not part of
        # the textual representation.
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason}, "
                f"stop_reason={self.stop_reason})")
58
59


60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class EmbeddingOutput:
    """The output data of one embedding output of a request.

    Args:
        embedding: The embedding vector, which is a list of floats. The
            length of vector depends on the model as listed in the embedding
            guide.
    """

    def __init__(
        self,
        embedding: List[float],
    ) -> None:
        self.embedding = embedding

    def __repr__(self) -> str:
        # Show only the vector length rather than every float, and make sure
        # the closing parenthesis is emitted so the repr is well-formed.
        return f"EmbeddingOutput(embedding={len(self.embedding)})"


79
class RequestOutput:
    """The output data of a completion request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
        prompt_token_ids: The token IDs of the prompt.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
    """

    def __init__(
        self,
        request_id: str,
        prompt: str,
        prompt_token_ids: List[int],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        """Build a RequestOutput from the current state of a sequence group.

        When the group holds more than one sequence, only the top
        ``sampling_params.n`` sequences are returned, ranked by beam-search
        score (if beam search is enabled) or by cumulative log probability.

        As a side effect, ``seq_group.set_finished_time`` is called with the
        current time when the group is finished and with None otherwise.

        Raises:
            ValueError: If ``seq_group.sampling_params`` is None.
        """
        if seq_group.sampling_params is None:
            raise ValueError(
                "Sampling parameters are missing for a CompletionRequest.")
        seqs = seq_group.get_seqs()
        if len(seqs) == 1:
            top_n_seqs = seqs
        else:
            # Get the top-n sequences.
            n = seq_group.sampling_params.n
            if seq_group.sampling_params.use_beam_search:
                sorting_key = lambda seq: seq.get_beam_search_score(
                    seq_group.sampling_params.length_penalty)
            else:
                sorting_key = lambda seq: seq.get_cumulative_logprob()
            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
            top_n_seqs = sorted_seqs[:n]

        # Create the outputs.
        # NOTE: We need omit logprobs here explicitly because the sequence
        # always has the logprobs of the sampled tokens even if the
        # logprobs are not requested.
        include_logprobs = seq_group.sampling_params.logprobs is not None
        text_buffer_length = seq_group.sampling_params.output_text_buffer_length
        outputs = [
            # The reported index is the sequence's position in the unsorted
            # list of the group's sequences, not its rank.
            CompletionOutput(seqs.index(seq),
                             seq.get_output_text_to_return(text_buffer_length),
                             seq.get_output_token_ids(),
                             seq.get_cumulative_logprob(),
                             seq.output_logprobs if include_logprobs else None,
                             SequenceStatus.get_finished_reason(seq.status),
                             seq.stop_reason) for seq in top_n_seqs
        ]

        # Every sequence in the sequence group should have the same prompt.
        prompt = seq_group.prompt
        prompt_token_ids = seq_group.prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)
        return cls(seq_group.request_id,
                   prompt,
                   prompt_token_ids,
                   prompt_logprobs,
                   outputs,
                   finished,
                   seq_group.metrics,
                   lora_request=seq_group.lora_request)

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request})")
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230


class EmbeddingRequestOutput:
    """
    Result of an embedding request submitted to the LLM.

    Args:
        request_id (str): Unique ID of the embedding request.
        outputs (EmbeddingOutput): Embedding produced for the request.
        prompt_token_ids (List[int]): Token IDs of the prompt.
        finished (bool): Whether the request is finished.
    """

    def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
                 prompt_token_ids: List[int], finished: bool):
        self.request_id = request_id
        self.outputs = outputs
        self.prompt_token_ids = prompt_token_ids
        self.finished = finished

    @classmethod
    def from_seq_group(cls,
                       seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
        """
        Build an EmbeddingRequestOutput from a sequence group.

        Raises:
            ValueError: If ``seq_group.embeddings`` is None.
        """
        embeddings = seq_group.embeddings
        if embeddings is None:
            raise ValueError(
                "Embeddings are missing in seq_group for EmbeddingRequest.")

        wrapped = EmbeddingOutput(embeddings)
        token_ids = seq_group.prompt_token_ids
        is_done = seq_group.is_finished()
        return cls(seq_group.request_id, wrapped, token_ids, is_done)

    def __repr__(self):
        """
        Return a string representation of this EmbeddingRequestOutput.

        The string lists the request ID (in single quotes), the repr of the
        outputs, the prompt token IDs and the finished flag.

        Returns:
            str: A string representation of the EmbeddingRequestOutput instance.
        """
        fields = ", ".join((
            f"request_id='{self.request_id}'",
            f"outputs={self.outputs!r}",
            f"prompt_token_ids={self.prompt_token_ids}",
            f"finished={self.finished}",
        ))
        return f"EmbeddingRequestOutput({fields})"


class RequestOutputFactory:
    """Creates the request-output object that matches a sequence group."""

    @staticmethod
    def create(seq_group):
        """Return the output built from ``seq_group``.

        A group that has a non-None ``embeddings`` attribute yields an
        EmbeddingRequestOutput; any other group yields a RequestOutput.
        """
        # A missing attribute and an explicit None are treated the same way.
        if getattr(seq_group, 'embeddings', None) is None:
            return RequestOutput.from_seq_group(seq_group)
        return EmbeddingRequestOutput.from_seq_group(seq_group)