Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
51c2e1fc
Unverified
Commit
51c2e1fc
authored
Nov 10, 2024
by
Cyrus Leung
Committed by
GitHub
Nov 09, 2024
Browse files
[CI/Build] Split up models tests (#10069)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
b09895a6
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
111 additions
and
125 deletions
+111
-125
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+14
-10
pyproject.toml
pyproject.toml
+1
-0
tests/models/decoder_only/language/test_aqlm.py
tests/models/decoder_only/language/test_aqlm.py
+1
-0
tests/models/decoder_only/language/test_fp8.py
tests/models/decoder_only/language/test_fp8.py
+1
-0
tests/models/decoder_only/language/test_gguf.py
tests/models/decoder_only/language/test_gguf.py
+16
-19
tests/models/decoder_only/language/test_gptq_marlin.py
tests/models/decoder_only/language/test_gptq_marlin.py
+1
-0
tests/models/decoder_only/language/test_gptq_marlin_24.py
tests/models/decoder_only/language/test_gptq_marlin_24.py
+1
-0
tests/models/decoder_only/language/test_granite.py
tests/models/decoder_only/language/test_granite.py
+2
-1
tests/models/decoder_only/language/test_granitemoe.py
tests/models/decoder_only/language/test_granitemoe.py
+0
-39
tests/models/decoder_only/language/test_modelopt.py
tests/models/decoder_only/language/test_modelopt.py
+1
-0
tests/models/decoder_only/language/test_models.py
tests/models/decoder_only/language/test_models.py
+1
-3
tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py
...ly/vision_language/mm_processor_kwargs/test_llava_next.py
+3
-1
tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py
...er_only/vision_language/mm_processor_kwargs/test_phi3v.py
+2
-1
tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py
...only/vision_language/mm_processor_kwargs/test_qwen2_vl.py
+11
-4
tests/models/decoder_only/vision_language/test_awq.py
tests/models/decoder_only/vision_language/test_awq.py
+11
-8
tests/models/decoder_only/vision_language/test_intern_vit.py
tests/models/decoder_only/vision_language/test_intern_vit.py
+9
-10
tests/models/decoder_only/vision_language/test_models.py
tests/models/decoder_only/vision_language/test_models.py
+23
-23
vllm/config.py
vllm/config.py
+8
-1
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+2
-4
vllm/model_executor/models/internlm2_ve.py
vllm/model_executor/models/internlm2_ve.py
+3
-1
No files found.
.buildkite/test-pipeline.yaml
View file @
51c2e1fc
...
@@ -305,7 +305,7 @@ steps:
...
@@ -305,7 +305,7 @@ steps:
##### models test #####
##### models test #####
-
label
:
Basic Models Test
#
3
min
-
label
:
Basic Models Test
#
10
min
source_file_dependencies
:
source_file_dependencies
:
-
vllm/
-
vllm/
-
tests/models
-
tests/models
...
@@ -314,23 +314,24 @@ steps:
...
@@ -314,23 +314,24 @@ steps:
-
pytest -v -s models/test_oot_registration.py
# it needs a clean process
-
pytest -v -s models/test_oot_registration.py
# it needs a clean process
-
pytest -v -s models/*.py --ignore=models/test_oot_registration.py
-
pytest -v -s models/*.py --ignore=models/test_oot_registration.py
-
label
:
Decoder-only Language Models Test (Standard)
#
35
min
-
label
:
Decoder-only Language Models Test (Standard)
#
18
min
#mirror_hardwares: [amd]
#mirror_hardwares: [amd]
source_file_dependencies
:
source_file_dependencies
:
-
vllm/
-
vllm/
-
tests/models/decoder_only/language
-
tests/models/decoder_only/language
commands
:
commands
:
-
pytest -v -s models/decoder_only/language/test_models.py
-
pytest -v -s models/decoder_only/language -m core_model
-
pytest -v -s models/decoder_only/language -m quant_model
-
label
:
Decoder-only Language Models Test (Extended)
#
1h20
min
-
label
:
Decoder-only Language Models Test (Extended)
#
46
min
nightly
:
true
nightly
:
true
source_file_dependencies
:
source_file_dependencies
:
-
vllm/
-
vllm/
-
tests/models/decoder_only/language
-
tests/models/decoder_only/language
commands
:
commands
:
-
pytest -v -s models/decoder_only/language -
-ign
ore
=
model
s/decoder_only/language/tes
t_model
s.py
-
pytest -v -s models/decoder_only/language -
m 'not c
ore
_
model
and not quan
t_model
'
-
label
:
Decoder-only Multi-Modal Models Test (Standard)
# 2
6
min
-
label
:
Decoder-only Multi-Modal Models Test (Standard)
# 2
2
min
#mirror_hardwares: [amd]
#mirror_hardwares: [amd]
source_file_dependencies
:
source_file_dependencies
:
-
vllm/
-
vllm/
...
@@ -339,21 +340,24 @@ steps:
...
@@ -339,21 +340,24 @@ steps:
commands
:
commands
:
-
pytest -v -s models/decoder_only/audio_language -m core_model
-
pytest -v -s models/decoder_only/audio_language -m core_model
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model
# No tests under this group for now
# - pytest -v -s models/decoder_only/audio_language -m quant_model
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m quant_model
-
label
:
Decoder-only Multi-Modal Models Test (Extended)
-
label
:
Decoder-only Multi-Modal Models Test (Extended)
# 1h10m
nightly
:
true
nightly
:
true
source_file_dependencies
:
source_file_dependencies
:
-
vllm/
-
vllm/
-
tests/models/decoder_only/audio_language
-
tests/models/decoder_only/audio_language
-
tests/models/decoder_only/vision_language
-
tests/models/decoder_only/vision_language
commands
:
commands
:
-
pytest -v -s models/decoder_only/audio_language -m 'not core_model'
-
pytest -v -s models/decoder_only/audio_language -m 'not core_model
and not quant_model
'
# HACK - run phi3v tests separately to sidestep this transformers bug
# HACK - run phi3v tests separately to sidestep this transformers bug
# https://github.com/huggingface/transformers/issues/34307
# https://github.com/huggingface/transformers/issues/34307
-
pytest -v -s models/decoder_only/vision_language/test_phi3v.py
-
pytest -v -s models/decoder_only/vision_language/test_phi3v.py
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model'
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model
and not quant_model
'
-
label
:
Other Models Test
#
6
min
-
label
:
Other Models Test
#
20
min
#mirror_hardwares: [amd]
#mirror_hardwares: [amd]
source_file_dependencies
:
source_file_dependencies
:
-
vllm/
-
vllm/
...
...
pyproject.toml
View file @
51c2e1fc
...
@@ -95,6 +95,7 @@ markers = [
...
@@ -95,6 +95,7 @@ markers = [
"skip_global_cleanup"
,
"skip_global_cleanup"
,
"core_model: enable this model test in each PR instead of only nightly"
,
"core_model: enable this model test in each PR instead of only nightly"
,
"cpu_model: enable this model test in CPU tests"
,
"cpu_model: enable this model test in CPU tests"
,
"quant_model: run this model test under Quantized category"
,
"distributed_2_gpus: run this test only in distributed tests for 2 GPUs"
,
"distributed_2_gpus: run this test only in distributed tests for 2 GPUs"
,
"skip_v1: do not run this test with v1"
,
"skip_v1: do not run this test with v1"
,
]
]
tests/models/decoder_only/language/test_aqlm.py
View file @
51c2e1fc
...
@@ -38,6 +38,7 @@ ground_truth_generations = [
...
@@ -38,6 +38,7 @@ ground_truth_generations = [
]
]
@
pytest
.
mark
.
quant_model
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"aqlm"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"aqlm"
),
reason
=
"AQLM is not supported on this GPU type."
)
reason
=
"AQLM is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"
])
...
...
tests/models/decoder_only/language/test_fp8.py
View file @
51c2e1fc
...
@@ -15,6 +15,7 @@ from ...utils import check_logprobs_close
...
@@ -15,6 +15,7 @@ from ...utils import check_logprobs_close
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"true"
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"true"
@
pytest
.
mark
.
quant_model
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"fp8 is not supported on this GPU type."
)
reason
=
"fp8 is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
...
tests/models/decoder_only/language/test_gguf.py
View file @
51c2e1fc
...
@@ -17,26 +17,21 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
...
@@ -17,26 +17,21 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN
=
1024
MAX_MODEL_LEN
=
1024
# FIXME: Move this to confest
MODELS
=
[
(
"meta-llama/Llama-3.2-1B-Instruct"
,
hf_hub_download
(
"bartowski/Llama-3.2-1B-Instruct-GGUF"
,
filename
=
"Llama-3.2-1B-Instruct-Q4_K_M.gguf"
)),
(
"meta-llama/Llama-3.2-1B-Instruct"
,
hf_hub_download
(
"bartowski/Llama-3.2-1B-Instruct-GGUF"
,
filename
=
"Llama-3.2-1B-Instruct-IQ4_XS.gguf"
)),
(
"Qwen/Qwen2-1.5B-Instruct"
,
hf_hub_download
(
"Qwen/Qwen2-1.5B-Instruct-GGUF"
,
filename
=
"qwen2-1_5b-instruct-q4_k_m.gguf"
)),
(
"Qwen/Qwen2-1.5B-Instruct"
,
hf_hub_download
(
"legraphista/Qwen2-1.5B-Instruct-IMat-GGUF"
,
filename
=
"Qwen2-1.5B-Instruct.IQ4_XS.gguf"
)),
]
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gguf"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gguf"
),
reason
=
"gguf is not supported on this GPU type."
)
reason
=
"gguf is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
((
"original_model"
,
"gguf_id"
,
"gguf_path"
),
[
(
"meta-llama/Llama-3.2-1B-Instruct"
,
"bartowski/Llama-3.2-1B-Instruct-GGUF"
,
"Llama-3.2-1B-Instruct-Q4_K_M.gguf"
),
(
"meta-llama/Llama-3.2-1B-Instruct"
,
"bartowski/Llama-3.2-1B-Instruct-GGUF"
,
"Llama-3.2-1B-Instruct-IQ4_XS.gguf"
),
(
"Qwen/Qwen2-1.5B-Instruct"
,
"Qwen/Qwen2-1.5B-Instruct-GGUF"
,
"qwen2-1_5b-instruct-q4_k_m.gguf"
),
(
"Qwen/Qwen2-1.5B-Instruct"
,
"legraphista/Qwen2-1.5B-Instruct-IMat-GGUF"
,
"Qwen2-1.5B-Instruct.IQ4_XS.gguf"
),
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
...
@@ -45,7 +40,9 @@ def test_models(
...
@@ -45,7 +40,9 @@ def test_models(
num_gpus_available
,
num_gpus_available
,
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
model
,
original_model
,
gguf_id
,
gguf_path
,
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
...
@@ -54,7 +51,7 @@ def test_models(
...
@@ -54,7 +51,7 @@ def test_models(
if
num_gpus_available
<
tp_size
:
if
num_gpus_available
<
tp_size
:
pytest
.
skip
(
f
"Not enough GPUs for tensor parallelism
{
tp_size
}
"
)
pytest
.
skip
(
f
"Not enough GPUs for tensor parallelism
{
tp_size
}
"
)
original_model
,
gguf_model
=
model
gguf_model
=
hf_hub_download
(
gguf_id
,
filename
=
gguf_path
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
original_model
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
original_model
)
messages
=
[[{
messages
=
[[{
...
...
tests/models/decoder_only/language/test_gptq_marlin.py
View file @
51c2e1fc
...
@@ -33,6 +33,7 @@ MODELS = [
...
@@ -33,6 +33,7 @@ MODELS = [
]
]
@
pytest
.
mark
.
quant_model
@
pytest
.
mark
.
flaky
(
reruns
=
3
)
@
pytest
.
mark
.
flaky
(
reruns
=
3
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"gptq_marlin is not supported on this GPU type."
)
reason
=
"gptq_marlin is not supported on this GPU type."
)
...
...
tests/models/decoder_only/language/test_gptq_marlin_24.py
View file @
51c2e1fc
...
@@ -38,6 +38,7 @@ model_pairs = [
...
@@ -38,6 +38,7 @@ model_pairs = [
]
]
@
pytest
.
mark
.
quant_model
@
pytest
.
mark
.
flaky
(
reruns
=
2
)
@
pytest
.
mark
.
flaky
(
reruns
=
2
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin_24"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin_24"
),
reason
=
"Marlin24 is not supported on this GPU type."
)
reason
=
"Marlin24 is not supported on this GPU type."
)
...
...
tests/models/decoder_only/language/test_granite.py
View file @
51c2e1fc
...
@@ -7,7 +7,9 @@ import pytest
...
@@ -7,7 +7,9 @@ import pytest
from
...utils
import
check_logprobs_close
from
...utils
import
check_logprobs_close
MODELS
=
[
MODELS
=
[
# TODO(sang): Sliding window should be tested separately.
"ibm/PowerLM-3b"
,
"ibm/PowerLM-3b"
,
"ibm/PowerMoE-3b"
,
]
]
...
@@ -24,7 +26,6 @@ def test_models(
...
@@ -24,7 +26,6 @@ def test_models(
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
)
->
None
:
)
->
None
:
# TODO(sang): Sliding window should be tested separately.
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
...
...
tests/models/decoder_only/language/test_granitemoe.py
deleted
100644 → 0
View file @
b09895a6
"""Compare the outputs of HF and vLLM for Granite models using greedy sampling.
Run `pytest tests/models/test_granite.py`.
"""
import
pytest
from
...utils
import
check_logprobs_close
MODELS
=
[
"ibm/PowerMoE-3b"
,
]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
)
->
None
:
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
example_prompts
,
max_tokens
,
num_logprobs
)
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
tests/models/decoder_only/language/test_modelopt.py
View file @
51c2e1fc
...
@@ -39,6 +39,7 @@ EXPECTED_STRS_MAP = {
...
@@ -39,6 +39,7 @@ EXPECTED_STRS_MAP = {
@
pytest
.
mark
.
skip
(
@
pytest
.
mark
.
skip
(
reason
=
reason
=
"Prevent unstable test based on golden strings from breaking the build."
)
"Prevent unstable test based on golden strings from breaking the build."
)
@
pytest
.
mark
.
quant_model
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"fp8 is not supported on this GPU type."
)
reason
=
"fp8 is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
MODELS
)
...
...
tests/models/decoder_only/language/test_models.py
View file @
51c2e1fc
"""Compare the outputs of HF and vLLM when using greedy sampling.
"""Compare the outputs of HF and vLLM when using greedy sampling.
This test only tests small models. Big models such as 7B should be tested from
test_big_models.py because it could use a larger instance to run tests.
Run `pytest tests/models/test_models.py`.
Run `pytest tests/models/test_models.py`.
"""
"""
import
pytest
import
pytest
...
@@ -35,6 +32,7 @@ if not current_platform.is_cpu():
...
@@ -35,6 +32,7 @@ if not current_platform.is_cpu():
target_dtype
=
"half"
target_dtype
=
"half"
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
target_dtype
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
target_dtype
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
...
...
tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py
View file @
51c2e1fc
...
@@ -56,11 +56,13 @@ def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
...
@@ -56,11 +56,13 @@ def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
ctx
.
model_config
.
hf_config
.
image_grid_pinpoints
=
gridpoints
ctx
.
model_config
.
hf_config
.
image_grid_pinpoints
=
gridpoints
seq_len
=
5000
# bigger than the max feature size for any image
seq_len
=
5000
# bigger than the max feature size for any image
seq_data
,
mm_data
=
dummy_data_for_llava_next
(
du
mm
y
_data
=
dummy_data_for_llava_next
(
ctx
,
ctx
,
seq_len
=
seq_len
,
seq_len
=
seq_len
,
mm_counts
=
{
"image"
:
1
},
mm_counts
=
{
"image"
:
1
},
)
)
seq_data
=
dummy_data
.
seq_data
mm_data
=
dummy_data
.
multi_modal_data
# The dummy data dims should match the gridpoint with the biggest feat size
# The dummy data dims should match the gridpoint with the biggest feat size
assert
mm_data
[
"image"
].
height
==
expected_size
[
0
]
assert
mm_data
[
"image"
].
height
==
expected_size
[
0
]
...
...
tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py
View file @
51c2e1fc
...
@@ -131,12 +131,13 @@ def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
...
@@ -131,12 +131,13 @@ def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
mm_processor_kwargs
=
None
,
mm_processor_kwargs
=
None
,
)
)
sequence
_data
,
_
,
=
dummy_data_for_phi3v
(
dummy
_data
=
dummy_data_for_phi3v
(
ctx
=
ctx
,
ctx
=
ctx
,
seq_len
=
8192
,
# Should be bigger than num_imgs * toks_per_img
seq_len
=
8192
,
# Should be bigger than num_imgs * toks_per_img
mm_counts
=
{
"image"
:
num_imgs
},
mm_counts
=
{
"image"
:
num_imgs
},
num_crops
=
num_crops
,
num_crops
=
num_crops
,
)
)
sequence_data
=
dummy_data
.
seq_data
# Ensure we have the right number of placeholders per num_crops size
# Ensure we have the right number of placeholders per num_crops size
img_tok_count
=
sequence_data
.
get_token_ids
().
count
(
_IMAGE_TOKEN_ID
)
img_tok_count
=
sequence_data
.
get_token_ids
().
count
(
_IMAGE_TOKEN_ID
)
assert
img_tok_count
==
toks_per_img
*
num_imgs
assert
img_tok_count
==
toks_per_img
*
num_imgs
...
...
tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py
View file @
51c2e1fc
...
@@ -86,10 +86,17 @@ def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
...
@@ -86,10 +86,17 @@ def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
# NOTE: video value is required, but isn't actually used
# NOTE: video value is required, but isn't actually used
# when making the dummy data except for error handling currently
# when making the dummy data except for error handling currently
seq_data
,
mm_data
=
dummy_data_for_qwen2_vl
(
qwen2_vl_context
,
seq_len
,
{
dummy_data
=
dummy_data_for_qwen2_vl
(
"image"
:
1
,
ctx
=
qwen2_vl_context
,
"video"
:
0
seq_len
=
seq_len
,
},
**
mm_processor_kwargs
)
mm_counts
=
{
"image"
:
1
,
"video"
:
0
},
**
mm_processor_kwargs
,
)
seq_data
=
dummy_data
.
seq_data
mm_data
=
dummy_data
.
multi_modal_data
# Ensure we have the right number of placeholders for min/max pixel values
# Ensure we have the right number of placeholders for min/max pixel values
assert
seq_data
.
get_token_ids
().
count
(
image_token_id
)
==
token_count
assert
seq_data
.
get_token_ids
().
count
(
image_token_id
)
==
token_count
...
...
tests/models/decoder_only/vision_language/test_
internvl
.py
→
tests/models/decoder_only/vision_language/test_
awq
.py
View file @
51c2e1fc
from
typing
import
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Type
import
pytest
import
pytest
import
torch
import
torch
...
@@ -19,7 +19,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
...
@@ -19,7 +19,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
def
run_awq_test
(
def
run_awq_test
(
vllm_runner
:
Type
[
VllmRunner
],
vllm_runner
:
Type
[
VllmRunner
],
image_assets
:
_ImageAssets
,
image_assets
:
_ImageAssets
,
models
:
Tuple
[
str
,
str
],
source_model
:
str
,
quant_model
:
str
,
*
,
*
,
size_factors
:
List
[
float
],
size_factors
:
List
[
float
],
dtype
:
str
,
dtype
:
str
,
...
@@ -28,8 +29,6 @@ def run_awq_test(
...
@@ -28,8 +29,6 @@ def run_awq_test(
tensor_parallel_size
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
):
source_model
,
quant_model
=
models
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
inputs_per_image
=
[(
inputs_per_image
=
[(
...
@@ -84,8 +83,11 @@ def run_awq_test(
...
@@ -84,8 +83,11 @@ def run_awq_test(
)
)
@
pytest
.
mark
.
quant_model
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"models"
,
[(
"OpenGVLab/InternVL2-2B"
,
"OpenGVLab/InternVL2-2B-AWQ"
)])
(
"source_model"
,
"quant_model"
),
[(
"OpenGVLab/InternVL2-2B"
,
"OpenGVLab/InternVL2-2B-AWQ"
)],
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"size_factors"
,
"size_factors"
,
[
[
...
@@ -103,12 +105,13 @@ def run_awq_test(
...
@@ -103,12 +105,13 @@ def run_awq_test(
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_awq_models
(
vllm_runner
,
image_assets
,
model
s
,
size_factors
,
def
test_awq_models
(
vllm_runner
,
image_assets
,
source_
model
,
quant_model
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
size_factors
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
run_awq_test
(
run_awq_test
(
vllm_runner
,
vllm_runner
,
image_assets
,
image_assets
,
models
,
source_model
,
quant_model
,
size_factors
=
size_factors
,
size_factors
=
size_factors
,
dtype
=
dtype
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
...
...
tests/models/decoder_only/vision_language/test_intern_vit.py
View file @
51c2e1fc
...
@@ -11,21 +11,17 @@ from ....conftest import _ImageAssets
...
@@ -11,21 +11,17 @@ from ....conftest import _ImageAssets
# we use snapshot_download to prevent conflicts between
# we use snapshot_download to prevent conflicts between
# dynamic_module and trust_remote_code for hf_runner
# dynamic_module and trust_remote_code for hf_runner
DOWNLOAD_PATTERN
=
[
"*.json"
,
"*.py"
,
"*.safetensors"
,
"*.txt"
,
"*.model"
]
DOWNLOAD_PATTERN
=
[
"*.json"
,
"*.py"
,
"*.safetensors"
,
"*.txt"
,
"*.model"
]
models
=
[
snapshot_download
(
"OpenGVLab/InternViT-300M-448px"
,
allow_patterns
=
DOWNLOAD_PATTERN
),
snapshot_download
(
"OpenGVLab/InternViT-6B-448px-V1-5"
,
allow_patterns
=
DOWNLOAD_PATTERN
),
]
def
run_intern_vit_test
(
def
run_intern_vit_test
(
image_assets
:
_ImageAssets
,
image_assets
:
_ImageAssets
,
model
:
str
,
model
_id
:
str
,
*
,
*
,
dtype
:
str
,
dtype
:
str
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
):
model
=
snapshot_download
(
model_id
,
allow_patterns
=
DOWNLOAD_PATTERN
)
img_processor
=
CLIPImageProcessor
.
from_pretrained
(
model
)
img_processor
=
CLIPImageProcessor
.
from_pretrained
(
model
)
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
pixel_values
=
[
pixel_values
=
[
...
@@ -67,12 +63,15 @@ def run_intern_vit_test(
...
@@ -67,12 +63,15 @@ def run_intern_vit_test(
assert
cos_similar
(
vllm_output
,
hf_output
).
mean
()
>
0.99
assert
cos_similar
(
vllm_output
,
hf_output
).
mean
()
>
0.99
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"OpenGVLab/InternViT-300M-448px"
,
"OpenGVLab/InternViT-6B-448px-V1-5"
,
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
half
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
half
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_models
(
dist_init
,
image_assets
,
model
,
dtype
:
str
)
->
None
:
def
test_models
(
dist_init
,
image_assets
,
model
_id
,
dtype
:
str
)
->
None
:
run_intern_vit_test
(
run_intern_vit_test
(
image_assets
,
image_assets
,
model
,
model
_id
,
dtype
=
dtype
,
dtype
=
dtype
,
)
)
tests/models/decoder_only/vision_language/test_models.py
View file @
51c2e1fc
...
@@ -130,8 +130,8 @@ VLM_TEST_SETTINGS = {
...
@@ -130,8 +130,8 @@ VLM_TEST_SETTINGS = {
max_num_seqs
=
2
,
max_num_seqs
=
2
,
auto_cls
=
AutoModelForVision2Seq
,
auto_cls
=
AutoModelForVision2Seq
,
vllm_output_post_proc
=
model_utils
.
qwen2_vllm_to_hf_output
,
vllm_output_post_proc
=
model_utils
.
qwen2_vllm_to_hf_output
,
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
image_size_factors
=
[(),
(
0.25
,),
(
0.25
,
0.25
,
0.25
),
(
0.25
,
0.2
,
0.15
)],
image_size_factors
=
[(),
(
0.25
,),
(
0.25
,
0.25
,
0.25
),
(
0.25
,
0.2
,
0.15
)],
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
),
),
#### Extended model tests
#### Extended model tests
"blip2"
:
VLMTestInfo
(
"blip2"
:
VLMTestInfo
(
...
@@ -159,9 +159,9 @@ VLM_TEST_SETTINGS = {
...
@@ -159,9 +159,9 @@ VLM_TEST_SETTINGS = {
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
marks
=
[
marks
=
[
pytest
.
mark
.
skipif
(
pytest
.
mark
.
skipif
(
transformers
.
__version__
.
startswith
(
"4.46"
)
,
transformers
.
__version__
<
"4.46
.2
"
,
reason
=
"Model broken in HF, see huggingface/transformers#34379"
reason
=
"Model broken in HF, see huggingface/transformers#34379"
)
)
,
]
]
),
),
"fuyu"
:
VLMTestInfo
(
"fuyu"
:
VLMTestInfo
(
...
@@ -185,8 +185,8 @@ VLM_TEST_SETTINGS = {
...
@@ -185,8 +185,8 @@ VLM_TEST_SETTINGS = {
max_num_seqs
=
2
,
max_num_seqs
=
2
,
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
get_stop_token_ids
=
lambda
tok
:
[
151329
,
151336
,
151338
],
get_stop_token_ids
=
lambda
tok
:
[
151329
,
151336
,
151338
],
marks
=
[
large_gpu_mark
(
min_gb
=
48
)],
patch_hf_runner
=
model_utils
.
glm_patch_hf_runner
,
patch_hf_runner
=
model_utils
.
glm_patch_hf_runner
,
marks
=
[
large_gpu_mark
(
min_gb
=
48
)],
),
),
"h2ovl"
:
VLMTestInfo
(
"h2ovl"
:
VLMTestInfo
(
models
=
[
models
=
[
...
@@ -205,6 +205,22 @@ VLM_TEST_SETTINGS = {
...
@@ -205,6 +205,22 @@ VLM_TEST_SETTINGS = {
use_tokenizer_eos
=
True
,
use_tokenizer_eos
=
True
,
patch_hf_runner
=
model_utils
.
h2ovl_patch_hf_runner
,
patch_hf_runner
=
model_utils
.
h2ovl_patch_hf_runner
,
),
),
"idefics3"
:
VLMTestInfo
(
models
=
[
"HuggingFaceM4/Idefics3-8B-Llama3"
],
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
MULTI_IMAGE
),
prompt_formatter
=
lambda
img_prompt
:
f
"<|begin_of_text|>User:
{
img_prompt
}
<end_of_utterance>
\n
Assistant:"
,
# noqa: E501
img_idx_to_prompt
=
lambda
idx
:
"<image>"
,
max_model_len
=
8192
,
max_num_seqs
=
2
,
auto_cls
=
AutoModelForVision2Seq
,
marks
=
[
pytest
.
mark
.
skipif
(
transformers
.
__version__
<
"4.46.0"
,
reason
=
"Model introduced in HF >= 4.46.0"
),
large_gpu_mark
(
min_gb
=
48
),
],
),
"intern_vl"
:
VLMTestInfo
(
"intern_vl"
:
VLMTestInfo
(
models
=
[
models
=
[
"OpenGVLab/InternVL2-1B"
,
"OpenGVLab/InternVL2-1B"
,
...
@@ -263,7 +279,6 @@ VLM_TEST_SETTINGS = {
...
@@ -263,7 +279,6 @@ VLM_TEST_SETTINGS = {
runner_mm_key
=
"videos"
,
runner_mm_key
=
"videos"
,
)],
)],
),
),
# FIXME
"llava_next_video"
:
VLMTestInfo
(
"llava_next_video"
:
VLMTestInfo
(
models
=
[
"llava-hf/LLaVA-NeXT-Video-7B-hf"
],
models
=
[
"llava-hf/LLaVA-NeXT-Video-7B-hf"
],
test_type
=
VLMTestType
.
VIDEO
,
test_type
=
VLMTestType
.
VIDEO
,
...
@@ -275,7 +290,7 @@ VLM_TEST_SETTINGS = {
...
@@ -275,7 +290,7 @@ VLM_TEST_SETTINGS = {
image_sizes
=
[((
1669
,
2560
),
(
2560
,
1669
),
(
183
,
488
),
(
488
,
183
))],
image_sizes
=
[((
1669
,
2560
),
(
2560
,
1669
),
(
183
,
488
),
(
488
,
183
))],
marks
=
[
marks
=
[
pytest
.
mark
.
skipif
(
pytest
.
mark
.
skipif
(
transformers
.
__version__
.
startswith
(
"4.46"
)
,
transformers
.
__version__
<
"4.46
.2
"
,
reason
=
"Model broken with changes in transformers 4.46"
reason
=
"Model broken with changes in transformers 4.46"
)
)
],
],
...
@@ -316,6 +331,7 @@ VLM_TEST_SETTINGS = {
...
@@ -316,6 +331,7 @@ VLM_TEST_SETTINGS = {
max_model_len
=
8192
,
max_model_len
=
8192
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
auto_cls
=
AutoModelForVision2Seq
,
auto_cls
=
AutoModelForVision2Seq
,
marks
=
[
large_gpu_mark
(
min_gb
=
48
)],
),
),
"qwen"
:
VLMTestInfo
(
"qwen"
:
VLMTestInfo
(
models
=
[
"Qwen/Qwen-VL"
],
models
=
[
"Qwen/Qwen-VL"
],
...
@@ -327,22 +343,6 @@ VLM_TEST_SETTINGS = {
...
@@ -327,22 +343,6 @@ VLM_TEST_SETTINGS = {
vllm_output_post_proc
=
model_utils
.
qwen_vllm_to_hf_output
,
vllm_output_post_proc
=
model_utils
.
qwen_vllm_to_hf_output
,
prompt_path_encoder
=
model_utils
.
qwen_prompt_path_encoder
,
prompt_path_encoder
=
model_utils
.
qwen_prompt_path_encoder
,
),
),
"idefics3"
:
VLMTestInfo
(
models
=
[
"HuggingFaceM4/Idefics3-8B-Llama3"
],
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
MULTI_IMAGE
),
prompt_formatter
=
lambda
img_prompt
:
f
"<|begin_of_text|>User:
{
img_prompt
}
<end_of_utterance>
\n
Assistant:"
,
# noqa: E501
img_idx_to_prompt
=
lambda
idx
:
"<image>"
,
max_model_len
=
8192
,
max_num_seqs
=
2
,
auto_cls
=
AutoModelForVision2Seq
,
marks
=
[
pytest
.
mark
.
skipif
(
transformers
.
__version__
<
"4.46.0"
,
reason
=
"Model introduced in HF >= 4.46.0"
),
large_gpu_mark
(
min_gb
=
48
),
],
),
### Tensor parallel / multi-gpu broadcast tests
### Tensor parallel / multi-gpu broadcast tests
"broadcast-chameleon"
:
VLMTestInfo
(
"broadcast-chameleon"
:
VLMTestInfo
(
models
=
[
"facebook/chameleon-7b"
],
models
=
[
"facebook/chameleon-7b"
],
...
@@ -362,7 +362,7 @@ VLM_TEST_SETTINGS = {
...
@@ -362,7 +362,7 @@ VLM_TEST_SETTINGS = {
reason
=
"Need at least 2 GPUs to run the test."
,
reason
=
"Need at least 2 GPUs to run the test."
,
),
),
pytest
.
mark
.
skipif
(
pytest
.
mark
.
skipif
(
transformers
.
__version__
.
startswith
(
"4.46"
)
,
transformers
.
__version__
<
"4.46
.2
"
,
reason
=
"Model broken in HF, see huggingface/transformers#34379"
reason
=
"Model broken in HF, see huggingface/transformers#34379"
)
)
],
],
...
...
vllm/config.py
View file @
51c2e1fc
import
copy
import
enum
import
enum
import
json
import
json
import
warnings
import
warnings
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
,
replace
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Dict
,
Final
,
List
,
Literal
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Dict
,
Final
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
...
@@ -2078,6 +2079,12 @@ class VllmConfig:
...
@@ -2078,6 +2079,12 @@ class VllmConfig:
return
quant_config
return
quant_config
return
None
return
None
def
with_hf_config
(
self
,
hf_config
:
PretrainedConfig
)
->
"VllmConfig"
:
model_config
=
copy
.
deepcopy
(
self
.
model_config
)
model_config
.
hf_config
=
hf_config
return
replace
(
self
,
model_config
=
model_config
)
def
__post_init__
(
self
):
def
__post_init__
(
self
):
"""Verify configs are valid & consistent with each other.
"""Verify configs are valid & consistent with each other.
"""
"""
...
...
vllm/model_executor/models/fuyu.py
View file @
51c2e1fc
...
@@ -229,7 +229,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -229,7 +229,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
self
.
config
=
config
...
@@ -246,9 +245,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -246,9 +245,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
gather_output
=
True
,
gather_output
=
True
,
)
)
self
.
language_model
=
PersimmonForCausalLM
(
config
.
text_config
,
self
.
language_model
=
PersimmonForCausalLM
(
cache_config
=
cache_config
,
vllm_config
.
with_hf_config
(
config
.
text_config
))
quant_config
=
quant_config
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/internlm2_ve.py
View file @
51c2e1fc
...
@@ -164,10 +164,12 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM):
...
@@ -164,10 +164,12 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM):
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
(
vllm_config
,
prefix
=
prefix
)
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
super
().
__init__
(
config
,
cache_config
,
quant_config
)
self
.
model
=
InternLM2VEModel
(
config
,
self
.
model
=
InternLM2VEModel
(
config
,
cache_config
,
cache_config
,
quant_config
,
quant_config
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment