Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b40cf640
Unverified
Commit
b40cf640
authored
Nov 15, 2024
by
Cyrus Leung
Committed by
GitHub
Nov 14, 2024
Browse files
[Model] Support Qwen2 embeddings and use tags to select model tests (#10184)
parent
2885ba0e
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
252 additions
and
178 deletions
+252
-178
.buildkite/run-cpu-test-ppc64le.sh
.buildkite/run-cpu-test-ppc64le.sh
+3
-3
.buildkite/run-cpu-test.sh
.buildkite/run-cpu-test.sh
+3
-3
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+23
-25
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+9
-4
tests/models/decoder_only/language/test_jamba.py
tests/models/decoder_only/language/test_jamba.py
+4
-14
tests/models/decoder_only/language/test_mamba.py
tests/models/decoder_only/language/test_mamba.py
+4
-14
tests/models/decoder_only/language/test_models.py
tests/models/decoder_only/language/test_models.py
+43
-28
tests/models/embedding/language/test_cls_models.py
tests/models/embedding/language/test_cls_models.py
+11
-19
tests/models/embedding/language/test_embedding.py
tests/models/embedding/language/test_embedding.py
+22
-20
tests/models/embedding/vision_language/test_llava_next.py
tests/models/embedding/vision_language/test_llava_next.py
+2
-0
tests/models/embedding/vision_language/test_phi3v.py
tests/models/embedding/vision_language/test_phi3v.py
+2
-0
tests/models/encoder_decoder/language/test_bart.py
tests/models/encoder_decoder/language/test_bart.py
+8
-3
tests/models/encoder_decoder/vision_language/test_mllama.py
tests/models/encoder_decoder/vision_language/test_mllama.py
+3
-0
tests/models/registry.py
tests/models/registry.py
+4
-0
tests/models/test_registry.py
tests/models/test_registry.py
+2
-2
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+99
-13
vllm/model_executor/models/qwen2_cls.py
vllm/model_executor/models/qwen2_cls.py
+2
-13
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+2
-14
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+6
-3
No files found.
.buildkite/run-cpu-test-ppc64le.sh
View file @
b40cf640
...
...
@@ -27,9 +27,9 @@ function cpu_tests() {
decord einops librosa peft Pillow sentence-transformers soundfile
\
transformers_stream_generator matplotlib datamodel_code_generator
pip install torchvision --index-url https://download.pytorch.org/whl/cpu
pytest -v -s tests/models/
embedding/language
pytest -v -s tests/models/e
ncoder_decoder/language
pytest -v -s tests/models/
d
ecoder_
only
/language
/test
_model
s.py
pytest -v -s tests/models/
decoder_only/language -m cpu_model
pytest -v -s tests/models/e
mbedding/language -m cpu_model
pytest -v -s tests/models/e
n
coder_
decoder
/language
-m cpu
_model
pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
pytest -v -s tests/models/decoder_only/vision_language -m cpu_model"
...
...
.buildkite/run-cpu-test.sh
View file @
b40cf640
...
...
@@ -38,9 +38,9 @@ function cpu_tests() {
decord einops librosa peft Pillow sentence-transformers soundfile
\
transformers_stream_generator matplotlib datamodel_code_generator
pip install torchvision --index-url https://download.pytorch.org/whl/cpu
pytest -v -s tests/models/
embedding/language
pytest -v -s tests/models/e
ncoder_decoder/language
pytest -v -s tests/models/
d
ecoder_
only
/language
/test
_model
s.py
pytest -v -s tests/models/
decoder_only/language -m cpu_model
pytest -v -s tests/models/e
mbedding/language -m cpu_model
pytest -v -s tests/models/e
n
coder_
decoder
/language
-m cpu
_model
pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
pytest -v -s tests/models/decoder_only/vision_language -m cpu_model"
...
...
.buildkite/test-pipeline.yaml
View file @
b40cf640
...
...
@@ -323,62 +323,60 @@ steps:
-
pytest -v -s models/test_registry.py
-
pytest -v -s models/test_initialization.py
-
label
:
Decoder-only
Language Models Test (Standard)
#
18
min
-
label
:
Language Models Test (Standard)
#
42
min
#mirror_hardwares: [amd]
source_file_dependencies
:
-
vllm/
-
tests/models/decoder_only/language
-
tests/models/embedding/language
-
tests/models/encoder_decoder/language
commands
:
-
pytest -v -s models/decoder_only/language -m core_model
-
pytest -v -s models/decoder_only/language -m quant_model
-
pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
-
pytest -v -s models/embedding/language -m core_model
-
pytest -v -s models/embedding/vision_language -m core_model
-
label
:
Decoder-only
Language Models Test (Extended)
#
46
min
-
label
:
Language Models Test (Extended)
#
50
min
nightly
:
true
source_file_dependencies
:
-
vllm/
-
tests/models/decoder_only/language
-
tests/models/embedding/language
-
tests/models/encoder_decoder/language
commands
:
-
pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
-
pytest -v -s models/embedding/language -m 'not core_model'
-
pytest -v -s models/embedding/vision_language -m 'not core_model'
-
label
:
Decoder-only
Multi-Modal Models Test (Standard)
# 2
2
min
-
label
:
Multi-Modal Models Test (Standard)
# 2
6
min
#mirror_hardwares: [amd]
source_file_dependencies
:
-
vllm/
-
tests/models/decoder_only/audio_language
-
tests/models/decoder_only/vision_language
-
tests/models/embedding/vision_language
-
tests/models/encoder_decoder/vision_language
commands
:
-
pytest -v -s models/decoder_only/audio_language -m core_model
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model
# No tests under this group for now
# - pytest -v -s models/decoder_only/audio_language -m quant_model
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m quant_model
-
pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
-
pytest -v -s models/encoder_decoder/language -m core_model
-
pytest -v -s models/encoder_decoder/vision_language -m core_model
-
label
:
Decoder-only
Multi-Modal Models Test (Extended)
# 1h1
0
m
-
label
:
Multi-Modal Models Test (Extended)
# 1h1
5
m
nightly
:
true
source_file_dependencies
:
-
vllm/
-
tests/models/decoder_only/audio_language
-
tests/models/decoder_only/vision_language
-
tests/models/embedding/vision_language
-
tests/models/encoder_decoder/vision_language
commands
:
-
pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
# HACK - run phi3v tests separately to sidestep this transformers bug
# https://github.com/huggingface/transformers/issues/34307
-
pytest -v -s models/decoder_only/vision_language/test_phi3v.py
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
-
label
:
Other Models Test
# 20min
#mirror_hardwares: [amd]
source_file_dependencies
:
-
vllm/
-
tests/models/embedding/language
-
tests/models/embedding/vision_language
-
tests/models/encoder_decoder/language
-
tests/models/encoder_decoder/vision_language
commands
:
-
pytest -v -s models/embedding/language
-
pytest -v -s models/embedding/vision_language
-
pytest -v -s models/encoder_decoder/language
-
pytest -v -s models/encoder_decoder/vision_language
-
pytest -v -s models/encoder_decoder/language -m 'not core_model'
-
pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
# This test is used only in PR development phase to test individual models and should never run on main
-
label
:
Custom Models Test
...
...
docs/source/models/supported_models.rst
View file @
b40cf640
...
...
@@ -330,11 +330,16 @@ Text Embedding
- :code:`BAAI/bge-multilingual-gemma2`, etc.
-
- ✅︎
* - :code:`MistralModel`
-
Mistral
-based
* -
:code:`LlamaModel`, :code:`LlamaForCausalLM`,
:code:`MistralModel`
, etc.
-
Llama
-based
- :code:`intfloat/e5-mistral-7b-instruct`, etc.
- ✅︎
- ✅︎
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
- Qwen2-based
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc.
- ✅︎
- ✅︎
.. important::
Some model architectures support both generation and embedding tasks.
...
...
@@ -355,7 +360,7 @@ Reward Modeling
* - :code:`Qwen2ForRewardModel`
- Qwen2-based
- :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
-
-
✅︎
- ✅︎
.. note::
...
...
@@ -376,7 +381,7 @@ Classification
* - :code:`Qwen2ForSequenceClassification`
- Qwen2-based
- :code:`jason9693/Qwen2.5-1.5B-apeach`, etc.
-
-
✅︎
- ✅︎
.. note::
...
...
tests/models/decoder_only/language/test_jamba.py
View file @
b40cf640
...
...
@@ -33,6 +33,10 @@ def test_models(
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
for
i
in
range
(
len
(
example_prompts
)):
hf_output_ids
,
hf_output_str
=
hf_outputs
[
i
]
...
...
@@ -293,17 +297,3 @@ def test_jamba_distributed_produces_identical_generation(
name_0
=
"vllm_tp_1"
,
name_1
=
"vllm_tp_2"
,
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
def
test_model_print
(
vllm_runner
,
model
:
str
,
dtype
:
str
,
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
tests/models/decoder_only/language/test_mamba.py
View file @
b40cf640
...
...
@@ -51,6 +51,10 @@ def test_models(
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
for
i
in
range
(
len
(
example_prompts
)):
hf_output_ids
,
hf_output_str
=
hf_outputs
[
i
]
...
...
@@ -279,17 +283,3 @@ def test_state_cleanup(
except
ValueError
:
pytest
.
fail
(
"Mamba inner state wasn't cleaned up between states, "
"could be related to finished_requests_ids"
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
def
test_model_print
(
vllm_runner
,
model
:
str
,
dtype
:
str
,
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
tests/models/decoder_only/language/test_models.py
View file @
b40cf640
...
...
@@ -4,37 +4,52 @@ Run `pytest tests/models/test_models.py`.
"""
import
pytest
from
vllm.platforms
import
current_platform
from
...utils
import
check_logprobs_close
MODELS
=
[
"facebook/opt-125m"
,
# opt
"openai-community/gpt2"
,
# gpt2
# "Milos/slovak-gpt-j-405M", # gptj
# "bigcode/tiny_starcoder_py", # gpt_bigcode
# "EleutherAI/pythia-70m", # gpt_neox
"bigscience/bloom-560m"
,
# bloom - testing alibi slopes
"microsoft/phi-2"
,
# phi
# "stabilityai/stablelm-3b-4e1t", # stablelm
# "bigcode/starcoder2-3b", # starcoder2
"google/gemma-1.1-2b-it"
,
# gemma
"Qwen/Qwen2.5-0.5B-Instruct"
,
# qwen2
"meta-llama/Llama-3.2-1B-Instruct"
,
# llama
]
if
not
current_platform
.
is_cpu
():
MODELS
+=
[
# fused_moe which not supported on CPU
"openbmb/MiniCPM3-4B"
,
]
target_dtype
=
"half"
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
target_dtype
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
pytest
.
param
(
"bigscience/bloom-560m"
,
# bloom - testing alibi slopes
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
),
pytest
.
param
(
"openai-community/gpt2"
,
# gpt2
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
),
pytest
.
param
(
"Milos/slovak-gpt-j-405M"
),
# gptj
pytest
.
param
(
"bigcode/tiny_starcoder_py"
),
# gpt_bigcode
pytest
.
param
(
"EleutherAI/pythia-70m"
),
# gpt_neox
pytest
.
param
(
"google/gemma-1.1-2b-it"
,
# gemma
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
),
pytest
.
param
(
"meta-llama/Llama-3.2-1B-Instruct"
,
# llama
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
),
pytest
.
param
(
"openbmb/MiniCPM3-4B"
,
# fused_moe not supported on CPU
marks
=
[
pytest
.
mark
.
core_model
],
),
pytest
.
param
(
"facebook/opt-125m"
,
# opt
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
),
pytest
.
param
(
"microsoft/phi-2"
,
# phi
marks
=
[
pytest
.
mark
.
core_model
],
),
pytest
.
param
(
"Qwen/Qwen2.5-0.5B-Instruct"
,
# qwen2
marks
=
[
pytest
.
mark
.
core_model
],
),
pytest
.
param
(
"stabilityai/stablelm-3b-4e1t"
),
# stablelm
pytest
.
param
(
"bigcode/starcoder2-3b"
),
# starcoder2
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
...
...
tests/models/embedding/language/test_cls_models.py
View file @
b40cf640
...
...
@@ -9,10 +9,14 @@ import pytest
import
torch
from
transformers
import
AutoModelForSequenceClassification
CLASSIFICATION_MODELS
=
[
"jason9693/Qwen2.5-1.5B-apeach"
]
@
pytest
.
mark
.
parametrize
(
"model"
,
CLASSIFICATION_MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
pytest
.
param
(
"jason9693/Qwen2.5-1.5B-apeach"
,
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
]),
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
def
test_classification_models
(
hf_runner
,
...
...
@@ -23,31 +27,19 @@ def test_classification_models(
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
classify
(
example_prompts
)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
with
hf_runner
(
model
,
dtype
=
dtype
,
auto_cls
=
AutoModelForSequenceClassification
)
as
hf_model
:
hf_outputs
=
hf_model
.
classify
(
example_prompts
)
print
(
hf_outputs
,
vllm_outputs
)
# check logits difference
for
hf_output
,
vllm_output
in
zip
(
hf_outputs
,
vllm_outputs
):
hf_output
=
torch
.
tensor
(
hf_output
)
vllm_output
=
torch
.
tensor
(
vllm_output
)
assert
torch
.
allclose
(
hf_output
,
vllm_output
,
1e-3
)
@
pytest
.
mark
.
parametrize
(
"model"
,
CLASSIFICATION_MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
def
test_classification_model_print
(
vllm_runner
,
model
:
str
,
dtype
:
str
,
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
tests/models/embedding/language/test_embedding.py
View file @
b40cf640
...
...
@@ -4,25 +4,25 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
"""
import
pytest
from
vllm.utils
import
current_platform
from
..utils
import
check_embeddings_close
# Model, Guard
MODELS
=
[
"intfloat/e5-mistral-7b-instruct"
,
"BAAI/bge-base-en-v1.5"
,
"BAAI/bge-multilingual-gemma2"
,
"intfloat/multilingual-e5-large"
,
]
ENCODER_ONLY
=
[
"BAAI/bge-base-en-v1.5"
,
"intfloat/multilingual-e5-large"
,
]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
# [Encoder-only]
pytest
.
param
(
"BAAI/bge-base-en-v1.5"
,
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
]),
pytest
.
param
(
"intfloat/multilingual-e5-large"
),
# [Encoder-decoder]
pytest
.
param
(
"intfloat/e5-mistral-7b-instruct"
,
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
]),
pytest
.
param
(
"BAAI/bge-multilingual-gemma2"
,
marks
=
[
pytest
.
mark
.
core_model
]),
pytest
.
param
(
"ssmits/Qwen2-7B-Instruct-embed-base"
),
pytest
.
param
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct"
),
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_models
(
hf_runner
,
...
...
@@ -31,9 +31,6 @@ def test_models(
model
,
dtype
:
str
,
)
->
None
:
if
model
not
in
ENCODER_ONLY
and
current_platform
.
is_cpu
():
pytest
.
skip
(
"Skip large embedding models test on CPU."
)
# The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see:
...
...
@@ -46,8 +43,13 @@ def test_models(
is_sentence_transformer
=
True
)
as
hf_model
:
hf_outputs
=
hf_model
.
encode
(
example_prompts
)
with
vllm_runner
(
model
,
dtype
=
dtype
,
max_model_len
=
None
)
as
vllm_model
:
with
vllm_runner
(
model
,
task
=
"embedding"
,
dtype
=
dtype
,
max_model_len
=
None
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
check_embeddings_close
(
embeddings_0_lst
=
hf_outputs
,
...
...
tests/models/embedding/vision_language/test_llava_next.py
View file @
b40cf640
...
...
@@ -88,6 +88,7 @@ def _run_test(
@
pytest
.
mark
.
skipif
(
transformers
.
__version__
.
startswith
(
"4.46"
),
reason
=
"Model broken with changes in transformers 4.46"
)
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_models_text
(
...
...
@@ -112,6 +113,7 @@ def test_models_text(
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_models_image
(
...
...
tests/models/embedding/vision_language/test_phi3v.py
View file @
b40cf640
...
...
@@ -74,6 +74,7 @@ def _run_test(
)
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_models_text
(
...
...
@@ -98,6 +99,7 @@ def test_models_text(
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_models_image
(
...
...
tests/models/encoder_decoder/language/test_bart.py
View file @
b40cf640
...
...
@@ -14,8 +14,6 @@ from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt,
from
....utils
import
multi_gpu_test
from
...utils
import
check_logprobs_close
MODELS
=
[
"facebook/bart-base"
,
"facebook/bart-large-cnn"
]
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
...
...
@@ -170,7 +168,14 @@ def run_test(
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
pytest
.
param
(
"facebook/bart-base"
,
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
]),
pytest
.
param
(
"facebook/bart-large-cnn"
),
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
...
...
tests/models/encoder_decoder/vision_language/test_mllama.py
View file @
b40cf640
...
...
@@ -233,6 +233,7 @@ def clear_cache():
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"sizes"
,
...
...
@@ -278,6 +279,7 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
...
...
@@ -326,6 +328,7 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
...
...
tests/models/registry.py
View file @
b40cf640
...
...
@@ -129,9 +129,13 @@ _EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel"
:
_HfExamplesInfo
(
"BAAI/bge-base-en-v1.5"
),
"Gemma2Model"
:
_HfExamplesInfo
(
"BAAI/bge-multilingual-gemma2"
),
"LlamaModel"
:
_HfExamplesInfo
(
"llama"
,
is_available_online
=
False
),
"MistralModel"
:
_HfExamplesInfo
(
"intfloat/e5-mistral-7b-instruct"
),
"Qwen2Model"
:
_HfExamplesInfo
(
"ssmits/Qwen2-7B-Instruct-embed-base"
),
"Qwen2ForRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-RM-72B"
),
"Qwen2ForSequenceClassification"
:
_HfExamplesInfo
(
"jason9693/Qwen2.5-1.5B-apeach"
),
# noqa: E501
"RobertaModel"
:
_HfExamplesInfo
(
"sentence-transformers/stsb-roberta-base-v2"
),
# noqa: E501
"XLMRobertaModel"
:
_HfExamplesInfo
(
"intfloat/multilingual-e5-large"
),
# [Multimodal]
"LlavaNextForConditionalGeneration"
:
_HfExamplesInfo
(
"royokong/e5-v"
),
"Phi3VForCausalLM"
:
_HfExamplesInfo
(
"TIGER-Lab/VLM2Vec-Full"
,
...
...
tests/models/test_registry.py
View file @
b40cf640
...
...
@@ -77,8 +77,8 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda):
def
test_hf_registry_coverage
():
untested_archs
=
(
HF_EXAMPLE_MODELS
.
get_supported_archs
()
-
set
(
ModelRegistry
.
get_supported_archs
())
)
untested_archs
=
(
ModelRegistry
.
get_supported_archs
()
-
HF_EXAMPLE_MODELS
.
get_supported_archs
())
assert
not
untested_archs
,
(
"Please add the following architectures to "
...
...
vllm/model_executor/models/qwen2.py
View file @
b40cf640
...
...
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
...
...
@@ -44,8 +45,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
...
...
@@ -247,6 +249,18 @@ class Qwen2Model(nn.Module):
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
# TODO (@robertgshaw2): see if this can be moved out
if
(
cache_config
.
sliding_window
is
not
None
and
hasattr
(
config
,
"max_window_layers"
)):
raise
ValueError
(
"Sliding window for some but all layers is not "
"supported. This model uses sliding window "
"but `max_window_layers` = {} is less than "
"`num_hidden_layers` = {}. Please open an issue "
"to discuss this feature."
.
format
(
config
.
max_window_layers
,
config
.
num_hidden_layers
,
))
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
...
...
@@ -405,20 +419,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
# TODO (@robertgshaw2): see if this can be moved out
if
(
cache_config
.
sliding_window
is
not
None
and
hasattr
(
config
,
"max_window_layers"
)):
raise
ValueError
(
"Sliding window for some but all layers is not "
"supported. This model uses sliding window "
"but `max_window_layers` = {} is less than "
"`num_hidden_layers` = {}. Please open an issue "
"to discuss this feature."
.
format
(
config
.
max_window_layers
,
config
.
num_hidden_layers
,
))
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
config
=
config
self
.
lora_config
=
lora_config
...
...
@@ -438,6 +441,15 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
# The same model class supports both language generation and embedding
# because the architecture name is the same
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
...
...
@@ -475,6 +487,13 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
loader
=
AutoWeightsLoader
(
self
,
...
...
@@ -482,3 +501,70 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
self
.
config
.
tie_word_embeddings
else
None
),
)
loader
.
load_weights
(
weights
)
class
Qwen2EmbeddingModel
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
Qwen2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
MEAN
,
normalize
=
True
,
softmax
=
False
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
loader
=
AutoWeightsLoader
(
self
,
ignore_unexpected_prefixes
=
[
"lm_head."
])
loader
.
load_weights
(
weights
)
vllm/model_executor/models/qwen2_cls.py
View file @
b40cf640
...
...
@@ -17,10 +17,11 @@ from vllm.model_executor.models.qwen2 import Qwen2Model
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
AutoWeightsLoader
,
maybe_prefix
class
Qwen2ForSequenceClassification
(
nn
.
Module
):
class
Qwen2ForSequenceClassification
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -46,21 +47,9 @@ class Qwen2ForSequenceClassification(nn.Module):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
# TODO (@robertgshaw2): see if this can be moved out
if
(
cache_config
.
sliding_window
is
not
None
and
hasattr
(
config
,
"max_window_layers"
)):
raise
ValueError
(
"Sliding window for some but all layers is not "
"supported. This model uses sliding window "
"but `max_window_layers` = {} is less than "
"`num_hidden_layers` = {}. Please open an issue "
"to discuss this feature."
.
format
(
config
.
max_window_layers
,
config
.
num_hidden_layers
,
))
self
.
config
=
config
self
.
lora_config
=
lora_config
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
b40cf640
...
...
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.qwen2
import
Qwen2Model
from
.utils
import
AutoWeightsLoader
,
maybe_prefix
...
...
@@ -32,7 +32,7 @@ class ReLU(nn.Module):
return
self
.
activation
(
input
)
class
Qwen2ForRewardModel
(
nn
.
Module
,
SupportsPP
):
class
Qwen2ForRewardModel
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -58,21 +58,9 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
# TODO (@robertgshaw2): see if this can be moved out
if
(
cache_config
.
sliding_window
is
not
None
and
hasattr
(
config
,
"max_window_layers"
)):
raise
ValueError
(
"Sliding window for some but all layers is not "
"supported. This model uses sliding window "
"but `max_window_layers` = {} is less than "
"`num_hidden_layers` = {}. Please open an issue "
"to discuss this feature."
.
format
(
config
.
max_window_layers
,
config
.
num_hidden_layers
,
))
self
.
config
=
config
self
.
lora_config
=
lora_config
...
...
vllm/model_executor/models/registry.py
View file @
b40cf640
...
...
@@ -11,7 +11,8 @@ import tempfile
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
field
from
functools
import
lru_cache
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
from
typing
import
(
AbstractSet
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
)
import
cloudpickle
import
torch.nn
as
nn
...
...
@@ -110,6 +111,8 @@ _EMBEDDING_MODELS = {
},
"MistralModel"
:
(
"llama"
,
"LlamaEmbeddingModel"
),
"Phi3ForCausalLM"
:
(
"phi3"
,
"Phi3ForCausalLM"
),
"Qwen2Model"
:
(
"qwen2"
,
"Qwen2EmbeddingModel"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2ForRewardModel"
:
(
"qwen2_rm"
,
"Qwen2ForRewardModel"
),
"Qwen2ForSequenceClassification"
:
(
"qwen2_cls"
,
"Qwen2ForSequenceClassification"
),
# noqa: E501
# [Multimodal]
...
...
@@ -301,8 +304,8 @@ class _ModelRegistry:
# Keyed by model_arch
models
:
Dict
[
str
,
_BaseRegisteredModel
]
=
field
(
default_factory
=
dict
)
def
get_supported_archs
(
self
)
->
Lis
t
[
str
]:
return
list
(
self
.
models
.
keys
()
)
def
get_supported_archs
(
self
)
->
AbstractSe
t
[
str
]:
return
self
.
models
.
keys
()
def
register_model
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment