Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
afd0da21
Commit
afd0da21
authored
Feb 03, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.1' into v0.7.1-dev
parents
1a11f127
4f4d427a
Changes
587
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1427 additions
and
259 deletions
+1427
-259
tests/models/multimodal/processing/test_phi3v.py
tests/models/multimodal/processing/test_phi3v.py
+55
-0
tests/models/multimodal/processing/test_qwen2_vl.py
tests/models/multimodal/processing/test_qwen2_vl.py
+54
-0
tests/models/registry.py
tests/models/registry.py
+84
-6
tests/models/test_initialization.py
tests/models/test_initialization.py
+5
-6
tests/models/test_registry.py
tests/models/test_registry.py
+3
-0
tests/multi_step/test_correctness_async_llm.py
tests/multi_step/test_correctness_async_llm.py
+3
-3
tests/multi_step/test_correctness_llm.py
tests/multi_step/test_correctness_llm.py
+16
-1
tests/multimodal/test_processing.py
tests/multimodal/test_processing.py
+217
-66
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+242
-26
tests/multimodal/utils.py
tests/multimodal/utils.py
+33
-0
tests/neuron/test_prefix_prefill.py
tests/neuron/test_prefix_prefill.py
+456
-0
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py
...ins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py
+6
-4
tests/plugins/vllm_add_dummy_platform/setup.py
tests/plugins/vllm_add_dummy_platform/setup.py
+11
-0
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
...lm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
+5
-0
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py
...atform/vllm_add_dummy_platform/dummy_attention_backend.py
+8
-0
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
..._dummy_platform/vllm_add_dummy_platform/dummy_platform.py
+9
-0
tests/plugins_tests/test_platform_plugins.py
tests/plugins_tests/test_platform_plugins.py
+30
-0
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+141
-107
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+29
-23
tests/quantization/test_lm_head.py
tests/quantization/test_lm_head.py
+20
-17
No files found.
Too many changes to show.
To preserve performance only
587 of 587+
files are displayed.
Plain diff
Email patch
tests/models/multimodal/processing/test_phi3v.py
0 → 100644
View file @
afd0da21
"""Tests for phi3v's multimodal preprocessing kwargs."""
import
pytest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
....conftest
import
_ImageAssets
from
...utils
import
build_model_context
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"microsoft/Phi-3.5-vision-instruct"
])
# yapf: disable
@
pytest
.
mark
.
parametrize
(
(
"mm_processor_kwargs"
,
"expected_toks_per_img"
),
[
({
"num_crops"
:
4
},
757
),
({
"num_crops"
:
16
},
1921
),
# the default num_crops of phi-3.5-vision is 4
({},
757
),
])
# yapf: enable
@
pytest
.
mark
.
parametrize
(
"num_imgs"
,
[
1
,
2
])
def
test_processor_override
(
image_assets
:
_ImageAssets
,
model_id
:
str
,
mm_processor_kwargs
:
dict
[
str
,
int
],
expected_toks_per_img
:
int
,
num_imgs
:
int
,
):
"""Ensure input_processor_for_phi3v handles num_crops properly."""
# Avoid initializing CUDA early
from
vllm.model_executor.models.phi3v
import
_IMAGE_TOKEN_ID
ctx
=
build_model_context
(
model_name
=
model_id
,
tokenizer_name
=
model_id
,
trust_remote_code
=
True
,
limit_mm_per_prompt
=
{
"image"
:
num_imgs
},
)
tokenizer
=
cached_get_tokenizer
(
ctx
.
model_config
.
tokenizer
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
tokenizer
,
)
# Build the image str / prompt based on the number of images we pass
img_str
=
""
.
join
([
f
"<|image_
{
idx
}
|>
\n
"
for
idx
in
range
(
1
,
num_imgs
+
1
)])
prompt
=
f
"<|user|>
\n
{
img_str
}
<|end|>
\n
<|assistant|>
\n
"
mm_data
=
{
"image"
:
[
image_assets
[
0
].
pil_image
]
*
num_imgs
}
processed_inputs
=
processor
.
apply
(
prompt
,
mm_data
,
mm_processor_kwargs
)
# Ensure we have the right number of placeholders per num_crops size
img_tok_count
=
processed_inputs
[
"prompt_token_ids"
].
count
(
_IMAGE_TOKEN_ID
)
assert
img_tok_count
==
expected_toks_per_img
*
num_imgs
tests/models/multimodal/processing/test_qwen2_vl.py
0 → 100644
View file @
afd0da21
import
pytest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
....conftest
import
_ImageAssets
from
...utils
import
build_model_context
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"Qwen/Qwen2-VL-2B-Instruct"
])
# yapf: disable
@
pytest
.
mark
.
parametrize
(
(
"mm_processor_kwargs"
,
"expected_toks_per_img"
,
"expected_pixels_shape"
),
[
({},
1426
,
(
5704
,
1176
)),
({
"min_pixels"
:
64
**
2
,
"max_pixels"
:
512
**
2
},
330
,
(
1320
,
1176
)),
])
# yapf: enable
@
pytest
.
mark
.
parametrize
(
"num_imgs"
,
[
1
,
2
])
def
test_processor_override
(
image_assets
:
_ImageAssets
,
model_id
:
str
,
mm_processor_kwargs
:
dict
[
str
,
object
],
expected_toks_per_img
:
int
,
expected_pixels_shape
:
tuple
[
int
,
int
],
num_imgs
:
int
,
):
"""Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly."""
ctx
=
build_model_context
(
model_name
=
model_id
,
tokenizer_name
=
model_id
,
mm_processor_kwargs
=
None
,
limit_mm_per_prompt
=
{
"image"
:
num_imgs
},
)
tokenizer
=
cached_get_tokenizer
(
ctx
.
model_config
.
tokenizer
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
tokenizer
,
)
# Build the image str / prompt based on the number of images we pass
prompt
=
"<|vision_start|><|image_pad|><|vision_end|>"
*
num_imgs
mm_data
=
{
"image"
:
[
image_assets
[
0
].
pil_image
]
*
num_imgs
}
processed_inputs
=
processor
.
apply
(
prompt
,
mm_data
,
mm_processor_kwargs
)
# Ensure we have the right number of placeholders per num_crops size
hf_processor
=
processor
.
info
.
get_hf_processor
(
**
mm_processor_kwargs
)
image_token_id
=
tokenizer
.
convert_tokens_to_ids
(
hf_processor
.
image_token
)
img_tok_count
=
processed_inputs
[
"prompt_token_ids"
].
count
(
image_token_id
)
pixel_shape
=
processed_inputs
[
"mm_kwargs"
][
"pixel_values"
].
shape
assert
img_tok_count
==
expected_toks_per_img
*
num_imgs
assert
pixel_shape
[
0
]
==
expected_pixels_shape
[
0
]
*
num_imgs
assert
pixel_shape
[
1
]
==
expected_pixels_shape
[
1
]
tests/models/registry.py
View file @
afd0da21
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
AbstractSet
,
Mapping
,
Optional
from
typing
import
AbstractSet
,
Any
,
Literal
,
Mapping
,
Optional
import
pytest
from
packaging.version
import
Version
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -22,6 +26,11 @@ class _HfExamplesInfo:
...
@@ -22,6 +26,11 @@ class _HfExamplesInfo:
for speculative decoding.
for speculative decoding.
"""
"""
min_transformers_version
:
Optional
[
str
]
=
None
"""
The minimum version of HF Transformers that is required to run this model.
"""
is_available_online
:
bool
=
True
is_available_online
:
bool
=
True
"""
"""
Set this to ``False`` if the name of this architecture no longer exists on
Set this to ``False`` if the name of this architecture no longer exists on
...
@@ -33,6 +42,50 @@ class _HfExamplesInfo:
...
@@ -33,6 +42,50 @@ class _HfExamplesInfo:
trust_remote_code
:
bool
=
False
trust_remote_code
:
bool
=
False
"""The ``trust_remote_code`` level required to load the model."""
"""The ``trust_remote_code`` level required to load the model."""
hf_overrides
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
"""The ``hf_overrides`` required to load the model."""
def
check_transformers_version
(
self
,
*
,
on_fail
:
Literal
[
"error"
,
"skip"
],
)
->
None
:
"""
If the installed transformers version does not meet the requirements,
perform the given action.
"""
if
self
.
min_transformers_version
is
None
:
return
current_version
=
TRANSFORMERS_VERSION
required_version
=
self
.
min_transformers_version
if
Version
(
current_version
)
<
Version
(
required_version
):
msg
=
(
f
"You have `transformers==
{
current_version
}
` installed, but "
f
"`transformers>=
{
required_version
}
` is required to run this "
"model"
)
if
on_fail
==
"error"
:
raise
RuntimeError
(
msg
)
else
:
pytest
.
skip
(
msg
)
def
check_available_online
(
self
,
*
,
on_fail
:
Literal
[
"error"
,
"skip"
],
)
->
None
:
"""
If the model is not available online, perform the given action.
"""
if
not
self
.
is_available_online
:
msg
=
"Model is not available online"
if
on_fail
==
"error"
:
raise
RuntimeError
(
msg
)
else
:
pytest
.
skip
(
msg
)
# yapf: disable
# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS
=
{
_TEXT_GENERATION_EXAMPLE_MODELS
=
{
...
@@ -43,8 +96,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -43,8 +96,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"ArcticForCausalLM"
:
_HfExamplesInfo
(
"Snowflake/snowflake-arctic-instruct"
,
"ArcticForCausalLM"
:
_HfExamplesInfo
(
"Snowflake/snowflake-arctic-instruct"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
,
trust_remote_code
=
True
),
"BaiChuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan-7B"
,
"BaiChuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan-7B"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"BaichuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan2-7B-chat"
,
"BaichuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan2-7B-chat"
,
...
@@ -64,6 +115,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -64,6 +115,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"DeepseekV3ForCausalLM"
:
_HfExamplesInfo
(
"deepseek-ai/DeepSeek-V3"
,
# noqa: E501
"DeepseekV3ForCausalLM"
:
_HfExamplesInfo
(
"deepseek-ai/DeepSeek-V3"
,
# noqa: E501
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"ExaoneForCausalLM"
:
_HfExamplesInfo
(
"LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"
),
# noqa: E501
"ExaoneForCausalLM"
:
_HfExamplesInfo
(
"LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"
),
# noqa: E501
"Fairseq2LlamaForCausalLM"
:
_HfExamplesInfo
(
"mgleize/fairseq2-dummy-Llama-3.2-1B"
),
# noqa: E501
"FalconForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-7b"
),
"FalconForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-7b"
),
"GemmaForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-2b"
),
"GemmaForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-2b"
),
"Gemma2ForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-2-9b"
),
"Gemma2ForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-2-9b"
),
...
@@ -80,6 +132,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -80,6 +132,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"InternLM2VEForCausalLM"
:
_HfExamplesInfo
(
"OpenGVLab/Mono-InternVL-2B"
,
"InternLM2VEForCausalLM"
:
_HfExamplesInfo
(
"OpenGVLab/Mono-InternVL-2B"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"InternLM3ForCausalLM"
:
_HfExamplesInfo
(
"internlm/internlm3-8b-instruct"
,
trust_remote_code
=
True
),
"JAISLMHeadModel"
:
_HfExamplesInfo
(
"inceptionai/jais-13b-chat"
),
"JAISLMHeadModel"
:
_HfExamplesInfo
(
"inceptionai/jais-13b-chat"
),
"JambaForCausalLM"
:
_HfExamplesInfo
(
"ai21labs/AI21-Jamba-1.5-Mini"
),
"JambaForCausalLM"
:
_HfExamplesInfo
(
"ai21labs/AI21-Jamba-1.5-Mini"
),
"LlamaForCausalLM"
:
_HfExamplesInfo
(
"meta-llama/Meta-Llama-3-8B"
),
"LlamaForCausalLM"
:
_HfExamplesInfo
(
"meta-llama/Meta-Llama-3-8B"
),
...
@@ -140,11 +194,14 @@ _EMBEDDING_EXAMPLE_MODELS = {
...
@@ -140,11 +194,14 @@ _EMBEDDING_EXAMPLE_MODELS = {
"BertModel"
:
_HfExamplesInfo
(
"BAAI/bge-base-en-v1.5"
),
"BertModel"
:
_HfExamplesInfo
(
"BAAI/bge-base-en-v1.5"
),
"Gemma2Model"
:
_HfExamplesInfo
(
"BAAI/bge-multilingual-gemma2"
),
"Gemma2Model"
:
_HfExamplesInfo
(
"BAAI/bge-multilingual-gemma2"
),
"GritLM"
:
_HfExamplesInfo
(
"parasail-ai/GritLM-7B-vllm"
),
"GritLM"
:
_HfExamplesInfo
(
"parasail-ai/GritLM-7B-vllm"
),
"InternLM2ForRewardModel"
:
_HfExamplesInfo
(
"internlm/internlm2-1_8b-reward"
,
trust_remote_code
=
True
),
"JambaForSequenceClassification"
:
_HfExamplesInfo
(
"ai21labs/Jamba-tiny-reward-dev"
),
# noqa: E501
"JambaForSequenceClassification"
:
_HfExamplesInfo
(
"ai21labs/Jamba-tiny-reward-dev"
),
# noqa: E501
"LlamaModel"
:
_HfExamplesInfo
(
"llama"
,
is_available_online
=
False
),
"LlamaModel"
:
_HfExamplesInfo
(
"llama"
,
is_available_online
=
False
),
"MistralModel"
:
_HfExamplesInfo
(
"intfloat/e5-mistral-7b-instruct"
),
"MistralModel"
:
_HfExamplesInfo
(
"intfloat/e5-mistral-7b-instruct"
),
"Qwen2Model"
:
_HfExamplesInfo
(
"ssmits/Qwen2-7B-Instruct-embed-base"
),
"Qwen2Model"
:
_HfExamplesInfo
(
"ssmits/Qwen2-7B-Instruct-embed-base"
),
"Qwen2ForRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-RM-72B"
),
"Qwen2ForRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-RM-72B"
),
"Qwen2ForProcessRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-PRM-7B"
),
"Qwen2ForSequenceClassification"
:
_HfExamplesInfo
(
"jason9693/Qwen2.5-1.5B-apeach"
),
# noqa: E501
"Qwen2ForSequenceClassification"
:
_HfExamplesInfo
(
"jason9693/Qwen2.5-1.5B-apeach"
),
# noqa: E501
"RobertaModel"
:
_HfExamplesInfo
(
"sentence-transformers/stsb-roberta-base-v2"
),
# noqa: E501
"RobertaModel"
:
_HfExamplesInfo
(
"sentence-transformers/stsb-roberta-base-v2"
),
# noqa: E501
"RobertaForMaskedLM"
:
_HfExamplesInfo
(
"sentence-transformers/all-roberta-large-v1"
),
# noqa: E501
"RobertaForMaskedLM"
:
_HfExamplesInfo
(
"sentence-transformers/all-roberta-large-v1"
),
# noqa: E501
...
@@ -165,6 +222,8 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
...
@@ -165,6 +222,8 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
_MULTIMODAL_EXAMPLE_MODELS
=
{
_MULTIMODAL_EXAMPLE_MODELS
=
{
# [Decoder-only]
# [Decoder-only]
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
,
min_transformers_version
=
"4.48"
),
"Blip2ForConditionalGeneration"
:
_HfExamplesInfo
(
"Salesforce/blip2-opt-2.7b"
),
# noqa: E501
"Blip2ForConditionalGeneration"
:
_HfExamplesInfo
(
"Salesforce/blip2-opt-2.7b"
),
# noqa: E501
"ChameleonForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/chameleon-7b"
),
# noqa: E501
"ChameleonForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/chameleon-7b"
),
# noqa: E501
"ChatGLMModel"
:
_HfExamplesInfo
(
"THUDM/glm-4v-9b"
,
"ChatGLMModel"
:
_HfExamplesInfo
(
"THUDM/glm-4v-9b"
,
...
@@ -172,6 +231,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -172,6 +231,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"ChatGLMForConditionalGeneration"
:
_HfExamplesInfo
(
"chatglm2-6b"
,
"ChatGLMForConditionalGeneration"
:
_HfExamplesInfo
(
"chatglm2-6b"
,
is_available_online
=
False
),
is_available_online
=
False
),
"DeepseekVLV2ForCausalLM"
:
_HfExamplesInfo
(
"deepseek-ai/deepseek-vl2-tiny"
,
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"DeepseekVLV2ForCausalLM"
]}),
# noqa: E501
"FuyuForCausalLM"
:
_HfExamplesInfo
(
"adept/fuyu-8b"
),
"FuyuForCausalLM"
:
_HfExamplesInfo
(
"adept/fuyu-8b"
),
"H2OVLChatModel"
:
_HfExamplesInfo
(
"h2oai/h2ovl-mississippi-800m"
),
"H2OVLChatModel"
:
_HfExamplesInfo
(
"h2oai/h2ovl-mississippi-800m"
),
"InternVLChatModel"
:
_HfExamplesInfo
(
"OpenGVLab/InternVL2-1B"
,
"InternVLChatModel"
:
_HfExamplesInfo
(
"OpenGVLab/InternVL2-1B"
,
...
@@ -182,8 +243,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -182,8 +243,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"LlavaNextForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-v1.6-mistral-7b-hf"
),
# noqa: E501
"LlavaNextForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-v1.6-mistral-7b-hf"
),
# noqa: E501
"LlavaNextVideoForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/LLaVA-NeXT-Video-7B-hf"
),
# noqa: E501
"LlavaNextVideoForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/LLaVA-NeXT-Video-7B-hf"
),
# noqa: E501
"LlavaOnevisionForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
),
# noqa: E501
"LlavaOnevisionForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
),
# noqa: E501
"MantisForConditionalGeneration"
:
_HfExamplesInfo
(
"TIGER-Lab/Mantis-8B-siglip-llama3"
),
# noqa: E501
"MantisForConditionalGeneration"
:
_HfExamplesInfo
(
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
# noqa: E501
"MiniCPMV"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-Llama3-V-2_5"
,
hf_overrides
=
{
"architectures"
:
[
"MantisForConditionalGeneration"
]}),
# noqa: E501
"MiniCPMO"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-o-2_6"
,
trust_remote_code
=
True
),
"MiniCPMV"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-V-2_6"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"MolmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Molmo-7B-D-0924"
,
"MolmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Molmo-7B-D-0924"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
...
@@ -199,9 +263,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -199,9 +263,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Qwen2AudioForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-Audio-7B-Instruct"
),
# noqa: E501
"Qwen2AudioForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-Audio-7B-Instruct"
),
# noqa: E501
"Qwen2VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-VL-2B-Instruct"
),
# noqa: E501
"Qwen2VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-VL-2B-Instruct"
),
# noqa: E501
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_3"
),
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_3"
,
trust_remote_code
=
True
),
# [Encoder-decoder]
# [Encoder-decoder]
"MllamaForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-11B-Vision-Instruct"
),
# noqa: E501
"MllamaForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-11B-Vision-Instruct"
),
# noqa: E501
"WhisperForConditionalGeneration"
:
_HfExamplesInfo
(
"openai/whisper-large-v3"
),
# noqa: E501
}
}
_SPECULATIVE_DECODING_EXAMPLE_MODELS
=
{
_SPECULATIVE_DECODING_EXAMPLE_MODELS
=
{
...
@@ -234,5 +300,17 @@ class HfExampleModels:
...
@@ -234,5 +300,17 @@ class HfExampleModels:
def
get_hf_info
(
self
,
model_arch
:
str
)
->
_HfExamplesInfo
:
def
get_hf_info
(
self
,
model_arch
:
str
)
->
_HfExamplesInfo
:
return
self
.
hf_models
[
model_arch
]
return
self
.
hf_models
[
model_arch
]
def
find_hf_info
(
self
,
model_id
:
str
)
->
_HfExamplesInfo
:
for
info
in
self
.
hf_models
.
values
():
if
info
.
default
==
model_id
:
return
info
# Fallback to extras
for
info
in
self
.
hf_models
.
values
():
if
any
(
extra
==
model_id
for
extra
in
info
.
extras
.
values
()):
return
info
raise
ValueError
(
f
"No example model defined for
{
model_id
}
"
)
HF_EXAMPLE_MODELS
=
HfExampleModels
(
_EXAMPLE_MODELS
)
HF_EXAMPLE_MODELS
=
HfExampleModels
(
_EXAMPLE_MODELS
)
tests/models/test_initialization.py
View file @
afd0da21
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
import
pytest
import
transformers
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm
import
LLM
from
vllm
import
LLM
...
@@ -12,14 +11,14 @@ from .registry import HF_EXAMPLE_MODELS
...
@@ -12,14 +11,14 @@ from .registry import HF_EXAMPLE_MODELS
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
HF_EXAMPLE_MODELS
.
get_supported_archs
())
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
HF_EXAMPLE_MODELS
.
get_supported_archs
())
def
test_can_initialize
(
model_arch
):
def
test_can_initialize
(
model_arch
):
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
if
(
model_arch
==
"Cohere2ForCausalLM"
model_info
.
check_available_online
(
on_fail
=
"skip"
)
and
transformers
.
__version__
<
"4.48.0"
):
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
pytest
.
skip
(
reason
=
"Model introduced in HF >= 4.48.0"
)
if
not
model_info
.
is_available_online
:
pytest
.
skip
(
"Model is not available online"
)
# Avoid OOM
# Avoid OOM
def
hf_overrides
(
hf_config
:
PretrainedConfig
)
->
PretrainedConfig
:
def
hf_overrides
(
hf_config
:
PretrainedConfig
)
->
PretrainedConfig
:
if
hf_config
.
model_type
==
"deepseek_vl_v2"
:
hf_config
.
update
({
"architectures"
:
[
"DeepseekVLV2ForCausalLM"
]})
if
hasattr
(
hf_config
,
"text_config"
):
if
hasattr
(
hf_config
,
"text_config"
):
text_config
:
PretrainedConfig
=
hf_config
.
text_config
text_config
:
PretrainedConfig
=
hf_config
.
text_config
else
:
else
:
...
...
tests/models/test_registry.py
View file @
afd0da21
...
@@ -21,6 +21,9 @@ from .registry import HF_EXAMPLE_MODELS
...
@@ -21,6 +21,9 @@ from .registry import HF_EXAMPLE_MODELS
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
ModelRegistry
.
get_supported_archs
())
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
ModelRegistry
.
get_supported_archs
())
def
test_registry_imports
(
model_arch
):
def
test_registry_imports
(
model_arch
):
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
# Ensure all model classes can be imported successfully
# Ensure all model classes can be imported successfully
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
model_arch
)
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
model_arch
)
...
...
tests/multi_step/test_correctness_async_llm.py
View file @
afd0da21
...
@@ -17,8 +17,8 @@ NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
...
@@ -17,8 +17,8 @@ NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
NUM_PROMPTS
=
[
10
]
NUM_PROMPTS
=
[
10
]
DEFAULT_SERVER_ARGS
:
List
[
str
]
=
[
DEFAULT_SERVER_ARGS
:
List
[
str
]
=
[
"--dis
able-log-requests
"
,
"--dis
tributed-executor-backend
"
,
"
--worker-use-
ray"
,
"ray"
,
"--gpu-memory-utilization"
,
"--gpu-memory-utilization"
,
"0.85"
,
"0.85"
,
"--swap-space"
,
"--swap-space"
,
...
@@ -112,7 +112,7 @@ async def test_multi_step(
...
@@ -112,7 +112,7 @@ async def test_multi_step(
# Spin up client/server & issue completion API requests.
# Spin up client/server & issue completion API requests.
# Default `max_wait_seconds` is 240 but was empirically
# Default `max_wait_seconds` is 240 but was empirically
# was raised
3
x to
72
0 *just for this test* due to
# was raised
5
x to
120
0 *just for this test* due to
# observed timeouts in GHA CI
# observed timeouts in GHA CI
ref_completions
=
await
completions_with_server_args
(
ref_completions
=
await
completions_with_server_args
(
prompts
,
prompts
,
...
...
tests/multi_step/test_correctness_llm.py
View file @
afd0da21
...
@@ -6,6 +6,8 @@ from typing import Optional
...
@@ -6,6 +6,8 @@ from typing import Optional
import
pytest
import
pytest
import
os
import
os
from
tests.kernels.utils
import
override_backend_env_variable
from
..models.utils
import
check_logprobs_close
,
check_outputs_equal
from
..models.utils
import
check_logprobs_close
,
check_outputs_equal
from
..utils
import
models_path_prefix
from
..utils
import
models_path_prefix
...
@@ -21,10 +23,11 @@ NUM_PROMPTS = [10]
...
@@ -21,10 +23,11 @@ NUM_PROMPTS = [10]
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"enable_chunked_prefill"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"enable_chunked_prefill"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
def
test_multi_step_llm
(
def
test_multi_step_llm
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
...
@@ -38,6 +41,8 @@ def test_multi_step_llm(
...
@@ -38,6 +41,8 @@ def test_multi_step_llm(
num_scheduler_steps
:
int
,
num_scheduler_steps
:
int
,
num_prompts
:
int
,
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
num_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
)
->
None
:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
...
@@ -65,6 +70,7 @@ def test_multi_step_llm(
...
@@ -65,6 +70,7 @@ def test_multi_step_llm(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> 1 logprob returned.
completions endpoint; `None` -> 1 logprob returned.
"""
"""
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
prompts
=
example_prompts
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
if
len
(
prompts
)
<
num_prompts
:
...
@@ -116,6 +122,7 @@ def test_multi_step_llm(
...
@@ -116,6 +122,7 @@ def test_multi_step_llm(
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs,num_prompt_logprobs"
,
[(
5
,
5
)])
@
pytest
.
mark
.
parametrize
(
"num_logprobs,num_prompt_logprobs"
,
[(
5
,
5
)])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASH_ATTN"
])
def
test_multi_step_llm_w_prompt_logprobs
(
def
test_multi_step_llm_w_prompt_logprobs
(
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
...
@@ -128,6 +135,8 @@ def test_multi_step_llm_w_prompt_logprobs(
...
@@ -128,6 +135,8 @@ def test_multi_step_llm_w_prompt_logprobs(
num_prompts
:
int
,
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
num_logprobs
:
Optional
[
int
],
num_prompt_logprobs
:
Optional
[
int
],
num_prompt_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
)
->
None
:
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine.
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine.
...
@@ -157,6 +166,7 @@ def test_multi_step_llm_w_prompt_logprobs(
...
@@ -157,6 +166,7 @@ def test_multi_step_llm_w_prompt_logprobs(
note that this argument is not supported by the
note that this argument is not supported by the
OpenAI completions endpoint.
OpenAI completions endpoint.
"""
"""
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
prompts
=
example_prompts
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
if
len
(
prompts
)
<
num_prompts
:
...
@@ -207,6 +217,7 @@ def test_multi_step_llm_w_prompt_logprobs(
...
@@ -207,6 +217,7 @@ def test_multi_step_llm_w_prompt_logprobs(
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASH_ATTN"
])
def
test_multi_step_llm_chunked_prefill_prefix_cache
(
def
test_multi_step_llm_chunked_prefill_prefix_cache
(
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
...
@@ -218,6 +229,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
...
@@ -218,6 +229,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
num_scheduler_steps
:
int
,
num_scheduler_steps
:
int
,
num_prompts
:
int
,
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
num_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
)
->
None
:
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
...
@@ -280,6 +293,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
...
@@ -280,6 +293,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
#
#
# The Incorrect scheduling behavior - if it occurs - will cause an exception
# The Incorrect scheduling behavior - if it occurs - will cause an exception
# in the model runner resulting from `do_sample=False`.
# in the model runner resulting from `do_sample=False`.
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
assert
len
(
example_prompts
)
>=
2
assert
len
(
example_prompts
)
>=
2
challenge_prompts
=
copy
.
deepcopy
(
example_prompts
)
challenge_prompts
=
copy
.
deepcopy
(
example_prompts
)
challenge_prompts
[
0
]
=
(
'vLLM is a high-throughput and memory-efficient '
challenge_prompts
[
0
]
=
(
'vLLM is a high-throughput and memory-efficient '
...
...
tests/multimodal/test_processing.py
View file @
afd0da21
from
contextlib
import
nullcontext
from
typing
import
cast
from
typing
import
cast
from
unittest.mock
import
MagicMock
import
numpy
as
np
import
pytest
import
pytest
from
vllm.multimodal.processing
import
(
PromptReplacement
,
_PlaceholderInfo
,
from
vllm.config
import
ModelConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.multimodal.processing
import
(
PlaceholderFeaturesInfo
,
PromptReplacement
,
find_mm_placeholders
,
find_text_matches
,
find_token_matches
,
find_text_matches
,
find_token_matches
,
iter_placeholders
,
iter_token_matches
,
iter_token_matches
,
replace_text_matches
,
replace_text_matches
,
replace_token_matches
)
replace_token_matches
)
# yapf: enable
from
vllm.multimodal.profiling
import
MultiModalProfiler
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
full_groupby
from
vllm.utils
import
full_groupby
from
.utils
import
random_image
# yapf: disable
# yapf: disable
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -304,21 +318,27 @@ def test_find_replace_text(
...
@@ -304,21 +318,27 @@ def test_find_replace_text(
# Should not be used since there is nothing to convert to text
# Should not be used since there is nothing to convert to text
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
prompt_repls
=
[
mm_prompt_repls
=
{
PromptReplacement
(
key
,
target
,
repl_by_key
[
key
]).
bind
(
mock_tokenizer
)
key
:
[
PromptReplacement
(
key
,
target
,
repl_by_key
[
key
]).
bind
(
mock_tokenizer
)
]
for
key
,
target
in
target_by_key
.
items
()
for
key
,
target
in
target_by_key
.
items
()
]
}
matches
=
find_text_matches
(
prompt
,
prompt_repls
)
mm_matches
=
{
key
:
find_text_matches
(
prompt
,
prompt_repls
)
for
key
,
prompt_repls
in
mm_prompt_repls
.
items
()
}
result
=
replace_text_matches
(
result
=
replace_text_matches
(
prompt
,
prompt
,
matches
,
mm_
matches
,
{
key
:
mm_count
{
key
:
mm_count
for
key
in
repl_by_key
},
for
key
in
repl_by_key
},
)
)
# Only displayed on error
# Only displayed on error
print
(
"matches:"
,
matches
)
print
(
"
mm_
matches:"
,
mm_
matches
)
print
(
"result:"
,
result
)
print
(
"result:"
,
result
)
# Manually constructed results
# Manually constructed results
...
@@ -370,21 +390,27 @@ def test_find_replace_tokens(
...
@@ -370,21 +390,27 @@ def test_find_replace_tokens(
# Should not be used since there is nothing to convert to tokens
# Should not be used since there is nothing to convert to tokens
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
prompt_repls
=
[
mm_prompt_repls
=
{
PromptReplacement
(
key
,
target
,
repl_by_key
[
key
]).
bind
(
mock_tokenizer
)
key
:
[
PromptReplacement
(
key
,
target
,
repl_by_key
[
key
]).
bind
(
mock_tokenizer
)
]
for
key
,
target
in
target_by_key
.
items
()
for
key
,
target
in
target_by_key
.
items
()
]
}
matches
=
find_token_matches
(
prompt
,
prompt_repls
)
mm_matches
=
{
key
:
find_token_matches
(
prompt
,
prompt_repls
)
for
key
,
prompt_repls
in
mm_prompt_repls
.
items
()
}
result
=
replace_token_matches
(
result
=
replace_token_matches
(
prompt
,
prompt
,
matches
,
mm_
matches
,
{
key
:
mm_count
{
key
:
mm_count
for
key
in
repl_by_key
},
for
key
in
repl_by_key
},
)
)
# Only displayed on error
# Only displayed on error
print
(
"matches:"
,
matches
)
print
(
"
mm_
matches:"
,
mm_
matches
)
print
(
"result:"
,
result
)
print
(
"result:"
,
result
)
# Manually constructed results
# Manually constructed results
...
@@ -399,6 +425,8 @@ def test_find_replace_tokens(
...
@@ -399,6 +425,8 @@ def test_find_replace_tokens(
"pattern_1"
:
[
32000
,
32000
],
"pattern_1"
:
[
32000
,
32000
],
"pattern_2"
:
[],
"pattern_2"
:
[],
"pattern_3"
:
[
1550
,
918
,
1550
],
"pattern_3"
:
[
1550
,
918
,
1550
],
# Test different modalities having the same tokens (32000)
"pattern_4"
:
[
32000
],
},
},
],
],
)
)
...
@@ -407,57 +435,93 @@ def test_find_replace_tokens(
...
@@ -407,57 +435,93 @@ def test_find_replace_tokens(
[
[
(
(
[
1
,
9833
,
28747
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
],
[
1
,
9833
,
28747
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
],
[
{
_PlaceholderInfo
(
"pattern_1"
:
[
modality
=
"pattern_1"
,
PlaceholderFeaturesInfo
(
start_idx
=
6
,
modality
=
"pattern_1"
,
replacement
=
[
32000
,
32000
],
item_idx
=
0
,
),
start_idx
=
6
,
],
tokens
=
[
32000
,
32000
],
),
],
"pattern_4"
:
[
PlaceholderFeaturesInfo
(
modality
=
"pattern_4"
,
item_idx
=
0
,
start_idx
=
3
,
tokens
=
[
32000
],
),
],
}
),
),
(
(
[
1
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
1550
,
918
,
1550
],
[
1
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
1550
,
918
,
1550
],
[
{
_PlaceholderInfo
(
"pattern_1"
:
[
modality
=
"pattern_1"
,
PlaceholderFeaturesInfo
(
start_idx
=
1
,
modality
=
"pattern_1"
,
replacement
=
[
32000
,
32000
],
item_idx
=
0
,
),
start_idx
=
1
,
_PlaceholderInfo
(
tokens
=
[
32000
,
32000
],
modality
=
"pattern_1"
,
),
start_idx
=
5
,
PlaceholderFeaturesInfo
(
replacement
=
[
32000
,
32000
],
modality
=
"pattern_1"
,
),
item_idx
=
1
,
_PlaceholderInfo
(
start_idx
=
5
,
modality
=
"pattern_3"
,
tokens
=
[
32000
,
32000
],
start_idx
=
7
,
),
replacement
=
[
1550
,
918
,
1550
],
],
),
"pattern_3"
:
[
],
PlaceholderFeaturesInfo
(
modality
=
"pattern_3"
,
item_idx
=
0
,
start_idx
=
7
,
tokens
=
[
1550
,
918
,
1550
],
),
],
# No match for pattern_4 as it has lower priority than pattern_1
}
),
),
(
(
[
1
,
32000
,
32000
,
32000
,
32000
,
32000
,
1550
,
918
,
1550
],
[
1
,
32000
,
32000
,
32000
,
32000
,
32000
,
1550
,
918
,
1550
],
[
{
_PlaceholderInfo
(
"pattern_1"
:
[
modality
=
"pattern_1"
,
PlaceholderFeaturesInfo
(
start_idx
=
1
,
modality
=
"pattern_1"
,
replacement
=
[
32000
,
32000
],
item_idx
=
0
,
),
start_idx
=
1
,
_PlaceholderInfo
(
tokens
=
[
32000
,
32000
],
modality
=
"pattern_1"
,
),
start_idx
=
3
,
PlaceholderFeaturesInfo
(
replacement
=
[
32000
,
32000
],
modality
=
"pattern_1"
,
),
item_idx
=
1
,
_PlaceholderInfo
(
start_idx
=
3
,
modality
=
"pattern_3"
,
tokens
=
[
32000
,
32000
],
start_idx
=
6
,
),
replacement
=
[
1550
,
918
,
1550
],
],
),
"pattern_4"
:
[
],
PlaceholderFeaturesInfo
(
modality
=
"pattern_4"
,
item_idx
=
0
,
start_idx
=
5
,
tokens
=
[
32000
],
),
],
"pattern_3"
:
[
PlaceholderFeaturesInfo
(
modality
=
"pattern_3"
,
item_idx
=
0
,
start_idx
=
6
,
tokens
=
[
1550
,
918
,
1550
],
),
],
}
),
),
]
]
)
)
def
test_iter_placeholders
(
# yapf: enable
def
test_find_mm_placeholders
(
repl_by_key
,
repl_by_key
,
prompt
,
prompt
,
expected
,
expected
,
...
@@ -465,21 +529,108 @@ def test_iter_placeholders(
...
@@ -465,21 +529,108 @@ def test_iter_placeholders(
# Should not be used since there is nothing to convert to tokens
# Should not be used since there is nothing to convert to tokens
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
prompt_repls
=
[
mm_
prompt_repls
=
{
PromptReplacement
(
key
,
[],
repl
).
bind
(
mock_tokenizer
)
key
:
[
PromptReplacement
(
key
,
[],
repl
).
bind
(
mock_tokenizer
)
]
for
key
,
repl
in
repl_by_key
.
items
()
for
key
,
repl
in
repl_by_key
.
items
()
]
}
result
=
list
(
result
=
find_mm_placeholders
(
iter_placeholders
(
mm_prompt_repls
,
prompt
_repls
,
prompt
,
prompt
,
# Effectively match all occurrences in the
prompt
# Effectively match all occurrences in the prompt
{
key
:
3
{
key
:
3
for
key
in
repl_by_key
},
for
key
in
repl_by_key
},
)
)
)
# Only displayed on error
# Only displayed on error
print
(
"result:"
,
result
)
print
(
"result:"
,
result
)
# Manually constructed results
# Manually constructed results
assert
result
==
expected
assert
result
==
expected
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("limit", "num_supported", "is_valid"),
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
    """Check dummy-data profiling against the per-prompt image limit.

    The processor's ``get_supported_mm_limits`` is replaced by a mock that
    reports ``num_supported`` images. Cases flagged ``is_valid=False`` are
    expected to raise ``ValueError`` from ``get_dummy_data``.
    """
    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt={"image": limit},
    )

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=tokenizer,
    )
    profiler = MultiModalProfiler(processor)

    # Override what the model claims to support for this test case.
    processor.info.get_supported_mm_limits = MagicMock(
        return_value={"image": num_supported})

    exc_ctx = (nullcontext() if is_valid else pytest.raises(
        ValueError, match="this model only supports"))

    with exc_ctx:
        profiler.get_dummy_data(model_config.max_model_len)
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
    """Check ``processor.apply`` against the per-prompt image limit.

    ``num_images`` copies of one random image are passed together with a
    prompt containing the same number of ``<image>`` tags. Cases flagged
    ``is_valid=False`` are expected to raise ``ValueError``.
    """
    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt={"image": limit},
    )

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=tokenizer,
    )

    image = random_image(np.random.RandomState(0), min_wh=128, max_wh=256)

    # Zero images -> no data; one image -> bare item; several -> list.
    if num_images > 1:
        mm_data = {"image": [image] * num_images}
    elif num_images == 1:
        mm_data = {"image": image}
    else:
        mm_data = {}

    exc_ctx = (nullcontext() if is_valid else pytest.raises(
        ValueError, match=f"passed {num_images} image"))

    with exc_ctx:
        processor.apply(
            "<image>" * num_images,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )
tests/multimodal/test_utils.py
View file @
afd0da21
...
@@ -2,7 +2,7 @@ import base64
...
@@ -2,7 +2,7 @@ import base64
import
mimetypes
import
mimetypes
import
os
import
os
from
tempfile
import
NamedTemporaryFile
,
TemporaryDirectory
from
tempfile
import
NamedTemporaryFile
,
TemporaryDirectory
from
typing
import
Dict
,
Tuple
from
typing
import
TYPE_CHECKING
,
Dict
,
NamedTuple
,
Optional
,
Tuple
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -11,10 +11,16 @@ import os
...
@@ -11,10 +11,16 @@ import os
from
PIL
import
Image
,
ImageChops
from
PIL
import
Image
,
ImageChops
from
transformers
import
AutoConfig
,
AutoTokenizer
from
transformers
import
AutoConfig
,
AutoTokenizer
from
vllm.multimodal.utils
import
(
async_fetch_image
,
fetch_image
,
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.multimodal.utils
import
(
MediaConnector
,
merge_and_sort_multimodal_metadata
,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
..utils
import
models_path_prefix
,
urls_port
from
..utils
import
models_path_prefix
,
urls_port
if
TYPE_CHECKING
:
from
vllm.multimodal.hasher
import
MultiModalHashDict
from
vllm.multimodal.inputs
import
MultiModalPlaceholderDict
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS
=
[
TEST_IMAGE_URLS
=
[
f
"http://localhost:
{
urls_port
}
/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
,
f
"http://localhost:
{
urls_port
}
/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
,
...
@@ -26,7 +32,12 @@ TEST_IMAGE_URLS = [
...
@@ -26,7 +32,12 @@ TEST_IMAGE_URLS = [
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
url_images
()
->
Dict
[
str
,
Image
.
Image
]:
def
url_images
()
->
Dict
[
str
,
Image
.
Image
]:
return
{
image_url
:
fetch_image
(
image_url
)
for
image_url
in
TEST_IMAGE_URLS
}
connector
=
MediaConnector
()
return
{
image_url
:
connector
.
fetch_image
(
image_url
)
for
image_url
in
TEST_IMAGE_URLS
}
def
get_supported_suffixes
()
->
Tuple
[
str
,
...]:
def
get_supported_suffixes
()
->
Tuple
[
str
,
...]:
...
@@ -46,8 +57,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
...
@@ -46,8 +57,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
async
def
test_fetch_image_http
(
image_url
:
str
):
async
def
test_fetch_image_http
(
image_url
:
str
):
image_sync
=
fetch_image
(
image_url
)
connector
=
MediaConnector
()
image_async
=
await
async_fetch_image
(
image_url
)
image_sync
=
connector
.
fetch_image
(
image_url
)
image_async
=
await
connector
.
fetch_image_async
(
image_url
)
assert
_image_equals
(
image_sync
,
image_async
)
assert
_image_equals
(
image_sync
,
image_async
)
...
@@ -56,6 +69,7 @@ async def test_fetch_image_http(image_url: str):
...
@@ -56,6 +69,7 @@ async def test_fetch_image_http(image_url: str):
@
pytest
.
mark
.
parametrize
(
"suffix"
,
get_supported_suffixes
())
@
pytest
.
mark
.
parametrize
(
"suffix"
,
get_supported_suffixes
())
async
def
test_fetch_image_base64
(
url_images
:
Dict
[
str
,
Image
.
Image
],
async
def
test_fetch_image_base64
(
url_images
:
Dict
[
str
,
Image
.
Image
],
image_url
:
str
,
suffix
:
str
):
image_url
:
str
,
suffix
:
str
):
connector
=
MediaConnector
()
url_image
=
url_images
[
image_url
]
url_image
=
url_images
[
image_url
]
try
:
try
:
...
@@ -78,48 +92,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
...
@@ -78,48 +92,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
base64_image
=
base64
.
b64encode
(
f
.
read
()).
decode
(
"utf-8"
)
base64_image
=
base64
.
b64encode
(
f
.
read
()).
decode
(
"utf-8"
)
data_url
=
f
"data:
{
mime_type
}
;base64,
{
base64_image
}
"
data_url
=
f
"data:
{
mime_type
}
;base64,
{
base64_image
}
"
data_image_sync
=
fetch_image
(
data_url
)
data_image_sync
=
connector
.
fetch_image
(
data_url
)
if
_image_equals
(
url_image
,
Image
.
open
(
f
)):
if
_image_equals
(
url_image
,
Image
.
open
(
f
)):
assert
_image_equals
(
url_image
,
data_image_sync
)
assert
_image_equals
(
url_image
,
data_image_sync
)
else
:
else
:
pass
# Lossy format; only check that image can be opened
pass
# Lossy format; only check that image can be opened
data_image_async
=
await
async_
fetch_image
(
data_url
)
data_image_async
=
await
connector
.
fetch_image
_async
(
data_url
)
assert
_image_equals
(
data_image_sync
,
data_image_async
)
assert
_image_equals
(
data_image_sync
,
data_image_async
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
async
def
test_fetch_image_local_files
(
image_url
:
str
):
async
def
test_fetch_image_local_files
(
image_url
:
str
):
connector
=
MediaConnector
()
with
TemporaryDirectory
()
as
temp_dir
:
with
TemporaryDirectory
()
as
temp_dir
:
origin_image
=
fetch_image
(
image_url
)
local_connector
=
MediaConnector
(
allowed_local_media_path
=
temp_dir
)
origin_image
=
connector
.
fetch_image
(
image_url
)
origin_image
.
save
(
os
.
path
.
join
(
temp_dir
,
os
.
path
.
basename
(
image_url
)),
origin_image
.
save
(
os
.
path
.
join
(
temp_dir
,
os
.
path
.
basename
(
image_url
)),
quality
=
100
,
quality
=
100
,
icc_profile
=
origin_image
.
info
.
get
(
'icc_profile'
))
icc_profile
=
origin_image
.
info
.
get
(
'icc_profile'
))
image_async
=
await
async_fetch_image
(
image_async
=
await
local_connector
.
fetch_image_async
(
f
"file://
{
temp_dir
}
/
{
os
.
path
.
basename
(
image_url
)
}
"
,
f
"file://
{
temp_dir
}
/
{
os
.
path
.
basename
(
image_url
)
}
"
)
allowed_local_media_path
=
temp_dir
)
image_sync
=
local_connector
.
fetch_image
(
f
"file://
{
temp_dir
}
/
{
os
.
path
.
basename
(
image_url
)
}
"
)
image_sync
=
fetch_image
(
f
"file://
{
temp_dir
}
/
{
os
.
path
.
basename
(
image_url
)
}
"
,
allowed_local_media_path
=
temp_dir
)
# Check that the images are equal
# Check that the images are equal
assert
not
ImageChops
.
difference
(
image_sync
,
image_async
).
getbbox
()
assert
not
ImageChops
.
difference
(
image_sync
,
image_async
).
getbbox
()
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
,
match
=
"must be a subpath"
):
await
async_fetch_image
(
await
local_connector
.
fetch_image_async
(
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
,
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
)
allowed_local_media_path
=
temp_dir
)
with
pytest
.
raises
(
RuntimeError
,
match
=
"Cannot load local files"
):
with
pytest
.
raises
(
ValueError
):
await
connector
.
fetch_image_async
(
await
async_fetch_image
(
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
)
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
,
match
=
"must be a subpath"
):
fetch_image
(
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
,
local_connector
.
fetch_image
(
allowed_local_media_path
=
temp_dir
)
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
RuntimeError
,
match
=
"Cannot load local files"
):
fetch_image
(
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
)
connector
.
fetch_image
(
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
os
.
path
.
join
(
models_path_prefix
,
"llava-hf/llava-v1.6-mistral-7b-hf"
)])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
os
.
path
.
join
(
models_path_prefix
,
"llava-hf/llava-v1.6-mistral-7b-hf"
)])
...
@@ -185,3 +200,204 @@ def test_repeat_and_pad_placeholder_tokens(model):
...
@@ -185,3 +200,204 @@ def test_repeat_and_pad_placeholder_tokens(model):
assert
new_prompt
==
expected_prompt
assert
new_prompt
==
expected_prompt
assert
new_token_ids
==
expected_token_ids
assert
new_token_ids
==
expected_token_ids
assert
ranges
==
expected_ranges
assert
ranges
==
expected_ranges
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
# Each instance bundles the two arguments passed to that function with the
# three values its result is compared against.
class TestCase(NamedTuple):
    # Input: per-modality placeholder ranges (first argument of the call).
    mm_positions: "MultiModalPlaceholderDict"
    # Input: per-modality hashes, or None (second argument of the call).
    mm_hashes: Optional["MultiModalHashDict"]
    # Expected first element of the returned triple.
    expected_modalities: list[str]
    # Expected second element of the returned triple.
    expected_ranges: list[PlaceholderRange]
    # Expected third element of the returned triple (None when no hashes).
    expected_hashes: Optional[list[str]]
def test_merge_and_sort_multimodal_metadata():
    """Check merged modalities, ranges and hashes for non-interleaved input."""

    def spans(*offset_length_pairs):
        # Build a list of PlaceholderRange from (offset, length) pairs.
        return [
            PlaceholderRange(offset=offset, length=length)
            for offset, length in offset_length_pairs
        ]

    test_cases = [
        # Single modality should return result as is but flattened
        TestCase(
            mm_positions={"image": spans((0, 2), (3, 2))},
            mm_hashes={"image": ["hash1", "hash2"]},
            expected_modalities=["image"],
            expected_ranges=spans((0, 2), (3, 2)),
            expected_hashes=["hash1", "hash2"],
        ),

        # Single modality without hashes return None for mm hash.
        TestCase(
            mm_positions={"image": spans((0, 2), (2, 2))},
            mm_hashes=None,
            expected_modalities=["image"],
            expected_ranges=spans((0, 2), (2, 2)),
            expected_hashes=None,
        ),

        # Multiple modalities with hashes should return sorted modalities
        # and flattened ranges and hashes.
        TestCase(
            mm_positions={
                "image": spans((7, 4), (11, 5)),
                "audio": spans((0, 2), (2, 3)),
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1", "audio_hash2"],
            },
            expected_modalities=["audio", "image"],
            expected_ranges=spans((0, 2), (2, 3), (7, 4), (11, 5)),
            expected_hashes=[
                "audio_hash1",
                "audio_hash2",
                "image_hash1",
                "image_hash2",
            ],
        ),

        # Multiple modalities without hashes should return sorted modalities
        # and flattened ranges and None.
        TestCase(
            mm_positions={
                "image": spans((7, 4), (11, 5)),
                "audio": spans((0, 2), (2, 3)),
            },
            mm_hashes=None,
            expected_modalities=["audio", "image"],
            expected_ranges=spans((0, 2), (2, 3), (7, 4), (11, 5)),
            expected_hashes=None,
        ),

        # Three modalities
        TestCase(
            mm_positions={
                "image": spans((15, 7), (22, 8)),
                "audio": spans((0, 2)),
                "video": spans((3, 4), (7, 5), (12, 6)),
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1"],
                "video": ["video_hash1", "video_hash2", "video_hash3"],
            },
            expected_modalities=["audio", "video", "image"],
            expected_ranges=spans(
                (0, 2),
                (3, 4),
                (7, 5),
                (12, 6),
                (15, 7),
                (22, 8),
            ),
            expected_hashes=[
                "audio_hash1",
                "video_hash1",
                "video_hash2",
                "video_hash3",
                "image_hash1",
                "image_hash2",
            ],
        ),
    ]

    for case in test_cases:
        modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
            case.mm_positions, case.mm_hashes)

        assert modalities == case.expected_modalities
        assert ranges == case.expected_ranges
        assert hashes == case.expected_hashes
def test_merge_and_sort_multimodal_metadata_with_interleaving():
    """Check that interleaved mixed-modality placeholders are rejected."""

    def spans(*offset_length_pairs):
        # Build a list of PlaceholderRange from (offset, length) pairs.
        return [
            PlaceholderRange(offset=offset, length=length)
            for offset, length in offset_length_pairs
        ]

    test_cases = [
        # <image> <audio> <image> <audio>
        TestCase(
            mm_positions={
                "image": spans((0, 4), (8, 2)),
                "audio": spans((5, 2), (11, 4)),
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1", "audio_hash2"],
            },
            expected_modalities=[],
            expected_ranges=[],
            expected_hashes=None,
        ),

        # <image> <image> <video> <audio> <image>
        TestCase(
            mm_positions={
                "image": spans((0, 2), (2, 3), (20, 4)),
                "audio": spans((5, 2)),
                "video": spans((8, 5)),
            },
            mm_hashes=None,
            expected_modalities=[],
            expected_ranges=[],
            expected_hashes=None,
        ),
    ]

    # The expected_* fields are unused here: every case must raise.
    for case in test_cases:
        with pytest.raises(ValueError, match="Interleaved mixed-modality"):
            merge_and_sort_multimodal_metadata(case.mm_positions,
                                               case.mm_hashes)
tests/multimodal/utils.py
0 → 100644
View file @
afd0da21
import
numpy
as
np
from
PIL
import
Image
def random_image(rng: np.random.RandomState, min_wh: int, max_wh: int):
    """Return a PIL image filled with random 3-channel uint8 data.

    Both side lengths are drawn from ``rng.randint(min_wh, max_wh)``, and
    pixel values from ``rng.randint(0, 255)``; NumPy treats the upper bound
    of ``randint`` as exclusive.
    """
    # Draw both sides in one call so the RNG stream stays reproducible.
    side_a, side_b = rng.randint(min_wh, max_wh, size=(2, ))
    pixels = rng.randint(0, 255, size=(side_a, side_b, 3), dtype=np.uint8)
    return Image.fromarray(pixels)
def random_video(
    rng: np.random.RandomState,
    min_frames: int,
    max_frames: int,
    min_wh: int,
    max_wh: int,
):
    """Return a random uint8 array of shape ``(num_frames, w, h, 3)``.

    The frame count is drawn from ``rng.randint(min_frames, max_frames)``
    and rounded down to an even number; both spatial sizes are drawn from
    ``rng.randint(min_wh, max_wh)``.
    """
    # Temporary workaround for https://github.com/huggingface/transformers/issues/35412
    frame_count = rng.randint(min_frames, max_frames)
    frame_count -= frame_count % 2  # round down to the nearest even count

    spatial = rng.randint(min_wh, max_wh, size=(2, ))
    shape = (frame_count, spatial[0], spatial[1], 3)
    return rng.randint(0, 255, size=shape, dtype=np.uint8)
def random_audio(
    rng: np.random.RandomState,
    min_len: int,
    max_len: int,
    sr: int,
):
    """Return ``(samples, sr)`` where ``samples`` is a random 1-D array.

    The length is drawn from ``rng.randint(min_len, max_len)`` and the
    samples from ``rng.rand``; ``sr`` is passed through unchanged.
    """
    num_samples = rng.randint(min_len, max_len)
    samples = rng.rand(num_samples)
    return samples, sr
tests/neuron/test_prefix_prefill.py
0 → 100644
View file @
afd0da21
import
random
from
typing
import
Optional
import
pytest
import
torch
import
torch.nn.functional
as
F
class BlockDiagonalCausalFromBottomRightMask:
    """Builds boolean attention masks for a batch of packed sequences.

    Rows index query tokens (all sequences concatenated) and columns index
    key tokens. Each sequence contributes one rectangular block; outside its
    own block a query row is always ``False``.
    """

    @staticmethod
    def _from_seqlens(query_lens, seq_lens, block_size=None):
        """Build one ``(sum(query_lens), n_keys)`` boolean mask.

        Args:
            query_lens: list of per-sequence query lengths.
            seq_lens: list of per-sequence total lengths.
            block_size: when ``None``, keys span ``seq_lens`` and each block
                is causal, aligned to its bottom-right corner. Otherwise
                keys span the context (``seq_lens - query_lens``) rounded up
                to a multiple of ``block_size``, and every query row of a
                sequence sees that sequence's whole (unpadded) context.

        Returns:
            A bool tensor; ``True`` marks an allowed (query, key) pair.
        """
        contexted = block_size is None
        context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
        n_queries = sum(query_lens)
        num_seqs = len(query_lens)

        if contexted:
            key_lens_blockaligned = seq_lens
        else:
            # Round each context length up to a whole number of blocks.
            n_blocks_per_seq = (context_lens + block_size - 1) // block_size
            offset_per_seq = n_blocks_per_seq * block_size
            key_lens_blockaligned = offset_per_seq[:num_seqs].tolist()
        n_keys = sum(key_lens_blockaligned)

        # a[i, j] == i (query index), b[i, j] == j (key index).
        a = torch.arange(n_queries).reshape(n_queries,
                                            1).expand(n_queries, n_keys)
        b = torch.arange(n_keys).reshape(1, n_keys).expand(n_queries, n_keys)
        q_cumsum = torch.tensor([0] + query_lens).cumsum(dim=0)
        k_cumsum = torch.tensor([0] + key_lens_blockaligned).cumsum(dim=0)

        # Start from bool so the result dtype does not depend on whether
        # the loop below runs at least once.
        prior_mask = torch.zeros(n_queries, n_keys, dtype=torch.bool)
        for seq_id in range(num_seqs):
            ri = q_cumsum[seq_id]  # first query row of this sequence
            ci = k_cumsum[seq_id]  # first key column of this sequence
            nr = query_lens[seq_id]

            if contexted:
                nc = seq_lens[seq_id]
                # Shift the diagonal so it ends at the block's
                # bottom-right corner.
                a_offset = ci + nc - ri - nr
                new_mask = (a + a_offset) >= b
            else:
                nc = context_lens[seq_id]
                # Last valid (unpadded) context column of this sequence.
                a_offset = ci + nc - 1
                new_mask = a_offset >= b

            # Restrict to this sequence's rows and to columns at or after
            # its first key column.
            left_mask = b >= ci
            top_mask = a >= ri
            bottom_mask = a < (ri + nr)
            new_mask = torch.logical_and(
                torch.logical_and(torch.logical_and(new_mask, left_mask),
                                  top_mask),
                bottom_mask,
            )
            prior_mask = torch.logical_or(prior_mask, new_mask)

        return prior_mask

    @staticmethod
    def from_seqlens(query_lens, seq_lens, block_size=None):
        """Return ``(prior_mask, active_mask)`` for the batch.

        With ``block_size=None`` a single mask over all keys is returned and
        ``active_mask`` is ``None``. Otherwise ``prior_mask`` covers the
        block-aligned context and ``active_mask`` is the causal mask among
        the query tokens themselves.
        """
        contexted = block_size is None
        if contexted:
            prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
                query_lens, seq_lens)
            active_mask = None
        else:
            prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
                query_lens, seq_lens, block_size)
            active_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
                query_lens, query_lens)
        return prior_mask, active_mask
def ref_softmax(x: torch.Tensor,
                dim: int,
                mixed_precision=False,
                return_max_reduce=False):
    """Reference softmax along ``dim``, computed with max subtraction.

    Args:
        x: input tensor.
        dim: dimension to normalise over.
        mixed_precision: when true, the denominator is accumulated in
            float32 and then cast back to ``x.dtype``.
        return_max_reduce: when true, also return the intermediate
            reductions.

    Returns:
        The softmax of ``x``; or, when ``return_max_reduce`` is true, the
        tuple ``(softmax, max_value, 1 / sum_value)`` where the reductions
        keep ``dim`` with size 1.
    """
    max_value = torch.amax(x, dim=dim, keepdim=True)
    exp = torch.exp(x - max_value)
    if mixed_precision:
        # torch tensors have no ``astype``; ``to`` performs the dtype cast.
        sum_value = torch.sum(exp.to(torch.float32), dim=dim,
                              keepdim=True).to(x.dtype)
    else:
        sum_value = torch.sum(exp, dim=dim, keepdim=True)
    if return_max_reduce:
        return exp / sum_value, max_value, torch.reciprocal(sum_value)
    return exp / sum_value
def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
    return_max_reduce: Optional[bool] = False,
) -> torch.Tensor:
    """Reference scaled dot-product attention with an additive mask.

    The einsum patterns fix the layouts: ``query`` is ``(q, h, d)``, ``key``
    and ``value`` are ``(k, h, d)``, scores are ``(h, q, k)`` and the output
    is ``(q, h, d)``.

    Args:
        query, key, value: attention inputs (see layouts above).
        scale: multiplier applied to the raw ``QK^T`` scores.
        attn_mask: optional tensor added to the scaled scores before the
            softmax; ``None`` means no masking.
        return_max_reduce: when true, also return the intermediates.

    Returns:
        The attention output; or, when ``return_max_reduce`` is true, the
        tuple ``(out, cached_max, cached_sum_reciprocal, norm_score,
        masked_score, scaled_qk)``.
    """
    scaled_qk = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        masked_score = scaled_qk + attn_mask.float()
    else:
        # Without a mask the scores are used as-is; previously this left
        # ``masked_score`` unbound and raised NameError.
        masked_score = scaled_qk

    if return_max_reduce:
        norm_score, cached_max, cached_sum_reciprocal = ref_softmax(
            masked_score, dim=-1, return_max_reduce=True)
    else:
        norm_score = ref_softmax(masked_score, dim=-1)

    out = torch.einsum("hqk,khd->qhd", norm_score, value)
    if return_max_reduce:
        return (
            out,
            cached_max,
            cached_sum_reciprocal,
            norm_score,
            masked_score,
            scaled_qk,
        )
    else:
        return out
def ref_context_attention(
    query,
    key,
    value,
    query_lens,
    seq_lens,
    head_size,
    num_kv_heads,
    num_heads,
    num_queries_per_kv,
    return_max_reduce=False,
):
    """Reference attention over packed sequences with a causal block mask.

    Args:
        query, key, value: packed attention inputs, forwarded to
            ``ref_masked_attention``; ``key``/``value`` heads are repeated
            along dim 1 when ``num_queries_per_kv > 1``.
        query_lens, seq_lens: per-sequence lengths used to build the mask.
        head_size: used for the ``1 / sqrt(head_size)`` score scale.
        num_kv_heads, num_heads: accepted for signature compatibility;
            not read here.
        num_queries_per_kv: number of query heads sharing one KV head.
        return_max_reduce: when true, also return softmax intermediates.

    Returns:
        The output with an extra dim inserted at position 1; or, when
        ``return_max_reduce`` is true, the tuple ``(output, cached_max,
        cached_sum_reciprocal, lse, masked_score, scaled_qk)``.
    """
    scale = float(1.0 / (head_size**0.5))
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

    attn_mask, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        query_lens, seq_lens)

    # convert binary mask to -inf values
    attn_mask = torch.logical_not(attn_mask)
    attn_mask = attn_mask.float() * -30000

    result = ref_masked_attention(
        query,
        key,
        value,
        scale,
        attn_mask,
        return_max_reduce=return_max_reduce,
    )

    if not return_max_reduce:
        # Only the output tensor comes back in this mode; unpacking it into
        # six names (as before) fails or silently splits it along dim 0.
        return result.unsqueeze(1)

    (output, cached_max, cached_sum_reciprocal, lse, masked_score,
     scaled_qk) = result
    output = output.unsqueeze(1)
    return (
        output,
        cached_max,
        cached_sum_reciprocal,
        lse,
        masked_score,
        scaled_qk,
    )
@
pytest
.
mark
.
parametrize
(
"num_heads,num_queries_per_kv,head_size,mixed_precision"
,
[
(
4
,
2
,
8
,
False
),
(
4
,
2
,
8
,
True
),
(
32
,
8
,
64
,
True
),
],
)
@
torch
.
inference_mode
()
def
test_contexted_kv_attention
(
num_heads
:
int
,
num_queries_per_kv
:
int
,
head_size
:
int
,
mixed_precision
:
bool
,
)
->
None
:
import
os
import
torch_xla.core.xla_model
as
xm
from
vllm.attention.ops.nki_flash_attn
import
flash_attn_varlen_nkifunc
device
=
xm
.
xla_device
()
os
.
environ
[
"NEURON_CC_FLAGS"
]
=
(
" --model-type=transformer -O1 "
" --internal-hlo2tensorizer-options='--verify-hlo' "
)
random
.
seed
(
0
)
torch
.
manual_seed
(
0
)
torch
.
set_printoptions
(
sci_mode
=
False
)
min_ctx_len
=
2
max_ctx_len
=
64
min_query_len
=
2
max_query_len
=
64
prefill_batch_size
=
2
decode_batch_size
=
6
batch_size
=
prefill_batch_size
+
decode_batch_size
block_size
=
32
max_model_len
=
(
max_query_len
+
max_ctx_len
)
*
4
max_block_per_request
=
max_model_len
//
block_size
dtype
=
torch
.
float32
cache_size
=
(
batch_size
*
max_block_per_request
)
+
2
ctx_lens
=
[
random
.
randint
(
min_ctx_len
,
max_ctx_len
)
for
_
in
range
(
prefill_batch_size
)
]
+
[
random
.
randint
(
min_ctx_len
,
max_ctx_len
)
for
_
in
range
(
decode_batch_size
)
]
query_lens
=
[
random
.
randint
(
min_query_len
,
max_query_len
)
for
_
in
range
(
prefill_batch_size
)
]
+
[
1
for
_
in
range
(
decode_batch_size
)]
seq_lens
=
[
a
+
b
for
a
,
b
in
zip
(
query_lens
,
ctx_lens
)]
num_kv_heads
=
num_heads
//
num_queries_per_kv
num_tokens
=
sum
(
query_lens
)
query
=
torch
.
empty
(
num_tokens
,
num_heads
,
head_size
,
dtype
=
dtype
)
query
.
uniform_
(
-
1
,
1
)
torch
.
empty
(
num_tokens
,
num_heads
,
head_size
,
dtype
=
dtype
)
kv
=
torch
.
empty
(
sum
(
seq_lens
),
2
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
kv
.
uniform_
(
-
1
,
1
)
key
,
value
=
kv
.
unbind
(
dim
=
1
)
k_cache
=
torch
.
zeros
(
cache_size
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
v_cache
=
torch
.
zeros
(
cache_size
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
k
=
torch
.
zeros
(
sum
(
query_lens
),
num_kv_heads
,
head_size
,
dtype
=
dtype
)
v
=
torch
.
zeros
(
sum
(
query_lens
),
num_kv_heads
,
head_size
,
dtype
=
dtype
)
values
=
torch
.
arange
(
0
,
cache_size
,
dtype
=
torch
.
long
)
values
=
values
[
torch
.
randperm
(
cache_size
)]
block_table
=
values
[:
batch_size
*
max_block_per_request
].
view
(
batch_size
,
max_block_per_request
)
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
long
)
b_ctx_len
=
torch
.
tensor
(
ctx_lens
,
dtype
=
torch
.
long
)
b_start_loc
=
torch
.
cumsum
(
torch
.
tensor
([
0
]
+
query_lens
[:
-
1
],
dtype
=
torch
.
long
),
dim
=
0
)
# copy kv to cache
b_seq_start_loc
=
torch
.
cumsum
(
torch
.
tensor
([
0
]
+
seq_lens
[:
-
1
],
dtype
=
torch
.
long
),
dim
=
0
)
for
i
in
range
(
batch_size
):
for
j
in
range
(
query_lens
[
i
]):
k
[
b_start_loc
[
i
]
+
j
].
copy_
(
key
[
b_seq_start_loc
[
i
]
+
b_ctx_len
[
i
]
+
j
])
v
[
b_start_loc
[
i
]
+
j
].
copy_
(
value
[
b_seq_start_loc
[
i
]
+
b_ctx_len
[
i
]
+
j
])
cur_ctx
=
0
block_id
=
0
while
cur_ctx
<
b_ctx_len
[
i
]:
start_loc
=
b_seq_start_loc
[
i
]
+
cur_ctx
if
cur_ctx
+
block_size
>
b_ctx_len
[
i
]:
end_loc
=
b_seq_start_loc
[
i
]
+
b_ctx_len
[
i
]
else
:
end_loc
=
start_loc
+
block_size
start_slot
=
block_table
[
i
,
block_id
]
*
block_size
end_slot
=
start_slot
+
end_loc
-
start_loc
k_cache
.
view
(
-
1
,
num_kv_heads
,
head_size
)[
start_slot
:
end_slot
].
copy_
(
key
[
start_loc
:
end_loc
])
v_cache
.
view
(
-
1
,
num_kv_heads
,
head_size
)[
start_slot
:
end_slot
].
copy_
(
value
[
start_loc
:
end_loc
])
cur_ctx
+=
block_size
block_id
+=
1
(
output_ref
,
cached_max
,
cached_sum_reciprocal
,
lse
,
masked_score
,
scaled_qk
,
)
=
ref_context_attention
(
query
,
key
,
value
,
query_lens
,
seq_lens
,
head_size
,
num_kv_heads
,
num_heads
,
num_queries_per_kv
,
return_max_reduce
=
True
,
)
# build neuron program
return_debug_tensors
=
False
B_P_SIZE
=
128
LARGE_TILE_SZ
=
2048
max_num_queries
=
(
(
sum
(
query_lens
)
+
block_size
-
1
)
//
block_size
)
*
block_size
def
get_active_block_tables
(
block_tables
,
query_lens
,
seq_lens
,
block_size
,
num_blocks
):
context_lens
=
seq_lens
-
query_lens
blocks_per_seq
=
(
context_lens
+
block_size
-
1
)
//
block_size
num_seqs
=
len
(
seq_lens
)
active_blocks
:
list
[
int
]
=
[]
for
seq_id
in
range
(
num_seqs
):
active_blocks
=
(
active_blocks
+
block_tables
[
seq_id
,
:
blocks_per_seq
[
seq_id
]].
tolist
())
return
F
.
pad
(
torch
.
tensor
(
active_blocks
),
(
0
,
num_blocks
-
len
(
active_blocks
)),
"constant"
,
0
,
)
def
shift_bit_length
(
x
):
return
1
<<
(
x
-
1
).
bit_length
()
# calculate input shapes
max_num_queries_shifted
=
shift_bit_length
(
max_num_queries
)
max_num_queries_factor
=
B_P_SIZE
//
max_num_queries_shifted
max_num_queries_padded
=
max_num_queries_shifted
*
max_num_queries_factor
assert
(
max_num_queries_padded
==
B_P_SIZE
),
"invalid {max_num_queries_padded=}"
head_size_padded
=
B_P_SIZE
context_lens
=
torch
.
tensor
(
seq_lens
)
-
torch
.
tensor
(
query_lens
)
num_active_blocks_shifted
=
shift_bit_length
(
((
context_lens
+
block_size
-
1
)
//
block_size
).
sum
().
item
())
num_active_blocks_factor
=
(
LARGE_TILE_SZ
//
block_size
//
num_active_blocks_shifted
)
num_active_blocks
=
num_active_blocks_shifted
*
num_active_blocks_factor
assert
(
num_active_blocks
*
block_size
)
==
LARGE_TILE_SZ
,
"invalid {num_active_blocks=}"
context_kv_len
=
num_active_blocks
*
block_size
assert
context_kv_len
==
LARGE_TILE_SZ
,
f
"invalid
{
context_kv_len
=
}
"
# pad QKV tensors
pad_dims
=
(
0
,
head_size_padded
-
query
.
shape
[
2
],
0
,
0
,
0
,
max_num_queries_padded
-
query
.
shape
[
0
],
)
query
=
F
.
pad
(
query
,
pad_dims
,
"constant"
,
0
)
k
=
F
.
pad
(
k
,
pad_dims
,
"constant"
,
0
)
v
=
F
.
pad
(
v
,
pad_dims
,
"constant"
,
0
)
k_cache
=
F
.
pad
(
k_cache
,
(
0
,
head_size_padded
-
head_size
),
"constant"
,
0
)
v_cache
=
F
.
pad
(
v_cache
,
(
0
,
head_size_padded
-
head_size
),
"constant"
,
0
)
# permute QKV tensors
# query: (1, n_heads, d, seq_q)
# key: (1, n_kv_heads, d, seq_k)
# value: (1, n_kv_heads, seq_v, d)
query
=
query
.
unsqueeze
(
0
).
permute
(
0
,
2
,
3
,
1
).
contiguous
()
k
=
k
.
unsqueeze
(
0
).
permute
(
0
,
2
,
3
,
1
).
contiguous
()
v
=
v
.
unsqueeze
(
0
).
permute
(
0
,
2
,
1
,
3
).
contiguous
()
# transform block table
active_block_table
=
get_active_block_tables
(
block_table
,
torch
.
tensor
(
query_lens
),
torch
.
tensor
(
seq_lens
),
block_size
,
num_active_blocks
,
)
# Build attention masks
prior_mask
,
active_mask
=
(
BlockDiagonalCausalFromBottomRightMask
.
from_seqlens
(
query_lens
,
seq_lens
,
block_size
=
block_size
))
attn_mask
=
torch
.
concat
(
[
F
.
pad
(
prior_mask
,
(
0
,
context_kv_len
-
prior_mask
.
shape
[
1
],
0
,
B_P_SIZE
-
prior_mask
.
shape
[
0
],
),
"constant"
,
0
,
).
bool
(),
F
.
pad
(
active_mask
,
(
0
,
B_P_SIZE
-
active_mask
.
shape
[
1
],
0
,
B_P_SIZE
-
active_mask
.
shape
[
0
],
),
"constant"
,
0
,
).
bool
(),
],
dim
=
1
,
)
input_args
=
(
query
.
to
(
device
=
device
),
k
.
to
(
device
=
device
),
v
.
to
(
device
=
device
),
k_cache
.
to
(
device
=
device
),
v_cache
.
to
(
device
=
device
),
active_block_table
.
to
(
torch
.
int32
).
to
(
device
=
device
),
attn_mask
.
to
(
device
=
device
),
)
input_kwargs
=
dict
(
n_kv_head
=
num_kv_heads
,
head_size
=
head_size
,
mixed_precision
=
mixed_precision
,
)
if
return_debug_tensors
:
output_nki
,
*
debug_tensors
=
flash_attn_varlen_nkifunc
(
*
input_args
,
**
input_kwargs
)
else
:
output_nki
=
flash_attn_varlen_nkifunc
(
*
input_args
,
**
input_kwargs
)
debug_tensors
=
[]
output_nki
=
torch
.
tensor
(
output_nki
).
cpu
()
debug_tensors
=
[
torch
.
tensor
(
dt
).
cpu
()
for
dt
in
debug_tensors
]
num_actual_tokens
=
sum
(
query_lens
)
print
(
f
"
{
num_actual_tokens
=
}
"
)
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
output_nki
=
output_nki
.
permute
(
0
,
2
,
1
,
3
)[:,
:,
:,
:
head_size
].
cpu
()[
0
,
:
num_actual_tokens
,
:,
:]
output_ref_padded
=
F
.
pad
(
output_ref
,
(
0
,
0
,
0
,
0
,
0
,
0
,
0
,
max_num_queries_padded
-
output_ref
.
shape
[
0
]),
"constant"
,
0
,
)
output_ref
=
output_ref_padded
.
transpose
(
0
,
1
)[
0
,
:
num_actual_tokens
,
:,
:]
torch
.
testing
.
assert_close
(
output_nki
,
output_ref
,
atol
=
1e-2
,
rtol
=
0
)
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py
View file @
afd0da21
...
@@ -2,15 +2,17 @@ from typing import Optional
...
@@ -2,15 +2,17 @@ from typing import Optional
import
torch
import
torch
from
vllm.model_executor.models.llava
import
(
LlavaForConditionalGeneration
,
from
vllm.model_executor.models.llava
import
(
LlavaDummyInputsBuilder
,
LlavaForConditionalGeneration
,
LlavaMultiModalProcessor
,
LlavaMultiModalProcessor
,
get_max_llava_image_tokens
)
LlavaProcessingInfo
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_llava_image_tokens
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
LlavaMultiModalProcessor
,
@
MULTIMODAL_REGISTRY
.
register_processor
(
LlavaMultiModalProcessor
)
info
=
LlavaProcessingInfo
,
dummy_inputs
=
LlavaDummyInputsBuilder
)
class
MyLlava
(
LlavaForConditionalGeneration
):
class
MyLlava
(
LlavaForConditionalGeneration
):
def
compute_logits
(
def
compute_logits
(
...
...
tests/plugins/vllm_add_dummy_platform/setup.py
0 → 100644
View file @
afd0da21
from
setuptools
import
setup
# Package the dummy platform and advertise its plugin hook under the
# 'vllm.platform_plugins' entry-point group.
setup(
    name='vllm_add_dummy_platform',
    version='0.1',
    packages=['vllm_add_dummy_platform'],
    entry_points={
        'vllm.platform_plugins': [
            "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin"  # noqa
        ],
    },
)
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
0 → 100644
View file @
afd0da21
from
typing
import
Optional
def dummy_platform_plugin() -> Optional[str]:
    """Return the dotted import path of the dummy platform class."""
    platform_cls_path = "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
    return platform_cls_path
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py
0 → 100644
View file @
afd0da21
from
vllm.attention.backends.flash_attn
import
FlashAttentionBackend
class DummyAttentionBackend(FlashAttentionBackend):
    """FlashAttentionBackend subclass that only overrides the backend name."""

    @staticmethod
    def get_name() -> str:
        """Return the fixed name identifying this dummy backend."""
        backend_name = "Dummy_Backend"
        return backend_name
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
0 → 100644
View file @
afd0da21
from
vllm.platforms.cuda
import
CudaPlatform
class DummyPlatform(CudaPlatform):
    """CudaPlatform subclass with its own device name and attention backend."""

    # Name reported by this platform instead of the inherited one.
    device_name = "DummyDevice"

    def get_attn_backend_cls(self, backend_name, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla):
        """Return the dotted path of the dummy attention backend class.

        None of the arguments are consulted; the same path is returned for
        every call.
        """
        backend_cls_path = "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
        return backend_cls_path
tests/plugins_tests/test_platform_plugins.py
0 → 100644
View file @
afd0da21
import
torch
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm.attention.selector
import
get_attn_backend
from
vllm.utils
import
STR_INVALID_VAL
def test_platform_plugins():
    """Run the basic offline example, then check the dummy platform is active."""
    import os
    import runpy

    # simulate workload by running an example; the examples directory sits
    # three `dirname` hops above this file.
    base_dir = __file__
    for _ in range(3):
        base_dir = os.path.dirname(base_dir)
    example_file = os.path.join(base_dir, "examples",
                                "offline_inference/basic.py")
    runpy.run_path(example_file)

    # check if the plugin is loaded correctly; the import is done only after
    # the example has run, matching the original statement order.
    from vllm.platforms import _init_trace, current_platform

    assert current_platform.device_name == "DummyDevice", (
        f"Expected DummyDevice, got {current_platform.device_name}, "
        "possibly because current_platform is imported before the plugin"
        f" is loaded. The first import:\n{_init_trace}")
def test_oot_attention_backend(monkeypatch):
    """Check that attention-backend selection yields the dummy backend."""
    # ignore the backend env variable if it is set
    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)

    selected_backend = get_attn_backend(
        16,
        torch.float16,
        torch.float16,
        16,
        False,
    )
    selected_name = selected_backend.get_name()
    assert selected_name == "Dummy_Backend"
tests/quantization/test_compressed_tensors.py
View file @
afd0da21
...
@@ -33,50 +33,55 @@ from ..utils import models_path_prefix
...
@@ -33,50 +33,55 @@ from ..utils import models_path_prefix
def
test_compressed_tensors_w8a8_static_setup
(
vllm_runner
,
model_args
):
def
test_compressed_tensors_w8a8_static_setup
(
vllm_runner
,
model_args
):
model_path
,
strategy
,
quant_type
,
shape_0
,
is_symmetric
=
model_args
model_path
,
strategy
,
quant_type
,
shape_0
,
is_symmetric
=
model_args
with
vllm_runner
(
model_path
,
enforce_eager
=
True
)
as
llm
:
with
vllm_runner
(
model_path
,
enforce_eager
=
True
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
o_proj
=
layer
.
self_attn
.
o_proj
qkv_proj
=
layer
.
self_attn
.
qkv_proj
gate_up_proj
=
layer
.
mlp
.
gate_up_proj
o_proj
=
layer
.
self_attn
.
o_proj
down_proj
=
layer
.
mlp
.
down_proj
gate_up_proj
=
layer
.
mlp
.
gate_up_proj
down_proj
=
layer
.
mlp
.
down_proj
# assert zp for symmetric and asymmetric cases
def
zp_valid
(
zp
:
Optional
[
torch
.
Tensor
]):
# assert zp for symmetric and asymmetric cases
if
is_symmetric
:
def
zp_valid
(
zp
:
Optional
[
torch
.
Tensor
]):
return
zp
is
None
if
is_symmetric
:
return
zp
is
None
return
zp
is
not
None
and
zp
.
dtype
is
torch
.
int32
return
zp
is
not
None
and
zp
.
dtype
is
torch
.
int32
assert
zp_valid
(
qkv_proj
.
input_zero_point
)
assert
zp_valid
(
o_proj
.
input_zero_point
)
assert
zp_valid
(
qkv_proj
.
input_zero_point
)
assert
zp_valid
(
gate_up_proj
.
input_zero_point
)
assert
zp_valid
(
o_proj
.
input_zero_point
)
assert
zp_valid
(
down_proj
.
input_zero_point
)
assert
zp_valid
(
gate_up_proj
.
input_zero_point
)
assert
zp_valid
(
down_proj
.
input_zero_point
)
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
o_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
quant_method
,
assert
isinstance
(
gate_up_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
CompressedTensorsLinearMethod
)
assert
isinstance
(
o_proj
.
quant_method
,
assert
isinstance
(
down_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
CompressedTensorsLinearMethod
)
assert
isinstance
(
gate_up_proj
.
quant_method
,
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Int8
)
CompressedTensorsLinearMethod
)
assert
isinstance
(
down_proj
.
quant_method
,
assert
qkv_proj
.
scheme
.
strategy
==
strategy
CompressedTensorsLinearMethod
)
assert
qkv_proj
.
scheme
.
is_static_input_scheme
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Int8
)
expected_type
=
torch
.
int8
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
weight
.
dtype
is
expected_type
assert
qkv_proj
.
scheme
.
is_static_input_scheme
assert
o_proj
.
weight
.
dtype
is
expected_type
expected_type
=
torch
.
int8
assert
gate_up_proj
.
weight
.
dtype
is
expected_type
assert
qkv_proj
.
weight
.
dtype
is
expected_type
if
qkv_proj
.
scheme
.
strategy
==
"tensor"
:
assert
o_proj
.
weight
.
dtype
is
expected_type
# Make sure it is a channelwise buffer
assert
gate_up_proj
.
weight
.
dtype
is
expected_type
# After running process_weights_after_loading
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
2
if
qkv_proj
.
scheme
.
strategy
==
"tensor"
:
assert
qkv_proj
.
weight_scale
.
shape
[
0
]
==
shape_0
# Make sure it is a channelwise buffer
assert
qkv_proj
.
weight_scale
.
shape
[
1
]
==
1
# After running process_weights_after_loading
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
2
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
assert
qkv_proj
.
weight_scale
.
shape
[
0
]
==
shape_0
assert
qkv_proj
.
weight_scale
.
shape
[
1
]
==
1
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
20
)
output
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
20
)
assert
output
assert
output
...
@@ -132,16 +137,20 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
...
@@ -132,16 +137,20 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
def
test_compressed_tensors_w8a8_dynamic_per_token
(
vllm_runner
,
model_args
):
def
test_compressed_tensors_w8a8_dynamic_per_token
(
vllm_runner
,
model_args
):
model_path
,
strategy
=
model_args
model_path
,
strategy
=
model_args
with
vllm_runner
(
model_path
,
dtype
=
torch
.
float16
)
as
llm
:
with
vllm_runner
(
model_path
,
dtype
=
torch
.
float16
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Int8
)
assert
not
qkv_proj
.
scheme
.
is_static_input_scheme
assert
isinstance
(
qkv_proj
.
quant_method
,
assert
qkv_proj
.
scheme
.
strategy
==
strategy
CompressedTensorsLinearMethod
)
assert
qkv_proj
.
weight
.
dtype
is
torch
.
int8
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Int8
)
assert
not
qkv_proj
.
scheme
.
is_static_input_scheme
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
weight
.
dtype
is
torch
.
int8
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
20
)
output
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
20
)
assert
output
assert
output
...
@@ -157,19 +166,24 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
...
@@ -157,19 +166,24 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
def
test_compressed_tensors_wNa16
(
vllm_runner
,
wNa16_args
):
def
test_compressed_tensors_wNa16
(
vllm_runner
,
wNa16_args
):
model
,
strategy
,
group
,
pack_factor
=
wNa16_args
model
,
strategy
,
group
,
pack_factor
=
wNa16_args
with
vllm_runner
(
model
)
as
llm
:
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
layer
=
model
.
model
.
layers
[
0
]
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsWNA16
)
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsWNA16
)
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
scheme
.
group_size
==
(
-
1
if
group
is
None
else
group
)
assert
qkv_proj
.
scheme
.
group_size
==
(
-
1
if
group
is
None
else
group
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float16
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float16
assert
qkv_proj
.
scheme
.
pack_factor
==
pack_factor
assert
qkv_proj
.
scheme
.
pack_factor
==
pack_factor
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
assert
output
...
@@ -180,14 +194,18 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
...
@@ -180,14 +194,18 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
def
test_compressed_tensors_w4a16_marlin24
(
vllm_runner
):
def
test_compressed_tensors_w4a16_marlin24
(
vllm_runner
):
model_path
=
os
.
path
.
join
(
models_path_prefix
,
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
)
model_path
=
os
.
path
.
join
(
models_path_prefix
,
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
)
with
vllm_runner
(
model_path
)
as
llm
:
with
vllm_runner
(
model_path
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW4A16Sparse24
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
llm
.
apply_model
(
check_model
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW4A16Sparse24
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
assert
output
...
@@ -198,23 +216,27 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
...
@@ -198,23 +216,27 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
def
test_compressed_tensors_fp8
(
vllm_runner
):
def
test_compressed_tensors_fp8
(
vllm_runner
):
model_path
=
os
.
path
.
join
(
models_path_prefix
,
"nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
)
model_path
=
os
.
path
.
join
(
models_path_prefix
,
"nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
)
with
vllm_runner
(
model_path
)
as
llm
:
with
vllm_runner
(
model_path
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
scheme
,
(
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A16Fp8
))
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
(
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A16Fp8
))
if
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Fp8
):
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
qkv_proj
.
weight
.
dtype
is
torch
.
float8_e4m3fn
if
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Fp8
):
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
assert
qkv_proj
.
weight
.
dtype
is
torch
.
float8_e4m3fn
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
assert
output
...
@@ -259,12 +281,15 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
...
@@ -259,12 +281,15 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
def
test_compressed_tensors_2of4_quant_fp8
(
vllm_runner
,
args_2of4
):
def
test_compressed_tensors_2of4_quant_fp8
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
float8_e4m3fn
layer
=
model
.
model
.
layers
[
0
]
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
float8_e4m3fn
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
print
(
output
)
...
@@ -284,40 +309,49 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
...
@@ -284,40 +309,49 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
def
test_compressed_tensors_2of4_quant_int8
(
vllm_runner
,
args_2of4
):
def
test_compressed_tensors_2of4_quant_int8
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
int8
layer
=
model
.
model
.
layers
[
0
]
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
int8
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
print
(
output
)
assert
output
assert
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
@
pytest
.
mark
.
skip
(
reason
=
"2of4 sparse w16a16 CUTLASS produces bad output."
)
reason
=
"Sparse FP8 is not yet supported on this GPU type."
)
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"2of4 Sparse is not yet supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
"args_2of4"
,
[(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor"
)])
[(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor"
)])
def
test_compressed_tensors_2of4_sparse
(
vllm_runner
,
args_2of4
):
def
test_compressed_tensors_2of4_sparse
(
vllm_runner
,
args_2of4
):
model
=
args_2of4
model
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensors24
)
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
qkv_proj
.
scheme
.
weight_quant
is
None
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensors24
)
assert
qkv_proj
.
scheme
.
input_quant
is
None
assert
not
qkv_proj
.
scheme
.
quantized
assert
qkv_proj
.
scheme
.
weight_quant
is
None
assert
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
assert
qkv_proj
.
scheme
.
input_quant
is
None
sparsity_map
=
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
# noqa: E501
assert
not
qkv_proj
.
scheme
.
quantized
assert
sparsity_map
.
get
(
"Linear"
).
format
==
"dense"
assert
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
assert
sparsity_map
.
get
(
"Linear"
).
sparsity_structure
==
"2:4"
sparsity_map
=
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
# noqa: E501
assert
sparsity_map
.
get
(
"Linear"
).
format
==
"dense"
assert
sparsity_map
.
get
(
"Linear"
).
sparsity_structure
==
"2:4"
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
print
(
output
)
...
...
tests/quantization/test_fp8.py
View file @
afd0da21
...
@@ -52,13 +52,17 @@ KV_CACHE_MODELS = [
...
@@ -52,13 +52,17 @@ KV_CACHE_MODELS = [
def
test_kv_cache_model_load_and_run
(
vllm_runner
,
model_id
:
str
):
def
test_kv_cache_model_load_and_run
(
vllm_runner
,
model_id
:
str
):
with
vllm_runner
(
model_id
,
kv_cache_dtype
=
"fp8"
)
as
llm
:
with
vllm_runner
(
model_id
,
kv_cache_dtype
=
"fp8"
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
def
check_model
(
model
):
attn
=
model
.
model
.
layers
[
0
].
self_attn
.
attn
attn
=
model
.
model
.
layers
[
0
].
self_attn
.
attn
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
# NOTE: it is valid for scales to be 1.0 (default value), but we know
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
# these checkpoints have scales < 1.0
assert
0.0
<
attn
.
_k_scale
<
1.0
# NOTE: it is valid for scales to be 1.0 (default value), but
assert
0.0
<
attn
.
_v_scale
<
1.0
# we know these checkpoints have scales < 1.0
assert
0.0
<
attn
.
_k_scale
<
1.0
assert
0.0
<
attn
.
_v_scale
<
1.0
llm
.
apply_model
(
check_model
)
# note: this does not test accuracy, just that we can run through
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
# see lm-eval tests for accuracy
...
@@ -80,22 +84,24 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
...
@@ -80,22 +84,24 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
quantization
=
"fp8"
,
quantization
=
"fp8"
,
kv_cache_dtype
=
kv_cache_dtype
)
as
llm
:
kv_cache_dtype
=
kv_cache_dtype
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
def
check_model
(
model
):
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
assert
isinstance
(
fc1
.
quant_method
,
Fp8LinearMethod
)
assert
isinstance
(
fc1
.
quant_method
,
Fp8LinearMethod
)
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
attn
=
model
.
model
.
decoder
.
layers
[
0
].
self_attn
.
attn
attn
=
model
.
model
.
decoder
.
layers
[
0
].
self_attn
.
attn
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
assert
attn
.
_k_scale
==
1.0
assert
attn
.
_k_scale
==
1.0
assert
attn
.
_v_scale
==
1.0
assert
attn
.
_v_scale
==
1.0
if
current_platform
.
has_device_capability
(
89
)
and
not
force_marlin
:
if
current_platform
.
has_device_capability
(
89
)
and
not
force_marlin
:
# For GPUs with hardware support, we keep weights in fp8
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
else
:
else
:
# For GPUs without hardware support, we pack the fp8 weights
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
# for weight-only quantization using Marlin kernels
assert
fc1
.
weight
.
dtype
==
torch
.
int32
assert
fc1
.
weight
.
dtype
==
torch
.
int32
llm
.
apply_model
(
check_model
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
)
or
is_hip
(),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
)
or
is_hip
(),
...
...
tests/quantization/test_lm_head.py
View file @
afd0da21
...
@@ -30,20 +30,23 @@ def test_lm_head(
...
@@ -30,20 +30,23 @@ def test_lm_head(
model_lm_head_quant
:
Tuple
[
str
,
bool
],
model_lm_head_quant
:
Tuple
[
str
,
bool
],
)
->
None
:
)
->
None
:
model
,
lm_head_quantized
=
model_lm_head_quant
model
,
lm_head_quantized
=
model_lm_head_quant
vllm_model
=
vllm_runner
(
model
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
with
vllm_runner
(
model
,
dtype
=
torch
.
float16
,
lm_head_layer
=
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
max_model_len
=
2048
)
as
vllm_model
:
model_runner
.
model
.
lm_head
)
def
check_model
(
model
):
if
lm_head_quantized
:
lm_head_layer
=
model
.
lm_head
assert
isinstance
(
lm_head_layer
.
linear_method
,
if
lm_head_quantized
:
(
GPTQLinearMethod
,
GPTQMarlinLinearMethod
,
MarlinLinearMethod
))
assert
isinstance
(
lm_head_layer
.
linear_method
,
else
:
(
GPTQLinearMethod
,
GPTQMarlinLinearMethod
,
assert
isinstance
(
lm_head_layer
.
linear_method
,
MarlinLinearMethod
))
UnquantizedEmbeddingMethod
)
else
:
assert
isinstance
(
lm_head_layer
.
linear_method
,
print
(
UnquantizedEmbeddingMethod
)
vllm_model
.
generate_greedy
(
prompts
=
[
"Hello my name is"
],
max_tokens
=
10
)[
0
][
1
])
vllm_model
.
apply_model
(
check_model
)
del
vllm_model
print
(
vllm_model
.
generate_greedy
(
prompts
=
[
"Hello my name is"
],
max_tokens
=
10
)[
0
][
1
])
Prev
1
…
20
21
22
23
24
25
26
27
28
…
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment