detokenizer.py 7.33 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
from typing import Dict, List, Optional
5

6
7
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
                           Sequence, SequenceGroup)
8

9
10
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
                                detokenize_incrementally)
11
12
from .tokenizer import AnyTokenizer
from .tokenizer_group import BaseTokenizerGroup
13
14
15
16
17
18
19


class Detokenizer:
    """Provides methods to decode the output of a model into text."""

    def __init__(self, tokenizer_group: BaseTokenizerGroup):
        self.tokenizer_group = tokenizer_group
        # Read once at construction time. When VLLM_ZERO_OVERHEAD=1,
        # decode_sequence_inplace only detokenizes the "effective" prefix of
        # each sequence's token ids (see that method).
        self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'

    def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
        """Returns the HF tokenizer to use for a given sequence."""
        return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)

    def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
                                       prompt_logprobs: List[Optional[Dict[
                                           int, Logprob]]],
                                       position_offset: int) -> None:
        """Decodes the logprobs for the prompt of a sequence group in place.

        For every candidate token in ``prompt_logprobs`` whose
        ``decoded_token`` is still ``None`` (and whose id is not
        ``VLLM_INVALID_TOKEN_ID``), the text it would produce at that prompt
        position is computed and stored in ``Logprob.decoded_token``.

        Args:
            seq_group: The sequence group to decode.
            prompt_logprobs: The logprobs to decode. Entries may be ``None``
                or empty, in which case that position is skipped.
            position_offset: Offset of the first index of the logprobs
                relative to the start of the sequence (for chunked prefill).

        Returns:
            None. The ``Logprob`` objects inside ``prompt_logprobs`` are
            mutated in place.
        """
        prms = seq_group.sampling_params
        assert prms is not None

        # We can pick any sequence for the prompt.
        seq = seq_group.get_seqs()[0]
        # Only prompt, without the generated token.
        all_token_ids = seq.get_token_ids()
        prompt_token_ids = all_token_ids[:-1]
        tokenizer = self.get_tokenizer_for_seq(seq)
        prefix_offset = 0
        read_offset = 0
        next_iter_prefix_offset = 0
        next_iter_read_offset = 0
        next_iter_tokens: List[str] = []
        prev_tokens = None

        for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
                prompt_logprobs):

            # Absolute token position equals the index in the logprobs
            # list plus the offset of the entire logprobs list relative
            # to the start of the sequence.
            token_position = token_position_in_logprob + position_offset
            if not prompt_logprobs_for_token:
                continue
            for token_id, sample_logprob in prompt_logprobs_for_token.items():
                if (sample_logprob.decoded_token is None
                        and token_id != VLLM_INVALID_TOKEN_ID):
                    # Decode the candidate as if it followed the actual
                    # prompt tokens preceding this position.
                    prompt_token_ids_with_token = (
                        prompt_token_ids[:token_position] + [token_id])
                    (new_tokens, new_text, new_prefix_offset,
                     new_read_offset) = detokenize_incrementally(
                         tokenizer=tokenizer,
                         all_input_ids=prompt_token_ids_with_token,
                         prev_tokens=prev_tokens,
                         prefix_offset=prefix_offset,
                         read_offset=read_offset,
                         skip_special_tokens=prms.skip_special_tokens,
                         spaces_between_special_tokens=prms.
                         spaces_between_special_tokens,
                     )

                    sample_logprob.decoded_token = new_text

                    # Use the offsets & prev tokens corresponding to the
                    # real prompt token so that incremental detokenization
                    # stays consistent with the actual prompt.
                    if token_id == all_token_ids[token_position]:
                        next_iter_prefix_offset = new_prefix_offset
                        next_iter_read_offset = new_read_offset
                        next_iter_tokens = new_tokens

            # Advance to the next token position.
            prefix_offset = next_iter_prefix_offset
            read_offset = next_iter_read_offset
            if prev_tokens is None:
                prev_tokens = next_iter_tokens.copy()
            else:
                prev_tokens.extend(next_iter_tokens)

    def decode_sequence_inplace(self, seq: Sequence,
                                prms: SamplingParams) -> int:
        """Decodes the new token for a sequence. In-place operation.

        Updates ``seq.tokens``, ``seq.prefix_offset``, ``seq.read_offset``
        and ``seq.output_text``, and fills ``decoded_token`` on the entries
        of the latest ``seq.output_logprobs`` element that are still
        undecoded.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.

        Returns:
            The number of characters added to the output text.
        """
        all_input_ids = seq.get_token_ids()
        if self.zero_overhead:
            # Restrict decoding to the prompt plus the sequence's effective
            # output length. NOTE(review): presumably token ids beyond this
            # point are not final yet in zero-overhead mode -- confirm.
            eff_length = seq.get_prompt_len() + seq.data._effective_length
            all_input_ids = all_input_ids[:eff_length]

        token_id_generated_this_iteration = all_input_ids[-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        # Convert prompt token IDs to tokens if necessary.
        # Do it here so that we don't have to repeat this
        # computation for each logprob.
        if seq.tokens is None:
            (seq.tokens, seq.prefix_offset,
             seq.read_offset) = convert_prompt_ids_to_tokens(
                 tokenizer=tokenizer,
                 prompt_ids=all_input_ids[:-1],
                 skip_special_tokens=prms.skip_special_tokens,
             )

        (new_tokens, new_decoded_token_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             tokenizer=tokenizer,
             all_input_ids=all_input_ids,
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )

        # Decode logprobs
        logprobs = seq.output_logprobs[-1]
        if logprobs:
            previous_tokens = all_input_ids[:-1]
            for token_id, sample_logprob in logprobs.items():
                # If the token was generated this iteration,
                # use the provided text.
                if token_id == token_id_generated_this_iteration:
                    sample_logprob.decoded_token = new_decoded_token_text
                    continue

                if (sample_logprob.decoded_token is None
                        and token_id != VLLM_INVALID_TOKEN_ID):
                    # Decode the alternative candidate against the same
                    # detokenizer state; its offsets/tokens are discarded.
                    all_input_ids_with_logprob = previous_tokens + [token_id]
                    (_, new_text, _, _) = detokenize_incrementally(
                        tokenizer=tokenizer,
                        all_input_ids=all_input_ids_with_logprob,
                        prev_tokens=seq.tokens,
                        prefix_offset=seq.prefix_offset,
                        read_offset=seq.read_offset,
                        skip_special_tokens=prms.skip_special_tokens,
                        spaces_between_special_tokens=prms.
                        spaces_between_special_tokens,
                    )
                    sample_logprob.decoded_token = new_text

        # Commit the state for the token actually generated this iteration.
        seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_decoded_token_text

        return len(new_decoded_token_text)