Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
05fcd1b4
Unverified
Commit
05fcd1b4
authored
Apr 17, 2025
by
Nick Hill
Committed by
GitHub
Apr 17, 2025
Browse files
[V1][Perf] Faster incremental detokenization (#15137)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
7c02d6a1
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
316 additions
and
144 deletions
+316
-144
requirements/common.txt
requirements/common.txt
+1
-1
requirements/test.in
requirements/test.in
+1
-0
requirements/test.txt
requirements/test.txt
+4
-2
tests/lora/test_llama_tp.py
tests/lora/test_llama_tp.py
+1
-0
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+137
-55
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+9
-0
vllm/v1/engine/detokenizer.py
vllm/v1/engine/detokenizer.py
+163
-86
No files found.
requirements/common.txt
View file @
05fcd1b4
...
...
@@ -8,7 +8,7 @@ blake3
py-cpuinfo
transformers >= 4.51.1
huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
tokenizers >= 0.19.1 # Required for Llama 3.
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
aiohttp
...
...
requirements/test.in
View file @
05fcd1b4
...
...
@@ -35,6 +35,7 @@ opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
transformers==4.51.1
tokenizers==0.21.1
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
# quantization
bitsandbytes>=0.45.3
...
...
requirements/test.txt
View file @
05fcd1b4
...
...
@@ -624,8 +624,10 @@ tiktoken==0.7.0
# mistral-common
timm==1.0.11
# via -r requirements/test.in
tokenizers==0.21.0
# via transformers
tokenizers==0.21.1
# via
# -r requirements/test.in
# transformers
torch==2.6.0
# via
# -r requirements/test.in
...
...
tests/lora/test_llama_tp.py
View file @
05fcd1b4
...
...
@@ -47,6 +47,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
]
sampling_params
=
vllm
.
SamplingParams
(
temperature
=
0
,
max_tokens
=
256
,
skip_special_tokens
=
False
,
stop
=
[
"[/assistant]"
])
outputs
=
llm
.
generate
(
prompts
,
...
...
tests/tokenization/test_detokenize.py
View file @
05fcd1b4
...
...
@@ -4,14 +4,22 @@ from collections.abc import Generator
from
typing
import
Any
,
Optional
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
(
AutoTokenizer
,
PreTrainedTokenizer
,
PreTrainedTokenizerFast
)
from
vllm.inputs
import
token_inputs
from
vllm.sequence
import
Logprob
,
SamplingParams
,
Sequence
,
SequenceGroup
from
vllm.transformers_utils.detokenizer
import
(
Detokenizer
,
detokenize_incrementally
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer_group
import
get_tokenizer_group
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.detokenizer
import
(
FastIncrementalDetokenizer
,
IncrementalDetokenizer
,
SlowIncrementalDetokenizer
)
SPECIAL_TOKS_TRUTH
=
[
"Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>"
,
# noqa
]
TRUTH
=
[
"Hello here, this is a simple test"
,
...
...
@@ -22,7 +30,8 @@ TRUTH = [
# incomplete UTF-8 characters
# see https://github.com/vllm-project/vllm/pull/9625
"ပုံပြင်လေးပြောပြပါ်"
,
]
]
+
SPECIAL_TOKS_TRUTH
TOKENIZERS
=
[
"facebook/opt-125m"
,
"gpt2"
,
...
...
@@ -38,26 +47,37 @@ TOKENIZERS = [
]
def
_run_incremental_decode
(
tokenizer
,
all_input_ids
,
skip_special_tokens
:
bool
,
starting_index
:
int
):
decoded_text
=
""
offset
=
0
token_offset
=
0
prev_tokens
=
None
for
i
in
range
(
starting_index
,
len
(
all_input_ids
)):
new_tokens
,
text
,
offset
,
token_offset
=
detokenize_incrementally
(
tokenizer
,
all_input_ids
[:
i
+
1
],
prev_tokens
,
offset
,
token_offset
,
skip_special_tokens
=
skip_special_tokens
)
decoded_text
+=
text
if
prev_tokens
is
None
:
prev_tokens
=
new_tokens
else
:
prev_tokens
+=
new_tokens
return
decoded_text
def _run_incremental_decode(tokenizer,
                            all_input_ids,
                            skip_special_tokens: bool,
                            starting_index: int,
                            spaces_between_special_tokens: bool = True,
                            fast: Optional[bool] = None):
    """Stream-decode ``all_input_ids[starting_index:]`` one token at a time.

    The ids before ``starting_index`` are passed as the prompt of the request;
    the remaining ids are fed to the detokenizer individually and the delta
    text returned after each step is concatenated.

    Args:
        tokenizer: Tokenizer handed to the detokenizer under test.
        all_input_ids: Prompt ids followed by the ids to decode.
        skip_special_tokens: Forwarded to ``SamplingParams``.
        starting_index: Number of leading ids treated as the prompt.
        spaces_between_special_tokens: Forwarded to ``SamplingParams``.
        fast: ``None`` lets ``IncrementalDetokenizer.from_new_request`` pick
            the implementation; ``True`` / ``False`` force the fast / slow
            detokenizer respectively.

    Returns:
        A ``(output_text, output_token_ids)`` tuple, where ``output_token_ids``
        is read from the detokenizer after the final step.
    """
    prompt_token_ids = all_input_ids[:starting_index]

    params = SamplingParams(
        skip_special_tokens=skip_special_tokens,
        spaces_between_special_tokens=spaces_between_special_tokens,
    )
    # NOTE(review): EngineCoreRequest is built positionally; only the prompt
    # ids and sampling params matter here - confirm the field order against
    # the EngineCoreRequest definition if it changes.
    request = EngineCoreRequest("", "", prompt_token_ids, None, None, None,
                                params, None, 0.0, None)

    if fast is None:
        detokenizer = IncrementalDetokenizer.from_new_request(
            tokenizer, request)
    elif fast:
        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
    else:
        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)

    generated_ids = all_input_ids[starting_index:]
    # The loop index counts generated tokens only, so "last" must be measured
    # against the generated slice. Comparing with len(all_input_ids) - 1 would
    # never mark the final step as finished when a prompt is present.
    last_index = len(generated_ids) - 1

    output_text = ""
    for i, token_id in enumerate(generated_ids):
        detokenizer.update([token_id], False)
        finished = i == last_index
        output_text += detokenizer.get_next_output_text(finished, delta=True)

    return output_text, detokenizer.output_token_ids
@
pytest
.
fixture
...
...
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
starting_index
=
0
all_input_ids
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
decoded_text
=
_run_incremental_decode
(
tokenizer
,
all_input_ids
,
skip_special_tokens
=
True
,
starting_index
=
starting_index
)
decoded_text
,
out_ids
=
_run_incremental_decode
(
tokenizer
,
all_input_ids
,
skip_special_tokens
=
True
,
starting_index
=
starting_index
)
assert
decoded_text
==
truth
assert
out_ids
==
all_input_ids
[
starting_index
:]
@
pytest
.
fixture
...
...
@@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
@
pytest
.
mark
.
parametrize
(
"with_prompt"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"skip_special_tokens"
,
(
True
,
False
),
indirect
=
True
)
def
test_decode_streaming
(
tokenizer
,
truth
,
with_prompt
,
skip_special_tokens
):
@
pytest
.
mark
.
parametrize
(
"spaces_between_special_tokens"
,
(
True
,
False
))
@
pytest
.
mark
.
parametrize
(
"fast"
,
(
True
,
False
))
def
test_decode_streaming
(
tokenizer
,
truth
,
with_prompt
,
skip_special_tokens
,
spaces_between_special_tokens
,
fast
):
if
fast
and
not
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
pytest
.
skip
()
if
skip_special_tokens
and
not
spaces_between_special_tokens
:
pytest
.
skip
()
if
not
fast
and
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
# Fix up inconsistency in fast/slow tokenizer behaviour.
tokenizer
.
add_special_tokens
({
"additional_special_tokens"
:
[
at
for
at
in
tokenizer
.
_tokenizer
.
get_added_tokens_decoder
().
values
()
if
at
.
special
]
})
extra_decode_args
=
{}
if
not
isinstance
(
tokenizer
,
PreTrainedTokenizer
)
\
else
{
"spaces_between_special_tokens"
:
spaces_between_special_tokens
}
truth_tokens
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
if
tokenizer
.
bos_token_id
is
not
None
:
truth_tokens
.
insert
(
0
,
tokenizer
.
bos_token_id
)
truth_tokens
.
append
(
tokenizer
.
eos_token_id
)
new_truth
=
tokenizer
.
decode
(
truth_tokens
,
skip_special_tokens
=
skip_special_tokens
,
**
extra_decode_args
)
if
with_prompt
:
truth_tokens
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
prompt_input_ids
=
truth_tokens
[:
len
(
truth
)
//
2
]
generated_input_ids
=
truth_tokens
[
len
(
truth
)
//
2
:]
num_prompt_tokens
=
len
(
tokenizer
(
truth
[:
len
(
truth
)
//
2
],
add_special_tokens
=
False
).
input_ids
)
if
tokenizer
.
bos_token_id
is
not
None
:
num_prompt_tokens
+=
1
prompt_input_ids
=
truth_tokens
[:
num_prompt_tokens
]
generated_input_ids
=
truth_tokens
[
num_prompt_tokens
:]
all_input_ids
=
prompt_input_ids
+
generated_input_ids
starting_index
=
len
(
prompt_input_ids
)
prompt
=
tokenizer
.
decode
(
prompt_input_ids
,
skip_special_tokens
=
skip_special_tokens
)
generated
=
truth
[
len
(
prompt
):]
skip_special_tokens
=
skip_special_tokens
,
**
extra_decode_args
)
generated
=
new_truth
[
len
(
prompt
):]
else
:
generated
=
truth
generated
=
new_
truth
starting_index
=
0
all_input_ids
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
if
skip_special_tokens
:
if
tokenizer
.
bos_token_id
is
not
None
:
all_input_ids
=
[
tokenizer
.
bos_token_id
]
+
all_input_ids
starting_index
+=
1
all_input_ids
=
all_input_ids
+
[
tokenizer
.
eos_token_id
]
all_input_ids
=
truth_tokens
decoded_text
=
_run_incremental_decode
(
decoded_text
,
out_ids
=
_run_incremental_decode
(
tokenizer
,
all_input_ids
,
skip_special_tokens
=
skip_special_tokens
,
starting_index
=
starting_index
)
starting_index
=
starting_index
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
fast
=
fast
)
assert
decoded_text
==
generated
assert
out_ids
==
all_input_ids
[
starting_index
:]
decoded_text
=
_run_incremental_decode
(
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"fast"
,
(
True
,
False
))
def
test_oov_decode
(
tokenizer
,
fast
):
if
fast
and
not
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
pytest
.
skip
()
decoded_text
,
out_ids
=
_run_incremental_decode
(
tokenizer
,
[
len
(
tokenizer
)],
skip_special_tokens
=
skip_special_tokens
,
starting_index
=
starting_index
)
skip_special_tokens
=
True
,
starting_index
=
0
,
spaces_between_special_tokens
=
True
,
fast
=
fast
)
assert
decoded_text
==
''
assert
out_ids
==
[
len
(
tokenizer
)]
@
pytest
.
fixture
...
...
@@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
@
pytest
.
fixture
(
name
=
"complete_sequence_token_ids"
)
def
create_complete_sequence_token_ids
(
complete_sequence
:
str
,
tokenizer
)
->
list
[
int
]:
complete_sequence_token_ids
=
tokenizer
(
complete_sequence
).
input_ids
return
complete_sequence_token_ids
return
tokenizer
(
complete_sequence
,
add_special_tokens
=
False
).
input_ids
def
create_sequence
(
prompt_token_ids
=
None
):
prompt_token_ids
=
prompt_token_ids
or
[
1
]
prompt_token_ids
=
prompt_token_ids
or
[]
return
Sequence
(
seq_id
=
0
,
inputs
=
token_inputs
(
prompt_token_ids
,
prompt
=
"<s>"
),
inputs
=
token_inputs
(
prompt_token_ids
),
block_size
=
16
,
)
...
...
@@ -224,7 +291,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
assert
sequential_result
==
""
.
join
(
sequential_logprobs_text_chosen_token
)
assert
sequential_result
!=
""
.
join
(
sequential_logprobs_text_other_token
)
if
skip_special_tokens
:
if
not
skip_special_tokens
:
# Text for logprobs for the chosen token should be the same as the
# generated text. Note that this will only be true if we skip
# special tokens.
...
...
@@ -233,10 +300,23 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
def
test_decode_prompt_logprobs
(
complete_sequence_token_ids
:
list
[
int
],
def
test_decode_prompt_logprobs
(
complete_sequence
:
str
,
complete_sequence_token_ids
:
list
[
int
],
detokenizer
:
Detokenizer
):
# We want to use skip_special_tokens=False here but Mistral tokenizers
# don't support that.
if
complete_sequence
not
in
SPECIAL_TOKS_TRUTH
:
skip_special_tokens
=
True
elif
not
isinstance
(
detokenizer
.
tokenizer_group
.
get_lora_tokenizer
(
None
),
MistralTokenizer
):
skip_special_tokens
=
False
else
:
pytest
.
skip
(
"MistralTokenizers don't support "
"skip_special_tokens=False"
)
return
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params
=
SamplingParams
(
skip_special_tokens
=
True
,
sampling_params
=
SamplingParams
(
skip_special_tokens
=
skip_special_tokens
,
prompt_logprobs
=
1
)
# Run sequentially.
...
...
@@ -256,8 +336,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
# decoded_prompt_logprobs doesn't contain the first token.
token_ids
=
complete_sequence_token_ids
tokenizer
=
detokenizer
.
get_tokenizer_for_seq
(
seq
)
text_full
=
tokenizer
.
decode
(
token_ids
,
skip_special_tokens
=
True
)
text_first
=
tokenizer
.
decode
(
token_ids
[
0
],
skip_special_tokens
=
True
)
text_full
=
tokenizer
.
decode
(
token_ids
,
skip_special_tokens
=
skip_special_tokens
)
text_first
=
tokenizer
.
decode
(
token_ids
[
0
],
skip_special_tokens
=
skip_special_tokens
)
text
=
text_full
[
len
(
text_first
):]
# Text for logprobs for the chosen token should be the same as the
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
05fcd1b4
...
...
@@ -70,6 +70,15 @@ class MistralToolParser(ToolParser):
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!"
)
def adjust_request(
        self, request: ChatCompletionRequest) -> ChatCompletionRequest:
    """Keep special tokens in the output when tool calls may be produced.

    The request is modified in place and the same object is returned.
    """
    # Nothing to adjust when no tools were supplied or tool use is disabled.
    if not request.tools or request.tool_choice == 'none':
        return request

    # Mistral marks the start and end of the tool call information with
    # special tokens, so they must not be stripped during detokenization.
    request.skip_special_tokens = False
    return request
def
extract_tool_calls
(
self
,
model_output
:
str
,
...
...
vllm/v1/engine/detokenizer.py
View file @
05fcd1b4
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
,
field
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
from
tokenizers
import
Tokenizer
from
tokenizers.decoders
import
DecodeStream
from
transformers
import
PreTrainedTokenizerFast
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.detokenizer_utils
import
(
...
...
@@ -12,39 +15,22 @@ from vllm.v1.engine import EngineCoreRequest
logger
=
init_logger
(
__name__
)
@
dataclass
class
IncrementalDetokenizer
:
# Generation data
token_ids
:
list
[
int
]
output_text
:
str
=
""
tokens
:
list
[
str
]
=
field
(
default_factory
=
list
)
prompt_len
:
int
=
0
# Stop strings
stop
:
list
[
str
]
=
field
(
default_factory
=
list
)
include_stop_str_in_output
:
bool
=
False
# Metadata for incremental detokenization
prefix_offset
:
int
=
0
read_offset
:
int
=
0
# Parameters for detokenization
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer
:
Optional
[
AnyTokenizer
]
=
None
# Accounting for stop string buffering
stop_buffer_length
:
int
=
0
_last_output_text_offset
:
int
=
0
def
__init__
(
self
):
self
.
token_ids
:
list
[
int
]
=
[]
@
property
def
output_token_ids
(
self
)
->
list
[
int
]:
return
self
.
token_ids
if
not
self
.
prompt_len
else
(
self
.
token_ids
[
self
.
prompt_len
:])
return
self
.
token_ids
def
update
(
self
,
new_token_ids
:
list
[
int
],
stop_terminated
:
bool
)
->
Optional
[
str
]:
self
.
token_ids
.
extend
(
new_token_ids
)
return
None
def
get_next_output_text
(
self
,
finished
:
bool
,
delta
:
bool
)
->
str
:
return
""
@
classmethod
def
from_new_request
(
...
...
@@ -54,39 +40,37 @@ class IncrementalDetokenizer:
)
->
"IncrementalDetokenizer"
:
if
tokenizer
is
None
:
return
cls
(
token_ids
=
[])
# No tokenizer => skipping detokenization.
return
IncrementalDetokenizer
()
if
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
# Fast tokenizer => use tokenizers library DecodeStream.
return
FastIncrementalDetokenizer
(
tokenizer
,
request
)
# Fall back to slow python-based incremental detokenization.
return
SlowIncrementalDetokenizer
(
tokenizer
,
request
)
class
BaseIncrementalDetokenizer
(
IncrementalDetokenizer
,
ABC
):
def
__init__
(
self
,
request
:
EngineCoreRequest
):
super
().
__init__
()
tokens
,
prefix_offset
,
read_offset
=
convert_prompt_ids_to_tokens
(
tokenizer
=
tokenizer
,
prompt_ids
=
request
.
prompt_token_ids
,
skip_special_tokens
=
request
.
sampling_params
.
skip_special_tokens
,
)
# Stop strings
params
=
request
.
sampling_params
self
.
stop
=
stop
=
params
.
stop
self
.
include_stop_str_in_output
=
params
.
include_stop_str_in_output
stops
=
request
.
sampling_params
.
stop
# Number of chars to hold back when stop strings are to be excluded
# from streamed output.
if
stop
s
and
not
request
.
sampling_params
.
include_stop_str_in_output
:
stop_buffer_length
=
max
(
len
(
s
)
for
s
in
stop
s
)
-
1
if
stop
and
not
self
.
include_stop_str_in_output
:
self
.
stop_buffer_length
=
max
(
len
(
s
)
for
s
in
stop
)
-
1
else
:
stop_buffer_length
=
0
return
cls
(
tokens
=
tokens
,
# Detokenizer mutates this list, so need a unique copy.
# NOTE(Nick): could we take ownership of it though?
token_ids
=
request
.
prompt_token_ids
.
copy
(),
stop
=
stops
,
include_stop_str_in_output
=
request
.
sampling_params
.
include_stop_str_in_output
,
prefix_offset
=
prefix_offset
,
read_offset
=
read_offset
,
skip_special_tokens
=
request
.
sampling_params
.
skip_special_tokens
,
spaces_between_special_tokens
=
request
.
sampling_params
.
spaces_between_special_tokens
,
prompt_len
=
len
(
request
.
prompt_token_ids
),
tokenizer
=
tokenizer
,
stop_buffer_length
=
stop_buffer_length
,
)
self
.
stop_buffer_length
=
0
self
.
_last_output_text_offset
:
int
=
0
# Generation data
self
.
output_text
=
""
def
update
(
self
,
new_token_ids
:
list
[
int
],
stop_terminated
:
bool
)
->
Optional
[
str
]:
...
...
@@ -98,11 +82,7 @@ class IncrementalDetokenizer:
Return matched stop string or None.
"""
if
not
new_token_ids
:
# Skip detokenization if no new token ids
return
None
if
self
.
tokenizer
is
None
:
# Skip detokenization if no tokenizer
self
.
token_ids
.
extend
(
new_token_ids
)
# Skip detokenization if no new token ids.
return
None
if
stop_terminated
and
not
self
.
include_stop_str_in_output
:
...
...
@@ -116,34 +96,16 @@ class IncrementalDetokenizer:
# 1) Detokenize the new token ids incrementally.
# TODO(woosuk): This method becomes very inefficient when the number of
# new_token_ids is more than 1. We need to optimize this.
decoded_text
=
""
offset_before
=
len
(
self
.
output_text
)
for
new_token_id
in
new_token_ids
:
self
.
token_ids
.
append
(
new_token_id
)
(
new_tokens
,
new_decoded_token_text
,
prefix_offset
,
read_offset
)
=
detokenize_incrementally
(
tokenizer
=
self
.
tokenizer
,
all_input_ids
=
self
.
token_ids
,
prev_tokens
=
self
.
tokens
,
prefix_offset
=
self
.
prefix_offset
,
read_offset
=
self
.
read_offset
,
skip_special_tokens
=
self
.
skip_special_tokens
,
spaces_between_special_tokens
=
self
.
spaces_between_special_tokens
,
)
self
.
tokens
.
extend
(
new_tokens
)
self
.
prefix_offset
=
prefix_offset
self
.
read_offset
=
read_offset
decoded_text
+=
new_decoded_token_text
self
.
output_text
+=
decoded_text
self
.
output_text
+=
self
.
decode_next
(
new_token_id
)
if
stop_terminated
:
if
skipped_stop_token_id
is
not
None
:
# Cleanup after skipping detokenization
# Cleanup after skipping detokenization
.
self
.
token_ids
.
append
(
skipped_stop_token_id
)
# Stop token triggered; skip stop string check
# Stop token triggered; skip stop string check
.
return
None
# 2) Evaluate stop strings.
...
...
@@ -151,7 +113,7 @@ class IncrementalDetokenizer:
if
self
.
stop
:
stop
=
StopChecker
.
check_stop_strings
(
output_text
=
self
.
output_text
,
new_char_count
=
len
(
decoded_text
)
,
new_char_count
=
len
(
self
.
output_text
)
-
offset_before
,
stop
=
self
.
stop
,
include_in_output
=
self
.
include_stop_str_in_output
,
)
...
...
@@ -162,6 +124,10 @@ class IncrementalDetokenizer:
return
stop_string
@
abstractmethod
def
decode_next
(
self
,
next_token_id
:
int
)
->
str
:
raise
NotImplementedError
def
get_next_output_text
(
self
,
finished
:
bool
,
delta
:
bool
)
->
str
:
"""If delta is True, only new text since the last call to
this method is returned"""
...
...
@@ -177,3 +143,114 @@ class IncrementalDetokenizer:
self
.
_last_output_text_offset
=
length
return
self
.
output_text
[
last_offset
:
length
]
return
""
class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer backed by the tokenizers ``DecodeStream``."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast,
                 request: EngineCoreRequest):
        """Create the decode stream and prime it with the end of the prompt.

        Args:
            tokenizer: Fast tokenizer; its underlying ``_tokenizer`` object is
                what gets stored and used for decoding.
            request: Request supplying the prompt token ids and the sampling
                params (``skip_special_tokens``,
                ``spaces_between_special_tokens``).
        """
        super().__init__(request)

        sampling_params = request.sampling_params
        self.stream = DecodeStream(
            skip_special_tokens=sampling_params.skip_special_tokens)

        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Find a safe place to start: the shortest prompt suffix (of at least
        # 4 tokens) whose decoded text contains no replacement character.
        # The search is capped (i < 32) and never looks past the prompt
        # length; if no clean suffix is found within that window the whole
        # prompt is used. Using max() here instead of min() would let the
        # window grow with the prompt, costing one decode per step.
        prompt_suffix = request.prompt_token_ids
        prompt_len = len(prompt_suffix)
        if prompt_len > 4:
            for i in range(4, min(prompt_len + 1, 32)):
                suffix = request.prompt_token_ids[-i:]
                if '�' not in self.tokenizer.decode(suffix):
                    prompt_suffix = suffix
                    break

        # Prime the stream.
        for tid in prompt_suffix:
            self.stream.step(self.tokenizer, tid)

        # When special tokens are skipped there is nothing to join, so the
        # space-suppression path below is only needed if both flags are off.
        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens)

        if not self.spaces_between_special_tokens:
            # Store dict of added token ids so that we can suppress
            # the spaces between them. The mapping is cached as an attribute
            # on the tokenizer object so it is only built once per tokenizer.
            if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
                                           None)) is None:
                self.tokenizer.added_token_ids = added_token_ids = {
                    tid: tok.content
                    for tid, tok in
                    self.tokenizer.get_added_tokens_decoder().items()
                }

            if added_token_ids:
                self.last_special = False
                self.added_token_ids = added_token_ids
            else:
                # No added tokens.
                self.spaces_between_special_tokens = True

    def decode_next(self, next_token_id: int) -> str:
        """Advance the stream by one token and return the new text.

        Returns an empty string when the stream step yields nothing.
        """
        token = self.stream.step(self.tokenizer, next_token_id)

        if not self.spaces_between_special_tokens:
            special_token = self.added_token_ids.get(next_token_id)
            is_special = special_token is not None
            if is_special and self.last_special:
                # Return raw token string without any prefixed spaces.
                token = special_token
            self.last_special = is_special

        return token or ""
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Python-based incremental detokenizer using detokenize_incrementally."""

    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
        """Set up tokenizer state and seed it from the request's prompt."""
        super().__init__(request)

        sampling = request.sampling_params
        prompt_ids = request.prompt_token_ids

        self.tokenizer = tokenizer

        # Metadata for incremental detokenization.
        prompt_state = convert_prompt_ids_to_tokens(
            tokenizer=tokenizer,
            prompt_ids=prompt_ids,
            skip_special_tokens=sampling.skip_special_tokens,
        )
        self.tokens, self.prefix_offset, self.read_offset = prompt_state

        # token_ids holds prompt ids followed by generated ids; prompt_len
        # records where the generated part starts.
        self.token_ids.extend(prompt_ids)
        self.prompt_len = len(prompt_ids)

        # Detokenization parameters taken from the sampling params.
        self.skip_special_tokens = sampling.skip_special_tokens
        self.spaces_between_special_tokens = (
            sampling.spaces_between_special_tokens)

    @property
    def output_token_ids(self) -> list[int]:
        """Token ids after the prompt (the full list if there is no prompt)."""
        if self.prompt_len:
            return self.token_ids[self.prompt_len:]
        return self.token_ids

    def decode_next(self, next_token_id: int) -> str:
        """Decode incrementally from the stored ids and return the new text.

        The call works on ``self.token_ids``; ``next_token_id`` itself is not
        read here.
        """
        step_result = detokenize_incrementally(
            tokenizer=self.tokenizer,
            all_input_ids=self.token_ids,
            prev_tokens=self.tokens,
            prefix_offset=self.prefix_offset,
            read_offset=self.read_offset,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
        )
        added_tokens, text_delta, next_prefix, next_read = step_result

        # Carry the updated state forward for the next step.
        self.tokens.extend(added_tokens)
        self.prefix_offset = next_prefix
        self.read_offset = next_read

        return text_delta
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment