outputs.py 8.31 KB
Newer Older
1
import time
2
from dataclasses import dataclass
3
from typing import List, Optional, Tuple, Union
4

5
from vllm.lora.request import LoRARequest
6
7
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                           SequenceGroup, SequenceStatus)
8
9


10
@dataclass
class CompletionOutput:
    """A single generated completion belonging to a request.

    Args:
        index: Position of this completion among the request's sequences.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The top log probabilities at each position, or None when
            logprobs were not requested.
        finish_reason: Why the sequence finished; None while it has not.
        stop_reason: The stop string or token id that ended the completion,
            or None if it ended for some other reason (including reaching
            the EOS token).
        lora_request: The LoRA request that was used to generate the output.
    """

    index: int
    text: str
    token_ids: Tuple[int, ...]
    cumulative_logprob: float
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        """Return True once a finish reason has been recorded."""
        return self.finish_reason is not None

    def __repr__(self) -> str:
        # Note: lora_request is not part of this representation.
        parts = [
            f"index={self.index}",
            f"text={self.text!r}",
            f"token_ids={self.token_ids}",
            f"cumulative_logprob={self.cumulative_logprob}",
            f"logprobs={self.logprobs}",
            f"finish_reason={self.finish_reason}",
            f"stop_reason={self.stop_reason}",
        ]
        return "CompletionOutput(" + ", ".join(parts) + ")"
49
50


51
@dataclass
class EmbeddingOutput:
    """The embedding output data of a request.

    Args:
        embedding: The embedding vector as a list of floats. Its length
            depends on the model, as listed in the embedding guide.
    """

    embedding: List[float]

    def __repr__(self) -> str:
        # Only the vector's length is shown, never the individual values.
        return "EmbeddingOutput(embedding={})".format(len(self.embedding))
65
66


67
class RequestOutput:
    """The output data of a completion request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request (may be None).
        prompt_token_ids: The token IDs of the prompt.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: List[int],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        """Build a RequestOutput from the current state of a sequence group.

        Args:
            seq_group: The sequence group to convert.

        Returns:
            A RequestOutput whose ``outputs`` hold the group's top sequences
            (all of them when the group has exactly one sequence).

        Raises:
            ValueError: If the group has no sampling parameters.
        """
        params = seq_group.sampling_params
        if params is None:
            raise ValueError(
                "Sampling parameters are missing for a CompletionRequest.")

        all_seqs = seq_group.get_seqs()
        if len(all_seqs) != 1:
            # Rank the candidates (best first) and keep the top n.
            if params.use_beam_search:
                length_penalty = params.length_penalty

                def rank(seq):
                    return seq.get_beam_search_score(length_penalty)
            else:

                def rank(seq):
                    return seq.get_cumulative_logprob()

            chosen = sorted(all_seqs, key=rank, reverse=True)[:params.n]
        else:
            chosen = all_seqs

        # The sequence always carries logprobs for its sampled tokens, even
        # when none were requested, so they must be dropped explicitly.
        want_logprobs = params.logprobs is not None
        buffer_len = params.output_text_buffer_length

        def make_output(seq):
            return CompletionOutput(
                index=all_seqs.index(seq),
                text=seq.get_output_text_to_return(buffer_len),
                token_ids=seq.get_output_token_ids(),
                cumulative_logprob=seq.get_cumulative_logprob(),
                logprobs=seq.output_logprobs if want_logprobs else None,
                finish_reason=SequenceStatus.get_finished_reason(seq.status),
                stop_reason=seq.stop_reason,
            )

        outputs = [make_output(seq) for seq in chosen]

        # Every sequence in the group shares one prompt; read it from the group.
        prompt = seq_group.prompt
        prompt_token_ids = seq_group.prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
        # Stamp the finish time, or pass None while the group is unfinished.
        seq_group.set_finished_time(time.time() if finished else None)

        return cls(
            request_id=seq_group.request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=outputs,
            finished=finished,
            metrics=seq_group.metrics,
            lora_request=seq_group.lora_request,
        )

    def __repr__(self) -> str:
        fields = (
            f"request_id={self.request_id}",
            f"prompt={self.prompt!r}",
            f"prompt_token_ids={self.prompt_token_ids}",
            f"prompt_logprobs={self.prompt_logprobs}",
            f"outputs={self.outputs}",
            f"finished={self.finished}",
            f"metrics={self.metrics}",
            f"lora_request={self.lora_request}",
        )
        return f"RequestOutput({', '.join(fields)})"
161
162
163
164
165
166
167
168
169
170
171
172
173


class EmbeddingRequestOutput:
    """
    The output data of an embedding request to the LLM.

    Args:
        request_id (str): A unique identifier for the embedding request.
        outputs (EmbeddingOutput): The embedding results for the given input.
        prompt_token_ids (List[int]): A list of token IDs used in the prompt.
        finished (bool): A flag indicating whether the embedding is completed.
    """

    def __init__(self, request_id: str, outputs: "EmbeddingOutput",
                 prompt_token_ids: List[int], finished: bool) -> None:
        self.request_id = request_id
        self.prompt_token_ids = prompt_token_ids
        self.finished = finished
        self.outputs = outputs

    @classmethod
    def from_seq_group(cls,
                       seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
        """Build an EmbeddingRequestOutput from a sequence group.

        Args:
            seq_group: The sequence group whose ``embeddings``,
                ``prompt_token_ids``, ``request_id`` and finished state are
                copied into the result.

        Returns:
            An EmbeddingRequestOutput wrapping the group's embeddings in an
            EmbeddingOutput.

        Raises:
            ValueError: If ``seq_group.embeddings`` is None.
        """
        if seq_group.embeddings is None:
            raise ValueError(
                "Embeddings are missing in seq_group for EmbeddingRequest.")
        output = EmbeddingOutput(seq_group.embeddings)
        prompt_token_ids = seq_group.prompt_token_ids
        finished = seq_group.is_finished()

        return cls(seq_group.request_id, output, prompt_token_ids, finished)

    def __repr__(self) -> str:
        """Return a string representation of this EmbeddingRequestOutput.

        The representation contains the request id (single-quoted), the
        ``repr`` of ``outputs``, the prompt token ids and the finished flag.

        Returns:
            str: A string representation of the EmbeddingRequestOutput instance.
        """
        return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
                f"outputs={repr(self.outputs)}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"finished={self.finished})")


class RequestOutputFactory:
    """Build the appropriate request-output object for a sequence group."""

    @staticmethod
    def create(seq_group):
        """Convert ``seq_group`` into a request output.

        A group whose ``embeddings`` attribute exists and is not None becomes
        an EmbeddingRequestOutput; any other group (including one without an
        ``embeddings`` attribute at all) becomes a RequestOutput.
        """
        embeddings = getattr(seq_group, 'embeddings', None)
        if embeddings is None:
            return RequestOutput.from_seq_group(seq_group)
        return EmbeddingRequestOutput.from_seq_group(seq_group)