detokenizer.py 7.48 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import Dict, List, Optional
4

5
6
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
                           Sequence, SequenceGroup)
7

8
9
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
                                detokenize_incrementally)
10
from .tokenizer import AnyTokenizer
11
from .tokenizer_group import TokenizerGroup
12
13
14
15
16


class Detokenizer:
    """Provides methods to decode the output of a model into text.

    Two ways of obtaining a tokenizer are supported, selected by ``mode``:

    * any mode other than ``"cpm"`` (default ``"auto"``): the constructor
      argument is stored as ``self.tokenizer_group`` and the tokenizer for a
      sequence is looked up from it by the sequence's LoRA request;
    * ``"cpm"``: the constructor argument is used directly as the tokenizer
      and stored as ``self.tokenizer``.

    ``mode`` is also forwarded to ``detokenize_incrementally``.
    """

    def __init__(self, tokenizer_group: TokenizerGroup, mode="auto"):
        self.mode = mode
        if self.mode != "cpm":
            self.tokenizer_group = tokenizer_group
        else:
            # In "cpm" mode the argument is treated as the tokenizer itself
            # rather than as a group to look tokenizers up from.
            self.tokenizer = tokenizer_group

    def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
        """Returns the tokenizer to use for a given sequence.

        In ``"cpm"`` mode this is the single tokenizer passed to the
        constructor; otherwise it is fetched from the tokenizer group using
        the sequence's LoRA request.
        """
        if self.mode == "cpm":
            # No tokenizer group is stored in this mode (see __init__), so a
            # group lookup would raise AttributeError.
            return self.tokenizer
        return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)

    def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
                                       prompt_logprobs: List[Optional[Dict[
                                           int, Logprob]]],
                                       position_offset: int) -> None:
        """Decodes the logprobs for the prompt of a sequence group in place.

        For every ``Logprob`` whose ``decoded_token`` is still ``None`` (and
        whose token id is not ``VLLM_INVALID_TOKEN_ID``), ``decoded_token``
        is set to the text that token would produce at that prompt position.

        Args:
            seq_group: The sequence group to decode.
            prompt_logprobs: The logprobs to decode, one entry per prompt
                position; empty/``None`` entries are skipped.
            position_offset: Offset of the first index of the logprobs
                relative to the start of the sequence (for chunked prefill).

        Returns:
            None. The ``Logprob`` objects in ``prompt_logprobs`` are mutated.
        """
        prms = seq_group.sampling_params
        assert prms is not None

        # We can pick any sequence for the prompt.
        seq = seq_group.get_seqs()[0]
        # Only prompt, without the generated token.
        all_token_ids = seq.get_token_ids()
        prompt_token_ids = all_token_ids[:-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        prefix_offset = 0
        read_offset = 0
        next_iter_prefix_offset = 0
        next_iter_read_offset = 0
        next_iter_tokens: List[str] = []
        prev_tokens = None

        for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
                prompt_logprobs):

            # Absolute token position equals the index in the logprobs
            # list plus the offset of the entire logprobs list relative
            # to the start of the sequence.
            token_position = token_position_in_logprob + position_offset
            # NOTE(review): skipping here also skips the offset/prev_tokens
            # advance below for this position; matches upstream behavior.
            if not prompt_logprobs_for_token:
                continue
            for token_id, sample_logprob in prompt_logprobs_for_token.items():
                if (sample_logprob.decoded_token is None
                        and token_id != VLLM_INVALID_TOKEN_ID):
                    prompt_token_ids_with_token = (
                        prompt_token_ids[:token_position] + [token_id])
                    (new_tokens, new_text, new_prefix_offset,
                     new_read_offset) = detokenize_incrementally(
                         tokenizer=tokenizer,
                         all_input_ids=prompt_token_ids_with_token,
                         prev_tokens=prev_tokens,
                         prefix_offset=prefix_offset,
                         read_offset=read_offset,
                         skip_special_tokens=prms.skip_special_tokens,
                         spaces_between_special_tokens=prms.
                         spaces_between_special_tokens,
                         mode=self.mode,
                     )

                    sample_logprob.decoded_token = new_text

                    # Use the offsets & prev tokens corresponding to
                    # real tokens to ensure detokenization is consistent
                    # with the actual prompt.
                    if token_id == all_token_ids[token_position]:
                        next_iter_prefix_offset = new_prefix_offset
                        next_iter_read_offset = new_read_offset
                        next_iter_tokens = new_tokens

            # Advance to the next token position.
            prefix_offset = next_iter_prefix_offset
            read_offset = next_iter_read_offset
            if prev_tokens is None:
                prev_tokens = next_iter_tokens.copy()
            else:
                prev_tokens.extend(next_iter_tokens)

    def decode_sequence_inplace(self, seq: Sequence,
                                prms: SamplingParams) -> int:
        """Decodes the new token for a sequence. In-place operation.

        Updates ``seq.tokens``, ``seq.prefix_offset``, ``seq.read_offset``
        and ``seq.output_text`` with the last token id of the sequence, and
        fills ``decoded_token`` on the logprobs for that last position.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.

        Returns:
            The number of characters added to the output text.
        """
        all_input_ids = seq.get_token_ids()
        token_id_generated_this_iteration = all_input_ids[-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        # Convert prompt token IDs to tokens if necessary.
        # Do it here so that we don't have to repeat this
        # computation for each logprob.
        # NOTE(review): `mode` is not forwarded here, unlike the calls to
        # detokenize_incrementally — confirm this is intended for "cpm".
        if seq.tokens is None:
            (seq.tokens, seq.prefix_offset,
             seq.read_offset) = convert_prompt_ids_to_tokens(
                 tokenizer=tokenizer,
                 prompt_ids=all_input_ids[:-1],
                 skip_special_tokens=prms.skip_special_tokens,
             )

        (new_tokens, new_decoded_token_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             tokenizer=tokenizer,
             all_input_ids=all_input_ids,
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
             mode=self.mode,
         )

        # Decode logprobs
        logprobs = seq.output_logprobs[-1]
        if logprobs:
            previous_token_ids = all_input_ids[:-1]
            for token_id, sample_logprob in logprobs.items():
                # If the token was generated this iteration,
                # use the provided text.
                if token_id == token_id_generated_this_iteration:
                    sample_logprob.decoded_token = new_decoded_token_text
                    continue

                if (sample_logprob.decoded_token is None
                        and token_id != VLLM_INVALID_TOKEN_ID):
                    # Decode each alternative against the same state as the
                    # generated token; the returned offsets/tokens are
                    # discarded so `seq` only advances by the real token.
                    all_input_ids_with_logprob = (previous_token_ids +
                                                  [token_id])
                    (_, new_text, _, _) = detokenize_incrementally(
                        tokenizer=tokenizer,
                        all_input_ids=all_input_ids_with_logprob,
                        prev_tokens=seq.tokens,
                        prefix_offset=seq.prefix_offset,
                        read_offset=seq.read_offset,
                        skip_special_tokens=prms.skip_special_tokens,
                        spaces_between_special_tokens=prms.
                        spaces_between_special_tokens,
                        mode=self.mode,
                    )
                    sample_logprob.decoded_token = new_text

        seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_decoded_token_text

        return len(new_decoded_token_text)