outputs.py 2.69 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from typing import Dict, List, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from cacheflow.sequence import SequenceGroup


class CompletionOutput:

    def __init__(
        self,
        text: str,
        token_ids: List[int],
        cumulative_logprobs: float,
        logprobs: List[Dict[int, float]],
    ) -> None:
        self.text = text
        self.token_ids = token_ids
        self.cumulative_logprobs = cumulative_logprobs
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"CompletionOutput(output={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprobs={self.cumulative_logprobs}, "
                f"logprobs={self.logprobs})")


class RequestOutput:

    def __init__(
        self,
        request_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        outputs: List[CompletionOutput],
        done: bool = False,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.outputs = outputs
        self.done = done

    @staticmethod
    def from_seq_group(
        seq_group: SequenceGroup,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> "RequestOutput":
        outputs: List[CompletionOutput] = []
        seqs = seq_group.get_seqs()
        for seq in seqs:
            output_token_ids = seq.data.output_token_ids
            output_str = tokenizer.decode(output_token_ids,
                                          skip_special_tokens=True)
            seq_logprobs = seq.data.cumulative_logprobs

            logprobs = seq.output_logprobs
            if seq_group.sampling_params.logprobs == 0:
                # NOTE: We need to take care of this case because the sequence
                # always has the logprobs of the sampled tokens even if the
                # logprobs are not requested.
                logprobs = {}
            output = CompletionOutput(output_str, output_token_ids,
                                      seq_logprobs, logprobs)
            outputs.append(output)

        # Every sequence in the sequence group should have the same prompt.
        prompt = seqs[0].prompt
        prompt_token_ids = seqs[0].data.prompt_token_ids
        return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
                             outputs, seq_group.is_finished())

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"outputs={self.outputs}, "
                f"done={self.done})")