outputs.py 3.07 KB
Newer Older
Zhuohan Li's avatar
Zhuohan Li committed
1
from typing import Dict, List, Optional
2

Zhuohan Li's avatar
Zhuohan Li committed
3
from cacheflow.sequence import SequenceGroup, SequenceStatus
4
5
6
7
8
9


class CompletionOutput:
    """The output of one completion (one generated sequence) of a request.

    Instances are plain data holders; all attributes mirror the
    constructor arguments.
    """

    def __init__(
        self,
        index: int,
        text: str,
        token_ids: List[int],
        cumulative_logprob: float,
        logprobs: List[Dict[int, float]],
        finish_reason: Optional[str] = None,
    ) -> None:
        # Position of this completion within the request's output list.
        self.index = index
        # The generated text so far.
        self.text = text
        # Token ids of the generated text.
        self.token_ids = token_ids
        # Sum of the log-probabilities of the generated tokens.
        self.cumulative_logprob = cumulative_logprob
        # Per-token log-probability dicts (token id -> logprob); may be
        # empty when logprobs were not requested.
        self.logprobs = logprobs
        # Why generation stopped; None while the sequence is still running.
        self.finish_reason = finish_reason

    def finished(self) -> bool:
        """Return True if this completion has finished generating."""
        return self.finish_reason is not None

    def __repr__(self) -> str:
        # Fix: added the missing space after "logprobs={...}," so all
        # fields are consistently separated by ", " in the repr.
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason})")
34
35
36
37
38
39


class RequestOutput:
    """The output of a request: the prompt plus its generated completions."""

    def __init__(
        self,
        request_id: str,
        prompt: str,
        prompt_token_ids: List[int],
        outputs: List[CompletionOutput],
    ) -> None:
        # Unique id of the request this output belongs to.
        self.request_id = request_id
        # The prompt text (shared by every sequence in the group).
        self.prompt = prompt
        # Token ids of the prompt.
        self.prompt_token_ids = prompt_token_ids
        # The top-n completions of the request.
        self.outputs = outputs

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        """Build a RequestOutput from the top-n sequences of a SequenceGroup.

        The sequences are ranked by cumulative log-probability and the best
        ``sampling_params.n`` of them are converted into CompletionOutputs.
        """
        # Get the top-n sequences.
        n = seq_group.sampling_params.n
        seqs = seq_group.get_seqs()
        assert n <= len(seqs)
        sorted_seqs = sorted(
            seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
        top_n_seqs = sorted_seqs[:n]

        # Create the outputs.
        outputs: List[CompletionOutput] = []
        for seq in top_n_seqs:
            logprobs = seq.output_logprobs
            if seq_group.sampling_params.logprobs is None:
                # NOTE: We need to take care of this case because the sequence
                # always has the logprobs of the sampled tokens even if the
                # logprobs are not requested.
                # NOTE(review): an empty dict is assigned here although the
                # CompletionOutput annotation says List[Dict[...]] — kept
                # as-is; confirm the intended empty value with callers.
                logprobs = {}
            # Typo fix: local variable renamed "finshed_reason" ->
            # "finished_reason" (internal only; no interface change).
            finished_reason = SequenceStatus.get_finished_reason(seq.status)
            # The index is the sequence's position in the (unsorted) seqs
            # list, so it is stable across the ranking above.
            output = CompletionOutput(seqs.index(seq), seq.output_text,
                                      seq.get_output_token_ids(),
                                      seq.get_cumulative_logprob(), logprobs,
                                      finished_reason)
            outputs.append(output)

        # Every sequence in the sequence group should have the same prompt.
        prompt = top_n_seqs[0].prompt
        prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs)

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"outputs={self.outputs})")

    def finished(self) -> bool:
        """Return True once every completion in the request has finished."""
        return all(output.finished() for output in self.outputs)