# detokenizer.py
1
from typing import Dict, List, Optional
2

3
4
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
                           Sequence, SequenceGroup)
5

6
7
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
                                detokenize_incrementally)
8
9
from .tokenizer import AnyTokenizer
from .tokenizer_group import BaseTokenizerGroup
10
11
12
13
14
15
16
17


class Detokenizer:
    """Provides methods to decode the output of a model into text."""

    def __init__(self, tokenizer_group: BaseTokenizerGroup):
        """Stores the tokenizer group used to look up per-sequence tokenizers.

        Args:
            tokenizer_group: Object queried via ``get_lora_tokenizer`` for
                the tokenizer matching a sequence's LoRA request.
        """
        self.tokenizer_group = tokenizer_group

    def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
        """Returns the HF tokenizer to use for a given sequence."""
        return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)

    def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
                                       prompt_logprobs: List[Optional[Dict[
                                           int, Logprob]]],
                                       position_offset: int) -> None:
        """Decodes the logprobs for the prompt of a sequence group.

        This is an in-place operation: the ``decoded_token`` attribute of
        each not-yet-decoded ``Logprob`` in ``prompt_logprobs`` is filled
        in. Nothing is returned.

        Args:
            seq_group: The sequence group to decode.
            prompt_logprobs: The logprobs to decode. Entries that are
                ``None`` or empty are skipped.
            position_offset: Offset of the first index of the logprobs
                relative to the start of the sequence (for chunked prefill).
        """
        prms = seq_group.sampling_params
        assert prms is not None

        # We can pick any sequence for the prompt.
        seq = seq_group.get_seqs()[0]
        # Only prompt, without the generated token.
        all_token_ids = seq.get_token_ids()
        prompt_token_ids = all_token_ids[:-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        # Incremental detokenization state used for the current position.
        prefix_offset = 0
        read_offset = 0
        prev_tokens = None
        # State produced by decoding the *actual* prompt token at the
        # current position; it becomes the state for the next position.
        next_iter_prefix_offset = 0
        next_iter_read_offset = 0
        next_iter_tokens: List[str] = []

        for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
                prompt_logprobs):

            # Absolute token position equals the index in the logprobs
            # list plus the offset of the entire logprobs list relative
            # to the start of the sequence.
            token_position = token_position_in_logprob + position_offset
            if not prompt_logprobs_for_token:
                continue
            for token_id, sample_logprob in prompt_logprobs_for_token.items():
                if (sample_logprob.decoded_token is None
                        and token_id != VLLM_INVALID_TOKEN_ID):
                    # Decode the candidate as if it followed the real
                    # prompt prefix up to this position.
                    prompt_token_ids_with_token = (
                        prompt_token_ids[:token_position] + [token_id])
                    (new_tokens, new_text, new_prefix_offset,
                     new_read_offset) = detokenize_incrementally(
                         tokenizer=tokenizer,
                         all_input_ids=prompt_token_ids_with_token,
                         prev_tokens=prev_tokens,
                         prefix_offset=prefix_offset,
                         read_offset=read_offset,
                         skip_special_tokens=prms.skip_special_tokens,
                         spaces_between_special_tokens=prms.
                         spaces_between_special_tokens,
                     )

                    sample_logprob.decoded_token = new_text

                    # Use the offsets & prev tokens corresponding to
                    # real tokens to ensure detokenization is consistent
                    # with the actual prompt.
                    if token_id == all_token_ids[token_position]:
                        next_iter_prefix_offset = new_prefix_offset
                        next_iter_read_offset = new_read_offset
                        next_iter_tokens = new_tokens

            # Advance to the next token position.
            prefix_offset = next_iter_prefix_offset
            read_offset = next_iter_read_offset
            if prev_tokens is None:
                # Copy so that later extends do not mutate the list
                # still referenced by next_iter_tokens.
                prev_tokens = next_iter_tokens.copy()
            else:
                prev_tokens.extend(next_iter_tokens)

    def decode_sequence_inplace(self, seq: Sequence,
                                prms: SamplingParams) -> int:
        """Decodes the new token for a sequence. In-place operation.

        Updates ``seq.tokens``, ``seq.prefix_offset``, ``seq.read_offset``
        and ``seq.output_text``, and fills in ``decoded_token`` on the
        entries of the most recent element of ``seq.output_logprobs``.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.

        Returns:
            The number of characters added to the output text.
        """
        all_input_ids = seq.get_token_ids()
        token_id_generated_this_iteration = all_input_ids[-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        # Convert prompt token IDs to tokens if necessary.
        # Do it here so that we don't have to repeat this
        # computation for each logprob.
        if seq.tokens is None:
            (seq.tokens, seq.prefix_offset,
             seq.read_offset) = convert_prompt_ids_to_tokens(
                 tokenizer=tokenizer,
                 prompt_ids=all_input_ids[:-1],
                 skip_special_tokens=prms.skip_special_tokens,
             )

        (new_tokens, new_decoded_token_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             tokenizer=tokenizer,
             all_input_ids=all_input_ids,
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )

        # Decode logprobs
        logprobs = seq.output_logprobs[-1]
        if logprobs:
            previous_tokens = all_input_ids[:-1]
            for token_id, sample_logprob in logprobs.items():
                # If the token was generated this iteration,
                # use the provided text.
                if token_id == token_id_generated_this_iteration:
                    sample_logprob.decoded_token = new_decoded_token_text
                    continue

                if (sample_logprob.decoded_token is None
                        and token_id != VLLM_INVALID_TOKEN_ID):
                    # Decode the alternative candidate as if it had been
                    # generated instead, using the not-yet-advanced
                    # sequence state (seq.* is only updated below).
                    all_input_ids_with_logprob = previous_tokens + [token_id]
                    (_, new_text, _, _) = detokenize_incrementally(
                        tokenizer=tokenizer,
                        all_input_ids=all_input_ids_with_logprob,
                        prev_tokens=seq.tokens,
                        prefix_offset=seq.prefix_offset,
                        read_offset=seq.read_offset,
                        skip_special_tokens=prms.skip_special_tokens,
                        spaces_between_special_tokens=prms.
                        spaces_between_special_tokens,
                    )
                    sample_logprob.decoded_token = new_text

        # Commit the state for the token actually generated this step.
        seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_decoded_token_text

        return len(new_decoded_token_text)