Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
081057de
Commit
081057de
authored
Apr 29, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.5' into v0.8.5-ori
parents
7cf5d5c4
ba41cc90
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
832 additions
and
327 deletions
+832
-327
tests/models/embedding/language/test_jina.py
tests/models/embedding/language/test_jina.py
+21
-11
tests/models/embedding/language/test_snowflake_arctic_embed.py
.../models/embedding/language/test_snowflake_arctic_embed.py
+101
-0
tests/models/embedding/utils.py
tests/models/embedding/utils.py
+27
-0
tests/models/encoder_decoder/vision_language/test_florence2.py
.../models/encoder_decoder/vision_language/test_florence2.py
+10
-7
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+4
-0
tests/models/multimodal/processing/test_phi4mm.py
tests/models/multimodal/processing/test_phi4mm.py
+59
-0
tests/models/registry.py
tests/models/registry.py
+45
-25
tests/models/test_bitblas.py
tests/models/test_bitblas.py
+63
-0
tests/models/test_gptq_bitblas.py
tests/models/test_gptq_bitblas.py
+61
-0
tests/models/test_initialization.py
tests/models/test_initialization.py
+1
-4
tests/models/test_oot_registration.py
tests/models/test_oot_registration.py
+2
-3
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+15
-6
tests/samplers/test_beam_search.py
tests/samplers/test_beam_search.py
+82
-3
tests/spec_decode/test_scorer.py
tests/spec_decode/test_scorer.py
+2
-3
tests/test_config.py
tests/test_config.py
+24
-2
tests/test_utils.py
tests/test_utils.py
+128
-5
tests/tokenization/test_cached_tokenizer.py
tests/tokenization/test_cached_tokenizer.py
+31
-12
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+139
-62
tests/tokenization/test_tokenizer_group.py
tests/tokenization/test_tokenizer_group.py
+3
-184
tests/tool_use/utils.py
tests/tool_use/utils.py
+14
-0
No files found.
Too many changes to show.
To preserve performance only
554 of 554+
files are displayed.
Plain diff
Email patch
tests/models/embedding/language/test_jina.py
View file @
081057de
...
@@ -153,14 +153,24 @@ def test_matryoshka(
...
@@ -153,14 +153,24 @@ def test_matryoshka(
with
vllm_runner
(
model
,
task
=
"embed"
,
dtype
=
dtype
,
with
vllm_runner
(
model
,
task
=
"embed"
,
dtype
=
dtype
,
max_model_len
=
None
)
as
vllm_model
:
max_model_len
=
None
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
encode
(
matryoshka_dimensions
=
(
example_prompts
,
vllm_model
.
model
.
llm_engine
.
model_config
.
matryoshka_dimensions
)
pooling_params
=
PoolingParams
(
dimensions
=
dimensions
))
assert
matryoshka_dimensions
is
not
None
check_embeddings_close
(
if
dimensions
not
in
matryoshka_dimensions
:
embeddings_0_lst
=
hf_outputs
,
with
pytest
.
raises
(
ValueError
):
embeddings_1_lst
=
vllm_outputs
,
vllm_model
.
encode
(
name_0
=
"hf"
,
example_prompts
,
name_1
=
"vllm"
,
pooling_params
=
PoolingParams
(
dimensions
=
dimensions
))
tol
=
1e-2
,
else
:
)
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
,
pooling_params
=
PoolingParams
(
dimensions
=
dimensions
))
check_embeddings_close
(
embeddings_0_lst
=
hf_outputs
,
embeddings_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
tol
=
1e-2
,
)
tests/models/embedding/language/test_snowflake_arctic_embed.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
"""Compare the embedding outputs of HF and vLLM models.
Run `pytest tests/models/embedding/language/test_snowflake_arctic_embed.py`.
"""
import
pytest
from
tests.models.embedding.utils
import
EmbedModelInfo
from
..utils
import
check_embeddings_close
EMBEDDING_PROMPTS
=
[
'what is snowflake?'
,
'Where can I get the best tacos?'
,
'The Data Cloud!'
,
'Mexico City of Course!'
]
MODELS
=
[
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-xs"
,
is_matryoshka
=
False
,
architecture
=
"BertModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-s"
,
is_matryoshka
=
False
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-m"
,
is_matryoshka
=
False
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-m-long"
,
is_matryoshka
=
False
,
architecture
=
"NomicBertModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-l"
,
is_matryoshka
=
False
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-m-v1.5"
,
is_matryoshka
=
True
,
architecture
=
"BertModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-l-v2.0"
,
is_matryoshka
=
True
,
architecture
=
"XLMRobertaModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-m-v2.0"
,
is_matryoshka
=
True
,
architecture
=
"GteModel"
,
enable_test
=
True
),
]
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
model_info
:
EmbedModelInfo
,
dtype
:
str
,
monkeypatch
,
)
->
None
:
if
not
model_info
.
enable_test
:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest
.
skip
(
"Skipping test."
)
example_prompts
=
example_prompts
+
EMBEDDING_PROMPTS
vllm_extra_kwargs
=
{
"hf_overrides"
:
{
"is_matryoshka"
:
model_info
.
is_matryoshka
}
}
with
hf_runner
(
model_info
.
name
,
dtype
=
dtype
,
is_sentence_transformer
=
True
)
as
hf_model
:
hf_outputs
=
hf_model
.
encode
(
example_prompts
)
with
vllm_runner
(
model_info
.
name
,
task
=
"embed"
,
dtype
=
dtype
,
max_model_len
=
None
,
**
vllm_extra_kwargs
)
as
vllm_model
:
assert
(
vllm_model
.
model
.
llm_engine
.
model_config
.
is_matryoshka
==
model_info
.
is_matryoshka
)
if
model_info
.
architecture
:
assert
(
model_info
.
architecture
in
vllm_model
.
model
.
llm_engine
.
model_config
.
architectures
)
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
)
check_embeddings_close
(
embeddings_0_lst
=
hf_outputs
,
embeddings_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
tol
=
1e-2
,
)
tests/models/embedding/utils.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
NamedTuple
,
Optional
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -37,3 +38,29 @@ def matryoshka_fy(tensor, dimensions):
...
@@ -37,3 +38,29 @@ def matryoshka_fy(tensor, dimensions):
tensor
=
tensor
[...,
:
dimensions
]
tensor
=
tensor
[...,
:
dimensions
]
tensor
=
F
.
normalize
(
tensor
,
p
=
2
,
dim
=
1
)
tensor
=
F
.
normalize
(
tensor
,
p
=
2
,
dim
=
1
)
return
tensor
return
tensor
class
EmbedModelInfo
(
NamedTuple
):
name
:
str
is_matryoshka
:
bool
matryoshka_dimensions
:
Optional
[
list
[
int
]]
=
None
architecture
:
str
=
""
enable_test
:
bool
=
True
def
correctness_test
(
hf_model
,
inputs
,
vllm_outputs
:
Sequence
[
list
[
float
]],
dimensions
:
Optional
[
int
]
=
None
):
hf_outputs
=
hf_model
.
encode
(
inputs
)
if
dimensions
:
hf_outputs
=
matryoshka_fy
(
hf_outputs
,
dimensions
)
check_embeddings_close
(
embeddings_0_lst
=
hf_outputs
,
embeddings_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
tol
=
1e-2
,
)
tests/models/encoder_decoder/vision_language/test_florence2.py
View file @
081057de
...
@@ -13,12 +13,12 @@ from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
...
@@ -13,12 +13,12 @@ from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from
...utils
import
check_logprobs_close
from
...utils
import
check_logprobs_close
MODELS
=
[
"microsoft/Florence-2-base"
]
MODELS
=
[
"microsoft/Florence-2-base"
]
# Florence-2
uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Florence-2
model repo's tokenizer config is missing some special tokens.
# Therefore, we
borrow the BartTokenizer from the original Bart model
# Therefore, we
use a converted tokenizer from a forked repo
TOKENIZER
=
"
facebook/bart-base
"
TOKENIZER
=
"
Isotr0py/Florence-2-tokenizer
"
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
"stop_sign"
:
"stop_sign"
:
"<
CAPTION
>"
,
# special task token
"<
OD
>"
,
# special task token
which will output special tokens
"cherry_blossom"
:
"cherry_blossom"
:
"Describe in detail what is shown in the image."
,
"Describe in detail what is shown in the image."
,
})
})
...
@@ -45,7 +45,6 @@ def hf_to_vllm_output(hf_output: tuple[list[int], str,
...
@@ -45,7 +45,6 @@ def hf_to_vllm_output(hf_output: tuple[list[int], str,
output_ids
,
output_str
,
out_logprobs
=
hf_output
output_ids
,
output_str
,
out_logprobs
=
hf_output
output_str
=
output_str
.
replace
(
"</s>"
,
""
).
replace
(
"<s>"
,
""
)
output_str
=
output_str
.
replace
(
"</s>"
,
""
).
replace
(
"<s>"
,
""
)
output_ids
=
[
ids
for
ids
in
output_ids
if
ids
not
in
[
0
,
2
]]
return
output_ids
,
output_str
,
out_logprobs
return
output_ids
,
output_str
,
out_logprobs
...
@@ -71,8 +70,11 @@ def run_test(
...
@@ -71,8 +70,11 @@ def run_test(
enforce_eager
=
True
)
as
vllm_model
:
enforce_eager
=
True
)
as
vllm_model
:
vllm_outputs_per_case
=
[
vllm_outputs_per_case
=
[
vllm_model
.
generate_encoder_decoder_greedy_logprobs
(
vllm_model
.
generate_encoder_decoder_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
)
prompts
,
for
prompts
in
inputs
max_tokens
,
num_logprobs
=
num_logprobs
,
skip_special_tokens
=
False
,
)
for
prompts
in
inputs
]
]
hf_inputs
=
[
get_hf_images_prompts
(
prompts
)
for
prompts
in
inputs
]
hf_inputs
=
[
get_hf_images_prompts
(
prompts
)
for
prompts
in
inputs
]
...
@@ -93,6 +95,7 @@ def run_test(
...
@@ -93,6 +95,7 @@ def run_test(
outputs_1_lst
=
vllm_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
name_1
=
"vllm"
,
num_outputs_0_skip_tokens
=
1
,
)
)
...
...
tests/models/multimodal/processing/test_common.py
View file @
081057de
...
@@ -254,10 +254,12 @@ def _test_processing_correctness_mistral(
...
@@ -254,10 +254,12 @@ def _test_processing_correctness_mistral(
"adept/fuyu-8b"
,
"adept/fuyu-8b"
,
"google/gemma-3-4b-it"
,
"google/gemma-3-4b-it"
,
"THUDM/glm-4v-9b"
,
"THUDM/glm-4v-9b"
,
"ibm-granite/granite-speech-3.3-8b"
,
"h2oai/h2ovl-mississippi-800m"
,
"h2oai/h2ovl-mississippi-800m"
,
"OpenGVLab/InternVL2-1B"
,
"OpenGVLab/InternVL2-1B"
,
"HuggingFaceM4/Idefics3-8B-Llama3"
,
"HuggingFaceM4/Idefics3-8B-Llama3"
,
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
,
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
,
"moonshotai/Kimi-VL-A3B-Instruct"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
"llava-hf/llava-1.5-7b-hf"
,
"llava-hf/llava-1.5-7b-hf"
,
"llava-hf/llava-v1.6-mistral-7b-hf"
,
"llava-hf/llava-v1.6-mistral-7b-hf"
,
...
@@ -273,12 +275,14 @@ def _test_processing_correctness_mistral(
...
@@ -273,12 +275,14 @@ def _test_processing_correctness_mistral(
"nvidia/NVLM-D-72B"
,
"nvidia/NVLM-D-72B"
,
"google/paligemma-3b-mix-224"
,
"google/paligemma-3b-mix-224"
,
"google/paligemma2-3b-ft-docci-448"
,
"google/paligemma2-3b-ft-docci-448"
,
"microsoft/Phi-4-multimodal-instruct"
,
"mistralai/Pixtral-12B-2409"
,
"mistralai/Pixtral-12B-2409"
,
"mistral-community/pixtral-12b"
,
"mistral-community/pixtral-12b"
,
"Qwen/Qwen-VL-Chat"
,
"Qwen/Qwen-VL-Chat"
,
"Qwen/Qwen2-VL-2B-Instruct"
,
"Qwen/Qwen2-VL-2B-Instruct"
,
"Qwen/Qwen2.5-VL-3B-Instruct"
,
"Qwen/Qwen2.5-VL-3B-Instruct"
,
"Qwen/Qwen2-Audio-7B-Instruct"
,
"Qwen/Qwen2-Audio-7B-Instruct"
,
"Qwen/Qwen2.5-Omni-7B"
,
"Skywork/Skywork-R1V-38B"
,
"Skywork/Skywork-R1V-38B"
,
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
,
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
,
"openai/whisper-large-v3"
,
"openai/whisper-large-v3"
,
...
...
tests/models/multimodal/processing/test_phi4mm.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
"""Tests for phi4mm's multimodal preprocessing kwargs."""
import
pytest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
....conftest
import
_ImageAssets
from
...utils
import
build_model_context
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"microsoft/Phi-4-multimodal-instruct"
])
# yapf: disable
@
pytest
.
mark
.
parametrize
(
(
"mm_processor_kwargs"
,
"expected_toks_per_img"
),
[
({
"dynamic_hd"
:
4
},
1329
),
({
"dynamic_hd"
:
16
},
4433
),
# the default num_crops of phi-4-multimodal is 36
({},
9585
),
])
# yapf: enable
@
pytest
.
mark
.
parametrize
(
"num_imgs"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"kwargs_on_init"
,
[
True
,
False
])
def
test_processor_override
(
image_assets
:
_ImageAssets
,
model_id
:
str
,
mm_processor_kwargs
:
dict
[
str
,
int
],
expected_toks_per_img
:
int
,
num_imgs
:
int
,
kwargs_on_init
:
bool
,
):
"""Ensure Phi4MMMultiModalProcessor handles dynamic_hd properly."""
# Avoid initializing CUDA early
from
vllm.model_executor.models.phi4mm
import
_IMAGE_PLACEHOLDER_TOKEN_ID
ctx
=
build_model_context
(
model_id
,
mm_processor_kwargs
=
mm_processor_kwargs
if
kwargs_on_init
else
None
,
limit_mm_per_prompt
=
{
"image"
:
num_imgs
},
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
hf_processor_mm_kwargs
=
{}
if
kwargs_on_init
else
mm_processor_kwargs
# Build the image str / prompt based on the number of images we pass
img_str
=
""
.
join
([
f
"<|image_
{
idx
}
|>
\n
"
for
idx
in
range
(
1
,
num_imgs
+
1
)])
prompt
=
f
"<|user|>
\n
{
img_str
}
<|end|>
\n
<|assistant|>
\n
"
image_size
=
ctx
.
get_hf_config
(
).
embd_layer
[
"image_embd_layer"
][
"crop_size"
]
dummy_image_size
=
(
image_size
*
7
,
image_size
*
7
)
dummy_image
=
image_assets
[
0
].
pil_image
.
resize
(
dummy_image_size
)
mm_data
=
{
"image"
:
[
dummy_image
]
*
num_imgs
}
processed_inputs
=
processor
.
apply
(
prompt
,
mm_data
,
hf_processor_mm_kwargs
)
# Ensure we have the right number of placeholders per num_crops size
img_tok_count
=
processed_inputs
[
"prompt_token_ids"
].
count
(
_IMAGE_PLACEHOLDER_TOKEN_ID
)
assert
img_tok_count
==
expected_toks_per_img
*
num_imgs
tests/models/registry.py
View file @
081057de
...
@@ -121,9 +121,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -121,9 +121,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"BaichuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan2-7B-chat"
,
"BaichuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan2-7B-chat"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"BambaForCausalLM"
:
_HfExamplesInfo
(
"ibm-ai-platform/Bamba-9B"
),
"BambaForCausalLM"
:
_HfExamplesInfo
(
"ibm-ai-platform/Bamba-9B"
),
"BloomForCausalLM"
:
_HfExamplesInfo
(
"bigscience/bloomz-1b1"
),
"BloomForCausalLM"
:
_HfExamplesInfo
(
"bigscience/bloom-560m"
,
{
"1b"
:
"bigscience/bloomz-1b1"
}),
"ChatGLMModel"
:
_HfExamplesInfo
(
"THUDM/chatglm3-6b"
,
"ChatGLMModel"
:
_HfExamplesInfo
(
"THUDM/chatglm3-6b"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
,
max_transformers_version
=
"4.48"
),
"ChatGLMForConditionalGeneration"
:
_HfExamplesInfo
(
"thu-coai/ShieldLM-6B-chatglm3"
,
# noqa: E501
"ChatGLMForConditionalGeneration"
:
_HfExamplesInfo
(
"thu-coai/ShieldLM-6B-chatglm3"
,
# noqa: E501
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"CohereForCausalLM"
:
_HfExamplesInfo
(
"CohereForAI/c4ai-command-r-v01"
,
"CohereForCausalLM"
:
_HfExamplesInfo
(
"CohereForAI/c4ai-command-r-v01"
,
...
@@ -141,24 +143,26 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -141,24 +143,26 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"ExaoneForCausalLM"
:
_HfExamplesInfo
(
"LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"
),
# noqa: E501
"ExaoneForCausalLM"
:
_HfExamplesInfo
(
"LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"
),
# noqa: E501
"Fairseq2LlamaForCausalLM"
:
_HfExamplesInfo
(
"mgleize/fairseq2-dummy-Llama-3.2-1B"
),
# noqa: E501
"Fairseq2LlamaForCausalLM"
:
_HfExamplesInfo
(
"mgleize/fairseq2-dummy-Llama-3.2-1B"
),
# noqa: E501
"FalconForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-7b"
),
"FalconForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-7b"
),
"GemmaForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-
2b
"
),
"GemmaForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-
1.1-2b-it
"
),
"Gemma2ForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-2-9b"
),
"Gemma2ForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-2-9b"
),
"Gemma3ForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-3-1b-it"
,
"Gemma3ForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-3-1b-it"
),
min_transformers_version
=
"4.50"
),
"GlmForCausalLM"
:
_HfExamplesInfo
(
"THUDM/glm-4-9b-chat-hf"
),
"GlmForCausalLM"
:
_HfExamplesInfo
(
"THUDM/glm-4-9b-chat-hf"
),
"Glm4ForCausalLM"
:
_HfExamplesInfo
(
"Glm4ForCausalLM"
:
_HfExamplesInfo
(
"THUDM/GLM-4-32B-
Chat-
0414"
,
"THUDM/GLM-4-32B-0414"
,
is_available_online
=
False
,
is_available_online
=
False
,
min_transformers_version
=
"4.52.dev0"
min_transformers_version
=
"4.52.dev0"
),
),
"GPT2LMHeadModel"
:
_HfExamplesInfo
(
"gpt2"
),
"GPT2LMHeadModel"
:
_HfExamplesInfo
(
"openai-community/gpt2"
,
"GPTBigCodeForCausalLM"
:
_HfExamplesInfo
(
"bigcode/starcoder"
),
{
"alias"
:
"gpt2"
}),
"GPTJForCausalLM"
:
_HfExamplesInfo
(
"EleutherAI/gpt-j-6b"
),
"GPTBigCodeForCausalLM"
:
_HfExamplesInfo
(
"bigcode/starcoder"
,
"GPTNeoXForCausalLM"
:
_HfExamplesInfo
(
"EleutherAI/pythia-160m"
),
{
"tiny"
:
"bigcode/tiny_starcoder_py"
}),
# noqa: E501
"GPTJForCausalLM"
:
_HfExamplesInfo
(
"Milos/slovak-gpt-j-405M"
,
{
"6b"
:
"EleutherAI/gpt-j-6b"
}),
"GPTNeoXForCausalLM"
:
_HfExamplesInfo
(
"EleutherAI/pythia-70m"
,
{
"1b"
:
"EleutherAI/pythia-1.4b"
}),
"GraniteForCausalLM"
:
_HfExamplesInfo
(
"ibm/PowerLM-3b"
),
"GraniteForCausalLM"
:
_HfExamplesInfo
(
"ibm/PowerLM-3b"
),
"GraniteMoeForCausalLM"
:
_HfExamplesInfo
(
"ibm/PowerMoE-3b"
),
"GraniteMoeForCausalLM"
:
_HfExamplesInfo
(
"ibm/PowerMoE-3b"
),
"GraniteMoeSharedForCausalLM"
:
_HfExamplesInfo
(
"ibm-research/moe-7b-1b-active-shared-experts"
,
# noqa: E501
"GraniteMoeSharedForCausalLM"
:
_HfExamplesInfo
(
"ibm-research/moe-7b-1b-active-shared-experts"
),
# noqa: E501
min_transformers_version
=
"4.49"
),
# noqa: E501
"Grok1ModelForCausalLM"
:
_HfExamplesInfo
(
"hpcai-tech/grok-1"
,
"Grok1ModelForCausalLM"
:
_HfExamplesInfo
(
"hpcai-tech/grok-1"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"InternLMForCausalLM"
:
_HfExamplesInfo
(
"internlm/internlm-chat-7b"
,
"InternLMForCausalLM"
:
_HfExamplesInfo
(
"internlm/internlm-chat-7b"
,
...
@@ -186,7 +190,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -186,7 +190,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"MiniMaxText01ForCausalLM"
:
_HfExamplesInfo
(
"MiniMaxAI/MiniMax-Text-01"
,
"MiniMaxText01ForCausalLM"
:
_HfExamplesInfo
(
"MiniMaxAI/MiniMax-Text-01"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"MistralForCausalLM"
:
_HfExamplesInfo
(
"mistralai/Mistral-7B-Instruct-v0.1"
),
"MistralForCausalLM"
:
_HfExamplesInfo
(
"mistralai/Mistral-7B-Instruct-v0.1"
),
"MixtralForCausalLM"
:
_HfExamplesInfo
(
"mistralai/Mixtral-8x7B-Instruct-v0.1"
),
# noqa: E501
"MixtralForCausalLM"
:
_HfExamplesInfo
(
"mistralai/Mixtral-8x7B-Instruct-v0.1"
,
# noqa: E501
{
"falcon3"
:
"ehristoforu/Falcon3-MoE-2x7B-Insruct"
}),
# noqa: E501
"QuantMixtralForCausalLM"
:
_HfExamplesInfo
(
"mistral-community/Mixtral-8x22B-v0.1-AWQ"
),
# noqa: E501
"QuantMixtralForCausalLM"
:
_HfExamplesInfo
(
"mistral-community/Mixtral-8x22B-v0.1-AWQ"
),
# noqa: E501
"MptForCausalLM"
:
_HfExamplesInfo
(
"mpt"
,
is_available_online
=
False
),
"MptForCausalLM"
:
_HfExamplesInfo
(
"mpt"
,
is_available_online
=
False
),
"MPTForCausalLM"
:
_HfExamplesInfo
(
"mosaicml/mpt-7b"
),
"MPTForCausalLM"
:
_HfExamplesInfo
(
"mosaicml/mpt-7b"
),
...
@@ -194,7 +199,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -194,7 +199,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"OlmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMo-1B-hf"
),
"OlmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMo-1B-hf"
),
"Olmo2ForCausalLM"
:
_HfExamplesInfo
(
"shanearora/OLMo-7B-1124-hf"
),
"Olmo2ForCausalLM"
:
_HfExamplesInfo
(
"shanearora/OLMo-7B-1124-hf"
),
"OlmoeForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMoE-1B-7B-0924-Instruct"
),
"OlmoeForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMoE-1B-7B-0924-Instruct"
),
"OPTForCausalLM"
:
_HfExamplesInfo
(
"facebook/opt-iml-max-1.3b"
),
"OPTForCausalLM"
:
_HfExamplesInfo
(
"facebook/opt-125m"
,
{
"1b"
:
"facebook/opt-iml-max-1.3b"
}),
"OrionForCausalLM"
:
_HfExamplesInfo
(
"OrionStarAI/Orion-14B-Chat"
,
"OrionForCausalLM"
:
_HfExamplesInfo
(
"OrionStarAI/Orion-14B-Chat"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"PersimmonForCausalLM"
:
_HfExamplesInfo
(
"adept/persimmon-8b-chat"
),
"PersimmonForCausalLM"
:
_HfExamplesInfo
(
"adept/persimmon-8b-chat"
),
...
@@ -204,10 +210,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -204,10 +210,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"PhiMoEForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3.5-MoE-instruct"
,
"PhiMoEForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3.5-MoE-instruct"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Plamo2ForCausalLM"
:
_HfExamplesInfo
(
"pfnet/plamo-2-1b"
,
trust_remote_code
=
True
),
"QWenLMHeadModel"
:
_HfExamplesInfo
(
"Qwen/Qwen-7B-Chat"
,
"QWenLMHeadModel"
:
_HfExamplesInfo
(
"Qwen/Qwen-7B-Chat"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Qwen2ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen2-
7
B-Instruct"
,
"Qwen2ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen2-
0.5
B-Instruct"
,
extras
=
{
"2.5"
:
"Qwen/Qwen2.5-
7
B-Instruct"
}),
# noqa: E501
extras
=
{
"2.5"
:
"Qwen/Qwen2.5-
0.5
B-Instruct"
}),
# noqa: E501
"Qwen2MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen1.5-MoE-A2.7B-Chat"
),
"Qwen2MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen1.5-MoE-A2.7B-Chat"
),
"Qwen3ForCausalLM"
:
_HfExamplesInfo
(
"Qwen3ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-8B"
,
"Qwen/Qwen3-8B"
,
...
@@ -233,8 +241,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -233,8 +241,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"XverseForCausalLM"
:
_HfExamplesInfo
(
"xverse/XVERSE-7B-Chat"
,
"XverseForCausalLM"
:
_HfExamplesInfo
(
"xverse/XVERSE-7B-Chat"
,
is_available_online
=
False
,
is_available_online
=
False
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Zamba2ForCausalLM"
:
_HfExamplesInfo
(
"Zyphra/Zamba2-7B-instruct"
,
"Zamba2ForCausalLM"
:
_HfExamplesInfo
(
"Zyphra/Zamba2-7B-instruct"
),
min_transformers_version
=
"4.49"
),
# [Encoder-decoder]
# [Encoder-decoder]
"BartModel"
:
_HfExamplesInfo
(
"facebook/bart-base"
),
"BartModel"
:
_HfExamplesInfo
(
"facebook/bart-base"
),
"BartForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/bart-large-cnn"
),
"BartForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/bart-large-cnn"
),
...
@@ -245,11 +252,15 @@ _EMBEDDING_EXAMPLE_MODELS = {
...
@@ -245,11 +252,15 @@ _EMBEDDING_EXAMPLE_MODELS = {
"BertModel"
:
_HfExamplesInfo
(
"BAAI/bge-base-en-v1.5"
),
"BertModel"
:
_HfExamplesInfo
(
"BAAI/bge-base-en-v1.5"
),
"Gemma2Model"
:
_HfExamplesInfo
(
"BAAI/bge-multilingual-gemma2"
),
"Gemma2Model"
:
_HfExamplesInfo
(
"BAAI/bge-multilingual-gemma2"
),
"GritLM"
:
_HfExamplesInfo
(
"parasail-ai/GritLM-7B-vllm"
),
"GritLM"
:
_HfExamplesInfo
(
"parasail-ai/GritLM-7B-vllm"
),
"GteModel"
:
_HfExamplesInfo
(
"Snowflake/snowflake-arctic-embed-m-v2.0"
,
trust_remote_code
=
True
),
"InternLM2ForRewardModel"
:
_HfExamplesInfo
(
"internlm/internlm2-1_8b-reward"
,
"InternLM2ForRewardModel"
:
_HfExamplesInfo
(
"internlm/internlm2-1_8b-reward"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"JambaForSequenceClassification"
:
_HfExamplesInfo
(
"ai21labs/Jamba-tiny-reward-dev"
),
# noqa: E501
"JambaForSequenceClassification"
:
_HfExamplesInfo
(
"ai21labs/Jamba-tiny-reward-dev"
),
# noqa: E501
"LlamaModel"
:
_HfExamplesInfo
(
"llama"
,
is_available_online
=
False
),
"LlamaModel"
:
_HfExamplesInfo
(
"llama"
,
is_available_online
=
False
),
"MistralModel"
:
_HfExamplesInfo
(
"intfloat/e5-mistral-7b-instruct"
),
"MistralModel"
:
_HfExamplesInfo
(
"intfloat/e5-mistral-7b-instruct"
),
"NomicBertModel"
:
_HfExamplesInfo
(
"Snowflake/snowflake-arctic-embed-m-long"
,
# noqa: E501
trust_remote_code
=
True
),
"Qwen2Model"
:
_HfExamplesInfo
(
"ssmits/Qwen2-7B-Instruct-embed-base"
),
"Qwen2Model"
:
_HfExamplesInfo
(
"ssmits/Qwen2-7B-Instruct-embed-base"
),
"Qwen2ForRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-RM-72B"
),
"Qwen2ForRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-RM-72B"
),
"Qwen2ForProcessRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-PRM-7B"
),
"Qwen2ForProcessRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-PRM-7B"
),
...
@@ -273,6 +284,7 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
...
@@ -273,6 +284,7 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
"BertForSequenceClassification"
:
_HfExamplesInfo
(
"cross-encoder/ms-marco-MiniLM-L-6-v2"
),
# noqa: E501
"BertForSequenceClassification"
:
_HfExamplesInfo
(
"cross-encoder/ms-marco-MiniLM-L-6-v2"
),
# noqa: E501
"RobertaForSequenceClassification"
:
_HfExamplesInfo
(
"cross-encoder/quora-roberta-base"
),
# noqa: E501
"RobertaForSequenceClassification"
:
_HfExamplesInfo
(
"cross-encoder/quora-roberta-base"
),
# noqa: E501
"XLMRobertaForSequenceClassification"
:
_HfExamplesInfo
(
"BAAI/bge-reranker-v2-m3"
),
# noqa: E501
"XLMRobertaForSequenceClassification"
:
_HfExamplesInfo
(
"BAAI/bge-reranker-v2-m3"
),
# noqa: E501
"ModernBertForSequenceClassification"
:
_HfExamplesInfo
(
"Alibaba-NLP/gte-reranker-modernbert-base"
),
# noqa: E501
}
}
_MULTIMODAL_EXAMPLE_MODELS
=
{
_MULTIMODAL_EXAMPLE_MODELS
=
{
...
@@ -286,10 +298,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -286,10 +298,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras
=
{
"fork"
:
"Isotr0py/deepseek-vl2-tiny"
},
# noqa: E501
extras
=
{
"fork"
:
"Isotr0py/deepseek-vl2-tiny"
},
# noqa: E501
max_transformers_version
=
"4.48"
,
# noqa: E501
max_transformers_version
=
"4.48"
,
# noqa: E501
transformers_version_reason
=
"HF model is not compatible."
,
# noqa: E501
transformers_version_reason
=
"HF model is not compatible."
,
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"DeepseekVLV2ForCausalLM"
]}),
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"DeepseekVLV2ForCausalLM"
]}),
# noqa: E501
"FuyuForCausalLM"
:
_HfExamplesInfo
(
"adept/fuyu-8b"
),
"FuyuForCausalLM"
:
_HfExamplesInfo
(
"adept/fuyu-8b"
),
"Gemma3ForConditionalGeneration"
:
_HfExamplesInfo
(
"google/gemma-3-4b-it"
,
"Gemma3ForConditionalGeneration"
:
_HfExamplesInfo
(
"google/gemma-3-4b-it"
),
min_transformers_version
=
"4.50"
),
"GraniteSpeechForConditionalGeneration"
:
_HfExamplesInfo
(
"ibm-granite/granite-speech-3.3-8b"
,
# noqa: E501
min_transformers_version
=
"4.52.0"
),
# noqa: E501
"GLM4VForCausalLM"
:
_HfExamplesInfo
(
"THUDM/glm-4v-9b"
,
"GLM4VForCausalLM"
:
_HfExamplesInfo
(
"THUDM/glm-4v-9b"
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
hf_overrides
=
{
"architectures"
:
[
"GLM4VForCausalLM"
]}),
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"GLM4VForCausalLM"
]}),
# noqa: E501
...
@@ -302,6 +315,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -302,6 +315,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Idefics3ForConditionalGeneration"
:
_HfExamplesInfo
(
"HuggingFaceM4/Idefics3-8B-Llama3"
,
# noqa: E501
"Idefics3ForConditionalGeneration"
:
_HfExamplesInfo
(
"HuggingFaceM4/Idefics3-8B-Llama3"
,
# noqa: E501
{
"tiny"
:
"HuggingFaceTB/SmolVLM-256M-Instruct"
}),
# noqa: E501
{
"tiny"
:
"HuggingFaceTB/SmolVLM-256M-Instruct"
}),
# noqa: E501
"KimiVLForConditionalGeneration"
:
_HfExamplesInfo
(
"moonshotai/Kimi-VL-A3B-Instruct"
,
# noqa: E501
extras
=
{
"thinking"
:
"moonshotai/Kimi-VL-A3B-Thinking"
},
# noqa: E501
trust_remote_code
=
True
),
"Llama4ForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
# noqa: E501
"Llama4ForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
# noqa: E501
min_transformers_version
=
"4.51"
),
min_transformers_version
=
"4.51"
),
"LlavaForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-1.5-7b-hf"
,
"LlavaForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-1.5-7b-hf"
,
...
@@ -322,7 +338,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -322,7 +338,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras
=
{
"2.6"
:
"openbmb/MiniCPM-V-2_6"
},
# noqa: E501
extras
=
{
"2.6"
:
"openbmb/MiniCPM-V-2_6"
},
# noqa: E501
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Mistral3ForConditionalGeneration"
:
_HfExamplesInfo
(
"mistralai/Mistral-Small-3.1-24B-Instruct-2503"
,
# noqa: E501
"Mistral3ForConditionalGeneration"
:
_HfExamplesInfo
(
"mistralai/Mistral-Small-3.1-24B-Instruct-2503"
,
# noqa: E501
min_transformers_version
=
"4.50"
,
# noqa: E501
extras
=
{
"fp8"
:
"nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"
}),
# noqa: E501
extras
=
{
"fp8"
:
"nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"
}),
# noqa: E501
"MolmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Molmo-7B-D-0924"
,
"MolmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Molmo-7B-D-0924"
,
max_transformers_version
=
"4.48"
,
max_transformers_version
=
"4.48"
,
...
@@ -348,8 +363,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -348,8 +363,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
hf_overrides
=
{
"architectures"
:
[
"QwenVLForConditionalGeneration"
]}),
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"QwenVLForConditionalGeneration"
]}),
# noqa: E501
"Qwen2AudioForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-Audio-7B-Instruct"
),
# noqa: E501
"Qwen2AudioForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-Audio-7B-Instruct"
),
# noqa: E501
"Qwen2VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-VL-2B-Instruct"
),
# noqa: E501
"Qwen2VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-VL-2B-Instruct"
),
# noqa: E501
"Qwen2_5_VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-VL-3B-Instruct"
,
# noqa: E501
"Qwen2_5_VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-VL-3B-Instruct"
),
# noqa: E501
min_transformers_version
=
"4.49"
),
# noqa: E501
"Qwen2_5OmniModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Omni-7B"
,
# noqa: E501
min_transformers_version
=
"4.52"
),
# noqa: E501
"SkyworkR1VChatModel"
:
_HfExamplesInfo
(
"Skywork/Skywork-R1V-38B"
),
"SkyworkR1VChatModel"
:
_HfExamplesInfo
(
"Skywork/Skywork-R1V-38B"
),
"SmolVLMForConditionalGeneration"
:
_HfExamplesInfo
(
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
),
# noqa: E501
"SmolVLMForConditionalGeneration"
:
_HfExamplesInfo
(
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
),
# noqa: E501
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
,
# noqa: E501
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
,
# noqa: E501
...
@@ -358,7 +374,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -358,7 +374,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration"
:
_HfExamplesInfo
(
"microsoft/Florence-2-base"
,
# noqa: E501
"Florence2ForConditionalGeneration"
:
_HfExamplesInfo
(
"microsoft/Florence-2-base"
,
# noqa: E501
tokenizer
=
"
facebook/bart-base
"
,
tokenizer
=
"
Isotr0py/Florence-2-tokenizer
"
,
trust_remote_code
=
True
),
# noqa: E501
trust_remote_code
=
True
),
# noqa: E501
"MllamaForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-11B-Vision-Instruct"
),
# noqa: E501
"MllamaForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-11B-Vision-Instruct"
),
# noqa: E501
"Llama4ForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
),
# noqa: E501
"Llama4ForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
),
# noqa: E501
...
@@ -379,6 +395,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
...
@@ -379,6 +395,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code
=
True
,
trust_remote_code
=
True
,
speculative_model
=
"yuhuili/EAGLE-LLaMA3-Instruct-8B"
,
speculative_model
=
"yuhuili/EAGLE-LLaMA3-Instruct-8B"
,
tokenizer
=
"meta-llama/Meta-Llama-3-8B-Instruct"
),
# noqa: E501
tokenizer
=
"meta-llama/Meta-Llama-3-8B-Instruct"
),
# noqa: E501
"Eagle3LlamaForCausalLM"
:
_HfExamplesInfo
(
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
,
# noqa: E501
trust_remote_code
=
True
,
speculative_model
=
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
,
tokenizer
=
"meta-llama/Llama-3.1-8B-Instruct"
),
}
}
_TRANSFORMERS_MODELS
=
{
_TRANSFORMERS_MODELS
=
{
...
...
tests/models/test_bitblas.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of a GPTQ model to a bitblas model.
Note: GPTQ and bitblas do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.
Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.
Run `pytest tests/models/test_bitblas.py`.
"""
from
dataclasses
import
dataclass
import
pytest
from
.utils
import
check_logprobs_close
@dataclass
class ModelPair:
    """Pair of checkpoints whose outputs are compared against each other."""

    # Model ID of the checkpoint loaded with ``quantization="bitblas"``.
    model_bitblas: str
    # Model ID of the checkpoint loaded with ``quantization="gptq"``.
    model_gptq: str
# Checkpoint pairs exercised by ``test_models``; each entry names the
# BitBLAS-format checkpoint and the GPTQ checkpoint it is compared with.
model_pairs = [
    ModelPair(
        model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas",
        model_gptq="hxbgsyxh/opt-125m-4bit-128g",
    ),
]
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model_pair: ModelPair,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check that BitBLAS and GPTQ checkpoints produce close logprobs.

    Greedy logprobs are generated for ``example_prompts`` with the BitBLAS
    checkpoint and with the GPTQ checkpoint, and the two result lists are
    handed to ``check_logprobs_close``.

    The test is currently skipped unconditionally (see the ``skipif``
    marker) and, when enabled, is re-run up to two extra times on failure.
    """
    # Run each model in its own context so the first runner is released
    # before the second one is created.
    with vllm_runner(model_pair.model_bitblas,
                     dtype=dtype,
                     quantization="bitblas") as bitblas_model:
        bitblas_outputs = bitblas_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model_pair.model_gptq, dtype=dtype,
                     quantization="gptq") as gptq_model:
        gptq_outputs = gptq_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=gptq_outputs,
        outputs_1_lst=bitblas_outputs,
        name_0="gptq",
        name_1="bitblas",
    )
tests/models/test_gptq_bitblas.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of a GPTQ model to a bitblas model.
Note: GPTQ and bitblas do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.
Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.
Run `pytest tests/models/test_gptq_bitblas.py`.
"""
from
dataclasses
import
dataclass
import
pytest
from
.utils
import
check_logprobs_close
@dataclass
class ModelPair:
    """Checkpoint that is run under both quantization backends."""

    # Model ID of the GPTQ checkpoint used for both runs of the test.
    model_gptq: str
# GPTQ checkpoints exercised by ``test_models``.
model_pairs = [
    ModelPair(model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model_pair: ModelPair,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check one GPTQ checkpoint under the bitblas and gptq backends.

    The same checkpoint (``model_pair.model_gptq``) is loaded twice, once
    with ``quantization="bitblas"`` and once with ``quantization="gptq"``.
    Greedy logprobs for ``example_prompts`` from both runs are handed to
    ``check_logprobs_close``.

    The test is currently skipped unconditionally (see the ``skipif``
    marker) and, when enabled, is re-run up to two extra times on failure.
    """
    # Both runs intentionally load the *same* GPTQ checkpoint; only the
    # quantization backend differs.
    with vllm_runner(model_pair.model_gptq,
                     dtype=dtype,
                     quantization="bitblas") as bitblas_model:
        bitblas_outputs = bitblas_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model_pair.model_gptq, dtype=dtype,
                     quantization="gptq") as gptq_model:
        gptq_outputs = gptq_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=gptq_outputs,
        outputs_1_lst=bitblas_outputs,
        name_0="gptq",
        name_1="gptq_bitblas",
    )
tests/models/test_initialization.py
View file @
081057de
...
@@ -24,10 +24,7 @@ def test_can_initialize(model_arch):
...
@@ -24,10 +24,7 @@ def test_can_initialize(model_arch):
def
hf_overrides
(
hf_config
:
PretrainedConfig
)
->
PretrainedConfig
:
def
hf_overrides
(
hf_config
:
PretrainedConfig
)
->
PretrainedConfig
:
hf_config
.
update
(
model_info
.
hf_overrides
)
hf_config
.
update
(
model_info
.
hf_overrides
)
if
hasattr
(
hf_config
,
"text_config"
):
text_config
=
hf_config
.
get_text_config
()
text_config
:
PretrainedConfig
=
hf_config
.
text_config
else
:
text_config
=
hf_config
text_config
.
update
({
text_config
.
update
({
"num_layers"
:
1
,
"num_layers"
:
1
,
...
...
tests/models/test_oot_registration.py
View file @
081057de
...
@@ -18,10 +18,9 @@ def test_plugin(
...
@@ -18,10 +18,9 @@ def test_plugin(
m
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
m
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
m
.
setenv
(
"VLLM_PLUGINS"
,
""
)
m
.
setenv
(
"VLLM_PLUGINS"
,
""
)
with
pytest
.
raises
(
Exception
)
as
excinfo
:
match
=
"Cannot find model module"
with
pytest
.
raises
(
ValueError
,
match
=
match
):
LLM
(
model
=
dummy_opt_path
,
load_format
=
"dummy"
)
LLM
(
model
=
dummy_opt_path
,
load_format
=
"dummy"
)
error_msg
=
"has no vLLM implementation and the Transformers implementation is not compatible with vLLM"
# noqa: E501
assert
(
error_msg
in
str
(
excinfo
.
value
))
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
()
...
...
tests/quantization/test_compressed_tensors.py
View file @
081057de
...
@@ -261,16 +261,23 @@ def test_compressed_tensors_w8a8_dynamic_per_token(
...
@@ -261,16 +261,23 @@ def test_compressed_tensors_w8a8_dynamic_per_token(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"wNa16_args"
,
"wNa16_args"
,
[
[(
"nm-testing/tinyllama-oneshot-w4a16-channel-v2"
,
"channel"
,
None
,
8
,
(
"nm-testing/tinyllama-oneshot-w4a16-channel-v2"
,
"channel"
,
None
,
8
),
True
,
False
),
(
"nm-testing/tinyllama-oneshot-w4a16-group128-v2"
,
"group"
,
128
,
8
),
(
"nm-testing/tinyllama-oneshot-w4a16-group128-v2"
,
"group"
,
128
,
8
,
True
,
(
"nm-testing/tinyllama-oneshot-w8a16-per-channel"
,
"channel"
,
None
,
4
),
False
),
],
(
"nm-testing/tinyllama-oneshot-w8a16-per-channel"
,
"channel"
,
None
,
4
,
True
,
False
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256"
,
"group"
,
128
,
8
,
False
,
False
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel"
,
"channel"
,
None
,
8
,
False
,
False
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder"
,
"group"
,
128
,
8
,
False
,
True
)],
)
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"The tests are skipped on non-CUDA platform."
)
reason
=
"The tests are skipped on non-CUDA platform."
)
def
test_compressed_tensors_wNa16
(
vllm_runner
,
wNa16_args
):
def
test_compressed_tensors_wNa16
(
vllm_runner
,
wNa16_args
):
model
,
strategy
,
group
,
pack_factor
=
wNa16_args
model
,
strategy
,
group
,
pack_factor
,
symmetric
,
has_g_idx
=
wNa16_args
with
vllm_runner
(
model
)
as
llm
:
with
vllm_runner
(
model
)
as
llm
:
def
check_model
(
model
):
def
check_model
(
model
):
...
@@ -286,6 +293,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
...
@@ -286,6 +293,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
if
group
is
None
else
group
)
if
group
is
None
else
group
)
assert
qkv_proj
.
scheme
.
pack_factor
==
pack_factor
assert
qkv_proj
.
scheme
.
pack_factor
==
pack_factor
assert
qkv_proj
.
scheme
.
symmetric
==
symmetric
assert
qkv_proj
.
scheme
.
has_g_idx
==
has_g_idx
llm
.
apply_model
(
check_model
)
llm
.
apply_model
(
check_model
)
...
...
tests/samplers/test_beam_search.py
View file @
081057de
...
@@ -5,6 +5,9 @@ Run `pytest tests/samplers/test_beam_search.py`.
...
@@ -5,6 +5,9 @@ Run `pytest tests/samplers/test_beam_search.py`.
"""
"""
import
pytest
import
pytest
from
transformers
import
AutoModelForSeq2SeqLM
from
vllm.assets.audio
import
AudioAsset
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
...
@@ -19,6 +22,7 @@ def v1(run_with_both_engines):
...
@@ -19,6 +22,7 @@ def v1(run_with_both_engines):
# 3. Use the model "huggyllama/llama-7b".
# 3. Use the model "huggyllama/llama-7b".
MAX_TOKENS
=
[
64
]
MAX_TOKENS
=
[
64
]
BEAM_WIDTHS
=
[
4
]
BEAM_WIDTHS
=
[
4
]
MM_BEAM_WIDTHS
=
[
2
]
MODELS
=
[
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
]
MODELS
=
[
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
]
...
@@ -48,15 +52,90 @@ def test_beam_search_single_input(
...
@@ -48,15 +52,90 @@ def test_beam_search_single_input(
for
i
in
range
(
len
(
example_prompts
)):
for
i
in
range
(
len
(
example_prompts
)):
hf_output_ids
,
hf_output_texts
=
hf_outputs
[
i
]
hf_output_ids
,
hf_output_texts
=
hf_outputs
[
i
]
vllm_output_ids
,
vllm_output_texts
=
vllm_outputs
[
i
]
vllm_output_ids
,
vllm_output_texts
=
vllm_outputs
[
i
]
for
i
,
(
hf_text
,
for
j
,
(
hf_text
,
vllm_text
)
in
enumerate
(
zip
(
hf_output_texts
,
vllm_text
)
in
enumerate
(
zip
(
hf_output_texts
,
vllm_output_texts
)):
vllm_output_texts
)):
print
(
f
">>>
{
i
}
-th hf output:"
)
print
(
f
">>>
{
j
}
-th hf output:"
)
print
(
hf_text
)
print
(
hf_text
)
print
(
f
">>>
{
i
}
-th vllm output:"
)
print
(
f
">>>
{
j
}
-th vllm output:"
)
print
(
vllm_text
)
print
(
vllm_text
)
assert
len
(
hf_output_ids
)
==
len
(
vllm_output_ids
)
assert
len
(
hf_output_ids
)
==
len
(
vllm_output_ids
)
for
j
in
range
(
len
(
hf_output_ids
)):
for
j
in
range
(
len
(
hf_output_ids
)):
assert
hf_output_ids
[
j
]
==
vllm_output_ids
[
j
],
(
assert
hf_output_ids
[
j
]
==
vllm_output_ids
[
j
],
(
f
"Test
{
i
}
output
{
j
}
:
\n
HF:
{
hf_output_ids
}
\n
"
f
"Test
{
i
}
output
{
j
}
:
\n
HF:
{
hf_output_ids
}
\n
"
f
"vLLM:
{
vllm_output_ids
}
"
)
f
"vLLM:
{
vllm_output_ids
}
"
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
def test_beam_search_passes_multimodal_data(
    hf_runner,
    vllm_runner,
    dtype: str,
    max_tokens: int,
    beam_width: int,
) -> None:
    """Ensure that beam search passes multimodal data through correctly."""
    # NOTE - this test is primarily to check that mm data is passed to beams
    # correctly. As such, we just need to check one extra modality to make
    # sure things pass through properly.
    audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
    model = "Qwen/Qwen2-Audio-7B-Instruct"
    audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
    prompts = [
        f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n"  #noqa: E501
    ]

    with hf_runner(model, dtype=dtype,
                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
        audio_token_id = hf_model.config.audio_token_index
        eos_token_id = hf_model.tokenizer.eos_token_id  # <|im_end|>
        hf_outputs = hf_model.generate_beam_search(
            prompts,
            beam_width=beam_width,
            max_tokens=max_tokens,
            audios=audios,
        )

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_beam_search(
            prompts,
            beam_width=beam_width,
            max_tokens=max_tokens,
            audios=audios,
        )

    def seq_with_no_audio_toks(seq):
        """Return ``seq`` with every audio placeholder token removed."""
        return [tok for tok in seq if tok != audio_token_id]

    for i in range(len(prompts)):
        hf_output_ids, hf_output_texts = hf_outputs[i]
        vllm_output_ids, vllm_output_texts = vllm_outputs[i]

        # Dump both sets of beams to ease debugging when the test fails.
        for j, (hf_text,
                vllm_text) in enumerate(zip(hf_output_texts,
                                            vllm_output_texts)):
            print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
            print(hf_text)
            print(f">>>{j}-th vllm output:")
            print(vllm_text)
        assert len(hf_output_ids) == len(vllm_output_ids)

        # Lengths were asserted equal above, so zip() visits every beam.
        for hf_beam_ids, vllm_beam_ids in zip(hf_output_ids,
                                              vllm_output_ids):
            # Compare everything except for the audio tokens; we do this since
            # the IDs returned from the transformers helper expands the audio
            # token to match features, while the vLLM helper maintains the
            # single audio token in the input text
            filtered_hf_output_ids = seq_with_no_audio_toks(hf_beam_ids)
            filtered_vllm_output_ids = seq_with_no_audio_toks(vllm_beam_ids)

            # HF output IDs may contain the end of sequence
            if len(filtered_hf_output_ids
                   ) == len(filtered_vllm_output_ids) + 1:
                assert filtered_hf_output_ids[-1] == eos_token_id
                filtered_hf_output_ids = filtered_hf_output_ids[:-1]

            assert filtered_hf_output_ids == filtered_vllm_output_ids
tests/spec_decode/test_scorer.py
View file @
081057de
...
@@ -62,9 +62,8 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
...
@@ -62,9 +62,8 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
scorer_worker
=
create_worker
(
Worker
,
model_name
,
block_size
,
scorer_worker
=
create_worker
(
Worker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
)
num_gpu_blocks
,
seed
)
scorer_worker
.
model_runner
.
disable_logprobs
=
True
# accessed by mqa_scorer
scorer_worker
.
model_runner
.
disable_logprobs
=
True
# accessed by mqa_scorer
scorer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
scorer_worker
.
model_runner
.
sampler
.
include_gpu_probs_tensor
=
True
scorer_worker
.
model_runner
.
model
.
sampler
.
\
scorer_worker
.
model_runner
.
sampler
.
should_modify_greedy_probs_inplace
=
True
should_modify_greedy_probs_inplace
=
True
vocab_size
=
scorer_worker
.
vocab_size
vocab_size
=
scorer_worker
.
vocab_size
...
...
tests/test_config.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
asdict
from
dataclasses
import
MISSING
,
Field
,
asdict
,
dataclass
,
field
import
pytest
import
pytest
from
vllm.config
import
ModelConfig
,
PoolerConfig
from
vllm.config
import
ModelConfig
,
PoolerConfig
,
get_field
from
vllm.model_executor.layers.pooler
import
PoolingType
from
vllm.model_executor.layers.pooler
import
PoolingType
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
def test_get_field():
    """Check ``get_field`` on fields with and without defaults.

    A field without any default is expected to raise ``ValueError``; fields
    with a ``default_factory`` or a plain default are expected to come back
    as ``dataclasses.Field`` objects carrying that default.
    """

    @dataclass
    class TestConfig:
        a: int  # no default at all
        b: dict = field(default_factory=dict)
        c: str = "default"

    with pytest.raises(ValueError):
        get_field(TestConfig, "a")

    # default_factory only: ``default`` stays MISSING.
    b = get_field(TestConfig, "b")
    assert isinstance(b, Field)
    assert b.default is MISSING
    assert b.default_factory is dict

    # Plain default only: ``default_factory`` stays MISSING.
    c = get_field(TestConfig, "c")
    assert isinstance(c, Field)
    assert c.default == "default"
    assert c.default_factory is MISSING
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"model_id"
,
"expected_runner_type"
,
"expected_task"
),
(
"model_id"
,
"expected_runner_type"
,
"expected_task"
),
[
[
...
...
tests/test_utils.py
View file @
081057de
...
@@ -13,11 +13,11 @@ import torch
...
@@ -13,11 +13,11 @@ import torch
from
vllm_test_utils.monitor
import
monitor
from
vllm_test_utils.monitor
import
monitor
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.utils
import
(
FlexibleArgumentParser
,
MemorySnapshot
,
from
vllm.utils
import
(
CacheInfo
,
FlexibleArgumentParser
,
LRUCache
,
PlaceholderModule
,
StoreBoolean
,
bind_kv_cache
,
MemorySnapshot
,
PlaceholderModule
,
StoreBoolean
,
deprecate_kwargs
,
get_open_port
,
memory_profiling
,
bind_kv_cache
,
deprecate_kwargs
,
get_open_port
,
merge_async_iterators
,
sha256
,
supports_kw
,
memory_profiling
,
merge_async_iterators
,
sha256
,
swap_dict_values
)
supports_kw
,
swap_dict_values
)
from
.utils
import
create_new_process_for_each_test
,
error_on_warning
from
.utils
import
create_new_process_for_each_test
,
error_on_warning
...
@@ -417,6 +417,129 @@ def test_bind_kv_cache_pp():
...
@@ -417,6 +417,129 @@ def test_bind_kv_cache_pp():
assert
ctx
[
'layers.0.self_attn'
].
kv_cache
[
1
]
is
kv_cache
[
1
][
0
]
assert
ctx
[
'layers.0.self_attn'
].
kv_cache
[
1
]
is
kv_cache
[
1
][
0
]
class TestLRUCache(LRUCache):
    """``LRUCache`` subclass that counts ``_on_remove`` invocations."""

    # The ``Test`` prefix matches pytest's default class-collection pattern,
    # but this is a helper, not a test case; opt out of collection so pytest
    # does not warn about a test class with a constructor.
    __test__ = False

    def _on_remove(self, key, value):
        """Increment ``_remove_counter``, creating it on first use."""
        if not hasattr(self, "_remove_counter"):
            self._remove_counter = 0
        self._remove_counter += 1
def test_lru_cache():
    """Exercise ``LRUCache`` through its method API and its mapping API.

    The first half uses ``put``/``get``/``pop``/``remove_oldest``/``clear``
    and also checks the hit statistics reported by ``stat``; the second half
    repeats the insert/evict/delete scenario with ``cache[k] = v``,
    ``cache[k]`` and ``del cache[k]``. ``_remove_counter`` (maintained by
    ``TestLRUCache._on_remove``) tracks how many removals were reported.
    """
    cache = TestLRUCache(3)
    assert cache.stat() == CacheInfo(hits=0, total=0)
    assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)

    # --- Method API: fill the cache up to its capacity of 3 ---
    cache.put(1, 1)
    assert len(cache) == 1

    # Re-inserting an existing key must not grow the cache.
    cache.put(1, 1)
    assert len(cache) == 1

    cache.put(2, 2)
    assert len(cache) == 2

    cache.put(3, 3)
    assert len(cache) == 3
    assert set(cache.cache) == {1, 2, 3}

    # A fourth key is expected to push out key 1.
    cache.put(4, 4)
    assert len(cache) == 3
    assert set(cache.cache) == {2, 3, 4}
    assert cache._remove_counter == 1

    # --- Lookups and hit statistics ---
    assert cache.get(2) == 2
    assert cache.stat() == CacheInfo(hits=1, total=1)
    assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)

    assert cache[2] == 2
    assert cache.stat() == CacheInfo(hits=2, total=2)
    assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)

    # Key 2 was just read, so key 3 is the one expected to be dropped.
    cache.put(5, 5)
    assert set(cache.cache) == {2, 4, 5}
    assert cache._remove_counter == 2

    assert cache.pop(5) == 5
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    # A miss is counted in ``total`` but not in ``hits``.
    assert cache.get(-1) is None
    assert cache.stat() == CacheInfo(hits=2, total=3)
    assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)

    # Popping or getting an absent key must leave the cache untouched.
    cache.pop(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.get(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.put(6, 6)
    assert len(cache) == 3
    assert set(cache.cache) == {2, 4, 6}
    assert 2 in cache
    assert 4 in cache
    assert 6 in cache

    # --- Explicit eviction and clearing ---
    cache.remove_oldest()
    assert len(cache) == 2
    assert set(cache.cache) == {2, 6}
    assert cache._remove_counter == 4

    cache.clear()
    assert len(cache) == 0
    assert cache._remove_counter == 6
    assert cache.stat() == CacheInfo(hits=0, total=0)
    assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)

    # --- Mapping API: repeat the scenario with item syntax ---
    cache._remove_counter = 0

    cache[1] = 1
    assert len(cache) == 1

    cache[1] = 1
    assert len(cache) == 1

    cache[2] = 2
    assert len(cache) == 2

    cache[3] = 3
    assert len(cache) == 3
    assert set(cache.cache) == {1, 2, 3}

    cache[4] = 4
    assert len(cache) == 3
    assert set(cache.cache) == {2, 3, 4}
    assert cache._remove_counter == 1
    assert cache[2] == 2

    cache[5] = 5
    assert set(cache.cache) == {2, 4, 5}
    assert cache._remove_counter == 2

    del cache[5]
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.pop(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache[6] = 6
    assert len(cache) == 3
    assert set(cache.cache) == {2, 4, 6}
    assert 2 in cache
    assert 4 in cache
    assert 6 in cache
def
test_placeholder_module_error_handling
():
def
test_placeholder_module_error_handling
():
placeholder
=
PlaceholderModule
(
"placeholder_1234"
)
placeholder
=
PlaceholderModule
(
"placeholder_1234"
)
...
...
tests/tokenization/test_cached_tokenizer.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pickle
from
copy
import
deepcopy
from
copy
import
deepcopy
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
vllm.transformers_utils.tokenizer
import
get_cached_tokenizer
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
get_cached_tokenizer
)
def
test_cached_tokenizer
():
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"gpt2"
,
"THUDM/chatglm3-6b"
])
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"gpt2"
)
def
test_cached_tokenizer
(
model_id
:
str
):
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_id
,
trust_remote_code
=
True
)
reference_tokenizer
.
add_special_tokens
({
"cls_token"
:
"<CLS>"
})
reference_tokenizer
.
add_special_tokens
({
"cls_token"
:
"<CLS>"
})
reference_tokenizer
.
add_special_tokens
(
reference_tokenizer
.
add_special_tokens
(
{
"additional_special_tokens"
:
[
"<SEP>"
]})
{
"additional_special_tokens"
:
[
"<SEP>"
]})
cached_tokenizer
=
get_cached_tokenizer
(
deepcopy
(
reference_tokenizer
))
cached_tokenizer
=
get_cached_tokenizer
(
deepcopy
(
reference_tokenizer
))
_check_consistency
(
cached_tokenizer
,
reference_tokenizer
)
pickled_tokenizer
=
pickle
.
dumps
(
cached_tokenizer
)
unpickled_tokenizer
=
pickle
.
loads
(
pickled_tokenizer
)
_check_consistency
(
unpickled_tokenizer
,
reference_tokenizer
)
def
_check_consistency
(
target
:
AnyTokenizer
,
expected
:
AnyTokenizer
):
assert
isinstance
(
target
,
type
(
expected
))
# Cached attributes
assert
target
.
all_special_ids
==
expected
.
all_special_ids
assert
target
.
all_special_tokens
==
expected
.
all_special_tokens
assert
(
target
.
all_special_tokens_extended
==
expected
.
all_special_tokens_extended
)
assert
target
.
get_vocab
()
==
expected
.
get_vocab
()
assert
len
(
target
)
==
len
(
expected
)
# Other attributes
assert
getattr
(
target
,
"padding_side"
,
None
)
==
getattr
(
expected
,
"padding_side"
,
None
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
cached_tokenizer
.
encode
(
assert
target
.
encode
(
"prompt"
)
==
expected
.
encode
(
"prompt"
)
"prompt"
)
assert
set
(
reference_tokenizer
.
all_special_ids
)
==
set
(
cached_tokenizer
.
all_special_ids
)
assert
set
(
reference_tokenizer
.
all_special_tokens
)
==
set
(
cached_tokenizer
.
all_special_tokens
)
assert
set
(
reference_tokenizer
.
all_special_tokens_extended
)
==
set
(
cached_tokenizer
.
all_special_tokens_extended
)
tests/tokenization/test_detokenize.py
View file @
081057de
...
@@ -4,14 +4,22 @@ from collections.abc import Generator
...
@@ -4,14 +4,22 @@ from collections.abc import Generator
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
(
AutoTokenizer
,
PreTrainedTokenizer
,
PreTrainedTokenizerFast
)
from
vllm.inputs
import
token_inputs
from
vllm.inputs
import
token_inputs
from
vllm.sequence
import
Logprob
,
SamplingParams
,
Sequence
,
SequenceGroup
from
vllm.sequence
import
Logprob
,
SamplingParams
,
Sequence
,
SequenceGroup
from
vllm.transformers_utils.detokenizer
import
(
Detokenizer
,
from
vllm.transformers_utils.detokenizer
import
Detokenizer
detokenize_incrementally
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
get_tokenizer_group
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.detokenizer
import
(
FastIncrementalDetokenizer
,
IncrementalDetokenizer
,
SlowIncrementalDetokenizer
)
SPECIAL_TOKS_TRUTH
=
[
"Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>"
,
# noqa
]
TRUTH
=
[
TRUTH
=
[
"Hello here, this is a simple test"
,
"Hello here, this is a simple test"
,
...
@@ -22,7 +30,8 @@ TRUTH = [
...
@@ -22,7 +30,8 @@ TRUTH = [
# incomplete UTF-8 characters
# incomplete UTF-8 characters
# see https://github.com/vllm-project/vllm/pull/9625
# see https://github.com/vllm-project/vllm/pull/9625
"ပုံပြင်လေးပြောပြပါ်"
,
"ပုံပြင်လေးပြောပြပါ်"
,
]
]
+
SPECIAL_TOKS_TRUTH
TOKENIZERS
=
[
TOKENIZERS
=
[
"facebook/opt-125m"
,
"facebook/opt-125m"
,
"gpt2"
,
"gpt2"
,
...
@@ -38,26 +47,37 @@ TOKENIZERS = [
...
@@ -38,26 +47,37 @@ TOKENIZERS = [
]
]
def
_run_incremental_decode
(
tokenizer
,
all_input_ids
,
def
_run_incremental_decode
(
tokenizer
,
skip_special_tokens
:
bool
,
starting_index
:
int
):
all_input_ids
,
decoded_text
=
""
skip_special_tokens
:
bool
,
offset
=
0
starting_index
:
int
,
token_offset
=
0
spaces_between_special_tokens
:
bool
=
True
,
prev_tokens
=
None
fast
:
Optional
[
bool
]
=
None
):
for
i
in
range
(
starting_index
,
len
(
all_input_ids
)):
new_tokens
,
text
,
offset
,
token_offset
=
detokenize_incrementally
(
prompt_token_ids
=
all_input_ids
[:
starting_index
]
tokenizer
,
all_input_ids
[:
i
+
1
],
params
=
SamplingParams
(
prev_tokens
,
skip_special_tokens
=
skip_special_tokens
,
offset
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
token_offset
,
)
skip_special_tokens
=
skip_special_tokens
)
request
=
EngineCoreRequest
(
""
,
prompt_token_ids
,
None
,
None
,
None
,
params
,
decoded_text
+=
text
None
,
0.0
,
None
)
if
prev_tokens
is
None
:
prev_tokens
=
new_tokens
if
fast
is
None
:
else
:
detokenizer
=
IncrementalDetokenizer
.
from_new_request
(
prev_tokens
+=
new_tokens
tokenizer
,
request
)
return
decoded_text
elif
fast
:
detokenizer
=
FastIncrementalDetokenizer
(
tokenizer
,
request
)
else
:
detokenizer
=
SlowIncrementalDetokenizer
(
tokenizer
,
request
)
output_text
=
""
for
i
,
token_id
in
enumerate
(
all_input_ids
[
starting_index
:]):
detokenizer
.
update
([
token_id
],
False
)
finished
=
i
==
len
(
all_input_ids
)
-
1
output_text
+=
detokenizer
.
get_next_output_text
(
finished
,
delta
=
True
)
return
output_text
,
detokenizer
.
output_token_ids
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
...
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
starting_index
=
0
starting_index
=
0
all_input_ids
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
all_input_ids
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
decoded_text
=
_run_incremental_decode
(
tokenizer
,
decoded_text
,
out_ids
=
_run_incremental_decode
(
all_input_ids
,
tokenizer
,
skip_special_tokens
=
True
,
all_input_ids
,
starting_index
=
starting_index
)
skip_special_tokens
=
True
,
starting_index
=
starting_index
)
assert
decoded_text
==
truth
assert
decoded_text
==
truth
assert
out_ids
==
all_input_ids
[
starting_index
:]
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -106,45 +128,91 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
...
@@ -106,45 +128,91 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
@
pytest
.
mark
.
parametrize
(
"with_prompt"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"with_prompt"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"skip_special_tokens"
,
(
True
,
False
),
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"skip_special_tokens"
,
(
True
,
False
),
indirect
=
True
)
def
test_decode_streaming
(
tokenizer
,
truth
,
with_prompt
,
skip_special_tokens
):
@
pytest
.
mark
.
parametrize
(
"spaces_between_special_tokens"
,
(
True
,
False
))
@
pytest
.
mark
.
parametrize
(
"fast"
,
(
True
,
False
))
def
test_decode_streaming
(
tokenizer
,
truth
,
with_prompt
,
skip_special_tokens
,
spaces_between_special_tokens
,
fast
):
if
fast
and
not
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
pytest
.
skip
()
if
skip_special_tokens
and
not
spaces_between_special_tokens
:
pytest
.
skip
()
if
not
fast
and
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
# Fix up inconsistency in fast/slow tokenizer behaviour.
tokenizer
.
add_special_tokens
({
"additional_special_tokens"
:
[
at
for
at
in
tokenizer
.
_tokenizer
.
get_added_tokens_decoder
().
values
()
if
at
.
special
]
})
extra_decode_args
=
{}
if
not
isinstance
(
tokenizer
,
PreTrainedTokenizer
)
\
else
{
"spaces_between_special_tokens"
:
spaces_between_special_tokens
}
truth_tokens
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
if
tokenizer
.
bos_token_id
is
not
None
:
truth_tokens
.
insert
(
0
,
tokenizer
.
bos_token_id
)
truth_tokens
.
append
(
tokenizer
.
eos_token_id
)
new_truth
=
tokenizer
.
decode
(
truth_tokens
,
skip_special_tokens
=
skip_special_tokens
,
**
extra_decode_args
)
if
with_prompt
:
if
with_prompt
:
truth_tokens
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
num_prompt_tokens
=
len
(
prompt_input_ids
=
truth_tokens
[:
len
(
truth
)
//
2
]
tokenizer
(
truth
[:
len
(
truth
)
//
2
],
generated_input_ids
=
truth_tokens
[
len
(
truth
)
//
2
:]
add_special_tokens
=
False
).
input_ids
)
if
tokenizer
.
bos_token_id
is
not
None
:
num_prompt_tokens
+=
1
prompt_input_ids
=
truth_tokens
[:
num_prompt_tokens
]
generated_input_ids
=
truth_tokens
[
num_prompt_tokens
:]
all_input_ids
=
prompt_input_ids
+
generated_input_ids
all_input_ids
=
prompt_input_ids
+
generated_input_ids
starting_index
=
len
(
prompt_input_ids
)
starting_index
=
len
(
prompt_input_ids
)
prompt
=
tokenizer
.
decode
(
prompt_input_ids
,
prompt
=
tokenizer
.
decode
(
prompt_input_ids
,
skip_special_tokens
=
skip_special_tokens
)
skip_special_tokens
=
skip_special_tokens
,
generated
=
truth
[
len
(
prompt
):]
**
extra_decode_args
)
generated
=
new_truth
[
len
(
prompt
):]
else
:
else
:
generated
=
truth
generated
=
new_
truth
starting_index
=
0
starting_index
=
0
all_input_ids
=
tokenizer
(
truth
,
add_special_tokens
=
False
).
input_ids
all_input_ids
=
truth_tokens
if
skip_special_tokens
:
if
tokenizer
.
bos_token_id
is
not
None
:
all_input_ids
=
[
tokenizer
.
bos_token_id
]
+
all_input_ids
starting_index
+=
1
all_input_ids
=
all_input_ids
+
[
tokenizer
.
eos_token_id
]
decoded_text
=
_run_incremental_decode
(
decoded_text
,
out_ids
=
_run_incremental_decode
(
tokenizer
,
tokenizer
,
all_input_ids
,
all_input_ids
,
skip_special_tokens
=
skip_special_tokens
,
skip_special_tokens
=
skip_special_tokens
,
starting_index
=
starting_index
)
starting_index
=
starting_index
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
fast
=
fast
)
assert
decoded_text
==
generated
assert
decoded_text
==
generated
assert
out_ids
==
all_input_ids
[
starting_index
:]
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"fast"
,
(
True
,
False
))
def
test_oov_decode
(
tokenizer
,
fast
):
if
fast
and
not
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
pytest
.
skip
()
decoded_text
=
_run_incremental_decode
(
decoded_text
,
out_ids
=
_run_incremental_decode
(
tokenizer
,
[
len
(
tokenizer
)],
tokenizer
,
[
len
(
tokenizer
)],
skip_special_tokens
=
skip_special_tokens
,
skip_special_tokens
=
True
,
starting_index
=
starting_index
)
starting_index
=
0
,
spaces_between_special_tokens
=
True
,
fast
=
fast
)
assert
decoded_text
==
''
assert
decoded_text
==
''
assert
out_ids
==
[
len
(
tokenizer
)]
@
pytest
.
fixture
@
pytest
.
fixture
def
detokenizer
(
tokenizer_name
:
str
)
->
Detokenizer
:
def
detokenizer
(
tokenizer_name
:
str
)
->
Detokenizer
:
init_kwargs
=
dict
(
tokenizer_group
=
TokenizerGroup
(
tokenizer_id
=
tokenizer_name
,
tokenizer_id
=
tokenizer_name
,
enable_lora
=
False
,
enable_lora
=
False
,
max_num_seqs
=
100
,
max_num_seqs
=
100
,
...
@@ -154,26 +222,20 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
...
@@ -154,26 +222,20 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
revision
=
None
,
revision
=
None
,
)
)
tokenizer_group
=
get_tokenizer_group
(
None
,
**
init_kwargs
,
)
return
Detokenizer
(
tokenizer_group
)
return
Detokenizer
(
tokenizer_group
)
@
pytest
.
fixture
(
name
=
"complete_sequence_token_ids"
)
@
pytest
.
fixture
(
name
=
"complete_sequence_token_ids"
)
def
create_complete_sequence_token_ids
(
complete_sequence
:
str
,
def
create_complete_sequence_token_ids
(
complete_sequence
:
str
,
tokenizer
)
->
list
[
int
]:
tokenizer
)
->
list
[
int
]:
complete_sequence_token_ids
=
tokenizer
(
complete_sequence
).
input_ids
return
tokenizer
(
complete_sequence
,
add_special_tokens
=
False
).
input_ids
return
complete_sequence_token_ids
def
create_sequence
(
prompt_token_ids
=
None
):
def
create_sequence
(
prompt_token_ids
=
None
):
prompt_token_ids
=
prompt_token_ids
or
[
1
]
prompt_token_ids
=
prompt_token_ids
or
[]
return
Sequence
(
return
Sequence
(
seq_id
=
0
,
seq_id
=
0
,
inputs
=
token_inputs
(
prompt_token_ids
,
prompt
=
"<s>"
),
inputs
=
token_inputs
(
prompt_token_ids
),
block_size
=
16
,
block_size
=
16
,
)
)
...
@@ -224,7 +286,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
...
@@ -224,7 +286,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
assert
sequential_result
==
""
.
join
(
sequential_logprobs_text_chosen_token
)
assert
sequential_result
==
""
.
join
(
sequential_logprobs_text_chosen_token
)
assert
sequential_result
!=
""
.
join
(
sequential_logprobs_text_other_token
)
assert
sequential_result
!=
""
.
join
(
sequential_logprobs_text_other_token
)
if
skip_special_tokens
:
if
not
skip_special_tokens
:
# Text for logprobs for the chosen token should be the same as the
# Text for logprobs for the chosen token should be the same as the
# generated text. Note that this will only be true if we skip
# generated text. Note that this will only be true if we skip
# special tokens.
# special tokens.
...
@@ -233,10 +295,23 @@ def test_decode_sequence_logprobs(complete_sequence: str,
...
@@ -233,10 +295,23 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
def
test_decode_prompt_logprobs
(
complete_sequence_token_ids
:
list
[
int
],
def
test_decode_prompt_logprobs
(
complete_sequence
:
str
,
complete_sequence_token_ids
:
list
[
int
],
detokenizer
:
Detokenizer
):
detokenizer
:
Detokenizer
):
# We want to use skip_special_tokens=False here but Mistral tokenizers
# don't support that.
if
complete_sequence
not
in
SPECIAL_TOKS_TRUTH
:
skip_special_tokens
=
True
elif
not
isinstance
(
detokenizer
.
tokenizer_group
.
get_lora_tokenizer
(
None
),
MistralTokenizer
):
skip_special_tokens
=
False
else
:
pytest
.
skip
(
"MistralTokenizers don't support "
"skip_special_tokens=False"
)
return
"""Verify Detokenizer decodes prompt logprobs correctly."""
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params
=
SamplingParams
(
skip_special_tokens
=
True
,
sampling_params
=
SamplingParams
(
skip_special_tokens
=
skip_special_tokens
,
prompt_logprobs
=
1
)
prompt_logprobs
=
1
)
# Run sequentially.
# Run sequentially.
...
@@ -256,8 +331,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
...
@@ -256,8 +331,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
# decoded_prompt_logprobs doesn't contain the first token.
# decoded_prompt_logprobs doesn't contain the first token.
token_ids
=
complete_sequence_token_ids
token_ids
=
complete_sequence_token_ids
tokenizer
=
detokenizer
.
get_tokenizer_for_seq
(
seq
)
tokenizer
=
detokenizer
.
get_tokenizer_for_seq
(
seq
)
text_full
=
tokenizer
.
decode
(
token_ids
,
skip_special_tokens
=
True
)
text_full
=
tokenizer
.
decode
(
token_ids
,
text_first
=
tokenizer
.
decode
(
token_ids
[
0
],
skip_special_tokens
=
True
)
skip_special_tokens
=
skip_special_tokens
)
text_first
=
tokenizer
.
decode
(
token_ids
[
0
],
skip_special_tokens
=
skip_special_tokens
)
text
=
text_full
[
len
(
text_first
):]
text
=
text_full
[
len
(
text_first
):]
# Text for logprobs for the chosen token should be the same as the
# Text for logprobs for the chosen token should be the same as the
...
...
tests/tokenization/test_tokenizer_group.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
os
import
sys
from
typing
import
Optional
from
unittest.mock
import
patch
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
vllm.transformers_utils.tokenizer_group
import
(
TokenizerGroup
,
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
get_tokenizer_group
)
from
vllm.transformers_utils.tokenizer_group.ray_tokenizer_group
import
(
RayTokenizerGroupPool
)
from
..conftest
import
get_tokenizer_pool_config
class
CustomTokenizerGroup
(
TokenizerGroup
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_i
=
0
def
encode
(
self
,
*
args
,
**
kwargs
):
self
.
_i
+=
1
return
super
().
encode
(
*
args
,
**
kwargs
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"tokenizer_group_type"
,
async
def
test_tokenizer_group
():
[
None
,
"ray"
,
CustomTokenizerGroup
])
async
def
test_tokenizer_group
(
tokenizer_group_type
):
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"gpt2"
)
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"gpt2"
)
tokenizer_group
=
get_tokenizer_group
(
tokenizer_group
=
TokenizerGroup
(
get_tokenizer_pool_config
(
tokenizer_group_type
),
tokenizer_id
=
"gpt2"
,
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_num_seqs
=
1
,
...
@@ -49,159 +24,3 @@ async def test_tokenizer_group(tokenizer_group_type):
...
@@ -49,159 +24,3 @@ async def test_tokenizer_group(tokenizer_group_type):
PreTrainedTokenizerBase
)
PreTrainedTokenizerBase
)
assert
tokenizer_group
.
get_lora_tokenizer
(
assert
tokenizer_group
.
get_lora_tokenizer
(
None
)
==
await
tokenizer_group
.
get_lora_tokenizer_async
(
None
)
None
)
==
await
tokenizer_group
.
get_lora_tokenizer_async
(
None
)
if
tokenizer_group_type
is
CustomTokenizerGroup
:
assert
tokenizer_group
.
_i
>
0
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"tokenizer_group_type"
,
[
"ray"
])
async
def
test_tokenizer_group_pool
(
tokenizer_group_type
):
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"gpt2"
)
tokenizer_group_pool
=
get_tokenizer_group
(
get_tokenizer_pool_config
(
tokenizer_group_type
),
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
None
,
)
# Send multiple requests to the tokenizer group pool
# (more than the pool size)
# and check that all requests are processed correctly.
num_requests
=
tokenizer_group_pool
.
pool_size
*
5
requests
=
[
tokenizer_group_pool
.
encode_async
(
prompt
=
f
"prompt
{
i
}
"
,
lora_request
=
None
)
for
i
in
range
(
num_requests
)
]
results
=
await
asyncio
.
gather
(
*
requests
)
expected_results
=
[
reference_tokenizer
.
encode
(
f
"prompt
{
i
}
"
)
for
i
in
range
(
num_requests
)
]
assert
results
==
expected_results
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"tokenizer_group_type"
,
[
"ray"
])
async
def
test_tokenizer_group_ray_pool_env_var_propagation
(
tokenizer_group_type
):
"""Test that env vars from caller process are propagated to
tokenizer Ray actors."""
env_var
=
"MY_ENV_VAR"
class
EnvVarCheckerTokenizerGroup
(
TokenizerGroup
):
def
ping
(
self
):
assert
os
.
environ
.
get
(
env_var
)
==
"1"
return
super
().
ping
()
class
EnvVarCheckerRayTokenizerGroupPool
(
RayTokenizerGroupPool
):
_worker_cls
=
EnvVarCheckerTokenizerGroup
tokenizer_pool_config
=
get_tokenizer_pool_config
(
tokenizer_group_type
)
tokenizer_pool
=
EnvVarCheckerRayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
None
)
with
pytest
.
raises
(
AssertionError
):
tokenizer_pool
.
ping
()
with
patch
.
dict
(
os
.
environ
,
{
env_var
:
"1"
}):
tokenizer_pool_config
=
get_tokenizer_pool_config
(
tokenizer_group_type
)
tokenizer_pool
=
EnvVarCheckerRayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
None
)
tokenizer_pool
.
ping
()
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"tokenizer_group_type"
,
[
"ray"
])
async
def
test_tokenizer_group_ray_pool_fault_tolerance
(
tokenizer_group_type
):
"""Test that Ray tokenizer pool group can recover from failures and
if that's not possible, mark itself as unhealthy."""
class
FailingTokenizerGroup
(
TokenizerGroup
):
def
__init__
(
self
,
*
args
,
fail_at
:
Optional
[
list
[
int
]]
=
None
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
i
=
0
self
.
fail_at
=
fail_at
or
[]
def
encode
(
self
,
*
args
,
**
kwargs
):
self
.
i
+=
1
if
self
.
i
in
self
.
fail_at
:
sys
.
exit
(
1
)
return
super
().
encode
(
*
args
,
**
kwargs
)
class
FailingRayTokenizerGroupPool
(
RayTokenizerGroupPool
):
_worker_cls
=
FailingTokenizerGroup
# Fail at first iteration
fail_at
=
[
1
]
tokenizer_pool_config
=
get_tokenizer_pool_config
(
tokenizer_group_type
)
tokenizer_group_pool
=
FailingRayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
None
,
fail_at
=
fail_at
)
tokenizer_actors
=
tokenizer_group_pool
.
tokenizer_actors
.
copy
()
# Modify fail at to not fail at all (will be re-read when actor is
# re-initialized).
fail_at
[
0
]
=
1000
# We should recover successfully.
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
,
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
,
lora_request
=
None
)
# Check that we have a new actor
assert
len
(
tokenizer_group_pool
.
tokenizer_actors
)
==
len
(
tokenizer_actors
)
assert
tokenizer_group_pool
.
tokenizer_actors
!=
tokenizer_actors
# Fail at first iteration
fail_at
=
[
1
]
tokenizer_group_pool
=
FailingRayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
None
,
fail_at
=
fail_at
)
# We should fail after re-initialization.
with
pytest
.
raises
(
RuntimeError
):
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
,
lora_request
=
None
)
# check_health should raise the same thing
with
pytest
.
raises
(
RuntimeError
):
tokenizer_group_pool
.
check_health
()
# Ensure that non-ActorDiedErrors are still propagated correctly and do not
# cause a re-initialization.
fail_at
=
[]
tokenizer_group_pool
=
FailingRayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
2
,
fail_at
=
fail_at
)
tokenizer_actors
=
tokenizer_group_pool
.
tokenizer_actors
.
copy
()
# Prompt too long error
with
pytest
.
raises
(
ValueError
):
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
*
100
,
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
,
lora_request
=
None
)
# Actors should stay the same.
assert
tokenizer_group_pool
.
tokenizer_actors
==
tokenizer_actors
tests/tool_use/utils.py
View file @
081057de
...
@@ -98,6 +98,20 @@ CONFIGS: dict[str, ServerConfig] = {
...
@@ -98,6 +98,20 @@ CONFIGS: dict[str, ServerConfig] = {
"extended"
:
"extended"
:
True
True
},
},
"llama4_json"
:
{
"model"
:
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
"arguments"
:
[
"--enforce-eager"
,
"--no-enable-prefix-caching"
,
"-tp"
,
"4"
,
"--distributed-executor-backend"
,
"mp"
,
"--tool-call-parser"
,
"llama4_json"
,
"--chat-template"
,
str
(
VLLM_PATH
/
"examples/tool_chat_template_llama4_json.jinja"
)
],
"supports_parallel"
:
True
,
"extended"
:
True
},
"mistral"
:
{
"mistral"
:
{
"model"
:
"model"
:
"mistralai/Mistral-7B-Instruct-v0.3"
,
"mistralai/Mistral-7B-Instruct-v0.3"
,
...
...
Prev
1
…
12
13
14
15
16
17
18
19
20
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment