detokenizer.py 6.78 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional
5

6
7
8
from vllm.logprobs import Logprob
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
                           SequenceGroup)
9

10
11
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
                                detokenize_incrementally)
12
from .tokenizer import AnyTokenizer
13
14
15
16
17


class Detokenizer:
    """Provides methods to decode the output of a model into text."""

    def __init__(self, tokenizer: AnyTokenizer):
        self.tokenizer = tokenizer

    def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
                                       prompt_logprobs: list[Optional[dict[
                                           int, Logprob]]],
                                       position_offset: int) -> None:
        """Decodes the logprobs for the prompt of a sequence group in place.

        For every prompt position, each candidate token in the logprobs dict
        whose ``decoded_token`` is still ``None`` (and whose id is not
        ``VLLM_INVALID_TOKEN_ID``) is detokenized in the context of the
        actual prompt tokens preceding it. The resulting text is stored in
        ``Logprob.decoded_token``.

        Args:
            seq_group: The sequence group to decode. Its first sequence
                supplies the prompt token ids.
            prompt_logprobs: The logprobs to decode, one entry per prompt
                position; ``None`` or empty entries are skipped.
            position_offset: Offset of the first index of the logprobs
                relative to the start of the sequence (for chunked prefill).

        Returns:
            None. The ``Logprob`` objects inside ``prompt_logprobs`` are
            updated in place.
        """
        prms = seq_group.sampling_params
        assert prms is not None

        # We can pick any sequence for the prompt.
        seq = seq_group.get_seqs()[0]
        all_token_ids = seq.get_token_ids()
        # Only prompt, without the generated token.
        prompt_token_ids = all_token_ids[:-1]

        # Incremental detokenization state for the actual prompt tokens that
        # precede the current position. ``prev_tokens is None`` with zero
        # offsets is the "no cached state" starting point: the token list
        # returned by the next detokenize_incrementally call becomes the
        # cache (see the copy() below).
        prefix_offset = 0
        read_offset = 0
        prev_tokens: Optional[list[str]] = None

        for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
                prompt_logprobs):

            # Absolute token position equals the index in the logprobs
            # list plus the offset of the entire logprobs list relative
            # to the start of the sequence.
            token_position = token_position_in_logprob + position_offset

            # (tokens, prefix_offset, read_offset) produced while decoding
            # the actual prompt token at this position. Recreated for every
            # position so state from an earlier position is never reapplied.
            actual_token_state: Optional[tuple[list[str], int, int]] = None

            candidates = prompt_logprobs_for_token or {}
            for token_id, sample_logprob in candidates.items():
                if (sample_logprob.decoded_token is not None
                        or token_id == VLLM_INVALID_TOKEN_ID):
                    continue

                prompt_token_ids_with_token = (
                    prompt_token_ids[:token_position] + [token_id])
                (new_tokens, new_text, new_prefix_offset,
                 new_read_offset) = detokenize_incrementally(
                     tokenizer=self.tokenizer,
                     all_input_ids=prompt_token_ids_with_token,
                     prev_tokens=prev_tokens,
                     prefix_offset=prefix_offset,
                     read_offset=read_offset,
                     skip_special_tokens=prms.skip_special_tokens,
                     spaces_between_special_tokens=prms.
                     spaces_between_special_tokens,
                 )

                sample_logprob.decoded_token = new_text

                # Only the actual prompt token may advance the state, so
                # later positions are decoded consistently with the real
                # prompt rather than with an alternative candidate.
                if token_id == all_token_ids[token_position]:
                    actual_token_state = (new_tokens, new_prefix_offset,
                                          new_read_offset)

            if actual_token_state is None:
                # The actual prompt token was not decoded at this position
                # (no logprobs, already decoded, or invalid id), so there is
                # nothing correct to append. Return to the same no-cache
                # state the first position starts from, rather than
                # re-appending tokens left over from an earlier position.
                prefix_offset = 0
                read_offset = 0
                prev_tokens = None
                continue

            # Advance to the next token position.
            new_tokens, prefix_offset, read_offset = actual_token_state
            if prev_tokens is None:
                prev_tokens = new_tokens.copy()
            else:
                prev_tokens.extend(new_tokens)

    def decode_sequence_inplace(self, seq: Sequence,
                                prms: SamplingParams) -> int:
        """Decodes the new token for a sequence. In-place operation.

        Detokenizes the last token id of ``seq`` incrementally and appends
        the resulting text to ``seq.output_text``. It also updates
        ``seq.tokens``, ``seq.prefix_offset`` and ``seq.read_offset``, and
        fills in ``decoded_token`` for the latest ``seq.output_logprobs``
        entry.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.

        Returns:
            The number of characters added to the output text.
        """
        all_input_ids = seq.get_token_ids()
        token_id_generated_this_iteration = all_input_ids[-1]

        # Convert prompt token IDs to tokens if necessary.
        # Do it here so that we don't have to repeat this
        # computation for each logprob.
        if seq.tokens is None:
            (seq.tokens, seq.prefix_offset,
             seq.read_offset) = convert_prompt_ids_to_tokens(
                 tokenizer=self.tokenizer,
                 prompt_ids=all_input_ids[:-1],
                 skip_special_tokens=prms.skip_special_tokens,
             )

        (new_tokens, new_decoded_token_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             tokenizer=self.tokenizer,
             all_input_ids=all_input_ids,
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )

        # Decode logprobs
        logprobs = seq.output_logprobs[-1]
        if logprobs:
            # Every candidate shares the same context: all token ids before
            # the token generated this iteration.
            previous_token_ids = all_input_ids[:-1]
            for token_id, sample_logprob in logprobs.items():
                # If the token was generated this iteration,
                # use the provided text.
                if token_id == token_id_generated_this_iteration:
                    sample_logprob.decoded_token = new_decoded_token_text
                    continue

                if (sample_logprob.decoded_token is None
                        and token_id != VLLM_INVALID_TOKEN_ID):
                    all_input_ids_with_logprob = (previous_token_ids +
                                                  [token_id])
                    (_, new_text, _, _) = detokenize_incrementally(
                        tokenizer=self.tokenizer,
                        all_input_ids=all_input_ids_with_logprob,
                        prev_tokens=seq.tokens,
                        prefix_offset=seq.prefix_offset,
                        read_offset=seq.read_offset,
                        skip_special_tokens=prms.skip_special_tokens,
                        spaces_between_special_tokens=prms.
                        spaces_between_special_tokens,
                    )
                    sample_logprob.decoded_token = new_text

        # Commit the new incremental state only now: the candidate decodes
        # above must all see the state from *before* this token.
        seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_decoded_token_text

        return len(new_decoded_token_text)