outputs.py 5.06 KB
Newer Older
1
from typing import List, Optional
2

3
4
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
                           SequenceStatus)
5
from vllm.lora.request import LoRARequest
6
7
8


class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
        lora_request: The LoRA request that was used to generate the output.
    """

    def __init__(
        self,
        index: int,
        text: str,
        token_ids: List[int],
        cumulative_logprob: float,
        logprobs: Optional[SampleLogprobs],
        finish_reason: Optional[str] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.index = index
        self.text = text
        self.token_ids = token_ids
        self.cumulative_logprob = cumulative_logprob
        self.logprobs = logprobs
        self.finish_reason = finish_reason
        self.lora_request = lora_request

    def finished(self) -> bool:
        """Return True once a finish reason has been recorded."""
        return self.finish_reason is not None

    def __repr__(self) -> str:
        # Include every stored field (lora_request too), matching
        # RequestOutput.__repr__.
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason}, "
                f"lora_request={self.lora_request})")


class RequestOutput:
    """The output data of a request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
        prompt_token_ids: The token IDs of the prompt.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        lora_request: The LoRA request that was used to generate the output.
    """

    def __init__(
        self,
        request_id: str,
        prompt: str,
        prompt_token_ids: List[int],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.lora_request = lora_request

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        """Build a RequestOutput from the current state of a sequence group.

        The group's sequences are ranked in descending order: by beam-search
        score (using ``sampling_params.length_penalty``) when beam search is
        enabled, otherwise by cumulative log probability. The top
        ``sampling_params.n`` of them become the completion outputs.

        Args:
            seq_group: The sequence group to summarize.

        Returns:
            A RequestOutput holding the group's request ID, prompt data,
            top-n completions, finished flag and LoRA request.
        """
        sampling_params = seq_group.sampling_params
        seqs = seq_group.get_seqs()

        if sampling_params.use_beam_search:
            length_penalty = sampling_params.length_penalty

            def sorting_key(seq):
                return seq.get_beam_search_score(length_penalty)
        else:

            def sorting_key(seq):
                return seq.get_cumulative_logprob()

        # Carry each sequence's position in the group along with it, so the
        # output index needs no O(n) ``seqs.index`` lookup per sequence.
        # sorted() is stable (also with reverse=True), so ties keep their
        # original relative order exactly as before.
        ranked = sorted(enumerate(seqs),
                        key=lambda pair: sorting_key(pair[1]),
                        reverse=True)
        top_n = ranked[:sampling_params.n]

        # Create the outputs.
        outputs: List[CompletionOutput] = []
        for index, seq in top_n:
            if sampling_params.logprobs is None:
                # NOTE: We need to take care of this case because the sequence
                # always has the logprobs of the sampled tokens even if the
                # logprobs are not requested.
                logprobs = None
            else:
                logprobs = seq.output_logprobs
            finished_reason = SequenceStatus.get_finished_reason(seq.status)
            # Forward the group's LoRA request so each completion reports the
            # adapter that produced it, as CompletionOutput documents.
            outputs.append(
                CompletionOutput(index,
                                 seq.output_text,
                                 seq.get_output_token_ids(),
                                 seq.get_cumulative_logprob(),
                                 logprobs,
                                 finished_reason,
                                 lora_request=seq_group.lora_request))

        # Every sequence in the sequence group should have the same prompt.
        return cls(seq_group.request_id,
                   seq_group.prompt,
                   seq_group.prompt_token_ids,
                   seq_group.prompt_logprobs,
                   outputs,
                   seq_group.is_finished(),
                   lora_request=seq_group.lora_request)

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"lora_request={self.lora_request})")