# outputs.py
from typing import Dict, List, Optional
from vllm.sequence import SequenceGroup, SequenceStatus


class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested. May be ``None``.
        finish_reason: The reason why the sequence is finished. ``None``
            means the sequence has not finished yet (see :meth:`finished`).
    """

    def __init__(
        self,
        index: int,
        text: str,
        token_ids: List[int],
        cumulative_logprob: float,
        logprobs: Optional[List[Dict[int, float]]],
        finish_reason: Optional[str] = None,
    ) -> None:
        self.index = index
        self.text = text
        self.token_ids = token_ids
        self.cumulative_logprob = cumulative_logprob
        self.logprobs = logprobs
        self.finish_reason = finish_reason

    def finished(self) -> bool:
        """Return True iff a finish reason has been recorded."""
        return self.finish_reason is not None

    def __repr__(self) -> str:
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason})")


class RequestOutput:
    """The output data of a request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
        prompt_token_ids: The token IDs of the prompt.
        outputs: The output sequences of the request.
    """

    def __init__(
        self,
        request_id: str,
        prompt: str,
        prompt_token_ids: List[int],
        outputs: List[CompletionOutput],
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.outputs = outputs

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        """Build a RequestOutput from the best sequences of a sequence group.

        Sequences are ranked by cumulative log probability (highest first)
        and the first ``seq_group.sampling_params.n`` of them become the
        outputs. Each output's ``index`` is the sequence's position in
        ``seq_group.get_seqs()``.

        Raises:
            AssertionError: if the group holds fewer sequences than ``n``.
        """
        n = seq_group.sampling_params.n
        seqs = seq_group.get_seqs()
        assert n <= len(seqs)

        # Carry each sequence's position through the sort so the output index
        # is known without an O(len(seqs)) ``seqs.index`` lookup per output.
        # ``sorted`` is stable, so ties keep the same order as before.
        ranked = sorted(
            enumerate(seqs),
            key=lambda pos_seq: pos_seq[1].get_cumulative_logprob(),
            reverse=True)

        # Logprobs were not requested, so report None (matching the
        # ``Optional`` type of ``CompletionOutput.logprobs``). The sequence
        # always records the sampled tokens' logprobs, which must not leak
        # into the output.
        logprobs_requested = seq_group.sampling_params.logprobs is not None

        outputs: List[CompletionOutput] = []
        for position, seq in ranked[:n]:
            logprobs = seq.output_logprobs if logprobs_requested else None
            finished_reason = SequenceStatus.get_finished_reason(seq.status)
            outputs.append(
                CompletionOutput(position, seq.output_text,
                                 seq.get_output_token_ids(),
                                 seq.get_cumulative_logprob(), logprobs,
                                 finished_reason))

        # Every sequence in the sequence group should have the same prompt,
        # so read it from the first sequence (valid even when n == 0).
        prompt = seqs[0].prompt
        prompt_token_ids = seqs[0].data.prompt_token_ids
        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs)

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"outputs={self.outputs})")

    def finished(self) -> bool:
        """Return True when every completion output has finished."""
        return all(output.finished() for output in self.outputs)