Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
40a046cd
Unverified
Commit
40a046cd
authored
Dec 05, 2025
by
Rohan Potdar
Committed by
GitHub
Dec 05, 2025
Browse files
[Bugfix]: Fix `TokenizerLike` interface (#30009)
Signed-off-by:
Rohan138
<
rohanpotdar138@gmail.com
>
parent
e858bc4d
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
78 additions
and
52 deletions
+78
-52
vllm/benchmarks/datasets.py
vllm/benchmarks/datasets.py
+32
-27
vllm/benchmarks/serve.py
vllm/benchmarks/serve.py
+16
-16
vllm/benchmarks/throughput.py
vllm/benchmarks/throughput.py
+16
-6
vllm/config/model.py
vllm/config/model.py
+2
-1
vllm/tokenizers/deepseekv32.py
vllm/tokenizers/deepseekv32.py
+3
-0
vllm/tokenizers/mistral.py
vllm/tokenizers/mistral.py
+5
-1
vllm/tokenizers/protocol.py
vllm/tokenizers/protocol.py
+3
-0
vllm/tokenizers/registry.py
vllm/tokenizers/registry.py
+1
-1
No files found.
vllm/benchmarks/datasets.py
View file @
40a046cd
...
@@ -32,7 +32,6 @@ from typing import Any, cast
...
@@ -32,7 +32,6 @@ from typing import Any, cast
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
PreTrainedTokenizerBase
from
typing_extensions
import
deprecated
from
typing_extensions
import
deprecated
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -189,7 +188,7 @@ class BenchmarkDataset(ABC):
...
@@ -189,7 +188,7 @@ class BenchmarkDataset(ABC):
@
abstractmethod
@
abstractmethod
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
request_id_prefix
:
str
=
""
,
request_id_prefix
:
str
=
""
,
no_oversample
:
bool
=
False
,
no_oversample
:
bool
=
False
,
...
@@ -201,7 +200,7 @@ class BenchmarkDataset(ABC):
...
@@ -201,7 +200,7 @@ class BenchmarkDataset(ABC):
for generating a list of SampleRequest objects.
for generating a list of SampleRequest objects.
Args:
Args:
tokenizer (
PreTrained
Tokenizer
Bas
e): The tokenizer to be used
tokenizer (Tokenizer
Lik
e): The tokenizer to be used
for processing the dataset's text.
for processing the dataset's text.
num_requests (int): The number of sample requests to generate.
num_requests (int): The number of sample requests to generate.
request_id_prefix (str): The prefix of request_id.
request_id_prefix (str): The prefix of request_id.
...
@@ -380,7 +379,7 @@ def process_video(video: Any) -> Mapping[str, Any]:
...
@@ -380,7 +379,7 @@ def process_video(video: Any) -> Mapping[str, Any]:
def
gen_prompt_decode_to_target_len
(
def
gen_prompt_decode_to_target_len
(
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
token_sequence
:
list
[
int
],
token_sequence
:
list
[
int
],
target_token_len
:
int
,
target_token_len
:
int
,
max_retry
:
int
=
10
,
max_retry
:
int
=
10
,
...
@@ -468,7 +467,7 @@ class RandomDataset(BenchmarkDataset):
...
@@ -468,7 +467,7 @@ class RandomDataset(BenchmarkDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
request_id_prefix
:
str
=
""
,
request_id_prefix
:
str
=
""
,
no_oversample
:
bool
=
False
,
no_oversample
:
bool
=
False
,
...
@@ -580,7 +579,7 @@ class RandomDataset(BenchmarkDataset):
...
@@ -580,7 +579,7 @@ class RandomDataset(BenchmarkDataset):
range_ratio
:
float
,
range_ratio
:
float
,
input_len
:
int
,
input_len
:
int
,
output_len
:
int
,
output_len
:
int
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
"""
"""
Get the sampling parameters for the dataset.
Get the sampling parameters for the dataset.
...
@@ -626,7 +625,7 @@ class RandomDataset(BenchmarkDataset):
...
@@ -626,7 +625,7 @@ class RandomDataset(BenchmarkDataset):
def
generate_token_sequence
(
def
generate_token_sequence
(
self
,
self
,
*
,
*
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
prefix_token_ids
:
list
[
int
],
prefix_token_ids
:
list
[
int
],
prefix_len
:
int
,
prefix_len
:
int
,
vocab_size
:
int
,
vocab_size
:
int
,
...
@@ -686,7 +685,7 @@ class RandomDatasetForReranking(RandomDataset):
...
@@ -686,7 +685,7 @@ class RandomDatasetForReranking(RandomDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
request_id_prefix
:
str
=
""
,
request_id_prefix
:
str
=
""
,
range_ratio
:
float
=
RandomDataset
.
DEFAULT_RANGE_RATIO
,
range_ratio
:
float
=
RandomDataset
.
DEFAULT_RANGE_RATIO
,
...
@@ -716,7 +715,11 @@ class RandomDatasetForReranking(RandomDataset):
...
@@ -716,7 +715,11 @@ class RandomDatasetForReranking(RandomDataset):
doc_lens
,
_
,
doc_offsets
=
self
.
get_sampling_params
(
doc_lens
,
_
,
doc_offsets
=
self
.
get_sampling_params
(
num_requests
,
range_ratio
,
doc_len_param
,
0
,
tokenizer
num_requests
,
range_ratio
,
doc_len_param
,
0
,
tokenizer
)
)
vocab_size
=
tokenizer
.
vocab_size
vocab_size
=
tokenizer
.
vocab_size
prohibited_tokens
=
tokenizer
.
all_special_ids
all_tokens
=
np
.
arange
(
vocab_size
)
allowed_tokens
=
np
.
array
(
list
(
set
(
all_tokens
)
-
set
(
prohibited_tokens
)))
query_prompt
,
query_input_len
,
token_mismatch_total
=
(
query_prompt
,
query_input_len
,
token_mismatch_total
=
(
self
.
generate_token_sequence
(
self
.
generate_token_sequence
(
...
@@ -727,6 +730,7 @@ class RandomDatasetForReranking(RandomDataset):
...
@@ -727,6 +730,7 @@ class RandomDatasetForReranking(RandomDataset):
input_len
=
query_len
,
input_len
=
query_len
,
offset
=
int
(
query_offsets
[
0
]),
offset
=
int
(
query_offsets
[
0
]),
index
=
0
,
index
=
0
,
allowed_tokens
=
allowed_tokens
,
)
)
)
)
...
@@ -740,6 +744,7 @@ class RandomDatasetForReranking(RandomDataset):
...
@@ -740,6 +744,7 @@ class RandomDatasetForReranking(RandomDataset):
input_len
=
int
(
doc_lens
[
i
]),
input_len
=
int
(
doc_lens
[
i
]),
offset
=
int
(
doc_offsets
[
i
]),
offset
=
int
(
doc_offsets
[
i
]),
index
=
i
+
1
,
index
=
i
+
1
,
allowed_tokens
=
allowed_tokens
,
)
)
token_mismatch_total
+=
token_mismatch
token_mismatch_total
+=
token_mismatch
requests
.
append
((
prompt
,
total_input_len
))
requests
.
append
((
prompt
,
total_input_len
))
...
@@ -1077,7 +1082,7 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -1077,7 +1082,7 @@ class RandomMultiModalDataset(RandomDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
request_id_prefix
:
str
=
""
,
request_id_prefix
:
str
=
""
,
no_oversample
:
bool
=
False
,
no_oversample
:
bool
=
False
,
...
@@ -1231,7 +1236,7 @@ class ShareGPTDataset(BenchmarkDataset):
...
@@ -1231,7 +1236,7 @@ class ShareGPTDataset(BenchmarkDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
lora_path
:
str
|
None
=
None
,
lora_path
:
str
|
None
=
None
,
max_loras
:
int
|
None
=
None
,
max_loras
:
int
|
None
=
None
,
...
@@ -1633,7 +1638,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
...
@@ -1633,7 +1638,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
)
)
def
get_samples
(
args
,
tokenizer
)
->
list
[
SampleRequest
]:
def
get_samples
(
args
,
tokenizer
:
TokenizerLike
)
->
list
[
SampleRequest
]:
if
not
hasattr
(
args
,
"request_id_prefix"
):
if
not
hasattr
(
args
,
"request_id_prefix"
):
args
.
request_id_prefix
=
""
args
.
request_id_prefix
=
""
...
@@ -1971,7 +1976,7 @@ class CustomDataset(BenchmarkDataset):
...
@@ -1971,7 +1976,7 @@ class CustomDataset(BenchmarkDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
lora_path
:
str
|
None
=
None
,
lora_path
:
str
|
None
=
None
,
max_loras
:
int
|
None
=
None
,
max_loras
:
int
|
None
=
None
,
...
@@ -2101,7 +2106,7 @@ class SonnetDataset(BenchmarkDataset):
...
@@ -2101,7 +2106,7 @@ class SonnetDataset(BenchmarkDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
,
tokenizer
:
TokenizerLike
,
num_requests
:
int
,
num_requests
:
int
,
prefix_len
:
int
=
DEFAULT_PREFIX_LEN
,
prefix_len
:
int
=
DEFAULT_PREFIX_LEN
,
input_len
:
int
=
DEFAULT_INPUT_LEN
,
input_len
:
int
=
DEFAULT_INPUT_LEN
,
...
@@ -2202,7 +2207,7 @@ class BurstGPTDataset(BenchmarkDataset):
...
@@ -2202,7 +2207,7 @@ class BurstGPTDataset(BenchmarkDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
max_loras
:
int
|
None
=
None
,
max_loras
:
int
|
None
=
None
,
lora_path
:
str
|
None
=
None
,
lora_path
:
str
|
None
=
None
,
...
@@ -2287,7 +2292,7 @@ class ConversationDataset(HuggingFaceDataset):
...
@@ -2287,7 +2292,7 @@ class ConversationDataset(HuggingFaceDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
output_len
:
int
|
None
=
None
,
output_len
:
int
|
None
=
None
,
enable_multimodal_chat
:
bool
=
False
,
enable_multimodal_chat
:
bool
=
False
,
...
@@ -2347,7 +2352,7 @@ class MultiModalConversationDataset(HuggingFaceDataset):
...
@@ -2347,7 +2352,7 @@ class MultiModalConversationDataset(HuggingFaceDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
output_len
:
int
|
None
=
None
,
output_len
:
int
|
None
=
None
,
enable_multimodal_chat
:
bool
=
False
,
enable_multimodal_chat
:
bool
=
False
,
...
@@ -2416,7 +2421,7 @@ class VisionArenaDataset(HuggingFaceDataset):
...
@@ -2416,7 +2421,7 @@ class VisionArenaDataset(HuggingFaceDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
output_len
:
int
|
None
=
None
,
output_len
:
int
|
None
=
None
,
enable_multimodal_chat
:
bool
=
False
,
enable_multimodal_chat
:
bool
=
False
,
...
@@ -2470,7 +2475,7 @@ class MMVUDataset(HuggingFaceDataset):
...
@@ -2470,7 +2475,7 @@ class MMVUDataset(HuggingFaceDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
output_len
:
int
|
None
=
None
,
output_len
:
int
|
None
=
None
,
enable_multimodal_chat
:
bool
=
False
,
enable_multimodal_chat
:
bool
=
False
,
...
@@ -2531,7 +2536,7 @@ class InstructCoderDataset(HuggingFaceDataset):
...
@@ -2531,7 +2536,7 @@ class InstructCoderDataset(HuggingFaceDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
output_len
:
int
|
None
=
None
,
output_len
:
int
|
None
=
None
,
enable_multimodal_chat
:
bool
=
False
,
enable_multimodal_chat
:
bool
=
False
,
...
@@ -2595,7 +2600,7 @@ class MTBenchDataset(HuggingFaceDataset):
...
@@ -2595,7 +2600,7 @@ class MTBenchDataset(HuggingFaceDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
output_len
:
int
|
None
=
None
,
output_len
:
int
|
None
=
None
,
enable_multimodal_chat
:
bool
=
False
,
enable_multimodal_chat
:
bool
=
False
,
...
@@ -2661,7 +2666,7 @@ class BlazeditDataset(HuggingFaceDataset):
...
@@ -2661,7 +2666,7 @@ class BlazeditDataset(HuggingFaceDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
output_len
:
int
|
None
=
None
,
output_len
:
int
|
None
=
None
,
skip_chat_template
:
bool
=
False
,
skip_chat_template
:
bool
=
False
,
...
@@ -2742,7 +2747,7 @@ class AIMODataset(HuggingFaceDataset):
...
@@ -2742,7 +2747,7 @@ class AIMODataset(HuggingFaceDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
output_len
:
int
|
None
=
None
,
output_len
:
int
|
None
=
None
,
request_id_prefix
:
str
=
""
,
request_id_prefix
:
str
=
""
,
...
@@ -2852,7 +2857,7 @@ class NextEditPredictionDataset(HuggingFaceDataset):
...
@@ -2852,7 +2857,7 @@ class NextEditPredictionDataset(HuggingFaceDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
request_id_prefix
:
str
=
""
,
request_id_prefix
:
str
=
""
,
no_oversample
:
bool
=
False
,
no_oversample
:
bool
=
False
,
...
@@ -2924,7 +2929,7 @@ class ASRDataset(HuggingFaceDataset):
...
@@ -2924,7 +2929,7 @@ class ASRDataset(HuggingFaceDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
output_len
:
int
|
None
=
None
,
output_len
:
int
|
None
=
None
,
request_id_prefix
:
str
=
""
,
request_id_prefix
:
str
=
""
,
...
@@ -3002,7 +3007,7 @@ class MLPerfDataset(HuggingFaceDataset):
...
@@ -3002,7 +3007,7 @@ class MLPerfDataset(HuggingFaceDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
output_len
:
int
|
None
=
None
,
output_len
:
int
|
None
=
None
,
request_id_prefix
:
str
=
""
,
request_id_prefix
:
str
=
""
,
...
@@ -3081,7 +3086,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
...
@@ -3081,7 +3086,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
prefix_len
:
int
=
DEFAULT_PREFIX_LEN
,
prefix_len
:
int
=
DEFAULT_PREFIX_LEN
,
suffix_len
:
int
=
DEFAULT_SUFFIX_LEN
,
suffix_len
:
int
=
DEFAULT_SUFFIX_LEN
,
...
@@ -3167,7 +3172,7 @@ class MMStarDataset(HuggingFaceDataset):
...
@@ -3167,7 +3172,7 @@ class MMStarDataset(HuggingFaceDataset):
def
sample
(
def
sample
(
self
,
self
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
num_requests
:
int
,
num_requests
:
int
,
output_len
:
int
|
None
=
None
,
output_len
:
int
|
None
=
None
,
enable_multimodal_chat
:
bool
=
False
,
enable_multimodal_chat
:
bool
=
False
,
...
...
vllm/benchmarks/serve.py
View file @
40a046cd
...
@@ -36,7 +36,6 @@ from typing import Any, Literal
...
@@ -36,7 +36,6 @@ from typing import Any, Literal
import
aiohttp
import
aiohttp
import
numpy
as
np
import
numpy
as
np
from
tqdm.asyncio
import
tqdm
from
tqdm.asyncio
import
tqdm
from
transformers
import
PreTrainedTokenizerBase
from
vllm.benchmarks.datasets
import
SampleRequest
,
add_dataset_parser
,
get_samples
from
vllm.benchmarks.datasets
import
SampleRequest
,
add_dataset_parser
,
get_samples
from
vllm.benchmarks.lib.endpoint_request_func
import
(
from
vllm.benchmarks.lib.endpoint_request_func
import
(
...
@@ -47,7 +46,7 @@ from vllm.benchmarks.lib.endpoint_request_func import (
...
@@ -47,7 +46,7 @@ from vllm.benchmarks.lib.endpoint_request_func import (
)
)
from
vllm.benchmarks.lib.ready_checker
import
wait_for_endpoint
from
vllm.benchmarks.lib.ready_checker
import
wait_for_endpoint
from
vllm.benchmarks.lib.utils
import
convert_to_pytorch_benchmark_format
,
write_to_json
from
vllm.benchmarks.lib.utils
import
convert_to_pytorch_benchmark_format
,
write_to_json
from
vllm.tokenizers
import
get_tokenizer
from
vllm.tokenizers
import
TokenizerLike
,
get_tokenizer
from
vllm.utils.gc_utils
import
freeze_gc_heap
from
vllm.utils.gc_utils
import
freeze_gc_heap
from
vllm.utils.network_utils
import
join_host_port
from
vllm.utils.network_utils
import
join_host_port
...
@@ -286,7 +285,7 @@ def calculate_metrics(
...
@@ -286,7 +285,7 @@ def calculate_metrics(
input_requests
:
list
[
SampleRequest
],
input_requests
:
list
[
SampleRequest
],
outputs
:
list
[
RequestFuncOutput
],
outputs
:
list
[
RequestFuncOutput
],
dur_s
:
float
,
dur_s
:
float
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
selected_percentiles
:
list
[
float
],
selected_percentiles
:
list
[
float
],
goodput_config_dict
:
dict
[
str
,
float
],
goodput_config_dict
:
dict
[
str
,
float
],
)
->
tuple
[
BenchmarkMetrics
,
list
[
int
]]:
)
->
tuple
[
BenchmarkMetrics
,
list
[
int
]]:
...
@@ -489,7 +488,7 @@ async def benchmark(
...
@@ -489,7 +488,7 @@ async def benchmark(
base_url
:
str
,
base_url
:
str
,
model_id
:
str
,
model_id
:
str
,
model_name
:
str
,
model_name
:
str
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
input_requests
:
list
[
SampleRequest
],
input_requests
:
list
[
SampleRequest
],
logprobs
:
int
|
None
,
logprobs
:
int
|
None
,
request_rate
:
float
,
request_rate
:
float
,
...
@@ -1032,6 +1031,19 @@ def add_cli_args(parser: argparse.ArgumentParser):
...
@@ -1032,6 +1031,19 @@ def add_cli_args(parser: argparse.ArgumentParser):
type
=
str
,
type
=
str
,
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
# noqa: E501
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
# noqa: E501
)
)
parser
.
add_argument
(
"--tokenizer-mode"
,
type
=
str
,
default
=
"auto"
,
help
=
"""Tokenizer mode:
\n
- "auto" will use the tokenizer from `mistral_common` for Mistral models
if available, otherwise it will use the "hf" tokenizer.
\n
- "hf" will use the fast tokenizer if available.
\n
- "slow" will always use the slow tokenizer.
\n
- "mistral" will always use the tokenizer from `mistral_common`.
\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.
\n
- Other custom values can be supported via plugins."""
,
)
parser
.
add_argument
(
"--use-beam-search"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--use-beam-search"
,
action
=
"store_true"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--logprobs"
,
"--logprobs"
,
...
@@ -1228,18 +1240,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
...
@@ -1228,18 +1240,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
help
=
"Common prefix length shared by all prompts (used by random dataset)"
,
help
=
"Common prefix length shared by all prompts (used by random dataset)"
,
)
)
parser
.
add_argument
(
"--tokenizer-mode"
,
type
=
str
,
default
=
"auto"
,
choices
=
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
],
help
=
'The tokenizer mode.
\n\n
* "auto" will use the '
'fast tokenizer if available.
\n
* "slow" will '
"always use the slow tokenizer.
\n
* "
'"mistral" will always use the `mistral_common` tokenizer.
\n
*'
'"custom" will use --tokenizer to select the preregistered tokenizer.'
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--served-model-name"
,
"--served-model-name"
,
type
=
str
,
type
=
str
,
...
...
vllm/benchmarks/throughput.py
View file @
40a046cd
...
@@ -14,7 +14,7 @@ from typing import Any
...
@@ -14,7 +14,7 @@ from typing import Any
import
torch
import
torch
import
uvloop
import
uvloop
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
,
PreTrainedTokenizerBase
from
transformers
import
AutoModelForCausalLM
,
PreTrainedTokenizerBase
from
vllm.benchmarks.datasets
import
(
from
vllm.benchmarks.datasets
import
(
AIMODataset
,
AIMODataset
,
...
@@ -35,6 +35,7 @@ from vllm.inputs import TextPrompt, TokensPrompt
...
@@ -35,6 +35,7 @@ from vllm.inputs import TextPrompt, TokensPrompt
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
BeamSearchParams
from
vllm.sampling_params
import
BeamSearchParams
from
vllm.tokenizers
import
TokenizerLike
,
get_tokenizer
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.async_utils
import
merge_async_iterators
...
@@ -246,12 +247,15 @@ async def run_vllm_async(
...
@@ -246,12 +247,15 @@ async def run_vllm_async(
def
run_hf
(
def
run_hf
(
requests
:
list
[
SampleRequest
],
requests
:
list
[
SampleRequest
],
model
:
str
,
model
:
str
,
tokenizer
:
PreTrained
Tokenizer
Bas
e
,
tokenizer
:
Tokenizer
Lik
e
,
n
:
int
,
n
:
int
,
max_batch_size
:
int
,
max_batch_size
:
int
,
trust_remote_code
:
bool
,
trust_remote_code
:
bool
,
disable_detokenize
:
bool
=
False
,
disable_detokenize
:
bool
=
False
,
)
->
float
:
)
->
float
:
assert
isinstance
(
tokenizer
,
PreTrainedTokenizerBase
),
(
"the hf backend only supports HF tokenizers"
)
llm
=
AutoModelForCausalLM
.
from_pretrained
(
llm
=
AutoModelForCausalLM
.
from_pretrained
(
model
,
dtype
=
torch
.
float16
,
trust_remote_code
=
trust_remote_code
model
,
dtype
=
torch
.
float16
,
trust_remote_code
=
trust_remote_code
)
)
...
@@ -692,15 +696,21 @@ def add_cli_args(parser: argparse.ArgumentParser):
...
@@ -692,15 +696,21 @@ def add_cli_args(parser: argparse.ArgumentParser):
def
main
(
args
:
argparse
.
Namespace
):
def
main
(
args
:
argparse
.
Namespace
):
if
args
.
tokenizer
is
None
:
args
.
tokenizer
=
args
.
model
validate_args
(
args
)
validate_args
(
args
)
if
args
.
seed
is
None
:
if
args
.
seed
is
None
:
args
.
seed
=
0
args
.
seed
=
0
random
.
seed
(
args
.
seed
)
random
.
seed
(
args
.
seed
)
# Sample the requests.
# Sample the requests.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
if
(
args
.
tokenizer
,
trust_remote_code
=
args
.
trust_remote_code
args
.
backend
==
"hf"
or
args
.
backend
==
"mii"
)
and
args
.
tokenizer_mode
==
"auto"
:
# mistral_common tokenizer is only supported on vllm and vllm-chat backends;
# for hf and mii backends, we use hf tokenizer
args
.
tokenizer_mode
=
"hf"
tokenizer
=
get_tokenizer
(
args
.
tokenizer
,
tokenizer_mode
=
args
.
tokenizer_mode
,
trust_remote_code
=
args
.
trust_remote_code
,
)
)
requests
=
get_requests
(
args
,
tokenizer
)
requests
=
get_requests
(
args
,
tokenizer
)
is_multi_modal
=
any
(
request
.
multi_modal_data
is
not
None
for
request
in
requests
)
is_multi_modal
=
any
(
request
.
multi_modal_data
is
not
None
for
request
in
requests
)
...
...
vllm/config/model.py
View file @
40a046cd
...
@@ -136,7 +136,8 @@ class ModelConfig:
...
@@ -136,7 +136,8 @@ class ModelConfig:
name or path will be used."""
name or path will be used."""
tokenizer_mode
:
TokenizerMode
|
str
=
"auto"
tokenizer_mode
:
TokenizerMode
|
str
=
"auto"
"""Tokenizer mode:
\n
"""Tokenizer mode:
\n
- "auto" will use "hf" tokenizer if Mistral's tokenizer is not available.
\n
- "auto" will use the tokenizer from `mistral_common` for Mistral models
if available, otherwise it will use the "hf" tokenizer.
\n
- "hf" will use the fast tokenizer if available.
\n
- "hf" will use the fast tokenizer if available.
\n
- "slow" will always use the slow tokenizer.
\n
- "slow" will always use the slow tokenizer.
\n
- "mistral" will always use the tokenizer from `mistral_common`.
\n
- "mistral" will always use the tokenizer from `mistral_common`.
\n
...
...
vllm/tokenizers/deepseekv32.py
View file @
40a046cd
...
@@ -54,6 +54,9 @@ class DeepseekV32Tokenizer(HfTokenizer):
...
@@ -54,6 +54,9 @@ class DeepseekV32Tokenizer(HfTokenizer):
prompt_str
=
encode_messages
(
messages
,
**
encode_config
)
# type: ignore
prompt_str
=
encode_messages
(
messages
,
**
encode_config
)
# type: ignore
return
prompt_str
return
prompt_str
def
num_special_tokens_to_add
(
self
)
->
int
:
return
len
(
self
.
encode
(
""
))
@
property
@
property
def
all_special_tokens
(
self
)
->
list
[
str
]:
def
all_special_tokens
(
self
)
->
list
[
str
]:
return
self
.
tokenizer
.
all_special_tokens
return
self
.
tokenizer
.
all_special_tokens
...
...
vllm/tokenizers/mistral.py
View file @
40a046cd
...
@@ -309,6 +309,9 @@ class MistralTokenizer(TokenizerLike):
...
@@ -309,6 +309,9 @@ class MistralTokenizer(TokenizerLike):
for
i
in
all_special_ids
for
i
in
all_special_ids
]
]
def
num_special_tokens_to_add
(
self
)
->
int
:
return
len
(
self
.
encode
(
""
))
# the following attributes are set to fit vLLM's design and are used
# the following attributes are set to fit vLLM's design and are used
# by the structured output backends.
# by the structured output backends.
@
property
@
property
...
@@ -421,6 +424,7 @@ class MistralTokenizer(TokenizerLike):
...
@@ -421,6 +424,7 @@ class MistralTokenizer(TokenizerLike):
)
->
list
[
int
]:
)
->
list
[
int
]:
add_generation_prompt
=
kwargs
.
pop
(
"add_generation_prompt"
,
False
)
add_generation_prompt
=
kwargs
.
pop
(
"add_generation_prompt"
,
False
)
continue_final_message
=
kwargs
.
get
(
"continue_final_message"
,
False
)
continue_final_message
=
kwargs
.
get
(
"continue_final_message"
,
False
)
tokenize
=
kwargs
.
get
(
"tokenize"
,
True
)
padding
=
kwargs
.
get
(
"padding"
,
False
)
padding
=
kwargs
.
get
(
"padding"
,
False
)
truncation
=
kwargs
.
get
(
"truncation"
,
False
)
truncation
=
kwargs
.
get
(
"truncation"
,
False
)
max_length
=
kwargs
.
get
(
"max_length"
)
max_length
=
kwargs
.
get
(
"max_length"
)
...
@@ -433,7 +437,7 @@ class MistralTokenizer(TokenizerLike):
...
@@ -433,7 +437,7 @@ class MistralTokenizer(TokenizerLike):
conversation
=
messages
,
conversation
=
messages
,
tools
=
tools
,
tools
=
tools
,
continue_final_message
=
continue_final_message
,
continue_final_message
=
continue_final_message
,
tokenize
=
Tru
e
,
tokenize
=
tokeniz
e
,
padding
=
padding
,
padding
=
padding
,
truncation
=
truncation
,
truncation
=
truncation
,
max_length
=
max_length
,
max_length
=
max_length
,
...
...
vllm/tokenizers/protocol.py
View file @
40a046cd
...
@@ -22,6 +22,9 @@ class TokenizerLike(Protocol):
...
@@ -22,6 +22,9 @@ class TokenizerLike(Protocol):
)
->
"TokenizerLike"
:
)
->
"TokenizerLike"
:
raise
NotImplementedError
raise
NotImplementedError
def
num_special_tokens_to_add
(
self
)
->
int
:
raise
NotImplementedError
@
property
@
property
def
all_special_tokens
(
self
)
->
list
[
str
]:
def
all_special_tokens
(
self
)
->
list
[
str
]:
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/tokenizers/registry.py
View file @
40a046cd
...
@@ -183,7 +183,7 @@ def get_tokenizer(
...
@@ -183,7 +183,7 @@ def get_tokenizer(
"`tokenizer_mode='custom'` when initializing vLLM."
,
"`tokenizer_mode='custom'` when initializing vLLM."
,
tokenizer_args
,
tokenizer_args
,
str
(
tokenizer_kwargs
),
str
(
tokenizer_kwargs
),
tokenizer_
mod
e
,
tokenizer_
nam
e
,
)
)
tokenizer_mode
=
str
(
tokenizer_name
)
tokenizer_mode
=
str
(
tokenizer_name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment