Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
afd0da21
Commit
afd0da21
authored
Feb 03, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.1' into v0.7.1-dev
parents
1a11f127
4f4d427a
Changes
587
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1427 additions
and
259 deletions
+1427
-259
tests/models/multimodal/processing/test_phi3v.py
tests/models/multimodal/processing/test_phi3v.py
+55
-0
tests/models/multimodal/processing/test_qwen2_vl.py
tests/models/multimodal/processing/test_qwen2_vl.py
+54
-0
tests/models/registry.py
tests/models/registry.py
+84
-6
tests/models/test_initialization.py
tests/models/test_initialization.py
+5
-6
tests/models/test_registry.py
tests/models/test_registry.py
+3
-0
tests/multi_step/test_correctness_async_llm.py
tests/multi_step/test_correctness_async_llm.py
+3
-3
tests/multi_step/test_correctness_llm.py
tests/multi_step/test_correctness_llm.py
+16
-1
tests/multimodal/test_processing.py
tests/multimodal/test_processing.py
+217
-66
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+242
-26
tests/multimodal/utils.py
tests/multimodal/utils.py
+33
-0
tests/neuron/test_prefix_prefill.py
tests/neuron/test_prefix_prefill.py
+456
-0
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py
...ins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py
+6
-4
tests/plugins/vllm_add_dummy_platform/setup.py
tests/plugins/vllm_add_dummy_platform/setup.py
+11
-0
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
...lm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
+5
-0
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py
...atform/vllm_add_dummy_platform/dummy_attention_backend.py
+8
-0
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
..._dummy_platform/vllm_add_dummy_platform/dummy_platform.py
+9
-0
tests/plugins_tests/test_platform_plugins.py
tests/plugins_tests/test_platform_plugins.py
+30
-0
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+141
-107
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+29
-23
tests/quantization/test_lm_head.py
tests/quantization/test_lm_head.py
+20
-17
No files found.
Too many changes to show.
To preserve performance only
587 of 587+
files are displayed.
Plain diff
Email patch
tests/models/multimodal/processing/test_phi3v.py
0 → 100644
View file @
afd0da21
"""Tests for phi3v's multimodal preprocessing kwargs."""
import
pytest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
....conftest
import
_ImageAssets
from
...utils
import
build_model_context
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
# yapf: disable
@pytest.mark.parametrize(
    ("mm_processor_kwargs", "expected_toks_per_img"),
    [
        ({"num_crops": 4}, 757),
        ({"num_crops": 16}, 1921),
        # the default num_crops of phi-3.5-vision is 4
        ({}, 757),
    ],
)
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
    image_assets: _ImageAssets,
    model_id: str,
    mm_processor_kwargs: dict[str, int],
    expected_toks_per_img: int,
    num_imgs: int,
):
    """Check the image placeholder count produced for each ``num_crops``.

    The processor is applied to a prompt holding ``num_imgs`` image tags and
    the number of ``_IMAGE_TOKEN_ID`` entries in the resulting token ids must
    equal ``expected_toks_per_img`` for every image passed.
    """
    # Imported lazily to avoid initializing CUDA early
    from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID

    ctx = build_model_context(
        model_name=model_id,
        tokenizer_name=model_id,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": num_imgs},
    )
    model_config = ctx.model_config
    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )

    # One numbered image tag (1-based) per image, wrapped in the chat template
    image_tags = "".join(f"<|image_{n}|>\n" for n in range(1, num_imgs + 1))
    prompt = f"<|user|>\n{image_tags}<|end|>\n<|assistant|>\n"
    # The same asset is reused for every image slot
    images = [image_assets[0].pil_image] * num_imgs

    outputs = processor.apply(prompt, {"image": images}, mm_processor_kwargs)

    # Every image must contribute the placeholder count for its num_crops
    placeholder_total = outputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
    assert placeholder_total == expected_toks_per_img * num_imgs
tests/models/multimodal/processing/test_qwen2_vl.py
0 → 100644
View file @
afd0da21
import
pytest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
....conftest
import
_ImageAssets
from
...utils
import
build_model_context
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
# yapf: disable
@pytest.mark.parametrize(
    ("mm_processor_kwargs", "expected_toks_per_img", "expected_pixels_shape"),
    [
        ({}, 1426, (5704, 1176)),
        ({"min_pixels": 64**2, "max_pixels": 512**2}, 330, (1320, 1176)),
    ],
)
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
    image_assets: _ImageAssets,
    model_id: str,
    mm_processor_kwargs: dict[str, object],
    expected_toks_per_img: int,
    expected_pixels_shape: tuple[int, int],
    num_imgs: int,
):
    """Check that min/max pixel overrides reach the Qwen2-VL processor.

    Both the number of image placeholder tokens and the shape of the
    ``pixel_values`` tensor are compared against the expected per-image
    values, scaled by ``num_imgs`` where applicable.
    """
    ctx = build_model_context(
        model_name=model_id,
        tokenizer_name=model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=tokenizer,
    )

    # One vision placeholder group per image; the same asset fills each slot
    prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
    images = [image_assets[0].pil_image] * num_imgs

    outputs = processor.apply(prompt, {"image": images}, mm_processor_kwargs)

    # Resolve the image token id through the HF processor built with the
    # same overrides that were passed to `apply`
    hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)
    image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)

    placeholder_total = outputs["prompt_token_ids"].count(image_token_id)
    actual_shape = outputs["mm_kwargs"]["pixel_values"].shape
    rows_per_img, row_width = expected_pixels_shape

    # Placeholder count and first pixel dimension scale with the image count;
    # the second pixel dimension is the same regardless of num_imgs
    assert placeholder_total == expected_toks_per_img * num_imgs
    assert actual_shape[0] == rows_per_img * num_imgs
    assert actual_shape[1] == row_width
tests/models/registry.py
View file @
afd0da21
from
dataclasses
import
dataclass
,
field
from
typing
import
AbstractSet
,
Mapping
,
Optional
from
typing
import
AbstractSet
,
Any
,
Literal
,
Mapping
,
Optional
import
pytest
from
packaging.version
import
Version
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
@
dataclass
(
frozen
=
True
)
...
...
@@ -22,6 +26,11 @@ class _HfExamplesInfo:
for speculative decoding.
"""
min_transformers_version
:
Optional
[
str
]
=
None
"""
The minimum version of HF Transformers that is required to run this model.
"""
is_available_online
:
bool
=
True
"""
Set this to ``False`` if the name of this architecture no longer exists on
...
...
@@ -33,6 +42,50 @@ class _HfExamplesInfo:
trust_remote_code
:
bool
=
False
"""The ``trust_remote_code`` level required to load the model."""
hf_overrides
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
"""The ``hf_overrides`` required to load the model."""
def check_transformers_version(
    self,
    *,
    on_fail: Literal["error", "skip"],
) -> None:
    """
    Compare the installed transformers version with the model's minimum.

    Nothing happens when ``min_transformers_version`` is ``None`` or the
    installed version is not older than it. Otherwise ``on_fail="error"``
    raises ``RuntimeError`` and any other value calls ``pytest.skip``.
    """
    minimum = self.min_transformers_version
    if minimum is None:
        # No requirement recorded for this model
        return

    installed = TRANSFORMERS_VERSION
    too_old = Version(installed) < Version(minimum)
    if not too_old:
        return

    problem = (f"You have `transformers=={installed}` installed, but "
               f"`transformers>={minimum}` is required to run this "
               "model")

    if on_fail == "error":
        raise RuntimeError(problem)
    pytest.skip(problem)
def check_available_online(
    self,
    *,
    on_fail: Literal["error", "skip"],
) -> None:
    """
    React to a model that is flagged as not available online.

    Returns immediately when ``is_available_online`` is truthy. Otherwise
    ``on_fail="error"`` raises ``RuntimeError`` and any other value calls
    ``pytest.skip``.
    """
    if self.is_available_online:
        return

    problem = "Model is not available online"
    if on_fail == "error":
        raise RuntimeError(problem)
    pytest.skip(problem)
# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS
=
{
...
...
@@ -43,8 +96,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
"ArcticForCausalLM"
:
_HfExamplesInfo
(
"Snowflake/snowflake-arctic-instruct"
,
trust_remote_code
=
True
),
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
,
trust_remote_code
=
True
),
"BaiChuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan-7B"
,
trust_remote_code
=
True
),
"BaichuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan2-7B-chat"
,
...
...
@@ -64,6 +115,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"DeepseekV3ForCausalLM"
:
_HfExamplesInfo
(
"deepseek-ai/DeepSeek-V3"
,
# noqa: E501
trust_remote_code
=
True
),
"ExaoneForCausalLM"
:
_HfExamplesInfo
(
"LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"
),
# noqa: E501
"Fairseq2LlamaForCausalLM"
:
_HfExamplesInfo
(
"mgleize/fairseq2-dummy-Llama-3.2-1B"
),
# noqa: E501
"FalconForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-7b"
),
"GemmaForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-2b"
),
"Gemma2ForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-2-9b"
),
...
...
@@ -80,6 +132,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
"InternLM2VEForCausalLM"
:
_HfExamplesInfo
(
"OpenGVLab/Mono-InternVL-2B"
,
trust_remote_code
=
True
),
"InternLM3ForCausalLM"
:
_HfExamplesInfo
(
"internlm/internlm3-8b-instruct"
,
trust_remote_code
=
True
),
"JAISLMHeadModel"
:
_HfExamplesInfo
(
"inceptionai/jais-13b-chat"
),
"JambaForCausalLM"
:
_HfExamplesInfo
(
"ai21labs/AI21-Jamba-1.5-Mini"
),
"LlamaForCausalLM"
:
_HfExamplesInfo
(
"meta-llama/Meta-Llama-3-8B"
),
...
...
@@ -140,11 +194,14 @@ _EMBEDDING_EXAMPLE_MODELS = {
"BertModel"
:
_HfExamplesInfo
(
"BAAI/bge-base-en-v1.5"
),
"Gemma2Model"
:
_HfExamplesInfo
(
"BAAI/bge-multilingual-gemma2"
),
"GritLM"
:
_HfExamplesInfo
(
"parasail-ai/GritLM-7B-vllm"
),
"InternLM2ForRewardModel"
:
_HfExamplesInfo
(
"internlm/internlm2-1_8b-reward"
,
trust_remote_code
=
True
),
"JambaForSequenceClassification"
:
_HfExamplesInfo
(
"ai21labs/Jamba-tiny-reward-dev"
),
# noqa: E501
"LlamaModel"
:
_HfExamplesInfo
(
"llama"
,
is_available_online
=
False
),
"MistralModel"
:
_HfExamplesInfo
(
"intfloat/e5-mistral-7b-instruct"
),
"Qwen2Model"
:
_HfExamplesInfo
(
"ssmits/Qwen2-7B-Instruct-embed-base"
),
"Qwen2ForRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-RM-72B"
),
"Qwen2ForProcessRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-PRM-7B"
),
"Qwen2ForSequenceClassification"
:
_HfExamplesInfo
(
"jason9693/Qwen2.5-1.5B-apeach"
),
# noqa: E501
"RobertaModel"
:
_HfExamplesInfo
(
"sentence-transformers/stsb-roberta-base-v2"
),
# noqa: E501
"RobertaForMaskedLM"
:
_HfExamplesInfo
(
"sentence-transformers/all-roberta-large-v1"
),
# noqa: E501
...
...
@@ -165,6 +222,8 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
_MULTIMODAL_EXAMPLE_MODELS
=
{
# [Decoder-only]
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
,
min_transformers_version
=
"4.48"
),
"Blip2ForConditionalGeneration"
:
_HfExamplesInfo
(
"Salesforce/blip2-opt-2.7b"
),
# noqa: E501
"ChameleonForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/chameleon-7b"
),
# noqa: E501
"ChatGLMModel"
:
_HfExamplesInfo
(
"THUDM/glm-4v-9b"
,
...
...
@@ -172,6 +231,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
"ChatGLMForConditionalGeneration"
:
_HfExamplesInfo
(
"chatglm2-6b"
,
is_available_online
=
False
),
"DeepseekVLV2ForCausalLM"
:
_HfExamplesInfo
(
"deepseek-ai/deepseek-vl2-tiny"
,
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"DeepseekVLV2ForCausalLM"
]}),
# noqa: E501
"FuyuForCausalLM"
:
_HfExamplesInfo
(
"adept/fuyu-8b"
),
"H2OVLChatModel"
:
_HfExamplesInfo
(
"h2oai/h2ovl-mississippi-800m"
),
"InternVLChatModel"
:
_HfExamplesInfo
(
"OpenGVLab/InternVL2-1B"
,
...
...
@@ -182,8 +243,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"LlavaNextForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-v1.6-mistral-7b-hf"
),
# noqa: E501
"LlavaNextVideoForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/LLaVA-NeXT-Video-7B-hf"
),
# noqa: E501
"LlavaOnevisionForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
),
# noqa: E501
"MantisForConditionalGeneration"
:
_HfExamplesInfo
(
"TIGER-Lab/Mantis-8B-siglip-llama3"
),
# noqa: E501
"MiniCPMV"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-Llama3-V-2_5"
,
"MantisForConditionalGeneration"
:
_HfExamplesInfo
(
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"MantisForConditionalGeneration"
]}),
# noqa: E501
"MiniCPMO"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-o-2_6"
,
trust_remote_code
=
True
),
"MiniCPMV"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-V-2_6"
,
trust_remote_code
=
True
),
"MolmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Molmo-7B-D-0924"
,
trust_remote_code
=
True
),
...
...
@@ -199,9 +263,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
"Qwen2AudioForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-Audio-7B-Instruct"
),
# noqa: E501
"Qwen2VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-VL-2B-Instruct"
),
# noqa: E501
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_3"
),
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_3"
,
trust_remote_code
=
True
),
# [Encoder-decoder]
"MllamaForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-11B-Vision-Instruct"
),
# noqa: E501
"WhisperForConditionalGeneration"
:
_HfExamplesInfo
(
"openai/whisper-large-v3"
),
# noqa: E501
}
_SPECULATIVE_DECODING_EXAMPLE_MODELS
=
{
...
...
@@ -234,5 +300,17 @@ class HfExampleModels:
def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
    """Return the entry stored in ``self.hf_models`` under ``model_arch``."""
    registry = self.hf_models
    return registry[model_arch]
def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
    """Return the first entry of ``self.hf_models`` that names ``model_id``.

    All entries are checked against their ``default`` model first; only if
    none matches are the values of each entry's ``extras`` consulted.

    Raises:
        ValueError: If no entry references ``model_id``.
    """
    not_found = object()

    # First pass: exact match on the default model of any entry
    by_default = next(
        (entry for entry in self.hf_models.values()
         if entry.default == model_id),
        not_found,
    )
    if by_default is not not_found:
        return by_default

    # Fallback to extras
    by_extra = next(
        (entry for entry in self.hf_models.values()
         if any(alt == model_id for alt in entry.extras.values())),
        not_found,
    )
    if by_extra is not not_found:
        return by_extra

    raise ValueError(f"No example model defined for {model_id}")
HF_EXAMPLE_MODELS
=
HfExampleModels
(
_EXAMPLE_MODELS
)
tests/models/test_initialization.py
View file @
afd0da21
from
unittest.mock
import
patch
import
pytest
import
transformers
from
transformers
import
PretrainedConfig
from
vllm
import
LLM
...
...
@@ -12,14 +11,14 @@ from .registry import HF_EXAMPLE_MODELS
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
HF_EXAMPLE_MODELS
.
get_supported_archs
())
def
test_can_initialize
(
model_arch
):
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
if
(
model_arch
==
"Cohere2ForCausalLM"
and
transformers
.
__version__
<
"4.48.0"
):
pytest
.
skip
(
reason
=
"Model introduced in HF >= 4.48.0"
)
if
not
model_info
.
is_available_online
:
pytest
.
skip
(
"Model is not available online"
)
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
# Avoid OOM
def
hf_overrides
(
hf_config
:
PretrainedConfig
)
->
PretrainedConfig
:
if
hf_config
.
model_type
==
"deepseek_vl_v2"
:
hf_config
.
update
({
"architectures"
:
[
"DeepseekVLV2ForCausalLM"
]})
if
hasattr
(
hf_config
,
"text_config"
):
text_config
:
PretrainedConfig
=
hf_config
.
text_config
else
:
...
...
tests/models/test_registry.py
View file @
afd0da21
...
...
@@ -21,6 +21,9 @@ from .registry import HF_EXAMPLE_MODELS
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
ModelRegistry
.
get_supported_archs
())
def
test_registry_imports
(
model_arch
):
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
# Ensure all model classes can be imported successfully
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
model_arch
)
...
...
tests/multi_step/test_correctness_async_llm.py
View file @
afd0da21
...
...
@@ -17,8 +17,8 @@ NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
NUM_PROMPTS
=
[
10
]
DEFAULT_SERVER_ARGS
:
List
[
str
]
=
[
"--dis
able-log-requests
"
,
"
--worker-use-
ray"
,
"--dis
tributed-executor-backend
"
,
"ray"
,
"--gpu-memory-utilization"
,
"0.85"
,
"--swap-space"
,
...
...
@@ -112,7 +112,7 @@ async def test_multi_step(
# Spin up client/server & issue completion API requests.
# Default `max_wait_seconds` is 240 but empirically
# was raised
3
x to
72
0 *just for this test* due to
# was raised
5
x to
120
0 *just for this test* due to
# observed timeouts in GHA CI
ref_completions
=
await
completions_with_server_args
(
prompts
,
...
...
tests/multi_step/test_correctness_llm.py
View file @
afd0da21
...
...
@@ -6,6 +6,8 @@ from typing import Optional
import
pytest
import
os
from
tests.kernels.utils
import
override_backend_env_variable
from
..models.utils
import
check_logprobs_close
,
check_outputs_equal
from
..utils
import
models_path_prefix
...
...
@@ -21,10 +23,11 @@ NUM_PROMPTS = [10]
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"enable_chunked_prefill"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
def
test_multi_step_llm
(
hf_runner
,
vllm_runner
,
...
...
@@ -38,6 +41,8 @@ def test_multi_step_llm(
num_scheduler_steps
:
int
,
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
...
...
@@ -65,6 +70,7 @@ def test_multi_step_llm(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> 1 logprob returned.
"""
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
...
...
@@ -116,6 +122,7 @@ def test_multi_step_llm(
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs,num_prompt_logprobs"
,
[(
5
,
5
)])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASH_ATTN"
])
def
test_multi_step_llm_w_prompt_logprobs
(
vllm_runner
,
example_prompts
,
...
...
@@ -128,6 +135,8 @@ def test_multi_step_llm_w_prompt_logprobs(
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
num_prompt_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine.
...
...
@@ -157,6 +166,7 @@ def test_multi_step_llm_w_prompt_logprobs(
note that this argument is not supported by the
OpenAI completions endpoint.
"""
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
...
...
@@ -207,6 +217,7 @@ def test_multi_step_llm_w_prompt_logprobs(
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASH_ATTN"
])
def
test_multi_step_llm_chunked_prefill_prefix_cache
(
vllm_runner
,
example_prompts
,
...
...
@@ -218,6 +229,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
num_scheduler_steps
:
int
,
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
...
...
@@ -280,6 +293,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
#
# The Incorrect scheduling behavior - if it occurs - will cause an exception
# in the model runner resulting from `do_sample=False`.
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
assert
len
(
example_prompts
)
>=
2
challenge_prompts
=
copy
.
deepcopy
(
example_prompts
)
challenge_prompts
[
0
]
=
(
'vLLM is a high-throughput and memory-efficient '
...
...
tests/multimodal/test_processing.py
View file @
afd0da21
from
contextlib
import
nullcontext
from
typing
import
cast
from
unittest.mock
import
MagicMock
import
numpy
as
np
import
pytest
from
vllm.multimodal.processing
import
(
PromptReplacement
,
_PlaceholderInfo
,
from
vllm.config
import
ModelConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.multimodal.processing
import
(
PlaceholderFeaturesInfo
,
PromptReplacement
,
find_mm_placeholders
,
find_text_matches
,
find_token_matches
,
iter_placeholders
,
iter_token_matches
,
iter_token_matches
,
replace_text_matches
,
replace_token_matches
)
# yapf: enable
from
vllm.multimodal.profiling
import
MultiModalProfiler
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
full_groupby
from
.utils
import
random_image
# yapf: disable
@
pytest
.
mark
.
parametrize
(
...
...
@@ -304,21 +318,27 @@ def test_find_replace_text(
# Should not be used since there is nothing to convert to text
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
prompt_repls
=
[
PromptReplacement
(
key
,
target
,
repl_by_key
[
key
]).
bind
(
mock_tokenizer
)
mm_prompt_repls
=
{
key
:
[
PromptReplacement
(
key
,
target
,
repl_by_key
[
key
]).
bind
(
mock_tokenizer
)
]
for
key
,
target
in
target_by_key
.
items
()
]
matches
=
find_text_matches
(
prompt
,
prompt_repls
)
}
mm_matches
=
{
key
:
find_text_matches
(
prompt
,
prompt_repls
)
for
key
,
prompt_repls
in
mm_prompt_repls
.
items
()
}
result
=
replace_text_matches
(
prompt
,
matches
,
mm_
matches
,
{
key
:
mm_count
for
key
in
repl_by_key
},
)
# Only displayed on error
print
(
"matches:"
,
matches
)
print
(
"
mm_
matches:"
,
mm_
matches
)
print
(
"result:"
,
result
)
# Manually constructed results
...
...
@@ -370,21 +390,27 @@ def test_find_replace_tokens(
# Should not be used since there is nothing to convert to tokens
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
prompt_repls
=
[
PromptReplacement
(
key
,
target
,
repl_by_key
[
key
]).
bind
(
mock_tokenizer
)
mm_prompt_repls
=
{
key
:
[
PromptReplacement
(
key
,
target
,
repl_by_key
[
key
]).
bind
(
mock_tokenizer
)
]
for
key
,
target
in
target_by_key
.
items
()
]
matches
=
find_token_matches
(
prompt
,
prompt_repls
)
}
mm_matches
=
{
key
:
find_token_matches
(
prompt
,
prompt_repls
)
for
key
,
prompt_repls
in
mm_prompt_repls
.
items
()
}
result
=
replace_token_matches
(
prompt
,
matches
,
mm_
matches
,
{
key
:
mm_count
for
key
in
repl_by_key
},
)
# Only displayed on error
print
(
"matches:"
,
matches
)
print
(
"
mm_
matches:"
,
mm_
matches
)
print
(
"result:"
,
result
)
# Manually constructed results
...
...
@@ -399,6 +425,8 @@ def test_find_replace_tokens(
"pattern_1"
:
[
32000
,
32000
],
"pattern_2"
:
[],
"pattern_3"
:
[
1550
,
918
,
1550
],
# Test different modalities having the same tokens (32000)
"pattern_4"
:
[
32000
],
},
],
)
...
...
@@ -407,57 +435,93 @@ def test_find_replace_tokens(
[
(
[
1
,
9833
,
28747
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
],
[
_PlaceholderInfo
(
modality
=
"pattern_1"
,
start_idx
=
6
,
replacement
=
[
32000
,
32000
],
),
],
{
"pattern_1"
:
[
PlaceholderFeaturesInfo
(
modality
=
"pattern_1"
,
item_idx
=
0
,
start_idx
=
6
,
tokens
=
[
32000
,
32000
],
),
],
"pattern_4"
:
[
PlaceholderFeaturesInfo
(
modality
=
"pattern_4"
,
item_idx
=
0
,
start_idx
=
3
,
tokens
=
[
32000
],
),
],
}
),
(
[
1
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
1550
,
918
,
1550
],
[
_PlaceholderInfo
(
modality
=
"pattern_1"
,
start_idx
=
1
,
replacement
=
[
32000
,
32000
],
),
_PlaceholderInfo
(
modality
=
"pattern_1"
,
start_idx
=
5
,
replacement
=
[
32000
,
32000
],
),
_PlaceholderInfo
(
modality
=
"pattern_3"
,
start_idx
=
7
,
replacement
=
[
1550
,
918
,
1550
],
),
],
{
"pattern_1"
:
[
PlaceholderFeaturesInfo
(
modality
=
"pattern_1"
,
item_idx
=
0
,
start_idx
=
1
,
tokens
=
[
32000
,
32000
],
),
PlaceholderFeaturesInfo
(
modality
=
"pattern_1"
,
item_idx
=
1
,
start_idx
=
5
,
tokens
=
[
32000
,
32000
],
),
],
"pattern_3"
:
[
PlaceholderFeaturesInfo
(
modality
=
"pattern_3"
,
item_idx
=
0
,
start_idx
=
7
,
tokens
=
[
1550
,
918
,
1550
],
),
],
# No match for pattern_4 as it has lower priority than pattern_1
}
),
(
[
1
,
32000
,
32000
,
32000
,
32000
,
32000
,
1550
,
918
,
1550
],
[
_PlaceholderInfo
(
modality
=
"pattern_1"
,
start_idx
=
1
,
replacement
=
[
32000
,
32000
],
),
_PlaceholderInfo
(
modality
=
"pattern_1"
,
start_idx
=
3
,
replacement
=
[
32000
,
32000
],
),
_PlaceholderInfo
(
modality
=
"pattern_3"
,
start_idx
=
6
,
replacement
=
[
1550
,
918
,
1550
],
),
],
{
"pattern_1"
:
[
PlaceholderFeaturesInfo
(
modality
=
"pattern_1"
,
item_idx
=
0
,
start_idx
=
1
,
tokens
=
[
32000
,
32000
],
),
PlaceholderFeaturesInfo
(
modality
=
"pattern_1"
,
item_idx
=
1
,
start_idx
=
3
,
tokens
=
[
32000
,
32000
],
),
],
"pattern_4"
:
[
PlaceholderFeaturesInfo
(
modality
=
"pattern_4"
,
item_idx
=
0
,
start_idx
=
5
,
tokens
=
[
32000
],
),
],
"pattern_3"
:
[
PlaceholderFeaturesInfo
(
modality
=
"pattern_3"
,
item_idx
=
0
,
start_idx
=
6
,
tokens
=
[
1550
,
918
,
1550
],
),
],
}
),
]
)
def
test_iter_placeholders
(
# yapf: enable
def
test_find_mm_placeholders
(
repl_by_key
,
prompt
,
expected
,
...
...
@@ -465,21 +529,108 @@ def test_iter_placeholders(
# Should not be used since there is nothing to convert to tokens
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
prompt_repls
=
[
PromptReplacement
(
key
,
[],
repl
).
bind
(
mock_tokenizer
)
mm_
prompt_repls
=
{
key
:
[
PromptReplacement
(
key
,
[],
repl
).
bind
(
mock_tokenizer
)
]
for
key
,
repl
in
repl_by_key
.
items
()
]
}
result
=
list
(
iter_placeholders
(
prompt
_repls
,
prompt
,
# Effectively match all occurrences in the prompt
{
key
:
3
for
key
in
repl_by_key
},
)
)
result
=
find_mm_placeholders
(
mm_prompt_repls
,
prompt
,
# Effectively match all occurrences in the
prompt
{
key
:
3
for
key
in
repl_by_key
},
)
# Only displayed on error
print
(
"result:"
,
result
)
# Manually constructed results
assert
result
==
expected
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("limit", "num_supported", "is_valid"),
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
    """Check dummy-data profiling against the supported image limit.

    The processor's ``get_supported_mm_limits`` is replaced by a mock that
    reports ``num_supported`` images; the invalid cases are expected to raise
    ``ValueError`` from ``get_dummy_data``.
    """
    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt={"image": limit},
    )

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=tokenizer,
    )
    profiler = MultiModalProfiler(processor)

    # Override what the processor claims to support for this test case
    processor.info.get_supported_mm_limits = MagicMock(
        return_value={"image": num_supported})

    expectation = (nullcontext() if is_valid else pytest.raises(
        ValueError, match="this model only supports"))

    with expectation:
        profiler.get_dummy_data(model_config.max_model_len)
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
    """Check ``processor.apply`` against the configured per-prompt image limit.

    ``num_images`` copies of one random image are passed alongside a prompt
    with as many ``<image>`` tags; the invalid cases are expected to raise
    ``ValueError``.
    """
    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt={"image": limit},
    )

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=tokenizer,
    )

    # Fixed seed so the generated image is the same on every run
    image = random_image(np.random.RandomState(0), min_wh=128, max_wh=256)

    # No images -> empty dict; one image -> bare item; otherwise a list
    if num_images == 0:
        mm_data = {}
    else:
        mm_data = {
            "image": image if num_images == 1 else [image] * num_images
        }

    expectation = (nullcontext() if is_valid else pytest.raises(
        ValueError, match=f"passed {num_images} image"))

    with expectation:
        processor.apply(
            "<image>" * num_images,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )
tests/multimodal/test_utils.py
View file @
afd0da21
...
...
@@ -2,7 +2,7 @@ import base64
import
mimetypes
import
os
from
tempfile
import
NamedTemporaryFile
,
TemporaryDirectory
from
typing
import
Dict
,
Tuple
from
typing
import
TYPE_CHECKING
,
Dict
,
NamedTuple
,
Optional
,
Tuple
import
numpy
as
np
import
pytest
...
...
@@ -11,10 +11,16 @@ import os
from
PIL
import
Image
,
ImageChops
from
transformers
import
AutoConfig
,
AutoTokenizer
from
vllm.multimodal.utils
import
(
async_fetch_image
,
fetch_image
,
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.multimodal.utils
import
(
MediaConnector
,
merge_and_sort_multimodal_metadata
,
repeat_and_pad_placeholder_tokens
)
from
..utils
import
models_path_prefix
,
urls_port
if
TYPE_CHECKING
:
from
vllm.multimodal.hasher
import
MultiModalHashDict
from
vllm.multimodal.inputs
import
MultiModalPlaceholderDict
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS
=
[
f
"http://localhost:
{
urls_port
}
/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
,
...
...
@@ -26,7 +32,12 @@ TEST_IMAGE_URLS = [
@
pytest
.
fixture
(
scope
=
"module"
)
def
url_images
()
->
Dict
[
str
,
Image
.
Image
]:
return
{
image_url
:
fetch_image
(
image_url
)
for
image_url
in
TEST_IMAGE_URLS
}
connector
=
MediaConnector
()
return
{
image_url
:
connector
.
fetch_image
(
image_url
)
for
image_url
in
TEST_IMAGE_URLS
}
def
get_supported_suffixes
()
->
Tuple
[
str
,
...]:
...
...
@@ -46,8 +57,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
async
def
test_fetch_image_http
(
image_url
:
str
):
image_sync
=
fetch_image
(
image_url
)
image_async
=
await
async_fetch_image
(
image_url
)
connector
=
MediaConnector
()
image_sync
=
connector
.
fetch_image
(
image_url
)
image_async
=
await
connector
.
fetch_image_async
(
image_url
)
assert
_image_equals
(
image_sync
,
image_async
)
...
...
@@ -56,6 +69,7 @@ async def test_fetch_image_http(image_url: str):
@
pytest
.
mark
.
parametrize
(
"suffix"
,
get_supported_suffixes
())
async
def
test_fetch_image_base64
(
url_images
:
Dict
[
str
,
Image
.
Image
],
image_url
:
str
,
suffix
:
str
):
connector
=
MediaConnector
()
url_image
=
url_images
[
image_url
]
try
:
...
...
@@ -78,48 +92,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
base64_image
=
base64
.
b64encode
(
f
.
read
()).
decode
(
"utf-8"
)
data_url
=
f
"data:
{
mime_type
}
;base64,
{
base64_image
}
"
data_image_sync
=
fetch_image
(
data_url
)
data_image_sync
=
connector
.
fetch_image
(
data_url
)
if
_image_equals
(
url_image
,
Image
.
open
(
f
)):
assert
_image_equals
(
url_image
,
data_image_sync
)
else
:
pass
# Lossy format; only check that image can be opened
data_image_async
=
await
async_
fetch_image
(
data_url
)
data_image_async
=
await
connector
.
fetch_image
_async
(
data_url
)
assert
_image_equals
(
data_image_sync
,
data_image_async
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
async
def
test_fetch_image_local_files
(
image_url
:
str
):
connector
=
MediaConnector
()
with
TemporaryDirectory
()
as
temp_dir
:
origin_image
=
fetch_image
(
image_url
)
local_connector
=
MediaConnector
(
allowed_local_media_path
=
temp_dir
)
origin_image
=
connector
.
fetch_image
(
image_url
)
origin_image
.
save
(
os
.
path
.
join
(
temp_dir
,
os
.
path
.
basename
(
image_url
)),
quality
=
100
,
icc_profile
=
origin_image
.
info
.
get
(
'icc_profile'
))
image_async
=
await
async_fetch_image
(
f
"file://
{
temp_dir
}
/
{
os
.
path
.
basename
(
image_url
)
}
"
,
allowed_local_media_path
=
temp_dir
)
image_sync
=
fetch_image
(
f
"file://
{
temp_dir
}
/
{
os
.
path
.
basename
(
image_url
)
}
"
,
allowed_local_media_path
=
temp_dir
)
image_async
=
await
local_connector
.
fetch_image_async
(
f
"file://
{
temp_dir
}
/
{
os
.
path
.
basename
(
image_url
)
}
"
)
image_sync
=
local_connector
.
fetch_image
(
f
"file://
{
temp_dir
}
/
{
os
.
path
.
basename
(
image_url
)
}
"
)
# Check that the images are equal
assert
not
ImageChops
.
difference
(
image_sync
,
image_async
).
getbbox
()
with
pytest
.
raises
(
ValueError
):
await
async_fetch_image
(
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
,
allowed_local_media_path
=
temp_dir
)
with
pytest
.
raises
(
ValueError
):
await
async_fetch_image
(
with
pytest
.
raises
(
ValueError
,
match
=
"must be a subpath"
):
await
local_connector
.
fetch_image_async
(
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
)
with
pytest
.
raises
(
RuntimeError
,
match
=
"Cannot load local files"
):
await
connector
.
fetch_image_async
(
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
)
with
pytest
.
raises
(
ValueError
):
fetch_image
(
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
,
allowed_local_media_path
=
temp_dir
)
with
pytest
.
raises
(
ValueError
):
fetch_image
(
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
)
with
pytest
.
raises
(
ValueError
,
match
=
"must be a subpath"
):
local_connector
.
fetch_image
(
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
)
with
pytest
.
raises
(
RuntimeError
,
match
=
"Cannot load local files"
):
connector
.
fetch_image
(
f
"file://
{
temp_dir
}
/../
{
os
.
path
.
basename
(
image_url
)
}
"
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
os
.
path
.
join
(
models_path_prefix
,
"llava-hf/llava-v1.6-mistral-7b-hf"
)])
...
...
@@ -185,3 +200,204 @@ def test_repeat_and_pad_placeholder_tokens(model):
assert
new_prompt
==
expected_prompt
assert
new_token_ids
==
expected_token_ids
assert
ranges
==
expected_ranges
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
class TestCase(NamedTuple):
    """One input/expected-output scenario for `merge_and_sort_multimodal_metadata`.

    The first two fields are the arguments passed to the function under test;
    the three `expected_*` fields are compared against its return values.
    """
    # Per-modality placeholder ranges (in the tests below: modality name ->
    # list of PlaceholderRange).
    mm_positions: "MultiModalPlaceholderDict"
    # Per-modality hashes (modality name -> list of str), or None when the
    # scenario provides no hashes.
    mm_hashes: Optional["MultiModalHashDict"]
    # Modality names in the order the function is expected to return them.
    expected_modalities: list[str]
    # Flattened placeholder ranges in the expected output order.
    expected_ranges: list[PlaceholderRange]
    # Flattened hashes in the expected output order, or None.
    expected_hashes: Optional[list[str]]
def test_merge_and_sort_multimodal_metadata():
    """Check merged modalities, ranges and hashes for non-interleaved inputs.

    Every scenario is fed to `merge_and_sort_multimodal_metadata` and the three
    returned values are compared with the scenario's expectations.
    """

    def span(offset, length):
        # Shorthand so the scenario tables below stay readable.
        return PlaceholderRange(offset=offset, length=length)

    test_cases = [
        # Single modality should return result as is but flattened
        TestCase(
            mm_positions={"image": [span(0, 2), span(3, 2)]},
            mm_hashes={"image": ["hash1", "hash2"]},
            expected_modalities=["image"],
            expected_ranges=[span(0, 2), span(3, 2)],
            expected_hashes=["hash1", "hash2"],
        ),
        # Single modality without hashes returns None for the hashes.
        TestCase(
            mm_positions={"image": [span(0, 2), span(2, 2)]},
            mm_hashes=None,
            expected_modalities=["image"],
            expected_ranges=[span(0, 2), span(2, 2)],
            expected_hashes=None,
        ),
        # Multiple modalities with hashes should return sorted modalities
        # and flattened ranges and hashes.
        TestCase(
            mm_positions={
                "image": [span(7, 4), span(11, 5)],
                "audio": [span(0, 2), span(2, 3)],
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1", "audio_hash2"],
            },
            expected_modalities=["audio", "image"],
            expected_ranges=[
                span(0, 2),
                span(2, 3),
                span(7, 4),
                span(11, 5),
            ],
            expected_hashes=[
                "audio_hash1",
                "audio_hash2",
                "image_hash1",
                "image_hash2",
            ],
        ),
        # Multiple modalities without hashes should return sorted modalities
        # and flattened ranges and None.
        TestCase(
            mm_positions={
                "image": [span(7, 4), span(11, 5)],
                "audio": [span(0, 2), span(2, 3)],
            },
            mm_hashes=None,
            expected_modalities=["audio", "image"],
            expected_ranges=[
                span(0, 2),
                span(2, 3),
                span(7, 4),
                span(11, 5),
            ],
            expected_hashes=None,
        ),
        # Three modalities
        TestCase(
            mm_positions={
                "image": [span(15, 7), span(22, 8)],
                "audio": [span(0, 2)],
                "video": [span(3, 4), span(7, 5), span(12, 6)],
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1"],
                "video": ["video_hash1", "video_hash2", "video_hash3"],
            },
            expected_modalities=["audio", "video", "image"],
            expected_ranges=[
                span(0, 2),
                span(3, 4),
                span(7, 5),
                span(12, 6),
                span(15, 7),
                span(22, 8),
            ],
            expected_hashes=[
                "audio_hash1",
                "video_hash1",
                "video_hash2",
                "video_hash3",
                "image_hash1",
                "image_hash2",
            ],
        ),
    ]

    for case in test_cases:
        modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
            case.mm_positions, case.mm_hashes)

        assert modalities == case.expected_modalities
        assert ranges == case.expected_ranges
        assert hashes == case.expected_hashes
def test_merge_and_sort_multimodal_metadata_with_interleaving():
    """Interleaved mixed-modality placeholders must be rejected.

    Each scenario places placeholders of different modalities in alternating
    order; `merge_and_sort_multimodal_metadata` is expected to raise a
    ValueError whose message mentions "Interleaved mixed-modality".
    """

    def span(offset, length):
        # Shorthand so the scenario tables below stay readable.
        return PlaceholderRange(offset=offset, length=length)

    test_cases = [
        # <image> <audio> <image> <audio>
        TestCase(
            mm_positions={
                "image": [span(0, 4), span(8, 2)],
                "audio": [span(5, 2), span(11, 4)],
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1", "audio_hash2"],
            },
            expected_modalities=[],
            expected_ranges=[],
            expected_hashes=None,
        ),
        # <image> <image> <video> <audio> <image>
        TestCase(
            mm_positions={
                "image": [span(0, 2), span(2, 3), span(20, 4)],
                "audio": [span(5, 2)],
                "video": [span(8, 5)],
            },
            mm_hashes=None,
            expected_modalities=[],
            expected_ranges=[],
            expected_hashes=None,
        ),
    ]

    for case in test_cases:
        # `match` performs a regex search on the exception text, which for
        # this plain phrase is the same as a substring check.
        with pytest.raises(ValueError, match="Interleaved mixed-modality"):
            merge_and_sort_multimodal_metadata(case.mm_positions,
                                               case.mm_hashes)
tests/multimodal/utils.py
0 → 100644
View file @
afd0da21
import
numpy
as
np
from
PIL
import
Image
def random_image(rng: np.random.RandomState, min_wh: int, max_wh: int):
    """Create a random 3-channel uint8 image using `rng`.

    Two side lengths are drawn from ``rng.randint(min_wh, max_wh)`` and pixel
    values from ``rng.randint(0, 255)``; `randint` excludes its upper bound,
    so sides are below `max_wh` and pixel values are at most 254.

    Returns:
        The array of shape ``(side0, side1, 3)`` wrapped by `Image.fromarray`.
    """
    side0, side1 = rng.randint(min_wh, max_wh, size=(2, ))
    pixels = rng.randint(0, 255, size=(side0, side1, 3), dtype=np.uint8)
    return Image.fromarray(pixels)
def random_video(
    rng: np.random.RandomState,
    min_frames: int,
    max_frames: int,
    min_wh: int,
    max_wh: int,
):
    """Create a random uint8 video array of shape (frames, side0, side1, 3).

    The frame count is drawn from ``rng.randint(min_frames, max_frames)`` and
    then rounded down to an even number; the two side lengths are drawn from
    ``rng.randint(min_wh, max_wh)``. All upper bounds are exclusive.
    """
    frame_count = rng.randint(min_frames, max_frames)
    # Temporary workaround for https://github.com/huggingface/transformers/issues/35412
    # (force an even number of frames).
    frame_count -= frame_count % 2

    side0, side1 = rng.randint(min_wh, max_wh, size=(2, ))
    frames_shape = (frame_count, side0, side1, 3)
    return rng.randint(0, 255, size=frames_shape, dtype=np.uint8)
def random_audio(
    rng: np.random.RandomState,
    min_len: int,
    max_len: int,
    sr: int,
):
    """Create a random 1-D audio clip and pair it with its sampling rate.

    The clip length is drawn from ``rng.randint(min_len, max_len)`` (upper
    bound exclusive) and the samples from ``rng.rand``.

    Returns:
        A ``(samples, sr)`` tuple; `sr` is passed through unchanged.
    """
    samples = rng.rand(rng.randint(min_len, max_len))
    return samples, sr
tests/neuron/test_prefix_prefill.py
0 → 100644
View file @
afd0da21
import
random
from
typing
import
Optional
import
pytest
import
torch
import
torch.nn.functional
as
F
class BlockDiagonalCausalFromBottomRightMask:
    """Boolean attention masks for a batch of sequences packed back to back.

    Rows index query tokens of all sequences concatenated; columns index key
    tokens of all sequences concatenated. `True` marks an allowed
    (query, key) pair.
    """

    @staticmethod
    def _from_seqlens(query_lens, seq_lens, block_size=None):
        """Build one mask of shape ``(sum(query_lens), n_keys)``.

        Args:
            query_lens: list with the number of query tokens per sequence.
            seq_lens: list with the total length per sequence.
            block_size: when None, keys span the full sequences and each
                sequence gets a causal block aligned to its bottom-right
                corner. Otherwise keys span only the context
                (``seq_len - query_len``), padded per sequence to a multiple
                of `block_size`, and every query of a sequence sees that
                sequence's un-padded context columns.

        Returns:
            A bool tensor; `True` where attention is allowed.
        """
        contexted = block_size is None
        context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
        n_queries = sum(query_lens)
        num_seqs = len(query_lens)
        if contexted:
            key_lens_blockaligned = seq_lens
        else:
            # Round each context length up to a whole number of blocks.
            n_blocks_per_seq = (context_lens + block_size - 1) // block_size
            offset_per_seq = n_blocks_per_seq * block_size
            key_lens_blockaligned = offset_per_seq[:num_seqs].tolist()
        n_keys = sum(key_lens_blockaligned)

        # a[r, c] == r (query row index), b[r, c] == c (key column index).
        a = torch.arange(n_queries).reshape(n_queries, 1).expand(
            n_queries, n_keys)
        b = torch.arange(n_keys).reshape(1, n_keys).expand(n_queries, n_keys)
        # Start row / start column of every sequence's block.
        q_cumsum = torch.tensor([0] + query_lens).cumsum(dim=0)
        k_cumsum = torch.tensor([0] + key_lens_blockaligned).cumsum(dim=0)

        # Start from a bool tensor so the result dtype does not depend on
        # whether the loop below runs at least once.
        prior_mask = torch.zeros(n_queries, n_keys, dtype=torch.bool)
        for seq_id in range(num_seqs):
            ri = q_cumsum[seq_id]
            ci = k_cumsum[seq_id]
            nr = query_lens[seq_id]

            if contexted:
                nc = seq_lens[seq_id]
                # Shift so the last query row reaches the last key column.
                a_offset = ci + nc - ri - nr
                new_mask = (a + a_offset) >= b
            else:
                nc = context_lens[seq_id]
                # Only the first `nc` columns of this sequence's (padded)
                # key region hold real context.
                a_offset = ci + nc - 1
                new_mask = a_offset >= b

            # Clip to this sequence's rows and to columns at or after its
            # first key column.
            left_mask = b >= ci
            top_mask = a >= ri
            bottom_mask = a < (ri + nr)

            new_mask = new_mask & left_mask & top_mask & bottom_mask
            prior_mask = prior_mask | new_mask
        return prior_mask

    @staticmethod
    def from_seqlens(query_lens, seq_lens, block_size=None):
        """Return ``(prior_mask, active_mask)``.

        Without `block_size`, `prior_mask` covers the full sequences and
        `active_mask` is None. With `block_size`, `prior_mask` covers the
        block-aligned context keys and `active_mask` is the causal mask of
        the query tokens against themselves.
        """
        build = BlockDiagonalCausalFromBottomRightMask._from_seqlens
        if block_size is None:
            return build(query_lens, seq_lens), None

        prior_mask = build(query_lens, seq_lens, block_size)
        active_mask = build(query_lens, query_lens)
        return prior_mask, active_mask
def ref_softmax(x: torch.Tensor,
                dim: int,
                mixed_precision=False,
                return_max_reduce=False):
    """Reference softmax along `dim`, computed by subtracting the maximum.

    Args:
        x: input scores.
        dim: dimension to normalise over.
        mixed_precision: when True, the denominator is accumulated in
            float32 and cast back to ``x.dtype``.
        return_max_reduce: when True, also return the intermediate
            reductions.

    Returns:
        The softmax of `x`, or the tuple
        ``(softmax, max_value, 1 / sum_value)`` when `return_max_reduce`
        is set (both reductions keep `dim` with size 1).
    """
    max_value = torch.amax(x, dim=dim, keepdim=True)
    exp = torch.exp(x - max_value)
    if mixed_precision:
        # torch tensors are converted with `.to`; `.astype` is the NumPy
        # spelling and does not exist on torch.Tensor.
        sum_value = torch.sum(exp.to(torch.float32), dim=dim,
                              keepdim=True).to(x.dtype)
    else:
        sum_value = torch.sum(exp, dim=dim, keepdim=True)
    if return_max_reduce:
        return exp / sum_value, max_value, torch.reciprocal(sum_value)
    return exp / sum_value
def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
    return_max_reduce: Optional[bool] = False,
) -> torch.Tensor:
    """Reference scaled dot-product attention with an optional additive mask.

    `query` is indexed ``(q, h, d)`` and `key`/`value` ``(k, h, d)``, as fixed
    by the einsum subscripts below.

    Args:
        query, key, value: attention inputs.
        scale: factor applied to the raw ``q . k`` scores.
        attn_mask: additive bias added to the scaled scores; when None the
            scores are used unmasked.
        return_max_reduce: when True, also return softmax intermediates.

    Returns:
        The attention output ``(q, h, d)``, or, when `return_max_reduce` is
        set, the tuple ``(out, cached_max, cached_sum_reciprocal, norm_score,
        masked_score, scaled_qk)``.
    """
    scaled_qk = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        masked_score = scaled_qk + attn_mask.float()
    else:
        # No mask supplied: attend over every key.
        masked_score = scaled_qk

    if return_max_reduce:
        norm_score, cached_max, cached_sum_reciprocal = ref_softmax(
            masked_score, dim=-1, return_max_reduce=True)
    else:
        norm_score = ref_softmax(masked_score, dim=-1)

    out = torch.einsum("hqk,khd->qhd", norm_score, value)
    if return_max_reduce:
        return (
            out,
            cached_max,
            cached_sum_reciprocal,
            norm_score,
            masked_score,
            scaled_qk,
        )
    return out
def ref_context_attention(
    query,
    key,
    value,
    query_lens,
    seq_lens,
    head_size,
    num_kv_heads,
    num_heads,
    num_queries_per_kv,
    return_max_reduce=False,
):
    """Reference attention over packed sequences with per-sequence causality.

    The mask comes from
    `BlockDiagonalCausalFromBottomRightMask.from_seqlens(query_lens, seq_lens)`
    and disallowed positions receive an additive bias of -30000. Scores are
    scaled by ``1 / sqrt(head_size)``.

    Args:
        query, key, value: packed attention inputs; `key`/`value` heads
            (dim 1) are repeated `num_queries_per_kv` times when that is > 1.
        query_lens, seq_lens: per-sequence query and total lengths.
        head_size: per-head feature size used for the scale.
        num_kv_heads, num_heads: accepted for signature compatibility; not
            read here.
        num_queries_per_kv: query heads per key/value head.
        return_max_reduce: when True, also return softmax intermediates.

    Returns:
        The attention output with a singleton dimension inserted at index 1,
        or, when `return_max_reduce` is set, the tuple ``(output, cached_max,
        cached_sum_reciprocal, norm_score, masked_score, scaled_qk)``.
    """
    scale = float(1.0 / (head_size**0.5))
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

    attn_mask, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        query_lens, seq_lens)

    # Convert the boolean "allowed" mask into an additive bias: 0 where
    # attention is allowed, a large negative value (-30000) elsewhere.
    attn_mask = torch.logical_not(attn_mask)
    attn_mask = attn_mask.float() * -30000

    # Always request the intermediates so the result can be unpacked
    # uniformly; with return_max_reduce=False the callee returns a single
    # tensor, which cannot be unpacked into six names.
    (
        output,
        cached_max,
        cached_sum_reciprocal,
        norm_score,
        masked_score,
        scaled_qk,
    ) = ref_masked_attention(
        query,
        key,
        value,
        scale,
        attn_mask,
        return_max_reduce=True,
    )

    output = output.unsqueeze(1)
    if return_max_reduce:
        return (
            output,
            cached_max,
            cached_sum_reciprocal,
            norm_score,
            masked_score,
            scaled_qk,
        )
    return output
@pytest.mark.parametrize(
    "num_heads,num_queries_per_kv,head_size,mixed_precision",
    [
        (4, 2, 8, False),
        (4, 2, 8, True),
        (32, 8, 64, True),
    ],
)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    mixed_precision: bool,
) -> None:
    """Compare `flash_attn_varlen_nkifunc` with the CPU reference attention.

    A batch of 2 prefill and 6 decode requests with random context/query
    lengths is generated; the context part of K/V is scattered into a paged
    cache through a shuffled block table, the inputs are padded to the kernel
    tile sizes, and the kernel output is compared with
    `ref_context_attention` (atol=1e-2).
    """
    import os

    import torch_xla.core.xla_model as xm

    from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc

    device = xm.xla_device()

    os.environ["NEURON_CC_FLAGS"] = (
        " --model-type=transformer -O1 "
        " --internal-hlo2tensorizer-options='--verify-hlo' ")

    random.seed(0)
    torch.manual_seed(0)
    torch.set_printoptions(sci_mode=False)

    min_ctx_len = 2
    max_ctx_len = 64
    min_query_len = 2
    max_query_len = 64
    prefill_batch_size = 2
    decode_batch_size = 6
    batch_size = prefill_batch_size + decode_batch_size
    block_size = 32
    max_model_len = (max_query_len + max_ctx_len) * 4

    max_block_per_request = max_model_len // block_size
    dtype = torch.float32
    cache_size = (batch_size * max_block_per_request) + 2
    ctx_lens = [
        random.randint(min_ctx_len, max_ctx_len)
        for _ in range(prefill_batch_size)
    ] + [
        random.randint(min_ctx_len, max_ctx_len)
        for _ in range(decode_batch_size)
    ]
    # Decode requests contribute exactly one query token each.
    query_lens = [
        random.randint(min_query_len, max_query_len)
        for _ in range(prefill_batch_size)
    ] + [1 for _ in range(decode_batch_size)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1, 1)

    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1, 1)
    key, value = kv.unbind(dim=1)

    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=dtype)
    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    # Shuffled block ids so each request maps to non-contiguous cache blocks.
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:batch_size * max_block_per_request].view(
        batch_size, max_block_per_request)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                            dtype=torch.long),
                               dim=0)
    # copy kv to cache
    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
                                                dtype=torch.long),
                                   dim=0)
    for i in range(batch_size):
        # The query-aligned tail of each sequence goes to the "active" k/v.
        for j in range(query_lens[i]):
            k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
                                            j])
            v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
                                              b_ctx_len[i] + j])
        # The context prefix is written block by block into the paged cache.
        cur_ctx = 0
        block_id = 0
        while cur_ctx < b_ctx_len[i]:
            start_loc = b_seq_start_loc[i] + cur_ctx
            if cur_ctx + block_size > b_ctx_len[i]:
                end_loc = b_seq_start_loc[i] + b_ctx_len[i]
            else:
                end_loc = start_loc + block_size
            start_slot = block_table[i, block_id] * block_size
            end_slot = start_slot + end_loc - start_loc
            k_cache.view(-1, num_kv_heads,
                         head_size)[start_slot:end_slot].copy_(
                             key[start_loc:end_loc])
            v_cache.view(-1, num_kv_heads,
                         head_size)[start_slot:end_slot].copy_(
                             value[start_loc:end_loc])
            cur_ctx += block_size
            block_id += 1

    (
        output_ref,
        cached_max,
        cached_sum_reciprocal,
        lse,
        masked_score,
        scaled_qk,
    ) = ref_context_attention(
        query,
        key,
        value,
        query_lens,
        seq_lens,
        head_size,
        num_kv_heads,
        num_heads,
        num_queries_per_kv,
        return_max_reduce=True,
    )

    # build neuron program
    return_debug_tensors = False
    B_P_SIZE = 128
    LARGE_TILE_SZ = 2048
    max_num_queries = (
        (sum(query_lens) + block_size - 1) // block_size) * block_size

    def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                                num_blocks):
        # Concatenate the context blocks of every sequence and zero-pad the
        # result to `num_blocks` entries.
        context_lens = seq_lens - query_lens
        blocks_per_seq = (context_lens + block_size - 1) // block_size
        num_seqs = len(seq_lens)
        active_blocks: list[int] = []
        for seq_id in range(num_seqs):
            active_blocks = (
                active_blocks +
                block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
        return F.pad(
            torch.tensor(active_blocks),
            (0, num_blocks - len(active_blocks)),
            "constant",
            0,
        )

    def shift_bit_length(x):
        # Smallest power of two that is >= x.
        return 1 << (x - 1).bit_length()

    # calculate input shapes
    max_num_queries_shifted = shift_bit_length(max_num_queries)
    max_num_queries_factor = B_P_SIZE // max_num_queries_shifted
    max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor
    assert (max_num_queries_padded == B_P_SIZE
            ), f"invalid {max_num_queries_padded=}"
    head_size_padded = B_P_SIZE
    context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
    num_active_blocks_shifted = shift_bit_length(
        ((context_lens + block_size - 1) // block_size).sum().item())
    num_active_blocks_factor = (LARGE_TILE_SZ // block_size //
                                num_active_blocks_shifted)
    num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor
    assert (num_active_blocks *
            block_size) == LARGE_TILE_SZ, f"invalid {num_active_blocks=}"
    context_kv_len = num_active_blocks * block_size
    assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"

    # pad QKV tensors
    pad_dims = (
        0,
        head_size_padded - query.shape[2],
        0,
        0,
        0,
        max_num_queries_padded - query.shape[0],
    )
    query = F.pad(query, pad_dims, "constant", 0)
    k = F.pad(k, pad_dims, "constant", 0)
    v = F.pad(v, pad_dims, "constant", 0)
    k_cache = F.pad(k_cache, (0, head_size_padded - head_size), "constant", 0)
    v_cache = F.pad(v_cache, (0, head_size_padded - head_size), "constant", 0)

    # permute QKV tensors
    # query: (1, n_heads, d, seq_q)
    # key:   (1, n_kv_heads, d, seq_k)
    # value: (1, n_kv_heads, seq_v, d)
    query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
    k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
    v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()

    # transform block table
    active_block_table = get_active_block_tables(
        block_table,
        torch.tensor(query_lens),
        torch.tensor(seq_lens),
        block_size,
        num_active_blocks,
    )

    # Build attention masks
    prior_mask, active_mask = (
        BlockDiagonalCausalFromBottomRightMask.from_seqlens(
            query_lens, seq_lens, block_size=block_size))
    attn_mask = torch.concat(
        [
            F.pad(
                prior_mask,
                (
                    0,
                    context_kv_len - prior_mask.shape[1],
                    0,
                    B_P_SIZE - prior_mask.shape[0],
                ),
                "constant",
                0,
            ).bool(),
            F.pad(
                active_mask,
                (
                    0,
                    B_P_SIZE - active_mask.shape[1],
                    0,
                    B_P_SIZE - active_mask.shape[0],
                ),
                "constant",
                0,
            ).bool(),
        ],
        dim=1,
    )

    input_args = (
        query.to(device=device),
        k.to(device=device),
        v.to(device=device),
        k_cache.to(device=device),
        v_cache.to(device=device),
        active_block_table.to(torch.int32).to(device=device),
        attn_mask.to(device=device),
    )
    input_kwargs = dict(
        n_kv_head=num_kv_heads,
        head_size=head_size,
        mixed_precision=mixed_precision,
    )

    if return_debug_tensors:
        output_nki, *debug_tensors = flash_attn_varlen_nkifunc(
            *input_args, **input_kwargs)
    else:
        output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
        debug_tensors = []

    output_nki = torch.tensor(output_nki).cpu()
    debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]

    num_actual_tokens = sum(query_lens)
    print(f"{num_actual_tokens=}")
    # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
    output_nki = output_nki.permute(
        0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :]
    output_ref_padded = F.pad(
        output_ref,
        (0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]),
        "constant",
        0,
    )
    output_ref = output_ref_padded.transpose(0, 1)[0, :num_actual_tokens, :, :]

    torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0)
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py
View file @
afd0da21
...
...
@@ -2,15 +2,17 @@ from typing import Optional
import
torch
from
vllm.model_executor.models.llava
import
(
LlavaForConditionalGeneration
,
from
vllm.model_executor.models.llava
import
(
LlavaDummyInputsBuilder
,
LlavaForConditionalGeneration
,
LlavaMultiModalProcessor
,
get_max_llava_image_tokens
)
LlavaProcessingInfo
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_llava_image_tokens
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
LlavaMultiModalProcessor
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
LlavaMultiModalProcessor
,
info
=
LlavaProcessingInfo
,
dummy_inputs
=
LlavaDummyInputsBuilder
)
class
MyLlava
(
LlavaForConditionalGeneration
):
def
compute_logits
(
...
...
tests/plugins/vllm_add_dummy_platform/setup.py
0 → 100644
View file @
afd0da21
from
setuptools
import
setup
# Entry points registered in the 'vllm.platform_plugins' group; the target is
# the `dummy_platform_plugin` callable of the `vllm_add_dummy_platform` package.
_PLATFORM_PLUGIN_ENTRY_POINTS = [
    "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin"  # noqa
]

setup(
    name='vllm_add_dummy_platform',
    version='0.1',
    packages=['vllm_add_dummy_platform'],
    entry_points={'vllm.platform_plugins': _PLATFORM_PLUGIN_ENTRY_POINTS},
)
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
0 → 100644
View file @
afd0da21
from
typing
import
Optional
def dummy_platform_plugin() -> Optional[str]:
    """Return the dotted path of the dummy platform class."""
    platform_cls_path = "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
    return platform_cls_path
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py
0 → 100644
View file @
afd0da21
from
vllm.attention.backends.flash_attn
import
FlashAttentionBackend
class DummyAttentionBackend(FlashAttentionBackend):
    """FlashAttentionBackend subclass that only overrides the reported name."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        backend_name = "Dummy_Backend"
        return backend_name
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
0 → 100644
View file @
afd0da21
from
vllm.platforms.cuda
import
CudaPlatform
class DummyPlatform(CudaPlatform):
    """CudaPlatform subclass with its own device name and attention backend."""
    device_name = "DummyDevice"

    def get_attn_backend_cls(self, backend_name, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla):
        """Return the dotted path of the dummy attention backend.

        All arguments are ignored; the same path is returned for every call.
        """
        backend_cls_path = "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
        return backend_cls_path
tests/plugins_tests/test_platform_plugins.py
0 → 100644
View file @
afd0da21
import
torch
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm.attention.selector
import
get_attn_backend
from
vllm.utils
import
STR_INVALID_VAL
def test_platform_plugins():
    """Run the basic offline-inference example, then check the active platform.

    After the example has run, `current_platform.device_name` must be
    "DummyDevice".
    """
    import os
    import runpy

    # simulate workload by running an example
    current_file = __file__
    # Three directory levels above this file.
    base_dir = current_file
    for _ in range(3):
        base_dir = os.path.dirname(base_dir)
    example_file = os.path.join(base_dir, "examples",
                                "offline_inference/basic.py")
    runpy.run_path(example_file)

    # check if the plugin is loaded correctly
    from vllm.platforms import _init_trace, current_platform
    failure_message = (
        f"Expected DummyDevice, got {current_platform.device_name}, "
        "possibly because current_platform is imported before the plugin"
        f" is loaded. The first import:\n{_init_trace}")
    assert current_platform.device_name == "DummyDevice", failure_message
def test_oot_attention_backend(monkeypatch):
    """The attention selector must return the out-of-tree dummy backend."""
    # ignore the backend env variable if it is set
    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)

    selector_args = (16, torch.float16, torch.float16, 16, False)
    backend = get_attn_backend(*selector_args)
    reported_name = backend.get_name()
    assert reported_name == "Dummy_Backend"
tests/quantization/test_compressed_tensors.py
View file @
afd0da21
...
...
@@ -33,50 +33,55 @@ from ..utils import models_path_prefix
def
test_compressed_tensors_w8a8_static_setup
(
vllm_runner
,
model_args
):
model_path
,
strategy
,
quant_type
,
shape_0
,
is_symmetric
=
model_args
with
vllm_runner
(
model_path
,
enforce_eager
=
True
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
o_proj
=
layer
.
self_attn
.
o_proj
gate_up_proj
=
layer
.
mlp
.
gate_up_proj
down_proj
=
layer
.
mlp
.
down_proj
# assert zp for symmetric and asymmetric cases
def
zp_valid
(
zp
:
Optional
[
torch
.
Tensor
]):
if
is_symmetric
:
return
zp
is
None
return
zp
is
not
None
and
zp
.
dtype
is
torch
.
int32
assert
zp_valid
(
qkv_proj
.
input_zero_point
)
assert
zp_valid
(
o_proj
.
input_zero_point
)
assert
zp_valid
(
gate_up_proj
.
input_zero_point
)
assert
zp_valid
(
down_proj
.
input_zero_point
)
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
o_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
gate_up_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
down_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Int8
)
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
scheme
.
is_static_input_scheme
expected_type
=
torch
.
int8
assert
qkv_proj
.
weight
.
dtype
is
expected_type
assert
o_proj
.
weight
.
dtype
is
expected_type
assert
gate_up_proj
.
weight
.
dtype
is
expected_type
if
qkv_proj
.
scheme
.
strategy
==
"tensor"
:
# Make sure it is a channelwise buffer
# After running process_weights_after_loading
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
2
assert
qkv_proj
.
weight_scale
.
shape
[
0
]
==
shape_0
assert
qkv_proj
.
weight_scale
.
shape
[
1
]
==
1
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
o_proj
=
layer
.
self_attn
.
o_proj
gate_up_proj
=
layer
.
mlp
.
gate_up_proj
down_proj
=
layer
.
mlp
.
down_proj
# assert zp for symmetric and asymmetric cases
def
zp_valid
(
zp
:
Optional
[
torch
.
Tensor
]):
if
is_symmetric
:
return
zp
is
None
return
zp
is
not
None
and
zp
.
dtype
is
torch
.
int32
assert
zp_valid
(
qkv_proj
.
input_zero_point
)
assert
zp_valid
(
o_proj
.
input_zero_point
)
assert
zp_valid
(
gate_up_proj
.
input_zero_point
)
assert
zp_valid
(
down_proj
.
input_zero_point
)
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
o_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
gate_up_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
down_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Int8
)
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
scheme
.
is_static_input_scheme
expected_type
=
torch
.
int8
assert
qkv_proj
.
weight
.
dtype
is
expected_type
assert
o_proj
.
weight
.
dtype
is
expected_type
assert
gate_up_proj
.
weight
.
dtype
is
expected_type
if
qkv_proj
.
scheme
.
strategy
==
"tensor"
:
# Make sure it is a channelwise buffer
# After running process_weights_after_loading
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
2
assert
qkv_proj
.
weight_scale
.
shape
[
0
]
==
shape_0
assert
qkv_proj
.
weight_scale
.
shape
[
1
]
==
1
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
20
)
assert
output
...
...
@@ -132,16 +137,20 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
def
test_compressed_tensors_w8a8_dynamic_per_token
(
vllm_runner
,
model_args
):
model_path
,
strategy
=
model_args
with
vllm_runner
(
model_path
,
dtype
=
torch
.
float16
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Int8
)
assert
not
qkv_proj
.
scheme
.
is_static_input_scheme
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
weight
.
dtype
is
torch
.
int8
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Int8
)
assert
not
qkv_proj
.
scheme
.
is_static_input_scheme
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
weight
.
dtype
is
torch
.
int8
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
20
)
assert
output
...
...
@@ -157,19 +166,24 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
def
test_compressed_tensors_wNa16
(
vllm_runner
,
wNa16_args
):
model
,
strategy
,
group
,
pack_factor
=
wNa16_args
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsWNA16
)
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsWNA16
)
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
scheme
.
group_size
==
(
-
1
if
group
is
None
else
group
)
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
scheme
.
group_size
==
(
-
1
if
group
is
None
else
group
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float16
assert
qkv_proj
.
scheme
.
pack_factor
==
pack_factor
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float16
assert
qkv_proj
.
scheme
.
pack_factor
==
pack_factor
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
...
...
@@ -180,14 +194,18 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
def
test_compressed_tensors_w4a16_marlin24
(
vllm_runner
):
model_path
=
os
.
path
.
join
(
models_path_prefix
,
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
)
with
vllm_runner
(
model_path
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW4A16Sparse24
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW4A16Sparse24
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
...
...
@@ -198,23 +216,27 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
def
test_compressed_tensors_fp8
(
vllm_runner
):
model_path
=
os
.
path
.
join
(
models_path_prefix
,
"nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
)
with
vllm_runner
(
model_path
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
(
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A16Fp8
))
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
(
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A16Fp8
))
if
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Fp8
):
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
qkv_proj
.
weight
.
dtype
is
torch
.
float8_e4m3fn
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
if
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Fp8
):
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
qkv_proj
.
weight
.
dtype
is
torch
.
float8_e4m3fn
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
...
...
@@ -259,12 +281,15 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
def
test_compressed_tensors_2of4_quant_fp8
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
float8_e4m3fn
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
float8_e4m3fn
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
...
...
@@ -284,40 +309,49 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
def
test_compressed_tensors_2of4_quant_int8
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
int8
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
int8
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
assert
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
)
@
pytest
.
mark
.
skip
(
reason
=
"2of4 sparse w16a16 CUTLASS produces bad output."
)
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"2of4 Sparse is not yet supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor"
)])
def
test_compressed_tensors_2of4_sparse
(
vllm_runner
,
args_2of4
):
model
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensors24
)
assert
qkv_proj
.
scheme
.
weight_quant
is
None
assert
qkv_proj
.
scheme
.
input_quant
is
None
assert
not
qkv_proj
.
scheme
.
quantized
assert
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
sparsity_map
=
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
# noqa: E501
assert
sparsity_map
.
get
(
"Linear"
).
format
==
"dense"
assert
sparsity_map
.
get
(
"Linear"
).
sparsity_structure
==
"2:4"
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensors24
)
assert
qkv_proj
.
scheme
.
weight_quant
is
None
assert
qkv_proj
.
scheme
.
input_quant
is
None
assert
not
qkv_proj
.
scheme
.
quantized
assert
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
sparsity_map
=
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
# noqa: E501
assert
sparsity_map
.
get
(
"Linear"
).
format
==
"dense"
assert
sparsity_map
.
get
(
"Linear"
).
sparsity_structure
==
"2:4"
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
...
...
tests/quantization/test_fp8.py
View file @
afd0da21
...
...
@@ -52,13 +52,17 @@ KV_CACHE_MODELS = [
def
test_kv_cache_model_load_and_run
(
vllm_runner
,
model_id
:
str
):
with
vllm_runner
(
model_id
,
kv_cache_dtype
=
"fp8"
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
attn
=
model
.
model
.
layers
[
0
].
self_attn
.
attn
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
# NOTE: it is valid for scales to be 1.0 (default value), but we know
# these checkpoints have scales < 1.0
assert
0.0
<
attn
.
_k_scale
<
1.0
assert
0.0
<
attn
.
_v_scale
<
1.0
def
check_model
(
model
):
attn
=
model
.
model
.
layers
[
0
].
self_attn
.
attn
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
# NOTE: it is valid for scales to be 1.0 (default value), but
# we know these checkpoints have scales < 1.0
assert
0.0
<
attn
.
_k_scale
<
1.0
assert
0.0
<
attn
.
_v_scale
<
1.0
llm
.
apply_model
(
check_model
)
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
...
...
@@ -80,22 +84,24 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
quantization
=
"fp8"
,
kv_cache_dtype
=
kv_cache_dtype
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
assert
isinstance
(
fc1
.
quant_method
,
Fp8LinearMethod
)
if
kv_cache_dtype
==
"fp8"
:
attn
=
model
.
model
.
decoder
.
layers
[
0
].
self_attn
.
attn
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
assert
attn
.
_k_scale
==
1.0
assert
attn
.
_v_scale
==
1.0
if
current_platform
.
has_device_capability
(
89
)
and
not
force_marlin
:
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
else
:
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
assert
fc1
.
weight
.
dtype
==
torch
.
int32
def
check_model
(
model
):
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
assert
isinstance
(
fc1
.
quant_method
,
Fp8LinearMethod
)
if
kv_cache_dtype
==
"fp8"
:
attn
=
model
.
model
.
decoder
.
layers
[
0
].
self_attn
.
attn
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
assert
attn
.
_k_scale
==
1.0
assert
attn
.
_v_scale
==
1.0
if
current_platform
.
has_device_capability
(
89
)
and
not
force_marlin
:
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
else
:
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
assert
fc1
.
weight
.
dtype
==
torch
.
int32
llm
.
apply_model
(
check_model
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
)
or
is_hip
(),
...
...
tests/quantization/test_lm_head.py
View file @
afd0da21
...
...
@@ -30,20 +30,23 @@ def test_lm_head(
model_lm_head_quant
:
Tuple
[
str
,
bool
],
)
->
None
:
model
,
lm_head_quantized
=
model_lm_head_quant
vllm_model
=
vllm_runner
(
model
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
lm_head_layer
=
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
.
lm_head
)
if
lm_head_quantized
:
assert
isinstance
(
lm_head_layer
.
linear_method
,
(
GPTQLinearMethod
,
GPTQMarlinLinearMethod
,
MarlinLinearMethod
))
else
:
assert
isinstance
(
lm_head_layer
.
linear_method
,
UnquantizedEmbeddingMethod
)
print
(
vllm_model
.
generate_greedy
(
prompts
=
[
"Hello my name is"
],
max_tokens
=
10
)[
0
][
1
])
del
vllm_model
with
vllm_runner
(
model
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
as
vllm_model
:
def
check_model
(
model
):
lm_head_layer
=
model
.
lm_head
if
lm_head_quantized
:
assert
isinstance
(
lm_head_layer
.
linear_method
,
(
GPTQLinearMethod
,
GPTQMarlinLinearMethod
,
MarlinLinearMethod
))
else
:
assert
isinstance
(
lm_head_layer
.
linear_method
,
UnquantizedEmbeddingMethod
)
vllm_model
.
apply_model
(
check_model
)
print
(
vllm_model
.
generate_greedy
(
prompts
=
[
"Hello my name is"
],
max_tokens
=
10
)[
0
][
1
])
Prev
1
…
20
21
22
23
24
25
26
27
28
…
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment