# outputs.py
1
import time
2
from typing import List, Optional, Union
3

4
from vllm.lora.request import LoRARequest
5
6
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                           SequenceGroup, SequenceStatus)
7
8
9


class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
        stop_reason: The stop string or token id that caused the completion
            to stop, None if the completion finished for some other reason
            including encountering the EOS token.
        lora_request: The LoRA request that was used to generate the output.
    """

    def __init__(
        self,
        index: int,
        text: str,
        token_ids: List[int],
        cumulative_logprob: float,
        logprobs: Optional[SampleLogprobs],
        finish_reason: Optional[str] = None,
        stop_reason: Union[int, str, None] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.index = index
        self.text = text
        self.token_ids = token_ids
        self.cumulative_logprob = cumulative_logprob
        self.logprobs = logprobs
        self.finish_reason = finish_reason
        self.stop_reason = stop_reason
        self.lora_request = lora_request

    def finished(self) -> bool:
        """Return True if a finish reason has been set for this output."""
        return self.finish_reason is not None

    def __repr__(self) -> str:
        # NOTE: lora_request is not included in this representation.
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason}, "
                f"stop_reason={self.stop_reason})")


class RequestOutput:
    """The output data of a request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
        prompt_token_ids: The token IDs of the prompt.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
    """

    def __init__(
        self,
        request_id: str,
        prompt: str,
        prompt_token_ids: List[int],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        """Build a RequestOutput snapshot from a sequence group.

        When the group holds more than one sequence, the top
        ``sampling_params.n`` sequences are kept. They are ranked by beam
        search score if beam search is enabled, otherwise by cumulative log
        probability. Each kept sequence becomes a CompletionOutput whose
        ``index`` is its position in ``seq_group.get_seqs()``.

        Side effect: calls ``seq_group.set_finished_time`` with the current
        wall-clock time if the group is finished, or with None otherwise.
        """
        sampling_params = seq_group.sampling_params
        seqs = seq_group.get_seqs()
        if len(seqs) == 1:
            top_n_seqs = seqs
        else:
            # Get the top-n sequences.
            n = sampling_params.n
            if sampling_params.use_beam_search:
                sorting_key = lambda seq: seq.get_beam_search_score(
                    sampling_params.length_penalty)
            else:
                sorting_key = lambda seq: seq.get_cumulative_logprob()
            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
            top_n_seqs = sorted_seqs[:n]

        # Create the outputs.
        # NOTE: We need to omit logprobs here explicitly because the sequence
        # always has the logprobs of the sampled tokens even if the
        # logprobs are not requested.
        include_logprobs = sampling_params.logprobs is not None
        outputs = [
            CompletionOutput(seqs.index(seq),
                             seq.output_text,
                             seq.get_output_token_ids(),
                             seq.get_cumulative_logprob(),
                             seq.output_logprobs if include_logprobs else None,
                             SequenceStatus.get_finished_reason(seq.status),
                             seq.stop_reason,
                             # Propagate the group's LoRA request so each
                             # completion carries the adapter it was
                             # generated with, as CompletionOutput documents.
                             lora_request=seq_group.lora_request)
            for seq in top_n_seqs
        ]

        # Every sequence in the sequence group should have the same prompt.
        prompt = seq_group.prompt
        prompt_token_ids = seq_group.prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)
        return cls(seq_group.request_id,
                   prompt,
                   prompt_token_ids,
                   prompt_logprobs,
                   outputs,
                   finished,
                   seq_group.metrics,
                   lora_request=seq_group.lora_request)

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request})")