outputs.py 2.73 KB
Newer Older
1
from typing import Dict, List
2
3
4
5
6
7
8
9

from cacheflow.sequence import SequenceGroup


class CompletionOutput:
    """A single generated completion belonging to a request.

    Attributes:
        index: Integer index identifying this completion, as supplied by the
            caller (``RequestOutput.from_seq_group`` passes the sequence's
            position within its sequence group).
        text: The generated text.
        token_ids: Token IDs of the generated output.
        cumulative_logprob: Cumulative log-probability reported for the
            generating sequence.
        logprobs: List of ``{token_id: logprob}`` dicts (presumably one per
            generated position — TODO confirm against the sampler).
    """

    def __init__(
        self,
        index: int,
        text: str,
        token_ids: List[int],
        cumulative_logprob: float,
        logprobs: List[Dict[int, float]],
    ) -> None:
        self.index = index
        self.text = text
        self.token_ids = token_ids
        self.cumulative_logprob = cumulative_logprob
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs})")


class RequestOutput:
    """The output of a single request.

    Attributes:
        request_id: Identifier of the request.
        prompt: The prompt text.
        prompt_token_ids: Token IDs of the prompt.
        outputs: The completions returned for the request; when built via
            ``from_seq_group`` these are the top-``n`` sequences, best first.
        done: Whether the request has finished (``seq_group.is_finished()``
            when built via ``from_seq_group``).
    """

    def __init__(
        self,
        request_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        outputs: List["CompletionOutput"],
        done: bool,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.outputs = outputs
        self.done = done

    @classmethod
    def from_seq_group(cls, seq_group: "SequenceGroup") -> "RequestOutput":
        """Build a ``RequestOutput`` from the best sequences of a group.

        The ``n`` sequences (``seq_group.sampling_params.n``) with the highest
        cumulative log-probability are kept, best first; ties keep the order
        of ``seq_group.get_seqs()``. Each completion's ``index`` is the
        sequence's position in ``seq_group.get_seqs()``.

        Args:
            seq_group: The sequence group to convert.

        Returns:
            A new ``RequestOutput`` for ``seq_group.request_id``.
        """
        sampling_params = seq_group.sampling_params
        n = sampling_params.n
        seqs = seq_group.get_seqs()
        assert n <= len(seqs)

        # Rank (position, seq) pairs so each sequence's index in the group is
        # known directly, instead of an O(len(seqs)) ``seqs.index`` lookup per
        # output (which also relies on ``Sequence`` equality semantics).
        # ``sorted`` stays stable with reverse=True, so tie order is unchanged.
        ranked = sorted(enumerate(seqs),
                        key=lambda pair: pair[1].get_cumulative_logprob(),
                        reverse=True)

        outputs: List[CompletionOutput] = []
        for index, seq in ranked[:n]:
            logprobs = seq.output_logprobs
            if sampling_params.logprobs == 0:
                # NOTE: We need to take care of this case because the sequence
                # always has the logprobs of the sampled tokens even if the
                # logprobs are not requested.
                # NOTE(review): ``{}`` does not match CompletionOutput's
                # ``List[Dict[int, float]]`` annotation; kept for backward
                # compatibility with existing consumers.
                logprobs = {}
            outputs.append(
                CompletionOutput(index, seq.output_text,
                                 seq.get_output_token_ids(),
                                 seq.get_cumulative_logprob(), logprobs))

        # Every sequence in the sequence group should have the same prompt, so
        # read it from the first sequence (also valid when n == 0).
        first_seq = seqs[0]
        prompt = first_seq.prompt
        prompt_token_ids = first_seq.data.prompt_token_ids
        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
                   seq_group.is_finished())

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"outputs={self.outputs}, "
                f"done={self.done})")