Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
05fcd1b4
Unverified
Commit 05fcd1b4 authored Apr 17, 2025 by Nick Hill
Committed by GitHub Apr 17, 2025
Browse files
[V1][Perf] Faster incremental detokenization (#15137)
Signed-off-by: Nick Hill <nhill@redhat.com>
parent
7c02d6a1
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing 7 changed files with 316 additions and 144 deletions
+316
-144
requirements/common.txt
requirements/common.txt
+1
-1
requirements/test.in
requirements/test.in
+1
-0
requirements/test.txt
requirements/test.txt
+4
-2
tests/lora/test_llama_tp.py
tests/lora/test_llama_tp.py
+1
-0
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+137
-55
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+9
-0
vllm/v1/engine/detokenizer.py
vllm/v1/engine/detokenizer.py
+163
-86
No files found.
requirements/common.txt
View file @
05fcd1b4
...
@@ -8,7 +8,7 @@ blake3
...
@@ -8,7 +8,7 @@ blake3
py-cpuinfo
py-cpuinfo
transformers >= 4.51.1
transformers >= 4.51.1
huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
tokenizers >= 0.19.1 # Required for Llama 3.
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf # Required by LlamaTokenizer.
protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
aiohttp
aiohttp
...
...
requirements/test.in
View file @
05fcd1b4
...
@@ -35,6 +35,7 @@ opencv-python-headless >= 4.11.0 # required for video test
...
@@ -35,6 +35,7 @@ opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
lm-eval[api]==0.4.8 # required for model evaluation test
transformers==4.51.1
transformers==4.51.1
tokenizers==0.21.1
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
# quantization
# quantization
bitsandbytes>=0.45.3
bitsandbytes>=0.45.3
...
...
requirements/test.txt
View file @
05fcd1b4
...
@@ -624,8 +624,10 @@ tiktoken==0.7.0
...
@@ -624,8 +624,10 @@ tiktoken==0.7.0
# mistral-common
# mistral-common
timm==1.0.11
timm==1.0.11
# via -r requirements/test.in
# via -r requirements/test.in
tokenizers==0.21.0
tokenizers==0.21.1
# via transformers
# via
# -r requirements/test.in
# transformers
torch==2.6.0
torch==2.6.0
# via
# via
# -r requirements/test.in
# -r requirements/test.in
...
...
tests/lora/test_llama_tp.py
View file @
05fcd1b4
...
@@ -47,6 +47,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
...
@@ -47,6 +47,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
]
]
sampling_params
=
vllm
.
SamplingParams
(
temperature
=
0
,
sampling_params
=
vllm
.
SamplingParams
(
temperature
=
0
,
max_tokens
=
256
,
max_tokens
=
256
,
skip_special_tokens
=
False
,
stop
=
[
"[/assistant]"
])
stop
=
[
"[/assistant]"
])
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
,
prompts
,
...
...
tests/tokenization/test_detokenize.py
View file @
05fcd1b4
...
@@ -4,14 +4,22 @@ from collections.abc import Generator
...
@@ -4,14 +4,22 @@ from collections.abc import Generator
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
(
AutoTokenizer
,
PreTrainedTokenizer
,
PreTrainedTokenizerFast
)
from
vllm.inputs
import
token_inputs
from
vllm.inputs
import
token_inputs
from
vllm.sequence
import
Logprob
,
SamplingParams
,
Sequence
,
SequenceGroup
from
vllm.sequence
import
Logprob
,
SamplingParams
,
Sequence
,
SequenceGroup
from
vllm.transformers_utils.detokenizer
import
(
Detokenizer
,
from
vllm.transformers_utils.detokenizer
import
Detokenizer
detokenize_incrementally
)
from
vllm.transformers_utils.tokenizer_group
import
get_tokenizer_group
from
vllm.transformers_utils.tokenizer_group
import
get_tokenizer_group
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.detokenizer
import
(
FastIncrementalDetokenizer
,
IncrementalDetokenizer
,
SlowIncrementalDetokenizer
)
SPECIAL_TOKS_TRUTH
=
[
"Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>"
,
# noqa
]
TRUTH
=
[
TRUTH
=
[
"Hello here, this is a simple test"
,
"Hello here, this is a simple test"
,
...
@@ -22,7 +30,8 @@ TRUTH = [
...
@@ -22,7 +30,8 @@ TRUTH = [
# incomplete UTF-8 characters
# incomplete UTF-8 characters
# see https://github.com/vllm-project/vllm/pull/9625
# see https://github.com/vllm-project/vllm/pull/9625
"ပုံပြင်လေးပြောပြပါ်"
,
"ပုံပြင်လေးပြောပြပါ်"
,
]
]
+
SPECIAL_TOKS_TRUTH
TOKENIZERS
=
[
TOKENIZERS
=
[
"facebook/opt-125m"
,
"facebook/opt-125m"
,
"gpt2"
,
"gpt2"
,
...
@@ -38,26 +47,37 @@ TOKENIZERS = [
...
@@ -38,26 +47,37 @@ TOKENIZERS = [
]
]
def
_run_incremental_decode
(
tokenizer
,
all_input_ids
,
def
_run_incremental_decode
(
tokenizer
,
skip_special_tokens
:
bool
,
starting_index
:
int
):
all_input_ids
,
decoded_text
=
""
skip_special_tokens
:
bool
,
offset
=
0
starting_index
:
int
,
token_offset
=
0
spaces_between_special_tokens
:
bool
=
True
,
prev_tokens
=
None
fast
:
Optional
[
bool
]
=
None
):
for
i
in
range
(
starting_index
,
len
(
all_input_ids
)):
new_tokens
,
text
,
offset
,
token_offset
=
detokenize_incrementally
(
prompt_token_ids
=
all_input_ids
[:
starting_index
]
tokenizer
,
all_input_ids
[:
i
+
1
],
params
=
SamplingParams
(
prev_tokens
,
skip_special_tokens
=
skip_special_tokens
,
offset
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
token_offset
,
)
skip_special_tokens
=
skip_special_tokens
)
request
=
EngineCoreRequest
(
""
,
""
,
prompt_token_ids
,
None
,
None
,
None
,
decoded_text
+=
text
params
,
None
,
0.0
,
None
)
if
prev_tokens
is
None
:
prev_tokens
=
new_tokens
if
fast
is
None
:
else
:
detokenizer
=
IncrementalDetokenizer
.
from_new_request
(
prev_tokens
+=
new_tokens
tokenizer
,
request
)
return
decoded_text
elif
fast
:
detokenizer
=
FastIncrementalDetokenizer
(
tokenizer
,
request
)
else
:
detokenizer
=
SlowIncrementalDetokenizer
(
tokenizer
,
request
)
output_text
=
""
for
i
,
token_id
in
enumerate
(
all_input_ids
[
starting_index
:]):
detokenizer
.
update
([
token_id
],
False
)
finished
=
i
==
len
(
all_input_ids
)
-
1
output_text
+=
detokenizer
.
get_next_output_text
(
finished
,
delta
=
True
)
return
output_text
,
detokenizer
.
output_token_ids
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
...
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
starting_index
=
0
starting_index
=
0
all_input_ids
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
all_input_ids
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
decoded_text
=
_run_incremental_decode
(
tokenizer
,
decoded_text
,
out_ids
=
_run_incremental_decode
(
all_input_ids
,
tokenizer
,
skip_special_tokens
=
True
,
all_input_ids
,
starting_index
=
starting_index
)
skip_special_tokens
=
True
,
starting_index
=
starting_index
)
assert
decoded_text
==
truth
assert
decoded_text
==
truth
assert
out_ids
==
all_input_ids
[
starting_index
:]
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
...
@@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
@
pytest
.
mark
.
parametrize
(
"with_prompt"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"with_prompt"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"skip_special_tokens"
,
(
True
,
False
),
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"skip_special_tokens"
,
(
True
,
False
),
indirect
=
True
)
def
test_decode_streaming
(
tokenizer
,
truth
,
with_prompt
,
skip_special_tokens
):
@
pytest
.
mark
.
parametrize
(
"spaces_between_special_tokens"
,
(
True
,
False
))
@
pytest
.
mark
.
parametrize
(
"fast"
,
(
True
,
False
))
def
test_decode_streaming
(
tokenizer
,
truth
,
with_prompt
,
skip_special_tokens
,
spaces_between_special_tokens
,
fast
):
if
fast
and
not
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
pytest
.
skip
()
if
skip_special_tokens
and
not
spaces_between_special_tokens
:
pytest
.
skip
()
if
not
fast
and
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
# Fix up inconsistency in fast/slow tokenizer behaviour.
tokenizer
.
add_special_tokens
({
"additional_special_tokens"
:
[
at
for
at
in
tokenizer
.
_tokenizer
.
get_added_tokens_decoder
().
values
()
if
at
.
special
]
})
extra_decode_args
=
{}
if
not
isinstance
(
tokenizer
,
PreTrainedTokenizer
)
\
else
{
"spaces_between_special_tokens"
:
spaces_between_special_tokens
}
truth_tokens
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
if
tokenizer
.
bos_token_id
is
not
None
:
truth_tokens
.
insert
(
0
,
tokenizer
.
bos_token_id
)
truth_tokens
.
append
(
tokenizer
.
eos_token_id
)
new_truth
=
tokenizer
.
decode
(
truth_tokens
,
skip_special_tokens
=
skip_special_tokens
,
**
extra_decode_args
)
if
with_prompt
:
if
with_prompt
:
truth_tokens
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
num_prompt_tokens
=
len
(
prompt_input_ids
=
truth_tokens
[:
len
(
truth
)
//
2
]
tokenizer
(
truth
[:
len
(
truth
)
//
2
],
generated_input_ids
=
truth_tokens
[
len
(
truth
)
//
2
:]
add_special_tokens
=
False
).
input_ids
)
if
tokenizer
.
bos_token_id
is
not
None
:
num_prompt_tokens
+=
1
prompt_input_ids
=
truth_tokens
[:
num_prompt_tokens
]
generated_input_ids
=
truth_tokens
[
num_prompt_tokens
:]
all_input_ids
=
prompt_input_ids
+
generated_input_ids
all_input_ids
=
prompt_input_ids
+
generated_input_ids
starting_index
=
len
(
prompt_input_ids
)
starting_index
=
len
(
prompt_input_ids
)
prompt
=
tokenizer
.
decode
(
prompt_input_ids
,
prompt
=
tokenizer
.
decode
(
prompt_input_ids
,
skip_special_tokens
=
skip_special_tokens
)
skip_special_tokens
=
skip_special_tokens
,
generated
=
truth
[
len
(
prompt
):]
**
extra_decode_args
)
generated
=
new_truth
[
len
(
prompt
):]
else
:
else
:
generated
=
truth
generated
=
new_
truth
starting_index
=
0
starting_index
=
0
all_input_ids
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
all_input_ids
=
truth_tokens
if
skip_special_tokens
:
if
tokenizer
.
bos_token_id
is
not
None
:
all_input_ids
=
[
tokenizer
.
bos_token_id
]
+
all_input_ids
starting_index
+=
1
all_input_ids
=
all_input_ids
+
[
tokenizer
.
eos_token_id
]
decoded_text
=
_run_incremental_decode
(
decoded_text
,
out_ids
=
_run_incremental_decode
(
tokenizer
,
tokenizer
,
all_input_ids
,
all_input_ids
,
skip_special_tokens
=
skip_special_tokens
,
skip_special_tokens
=
skip_special_tokens
,
starting_index
=
starting_index
)
starting_index
=
starting_index
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
fast
=
fast
)
assert
decoded_text
==
generated
assert
decoded_text
==
generated
assert
out_ids
==
all_input_ids
[
starting_index
:]
decoded_text
=
_run_incremental_decode
(
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"fast"
,
(
True
,
False
))
def
test_oov_decode
(
tokenizer
,
fast
):
if
fast
and
not
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
pytest
.
skip
()
decoded_text
,
out_ids
=
_run_incremental_decode
(
tokenizer
,
[
len
(
tokenizer
)],
tokenizer
,
[
len
(
tokenizer
)],
skip_special_tokens
=
skip_special_tokens
,
skip_special_tokens
=
True
,
starting_index
=
starting_index
)
starting_index
=
0
,
spaces_between_special_tokens
=
True
,
fast
=
fast
)
assert
decoded_text
==
''
assert
decoded_text
==
''
assert
out_ids
==
[
len
(
tokenizer
)]
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
...
@@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
@
pytest
.
fixture
(
name
=
"complete_sequence_token_ids"
)
@
pytest
.
fixture
(
name
=
"complete_sequence_token_ids"
)
def
create_complete_sequence_token_ids
(
complete_sequence
:
str
,
def
create_complete_sequence_token_ids
(
complete_sequence
:
str
,
tokenizer
)
->
list
[
int
]:
tokenizer
)
->
list
[
int
]:
complete_sequence_token_ids
=
tokenizer
(
complete_sequence
).
input_ids
return
tokenizer
(
complete_sequence
,
add_special_tokens
=
False
).
input_ids
return
complete_sequence_token_ids
def
create_sequence
(
prompt_token_ids
=
None
):
def
create_sequence
(
prompt_token_ids
=
None
):
prompt_token_ids
=
prompt_token_ids
or
[
1
]
prompt_token_ids
=
prompt_token_ids
or
[]
return
Sequence
(
return
Sequence
(
seq_id
=
0
,
seq_id
=
0
,
inputs
=
token_inputs
(
prompt_token_ids
,
prompt
=
"<s>"
),
inputs
=
token_inputs
(
prompt_token_ids
),
block_size
=
16
,
block_size
=
16
,
)
)
...
@@ -224,7 +291,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
...
@@ -224,7 +291,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
assert
sequential_result
==
""
.
join
(
sequential_logprobs_text_chosen_token
)
assert
sequential_result
==
""
.
join
(
sequential_logprobs_text_chosen_token
)
assert
sequential_result
!=
""
.
join
(
sequential_logprobs_text_other_token
)
assert
sequential_result
!=
""
.
join
(
sequential_logprobs_text_other_token
)
if
skip_special_tokens
:
if
not
skip_special_tokens
:
# Text for logprobs for the chosen token should be the same as the
# Text for logprobs for the chosen token should be the same as the
# generated text. Note that this will only be true if we skip
# generated text. Note that this will only be true if we skip
# special tokens.
# special tokens.
...
@@ -233,10 +300,23 @@ def test_decode_sequence_logprobs(complete_sequence: str,
...
@@ -233,10 +300,23 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
def
test_decode_prompt_logprobs
(
complete_sequence_token_ids
:
list
[
int
],
def
test_decode_prompt_logprobs
(
complete_sequence
:
str
,
complete_sequence_token_ids
:
list
[
int
],
detokenizer
:
Detokenizer
):
detokenizer
:
Detokenizer
):
# We want to use skip_special_tokens=False here but Mistral tokenizers
# don't support that.
if
complete_sequence
not
in
SPECIAL_TOKS_TRUTH
:
skip_special_tokens
=
True
elif
not
isinstance
(
detokenizer
.
tokenizer_group
.
get_lora_tokenizer
(
None
),
MistralTokenizer
):
skip_special_tokens
=
False
else
:
pytest
.
skip
(
"MistralTokenizers don't support "
"skip_special_tokens=False"
)
return
"""Verify Detokenizer decodes prompt logprobs correctly."""
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params
=
SamplingParams
(
skip_special_tokens
=
True
,
sampling_params
=
SamplingParams
(
skip_special_tokens
=
skip_special_tokens
,
prompt_logprobs
=
1
)
prompt_logprobs
=
1
)
# Run sequentially.
# Run sequentially.
...
@@ -256,8 +336,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
...
@@ -256,8 +336,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
# decoded_prompt_logprobs doesn't contain the first token.
# decoded_prompt_logprobs doesn't contain the first token.
token_ids
=
complete_sequence_token_ids
token_ids
=
complete_sequence_token_ids
tokenizer
=
detokenizer
.
get_tokenizer_for_seq
(
seq
)
tokenizer
=
detokenizer
.
get_tokenizer_for_seq
(
seq
)
text_full
=
tokenizer
.
decode
(
token_ids
,
skip_special_tokens
=
True
)
text_full
=
tokenizer
.
decode
(
token_ids
,
text_first
=
tokenizer
.
decode
(
token_ids
[
0
],
skip_special_tokens
=
True
)
skip_special_tokens
=
skip_special_tokens
)
text_first
=
tokenizer
.
decode
(
token_ids
[
0
],
skip_special_tokens
=
skip_special_tokens
)
text
=
text_full
[
len
(
text_first
):]
text
=
text_full
[
len
(
text_first
):]
# Text for logprobs for the chosen token should be the same as the
# Text for logprobs for the chosen token should be the same as the
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
05fcd1b4
...
@@ -70,6 +70,15 @@ class MistralToolParser(ToolParser):
...
@@ -70,6 +70,15 @@ class MistralToolParser(ToolParser):
"Mistral Tool Parser could not locate the tool call token in "
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!"
)
"the tokenizer!"
)
def
adjust_request
(
self
,
request
:
ChatCompletionRequest
)
->
ChatCompletionRequest
:
if
request
.
tools
and
request
.
tool_choice
!=
'none'
:
# do not skip special tokens because mistral uses the special
# tokens to indicate the start and end of the tool calls
# information.
request
.
skip_special_tokens
=
False
return
request
def
extract_tool_calls
(
def
extract_tool_calls
(
self
,
self
,
model_output
:
str
,
model_output
:
str
,
...
...
vllm/v1/engine/detokenizer.py
View file @
05fcd1b4
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
typing
import
Optional
from
tokenizers
import
Tokenizer
from
tokenizers.decoders
import
DecodeStream
from
transformers
import
PreTrainedTokenizerFast
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.detokenizer_utils
import
(
from
vllm.transformers_utils.detokenizer_utils
import
(
...
@@ -12,39 +15,22 @@ from vllm.v1.engine import EngineCoreRequest
...
@@ -12,39 +15,22 @@ from vllm.v1.engine import EngineCoreRequest
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
dataclass
class
IncrementalDetokenizer
:
class
IncrementalDetokenizer
:
# Generation data
def
__init__
(
self
):
token_ids
:
list
[
int
]
self
.
token_ids
:
list
[
int
]
=
[]
output_text
:
str
=
""
tokens
:
list
[
str
]
=
field
(
default_factory
=
list
)
prompt_len
:
int
=
0
# Stop strings
stop
:
list
[
str
]
=
field
(
default_factory
=
list
)
include_stop_str_in_output
:
bool
=
False
# Metadata for incremental detokenization
prefix_offset
:
int
=
0
read_offset
:
int
=
0
# Parameters for detokenization
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer
:
Optional
[
AnyTokenizer
]
=
None
# Accounting for stop string buffering
stop_buffer_length
:
int
=
0
_last_output_text_offset
:
int
=
0
@
property
@
property
def
output_token_ids
(
self
)
->
list
[
int
]:
def
output_token_ids
(
self
)
->
list
[
int
]:
return
self
.
token_ids
if
not
self
.
prompt_len
else
(
return
self
.
token_ids
self
.
token_ids
[
self
.
prompt_len
:])
def
update
(
self
,
new_token_ids
:
list
[
int
],
stop_terminated
:
bool
)
->
Optional
[
str
]:
self
.
token_ids
.
extend
(
new_token_ids
)
return
None
def
get_next_output_text
(
self
,
finished
:
bool
,
delta
:
bool
)
->
str
:
return
""
@
classmethod
@
classmethod
def
from_new_request
(
def
from_new_request
(
...
@@ -54,39 +40,37 @@ class IncrementalDetokenizer:
...
@@ -54,39 +40,37 @@ class IncrementalDetokenizer:
)
->
"IncrementalDetokenizer"
:
)
->
"IncrementalDetokenizer"
:
if
tokenizer
is
None
:
if
tokenizer
is
None
:
return
cls
(
token_ids
=
[])
# No tokenizer => skipping detokenization.
return
IncrementalDetokenizer
()
if
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
# Fast tokenizer => use tokenizers library DecodeStream.
return
FastIncrementalDetokenizer
(
tokenizer
,
request
)
# Fall back to slow python-based incremental detokenization.
return
SlowIncrementalDetokenizer
(
tokenizer
,
request
)
class
BaseIncrementalDetokenizer
(
IncrementalDetokenizer
,
ABC
):
def
__init__
(
self
,
request
:
EngineCoreRequest
):
super
().
__init__
()
tokens
,
prefix_offset
,
read_offset
=
convert_prompt_ids_to_tokens
(
# Stop strings
tokenizer
=
tokenizer
,
params
=
request
.
sampling_params
prompt_ids
=
request
.
prompt_token_ids
,
self
.
stop
=
stop
=
params
.
stop
skip_special_tokens
=
request
.
sampling_params
.
skip_special_tokens
,
self
.
include_stop_str_in_output
=
params
.
include_stop_str_in_output
)
stops
=
request
.
sampling_params
.
stop
# Number of chars to hold back when stop strings are to be excluded
# Number of chars to hold back when stop strings are to be excluded
# from streamed output.
# from streamed output.
if
stop
s
and
not
request
.
sampling_params
.
include_stop_str_in_output
:
if
stop
and
not
self
.
include_stop_str_in_output
:
stop_buffer_length
=
max
(
len
(
s
)
for
s
in
stop
s
)
-
1
self
.
stop_buffer_length
=
max
(
len
(
s
)
for
s
in
stop
)
-
1
else
:
else
:
stop_buffer_length
=
0
self
.
stop_buffer_length
=
0
self
.
_last_output_text_offset
:
int
=
0
return
cls
(
tokens
=
tokens
,
# Generation data
# Detokenizer mutates this list, so need a unique copy.
self
.
output_text
=
""
# NOTE(Nick): could we take ownership of it though?
token_ids
=
request
.
prompt_token_ids
.
copy
(),
stop
=
stops
,
include_stop_str_in_output
=
request
.
sampling_params
.
include_stop_str_in_output
,
prefix_offset
=
prefix_offset
,
read_offset
=
read_offset
,
skip_special_tokens
=
request
.
sampling_params
.
skip_special_tokens
,
spaces_between_special_tokens
=
request
.
sampling_params
.
spaces_between_special_tokens
,
prompt_len
=
len
(
request
.
prompt_token_ids
),
tokenizer
=
tokenizer
,
stop_buffer_length
=
stop_buffer_length
,
)
def
update
(
self
,
new_token_ids
:
list
[
int
],
def
update
(
self
,
new_token_ids
:
list
[
int
],
stop_terminated
:
bool
)
->
Optional
[
str
]:
stop_terminated
:
bool
)
->
Optional
[
str
]:
...
@@ -98,11 +82,7 @@ class IncrementalDetokenizer:
...
@@ -98,11 +82,7 @@ class IncrementalDetokenizer:
Return matched stop string or None.
Return matched stop string or None.
"""
"""
if
not
new_token_ids
:
if
not
new_token_ids
:
# Skip detokenization if no new token ids
# Skip detokenization if no new token ids.
return
None
if
self
.
tokenizer
is
None
:
# Skip detokenization if no tokenizer
self
.
token_ids
.
extend
(
new_token_ids
)
return
None
return
None
if
stop_terminated
and
not
self
.
include_stop_str_in_output
:
if
stop_terminated
and
not
self
.
include_stop_str_in_output
:
...
@@ -116,34 +96,16 @@ class IncrementalDetokenizer:
...
@@ -116,34 +96,16 @@ class IncrementalDetokenizer:
# 1) Detokenize the new token ids incrementally.
# 1) Detokenize the new token ids incrementally.
# TODO(woosuk): This method becomes very inefficient when the number of
# TODO(woosuk): This method becomes very inefficient when the number of
# new_token_ids is more than 1. We need to optimize this.
# new_token_ids is more than 1. We need to optimize this.
decoded_text
=
""
offset_before
=
len
(
self
.
output_text
)
for
new_token_id
in
new_token_ids
:
for
new_token_id
in
new_token_ids
:
self
.
token_ids
.
append
(
new_token_id
)
self
.
token_ids
.
append
(
new_token_id
)
(
new_tokens
,
new_decoded_token_text
,
prefix_offset
,
self
.
output_text
+=
self
.
decode_next
(
new_token_id
)
read_offset
)
=
detokenize_incrementally
(
tokenizer
=
self
.
tokenizer
,
all_input_ids
=
self
.
token_ids
,
prev_tokens
=
self
.
tokens
,
prefix_offset
=
self
.
prefix_offset
,
read_offset
=
self
.
read_offset
,
skip_special_tokens
=
self
.
skip_special_tokens
,
spaces_between_special_tokens
=
self
.
spaces_between_special_tokens
,
)
self
.
tokens
.
extend
(
new_tokens
)
self
.
prefix_offset
=
prefix_offset
self
.
read_offset
=
read_offset
decoded_text
+=
new_decoded_token_text
self
.
output_text
+=
decoded_text
if
stop_terminated
:
if
stop_terminated
:
if
skipped_stop_token_id
is
not
None
:
if
skipped_stop_token_id
is
not
None
:
# Cleanup after skipping detokenization
# Cleanup after skipping detokenization.
self
.
token_ids
.
append
(
skipped_stop_token_id
)
self
.
token_ids
.
append
(
skipped_stop_token_id
)
# Stop token triggered; skip stop string check
# Stop token triggered; skip stop string check.
return
None
return
None
# 2) Evaluate stop strings.
# 2) Evaluate stop strings.
...
@@ -151,7 +113,7 @@ class IncrementalDetokenizer:
...
@@ -151,7 +113,7 @@ class IncrementalDetokenizer:
if
self
.
stop
:
if
self
.
stop
:
stop
=
StopChecker
.
check_stop_strings
(
stop
=
StopChecker
.
check_stop_strings
(
output_text
=
self
.
output_text
,
output_text
=
self
.
output_text
,
new_char_count
=
len
(
decoded_text
)
,
new_char_count
=
len
(
self
.
output_text
)
-
offset_before
,
stop
=
self
.
stop
,
stop
=
self
.
stop
,
include_in_output
=
self
.
include_stop_str_in_output
,
include_in_output
=
self
.
include_stop_str_in_output
,
)
)
...
@@ -162,6 +124,10 @@ class IncrementalDetokenizer:
...
@@ -162,6 +124,10 @@ class IncrementalDetokenizer:
return
stop_string
return
stop_string
@
abstractmethod
def
decode_next
(
self
,
next_token_id
:
int
)
->
str
:
raise
NotImplementedError
def
get_next_output_text
(
self
,
finished
:
bool
,
delta
:
bool
)
->
str
:
def
get_next_output_text
(
self
,
finished
:
bool
,
delta
:
bool
)
->
str
:
"""If delta is True, only new text since the last call to
"""If delta is True, only new text since the last call to
this method is returned"""
this method is returned"""
...
@@ -177,3 +143,114 @@ class IncrementalDetokenizer:
...
@@ -177,3 +143,114 @@ class IncrementalDetokenizer:
self
.
_last_output_text_offset
=
length
self
.
_last_output_text_offset
=
length
return
self
.
output_text
[
last_offset
:
length
]
return
self
.
output_text
[
last_offset
:
length
]
return
""
return
""
class
FastIncrementalDetokenizer
(
BaseIncrementalDetokenizer
):
def
__init__
(
self
,
tokenizer
:
PreTrainedTokenizerFast
,
request
:
EngineCoreRequest
):
super
().
__init__
(
request
)
sampling_params
=
request
.
sampling_params
self
.
stream
=
DecodeStream
(
skip_special_tokens
=
sampling_params
.
skip_special_tokens
)
self
.
tokenizer
:
Tokenizer
=
tokenizer
.
_tokenizer
# Find a safe place to start.
prompt_suffix
=
request
.
prompt_token_ids
prompt_len
=
len
(
prompt_suffix
)
if
prompt_len
>
4
:
for
i
in
range
(
4
,
max
(
prompt_len
+
1
,
32
)):
suffix
=
request
.
prompt_token_ids
[
-
i
:]
if
'�'
not
in
self
.
tokenizer
.
decode
(
suffix
):
prompt_suffix
=
suffix
break
# Prime the stream.
for
tid
in
prompt_suffix
:
self
.
stream
.
step
(
self
.
tokenizer
,
tid
)
self
.
spaces_between_special_tokens
=
(
sampling_params
.
skip_special_tokens
or
sampling_params
.
spaces_between_special_tokens
)
if
not
self
.
spaces_between_special_tokens
:
# Store dict of added token ids so that we can suppress
# the spaces between them.
if
(
added_token_ids
:
=
getattr
(
self
.
tokenizer
,
"added_token_ids"
,
None
))
is
None
:
self
.
tokenizer
.
added_token_ids
=
added_token_ids
=
{
tid
:
tok
.
content
for
tid
,
tok
in
self
.
tokenizer
.
get_added_tokens_decoder
().
items
()
}
if
added_token_ids
:
self
.
last_special
=
False
self
.
added_token_ids
=
added_token_ids
else
:
# No added tokens.
self
.
spaces_between_special_tokens
=
True
def
decode_next
(
self
,
next_token_id
:
int
)
->
str
:
token
=
self
.
stream
.
step
(
self
.
tokenizer
,
next_token_id
)
if
not
self
.
spaces_between_special_tokens
:
special_token
=
self
.
added_token_ids
.
get
(
next_token_id
)
is_special
=
special_token
is
not
None
if
is_special
and
self
.
last_special
:
# Return raw token string without any prefixed spaces.
token
=
special_token
self
.
last_special
=
is_special
return
token
or
""
class
SlowIncrementalDetokenizer
(
BaseIncrementalDetokenizer
):
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
,
request
:
EngineCoreRequest
):
super
().
__init__
(
request
)
self
.
tokenizer
=
tokenizer
# Metadata for incremental detokenization.
self
.
tokens
,
self
.
prefix_offset
,
self
.
read_offset
=
(
convert_prompt_ids_to_tokens
(
tokenizer
=
tokenizer
,
prompt_ids
=
request
.
prompt_token_ids
,
skip_special_tokens
=
request
.
sampling_params
.
skip_special_tokens
,
))
self
.
token_ids
.
extend
(
request
.
prompt_token_ids
)
self
.
prompt_len
=
len
(
request
.
prompt_token_ids
)
params
=
request
.
sampling_params
self
.
skip_special_tokens
=
params
.
skip_special_tokens
self
.
spaces_between_special_tokens
=
(
params
.
spaces_between_special_tokens
)
@
property
def
output_token_ids
(
self
)
->
list
[
int
]:
return
self
.
token_ids
if
not
self
.
prompt_len
else
(
self
.
token_ids
[
self
.
prompt_len
:])
def
decode_next
(
self
,
next_token_id
:
int
)
->
str
:
new_tokens
,
decoded_text
,
prefix_offset
,
read_offset
=
(
detokenize_incrementally
(
tokenizer
=
self
.
tokenizer
,
all_input_ids
=
self
.
token_ids
,
prev_tokens
=
self
.
tokens
,
prefix_offset
=
self
.
prefix_offset
,
read_offset
=
self
.
read_offset
,
skip_special_tokens
=
self
.
skip_special_tokens
,
spaces_between_special_tokens
=
self
.
spaces_between_special_tokens
,
))
self
.
tokens
.
extend
(
new_tokens
)
self
.
prefix_offset
=
prefix_offset
self
.
read_offset
=
read_offset
return
decoded_text
Write
Preview
Markdown
is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment