Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e1a3f5e8
Unverified
Commit
e1a3f5e8
authored
Sep 29, 2024
by
Cyrus Leung
Committed by
GitHub
Sep 28, 2024
Browse files
[CI/Build] Update models tests & examples (#8874)
Co-authored-by:
Roger Wang
<
ywang@roblox.com
>
parent
19d02ff9
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
239 additions
and
184 deletions
+239
-184
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+32
-19
examples/offline_inference_vision_language.py
examples/offline_inference_vision_language.py
+19
-9
examples/offline_inference_vision_language_multi_image.py
examples/offline_inference_vision_language_multi_image.py
+10
-3
tests/conftest.py
tests/conftest.py
+44
-40
tests/models/decoder_only/vision_language/test_llava_onevision.py
...dels/decoder_only/vision_language/test_llava_onevision.py
+15
-14
tests/models/decoder_only/vision_language/test_minicpmv.py
tests/models/decoder_only/vision_language/test_minicpmv.py
+1
-1
tests/models/decoder_only/vision_language/test_phi3v.py
tests/models/decoder_only/vision_language/test_phi3v.py
+1
-1
tests/models/decoder_only/vision_language/test_qwen.py
tests/models/decoder_only/vision_language/test_qwen.py
+1
-1
tests/models/encoder_decoder/vision_language/test_broadcast.py
.../models/encoder_decoder/vision_language/test_broadcast.py
+35
-0
tests/models/encoder_decoder/vision_language/test_mllama.py
tests/models/encoder_decoder/vision_language/test_mllama.py
+69
-84
tests/models/utils.py
tests/models/utils.py
+8
-1
vllm/inputs/registry.py
vllm/inputs/registry.py
+2
-10
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+2
-1
No files found.
.buildkite/test-pipeline.yaml
View file @
e1a3f5e8
...
...
@@ -9,6 +9,7 @@
# label(str): the name of the test. emoji allowed.
# fast_check(bool): whether to run this on each commit on fastcheck pipeline.
# fast_check_only(bool): run this test on fastcheck pipeline only
# optional(bool): never run this test by default (i.e. need to unblock manually)
# command(str): the single command to run for tests. incompatible with commands.
# commands(list): the list of commands to run for test. incompatible with command.
# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd]
...
...
@@ -39,7 +40,7 @@ steps:
# Check API reference (if it fails, you may have missing mock imports)
-
grep \"sig sig-object py\" build/html/dev/sampling_params.html
-
label
:
Async Engine, Inputs, Utils, Worker Test
#
15
min
-
label
:
Async Engine, Inputs, Utils, Worker Test
#
24
min
fast_check
:
true
source_file_dependencies
:
-
vllm/
...
...
@@ -81,7 +82,7 @@ steps:
commands
:
-
pytest -v -s core
-
label
:
Entrypoints Test
#
2
0min
-
label
:
Entrypoints Test
#
4
0min
working_dir
:
"
/vllm-workspace/tests"
fast_check
:
true
mirror_hardwares
:
[
amd
]
...
...
@@ -151,7 +152,7 @@ steps:
# OOM in the CI unless we run this separately
-
pytest -v -s tokenization
-
label
:
Examples Test
# 1
2
min
-
label
:
Examples Test
# 1
5
min
working_dir
:
"
/vllm-workspace/examples"
#mirror_hardwares: [amd]
source_file_dependencies
:
...
...
@@ -169,7 +170,7 @@ steps:
-
python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
-
python3 offline_inference_encoder_decoder.py
-
label
:
Prefix Caching Test
#
7
min
-
label
:
Prefix Caching Test
#
9
min
#mirror_hardwares: [amd]
source_file_dependencies
:
-
vllm/
...
...
@@ -177,7 +178,7 @@ steps:
commands
:
-
pytest -v -s prefix_caching
-
label
:
Samplers Test
#
18
min
-
label
:
Samplers Test
#
36
min
source_file_dependencies
:
-
vllm/model_executor/layers
-
vllm/sampling_metadata.py
...
...
@@ -193,7 +194,7 @@ steps:
-
tests/test_logits_processor
command
:
pytest -v -s test_logits_processor.py
-
label
:
Speculative decoding tests
#
22
min
-
label
:
Speculative decoding tests
#
30
min
source_file_dependencies
:
-
vllm/spec_decode
-
tests/spec_decode
...
...
@@ -203,7 +204,7 @@ steps:
-
pytest -v -s spec_decode/e2e/test_multistep_correctness.py
-
pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
-
label
:
LoRA Test %N
#
30
min each
-
label
:
LoRA Test %N
#
15
min each
mirror_hardwares
:
[
amd
]
source_file_dependencies
:
-
vllm/lora
...
...
@@ -211,7 +212,7 @@ steps:
command
:
pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
parallelism
:
4
-
label
:
"
PyTorch
Fullgraph
Smoke
Test"
-
label
:
"
PyTorch
Fullgraph
Smoke
Test"
# 9min
fast_check
:
true
source_file_dependencies
:
-
vllm/
...
...
@@ -219,14 +220,14 @@ steps:
commands
:
-
pytest -v -s compile/test_full_graph_smoke.py
-
label
:
"
PyTorch
Fullgraph
Test"
-
label
:
"
PyTorch
Fullgraph
Test"
# 18min
source_file_dependencies
:
-
vllm/
-
tests/compile
commands
:
-
pytest -v -s compile/test_full_graph.py
-
label
:
Kernels Test %N
#
30min
each
-
label
:
Kernels Test %N
#
1h
each
mirror_hardwares
:
[
amd
]
source_file_dependencies
:
-
csrc/
...
...
@@ -256,7 +257,7 @@ steps:
-
pip install aiohttp
-
bash run-benchmarks.sh
-
label
:
Quantization Test
#
15
min
-
label
:
Quantization Test
#
33
min
source_file_dependencies
:
-
csrc/
-
vllm/model_executor/layers/quantization
...
...
@@ -300,7 +301,7 @@ steps:
-
pytest -v -s models/test_oot_registration.py
# it needs a clean process
-
pytest -v -s models/*.py --ignore=models/test_oot_registration.py
-
label
:
Decoder-only Language Models Test
# 1h3min
-
label
:
Decoder-only Language Models Test
# 1h3
6
min
#mirror_hardwares: [amd]
source_file_dependencies
:
-
vllm/
...
...
@@ -308,7 +309,7 @@ steps:
commands
:
-
pytest -v -s models/decoder_only/language
-
label
:
Decoder-only Multi-Modal Models Test
#
56
min
-
label
:
Decoder-only Multi-Modal Models Test
#
1h31
min
#mirror_hardwares: [amd]
source_file_dependencies
:
-
vllm/
...
...
@@ -318,15 +319,25 @@ steps:
-
pytest -v -s models/decoder_only/audio_language
-
pytest -v -s models/decoder_only/vision_language
-
label
:
Other Models Test
#
5
min
-
label
:
Other Models Test
#
6
min
#mirror_hardwares: [amd]
source_file_dependencies
:
-
vllm/
-
tests/models/embedding/language
-
tests/models/encoder_decoder/language
-
tests/models/encoder_decoder/vision_language
commands
:
-
pytest -v -s models/embedding/language
-
pytest -v -s models/encoder_decoder/language
-
pytest -v -s models/encoder_decoder/vision_language
-
label
:
Custom Models Test
#mirror_hardwares: [amd]
optional
:
true
commands
:
# PR authors can temporarily add commands below to test individual models
# e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py
# *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR*
##### 1 GPU test #####
##### multi gpus test #####
...
...
@@ -359,7 +370,7 @@ steps:
-
# the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'
-
label
:
Distributed Tests (2 GPUs)
#
28
min
-
label
:
Distributed Tests (2 GPUs)
#
40
min
#mirror_hardwares: [amd]
working_dir
:
"
/vllm-workspace/tests"
num_gpus
:
2
...
...
@@ -376,14 +387,16 @@ steps:
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
-
TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
# Avoid importing model tests that cause CUDA reinitialization error
-
pytest models/encoder_decoder/language/test_bart.py models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
-
pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus
-
pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
-
pytest models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
-
pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
-
pip install -e ./plugins/vllm_add_dummy_model
-
pytest -v -s distributed/test_distributed_oot.py
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
-
label
:
Multi-step Tests (4 GPUs)
#
21
min
-
label
:
Multi-step Tests (4 GPUs)
#
36
min
working_dir
:
"
/vllm-workspace/tests"
num_gpus
:
4
source_file_dependencies
:
...
...
@@ -401,7 +414,7 @@ steps:
-
pytest -v -s multi_step/test_correctness_async_llm.py
-
pytest -v -s multi_step/test_correctness_llm.py
-
label
:
Pipeline Parallelism Test
#
23
min
-
label
:
Pipeline Parallelism Test
#
45
min
working_dir
:
"
/vllm-workspace/tests"
num_gpus
:
4
source_file_dependencies
:
...
...
@@ -427,7 +440,7 @@ steps:
-
export VLLM_WORKER_MULTIPROC_METHOD=spawn
-
pytest -v -s -x lora/test_long_context.py
-
label
:
Weight Loading Multiple GPU Test
-
label
:
Weight Loading Multiple GPU Test
# 33min
working_dir
:
"
/vllm-workspace/tests"
num_gpus
:
2
source_file_dependencies
:
...
...
examples/offline_inference_vision_language.py
View file @
e1a3f5e8
...
...
@@ -12,6 +12,10 @@ from vllm.assets.image import ImageAsset
from
vllm.assets.video
import
VideoAsset
from
vllm.utils
import
FlexibleArgumentParser
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
# LLaVA-1.5
def
run_llava
(
question
,
modality
):
...
...
@@ -19,7 +23,7 @@ def run_llava(question, modality):
prompt
=
f
"USER: <image>
\n
{
question
}
\n
ASSISTANT:"
llm
=
LLM
(
model
=
"llava-hf/llava-1.5-7b-hf"
)
llm
=
LLM
(
model
=
"llava-hf/llava-1.5-7b-hf"
,
max_model_len
=
4096
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
...
...
@@ -57,7 +61,7 @@ def run_llava_onevision(question, modality):
<|im_start|>assistant
\n
"
llm
=
LLM
(
model
=
"llava-hf/llava-onevision-qwen2-7b-ov-hf"
,
max_model_len
=
32768
)
max_model_len
=
16384
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
...
...
@@ -67,7 +71,7 @@ def run_fuyu(question, modality):
assert
modality
==
"image"
prompt
=
f
"
{
question
}
\n
"
llm
=
LLM
(
model
=
"adept/fuyu-8b"
)
llm
=
LLM
(
model
=
"adept/fuyu-8b"
,
max_model_len
=
2048
,
max_num_seqs
=
2
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
...
...
@@ -99,7 +103,8 @@ def run_phi3v(question, modality):
llm
=
LLM
(
model
=
"microsoft/Phi-3-vision-128k-instruct"
,
trust_remote_code
=
True
,
max_num_seqs
=
5
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
mm_processor_kwargs
=
{
"num_crops"
:
16
},
)
stop_token_ids
=
None
...
...
@@ -122,7 +127,7 @@ def run_chameleon(question, modality):
assert
modality
==
"image"
prompt
=
f
"
{
question
}
<image>"
llm
=
LLM
(
model
=
"facebook/chameleon-7b"
)
llm
=
LLM
(
model
=
"facebook/chameleon-7b"
,
max_model_len
=
4096
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
...
...
@@ -145,6 +150,8 @@ def run_minicpmv(question, modality):
trust_remote_code
=
True
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
trust_remote_code
=
True
,
)
# NOTE The stop_token_ids are different for various versions of MiniCPM-V
...
...
@@ -177,7 +184,7 @@ def run_internvl(question, modality):
llm
=
LLM
(
model
=
model_name
,
trust_remote_code
=
True
,
max_
num_seqs
=
5
,
max_
model_len
=
4096
,
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
...
...
@@ -215,7 +222,8 @@ def run_qwen_vl(question, modality):
llm
=
LLM
(
model
=
"Qwen/Qwen-VL"
,
trust_remote_code
=
True
,
max_num_seqs
=
5
,
max_model_len
=
1024
,
max_num_seqs
=
2
,
)
prompt
=
f
"
{
question
}
Picture 1: <img></img>
\n
"
...
...
@@ -229,8 +237,10 @@ def run_qwen2_vl(question, modality):
model_name
=
"Qwen/Qwen2-VL-7B-Instruct"
# Tested on L40
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
8192
,
max_num_seqs
=
5
,
)
...
...
@@ -252,10 +262,10 @@ def run_mllama(question, modality):
# max_model_len (131072) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs.
# The configuration below has been confirmed to launch on a
# single H100 GPU.
# The configuration below has been confirmed to launch on a single L40 GPU.
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
4096
,
max_num_seqs
=
16
,
enforce_eager
=
True
,
)
...
...
examples/offline_inference_vision_language_multi_image.py
View file @
e1a3f5e8
...
...
@@ -28,12 +28,18 @@ class ModelRequestData(NamedTuple):
chat_template
:
Optional
[
str
]
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
def
load_qwenvl_chat
(
question
:
str
,
image_urls
:
List
[
str
])
->
ModelRequestData
:
model_name
=
"Qwen/Qwen-VL-Chat"
llm
=
LLM
(
model
=
model_name
,
trust_remote_code
=
True
,
max_num_seqs
=
5
,
max_model_len
=
1024
,
max_num_seqs
=
2
,
limit_mm_per_prompt
=
{
"image"
:
len
(
image_urls
)},
)
placeholders
=
""
.
join
(
f
"Picture
{
i
}
: <img></img>
\n
"
...
...
@@ -83,6 +89,7 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
model
=
"microsoft/Phi-3.5-vision-instruct"
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
limit_mm_per_prompt
=
{
"image"
:
len
(
image_urls
)},
mm_processor_kwargs
=
{
"num_crops"
:
4
},
)
...
...
@@ -106,7 +113,6 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
llm
=
LLM
(
model
=
model_name
,
trust_remote_code
=
True
,
max_num_seqs
=
5
,
max_model_len
=
4096
,
limit_mm_per_prompt
=
{
"image"
:
len
(
image_urls
)},
)
...
...
@@ -148,10 +154,11 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
model_name
=
"Qwen/Qwen2-VL-7B-Instruct"
# Tested on L40
llm
=
LLM
(
model
=
model_name
,
max_num_seqs
=
5
,
max_model_len
=
32768
if
process_vision_info
is
None
else
4096
,
max_num_seqs
=
5
,
limit_mm_per_prompt
=
{
"image"
:
len
(
image_urls
)},
)
...
...
tests/conftest.py
View file @
e1a3f5e8
...
...
@@ -246,17 +246,14 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
class
HfRunner
:
def
wrap_device
(
self
,
input
:
_T
)
->
_T
:
if
not
is_cpu
():
# Check if the input is already on the GPU
if
hasattr
(
input
,
'device'
)
and
input
.
device
.
type
==
"cuda"
:
return
input
# Already on GPU, no need to move
return
input
.
to
(
"cuda"
)
else
:
# Check if the input is already on the CPU
if
hasattr
(
input
,
'device'
)
and
input
.
device
.
type
==
"cpu"
:
return
input
# Already on CPU, no need to move
return
input
.
to
(
"cpu"
)
def
wrap_device
(
self
,
input
:
_T
,
device
:
Optional
[
str
]
=
None
)
->
_T
:
if
device
is
None
:
return
self
.
wrap_device
(
input
,
"cpu"
if
is_cpu
()
else
"cuda"
)
if
hasattr
(
input
,
"device"
)
and
input
.
device
.
type
==
device
:
return
input
return
input
.
to
(
device
)
def
__init__
(
self
,
...
...
@@ -333,7 +330,7 @@ class HfRunner:
inputs
=
self
.
postprocess_inputs
(
inputs
)
output_ids
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
,
device
=
self
.
model
.
device
.
type
),
use_cache
=
True
,
**
kwargs
,
)
...
...
@@ -406,7 +403,7 @@ class HfRunner:
inputs
=
self
.
postprocess_inputs
(
inputs
)
output
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
,
device
=
self
.
model
.
device
.
type
),
use_cache
=
True
,
do_sample
=
False
,
max_new_tokens
=
max_tokens
,
...
...
@@ -414,40 +411,39 @@ class HfRunner:
return_dict_in_generate
=
True
,
**
kwargs
,
)
seq_logprobs
:
List
[
torch
.
Tensor
]
=
[]
for
hidden_states
in
output
.
hidden_states
:
last_hidden_states
=
hidden_states
[
-
1
][
0
]
logits
=
torch
.
matmul
(
last_hidden_states
,
self
.
model
.
get_output_embeddings
().
weight
.
t
(),
)
if
self
.
model
.
get_output_embeddings
().
bias
is
not
None
:
logits
+=
self
.
model
.
get_output_embeddings
(
).
bias
.
unsqueeze
(
0
)
logprobs
=
F
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
seq_logprobs
.
append
(
logprobs
)
seq_logprobs
=
self
.
_hidden_states_to_seq_logprobs
(
output
.
hidden_states
)
all_logprobs
.
append
(
seq_logprobs
)
return
all_logprobs
def
_hidden_states_to_logprobs
(
def
_hidden_states_to_
seq_
logprobs
(
self
,
hidden_states
,
num_logprobs
,
)
->
Tuple
[
List
[
Dict
[
int
,
float
]],
int
]:
hidden_states
:
Tuple
[
Tuple
[
torch
.
Tensor
,
...],
...],
)
->
List
[
torch
.
Tensor
]:
output_embeddings
=
self
.
model
.
get_output_embeddings
()
seq_logprobs
:
List
[
torch
.
Tensor
]
=
[]
output_len
=
len
(
hidden_states
)
for
_
,
hidden_state
in
enumerate
(
hidden_states
):
last_hidden_states
=
hidden_state
[
-
1
][
0
]
logits
=
torch
.
matmul
(
last_hidden_states
,
self
.
model
.
get_
output_embeddings
()
.
weight
.
t
(),
last_hidden_states
.
to
(
output_embeddings
.
weight
.
device
)
,
output_embeddings
.
weight
.
t
(),
)
if
getattr
(
self
.
model
.
get_output_embeddings
(),
"bias"
,
None
)
is
not
None
:
logits
+=
self
.
model
.
get_output_embeddings
().
bias
.
unsqueeze
(
0
)
if
getattr
(
output_embeddings
,
"bias"
,
None
)
is
not
None
:
logits
+=
output_embeddings
.
bias
.
unsqueeze
(
0
)
logprobs
=
F
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
seq_logprobs
.
append
(
logprobs
)
return
seq_logprobs
def
_hidden_states_to_logprobs
(
self
,
hidden_states
:
Tuple
[
Tuple
[
torch
.
Tensor
,
...],
...],
num_logprobs
:
int
,
)
->
Tuple
[
List
[
Dict
[
int
,
float
]],
int
]:
seq_logprobs
=
self
.
_hidden_states_to_seq_logprobs
(
hidden_states
)
output_len
=
len
(
hidden_states
)
# convert to dict
seq_logprobs_lst
:
List
[
Dict
[
int
,
float
]]
=
[]
for
tok_idx
,
tok_logprobs
in
enumerate
(
seq_logprobs
):
...
...
@@ -500,7 +496,7 @@ class HfRunner:
inputs
=
self
.
postprocess_inputs
(
inputs
)
output
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
,
device
=
self
.
model
.
device
.
type
),
use_cache
=
True
,
do_sample
=
False
,
max_new_tokens
=
max_tokens
,
...
...
@@ -543,12 +539,20 @@ class HfRunner:
for
(
encoder_prompt
,
decoder_prompt
)
in
to_enc_dec_tuple_list
(
encoder_decoder_prompts
):
encoder_input_ids
=
self
.
wrap_device
(
self
.
tokenizer
(
encoder_prompt
,
return_tensors
=
"pt"
).
input_ids
)
decoder_input_ids
=
(
None
if
decoder_prompt
is
None
else
self
.
wrap_device
(
self
.
tokenizer
(
encoder_prompt
,
return_tensors
=
"pt"
).
input_ids
,
device
=
self
.
model
.
device
.
type
,
)
if
decoder_prompt
is
None
:
decoder_input_ids
=
None
else
:
decoder_input_ids
=
self
.
wrap_device
(
self
.
tokenizer
(
decoder_prompt
,
return_tensors
=
"pt"
).
input_ids
))
return_tensors
=
"pt"
).
input_ids
,
device
=
self
.
model
.
device
.
type
,
)
output
=
self
.
model
.
generate
(
encoder_input_ids
,
...
...
tests/models/decoder_only/vision_language/test_llava_onevision.py
View file @
e1a3f5e8
...
...
@@ -16,8 +16,7 @@ from ...utils import check_logprobs_close
# Video test
HF_VIDEO_PROMPTS
=
VIDEO_ASSETS
.
prompts
({
"sample_demo_1"
:
"<|im_start|>user <video>
\n
why is this video funny?
\
<|im_end|><|im_start|>assistant
\n
"
"<|im_start|>user
\n
<video>
\n
why is this video funny?<|im_end|>
\n
<|im_start|>assistant
\n
"
# noqa: E501
})
models
=
[
"llava-hf/llava-onevision-qwen2-7b-ov-hf"
]
...
...
@@ -165,6 +164,9 @@ def run_video_test(
)
@
pytest
.
mark
.
skip
(
reason
=
"Model is too big, test passed on L40 locally but will OOM on CI machine."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"size_factors"
,
...
...
@@ -208,6 +210,9 @@ def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
)
@
pytest
.
mark
.
skip
(
reason
=
"Model is too big, test passed on L40 locally but will OOM on CI machine."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"sizes"
,
...
...
@@ -254,9 +259,8 @@ def run_image_test(
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
dtype
=
dtype
,
max_num_seqs
=
1
,
max_model_len
=
16384
,
gpu_memory_utilization
=
0.98
,
max_num_seqs
=
2
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
,
...
...
@@ -302,8 +306,9 @@ def run_image_test(
)
# FIXME: Swap to a smaller model for this architecture
@
pytest
.
mark
.
skip
(
reason
=
"Model OOMing on CI"
)
@
pytest
.
mark
.
skip
(
reason
=
"Model is too big, test passed on L40 locally but will OOM on CI machine."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
...
...
@@ -316,14 +321,10 @@ def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
inputs
=
[(
[
"<|im_start|>user <image><image>
\n
Describe 2 images.
\
<|im_end|><|im_start|>assistant
\n
"
,
"<|im_start|>user <image><image>
\n
Describe 2 images.
\
<|im_end|><|im_start|>assistant
\n
"
,
"<|im_start|>user <image><image><image><image>
\n
Describe 4 images.
\
<|im_end|><|im_start|>assistant
\n
"
,
"<|im_start|>user <image>
\n
What is the season?
\
<|im_end|><|im_start|>assistant
\n
"
,
"<|im_start|>user
\n
<image><image>
\n
Describe 2 images.<|im_end|>
\n
<|im_start|>assistant
\n
"
,
# noqa: E501
"<|im_start|>user
\n
<image><image>
\n
Describe 2 images.<|im_end|>
\n
<|im_start|>assistant
\n
"
,
# noqa: E501
"<|im_start|>user
\n
<image><image><image><image>
\n
Describe 4 images.<|im_end|>
\n
<|im_start|>assistant
\n
"
,
# noqa: E501
"<|im_start|>user
\n
<image>
\n
What is the season?<|im_end|>
\n
<|im_start|>assistant
\n
"
,
# noqa: E501
],
[
[
stop_sign
,
cherry_blossom
],
...
...
tests/models/decoder_only/vision_language/test_minicpmv.py
View file @
e1a3f5e8
...
...
@@ -79,7 +79,7 @@ def run_test(
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
max_model_len
=
4096
,
max_num_seqs
=
1
,
max_num_seqs
=
2
,
dtype
=
dtype
,
limit_mm_per_prompt
=
{
"image"
:
mm_limit
},
tensor_parallel_size
=
tensor_parallel_size
,
...
...
tests/models/decoder_only/vision_language/test_phi3v.py
View file @
e1a3f5e8
...
...
@@ -90,7 +90,7 @@ def run_test(
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
max_model_len
=
4096
,
max_num_seqs
=
1
,
max_num_seqs
=
2
,
dtype
=
dtype
,
limit_mm_per_prompt
=
{
"image"
:
mm_limit
},
tensor_parallel_size
=
tensor_parallel_size
,
...
...
tests/models/decoder_only/vision_language/test_qwen.py
View file @
e1a3f5e8
...
...
@@ -221,7 +221,7 @@ def run_test(
# Qwen encodes each image into a fixed content size of 256
with
vllm_runner
(
model
,
max_model_len
=
1024
,
max_num_seqs
=
1
,
max_num_seqs
=
2
,
dtype
=
dtype
,
limit_mm_per_prompt
=
{
"image"
:
mm_limit
},
tensor_parallel_size
=
tensor_parallel_size
,
...
...
tests/models/encoder_decoder/vision_language/test_broadcast.py
0 → 100644
View file @
e1a3f5e8
import
pytest
from
....utils
import
multi_gpu_test
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"distributed_executor_backend"
,
[
"ray"
,
"mp"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"meta-llama/Llama-3.2-11B-Vision-Instruct"
,
])
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
distributed_executor_backend
,
model
)
->
None
:
dtype
=
"half"
max_tokens
=
5
num_logprobs
=
5
tensor_parallel_size
=
2
if
model
.
startswith
(
"meta-llama/Llama-3.2-11B-Vision-Instruct"
):
from
.test_mllama
import
models
,
run_test
else
:
raise
NotImplementedError
(
f
"Unsupported model:
{
model
}
"
)
run_test
(
hf_runner
,
vllm_runner
,
image_assets
,
model
=
models
[
0
],
size_factors
=
[
0.25
,
0.5
,
1.0
],
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
)
tests/models/encoder_decoder/vision_language/test_mllama.py
View file @
e1a3f5e8
...
...
@@ -9,7 +9,6 @@ from vllm.sequence import SampleLogprobs
from
....conftest
import
(
IMAGE_ASSETS
,
HfRunner
,
PromptImageInput
,
VllmRunner
,
_ImageAssets
)
from
....utils
import
multi_gpu_test
from
...utils
import
check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT
=
1
...
...
@@ -47,14 +46,46 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
if
token_id
!=
image_token_id
or
output_ids
[
idx
-
1
]
!=
image_token_id
]
assert
output_str
[
0
]
==
" "
hf_output_str
=
output_str
[
1
:]
hf_output_str
=
output_str
if
hf_output_ids
[
-
1
]
==
eos_token_id
:
hf_output_str
=
hf_output_str
+
tokenizer
.
decode
(
eos_token_id
)
return
hf_output_ids
,
hf_output_str
,
out_logprobs
def
_get_inputs
(
image_assets
:
_ImageAssets
,
*
,
size_factors
:
Optional
[
List
[
float
]]
=
None
,
sizes
:
Optional
[
List
[
Tuple
[
int
,
int
]]]
=
None
,
)
->
List
[
Tuple
[
List
[
str
],
PromptImageInput
]]:
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
if
size_factors
is
not
None
:
inputs_per_image
=
[(
[
prompt
for
_
in
size_factors
],
[
rescale_image_size
(
image
,
factor
)
for
factor
in
size_factors
],
)
for
image
,
prompt
in
zip
(
images
,
HF_IMAGE_PROMPTS
)]
elif
sizes
is
not
None
:
inputs_per_image
=
[(
[
prompt
if
size
is
not
None
else
text_only_prompts
[
0
]
for
size
in
sizes
],
[
image
.
resize
(
size
)
if
size
is
not
None
else
None
for
size
in
sizes
],
)
for
image
,
prompt
in
zip
(
images
,
HF_IMAGE_PROMPTS
)]
if
len
(
sizes
)
==
0
:
inputs_per_image
.
append
(
(
text_only_prompts
,
[
None
]
*
len
(
text_only_prompts
)))
else
:
raise
ValueError
(
"You must provide either `size_factors` or `sizes`"
)
return
inputs_per_image
@
overload
def
run_test
(
hf_runner
:
Type
[
HfRunner
],
...
...
@@ -103,39 +134,17 @@ def run_test(
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
if
size_factors
is
not
None
:
inputs_per_image
=
[(
[
prompt
for
_
in
size_factors
],
[
rescale_image_size
(
image
,
factor
)
for
factor
in
size_factors
],
)
for
image
,
prompt
in
zip
(
images
,
HF_IMAGE_PROMPTS
)]
elif
sizes
is
not
None
:
inputs_per_image
=
[(
[
prompt
if
size
is
not
None
else
text_only_prompts
[
0
]
for
size
in
sizes
],
[
image
.
resize
(
size
)
if
size
is
not
None
else
None
for
size
in
sizes
],
)
for
image
,
prompt
in
zip
(
images
,
HF_IMAGE_PROMPTS
)]
if
len
(
sizes
)
==
0
:
inputs_per_image
.
append
(
(
text_only_prompts
,
[
None
]
*
len
(
text_only_prompts
)))
else
:
raise
ValueError
(
"You must provide either `size_factors` or `sizes`"
)
_run_test
(
hf_runner
,
_run_test
(
hf_runner
,
vllm_runner
,
inputs_per_image
,
_get_inputs
(
image_assets
,
size_factors
=
size_factors
,
sizes
=
sizes
)
,
model
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
)
distributed_executor_backend
=
distributed_executor_backend
,
)
def
_run_test
(
...
...
@@ -167,8 +176,8 @@ def _run_test(
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
dtype
=
dtype
,
max_num_seqs
=
16
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
,
...
...
@@ -185,7 +194,6 @@ def _run_test(
def
process
(
hf_inputs
:
BatchEncoding
):
return
hf_inputs
from
transformers
import
AutoConfig
from
transformers.models.mllama
import
MllamaConfig
as
MllamaConfigHf
# use transformer's MllamaConfig for hf_runner
...
...
@@ -193,6 +201,7 @@ def _run_test(
AutoConfig
.
register
(
"mllama"
,
MllamaConfigHf
,
exist_ok
=
True
)
with
hf_runner
(
model
,
dtype
=
dtype
,
model_kwargs
=
{
"device_map"
:
"auto"
},
postprocess_inputs
=
process
,
auto_cls
=
AutoModelForVision2Seq
)
as
hf_model
:
hf_outputs_per_image
=
[
...
...
@@ -218,10 +227,7 @@ def _run_test(
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"sizes"
,
[
SIZES
=
[
# Text only
[],
# Single-size
...
...
@@ -236,40 +242,19 @@ def _run_test(
(
1024
,
1024
),
(
512
,
1536
),
(
512
,
2028
),
None
],
# mllama has 8 possible aspect ratios, carefully set the sizes
# to cover all of them
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
sizes
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
run_test
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
sizes
=
sizes
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
]
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
skip
(
reason
=
"Model is too big, test passed on L40 locally but will OOM on CI machine."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"sizes"
,
[
[(
512
,
512
),
(
1024
,
512
),
(
1536
,
512
),
(
2048
,
512
),
(
512
,
1024
),
(
1024
,
1024
),
(
512
,
1536
),
(
512
,
2028
),
None
],
],
)
@
pytest
.
mark
.
parametrize
(
"sizes"
,
SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
_distributed
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
sizes
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
sizes
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
run_test
(
hf_runner
,
vllm_runner
,
...
...
@@ -279,5 +264,5 @@ def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
2
,
tensor_parallel_size
=
1
,
)
tests/models/utils.py
View file @
e1a3f5e8
import
warnings
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
import
torch
from
vllm.config
import
ModelConfig
from
vllm.inputs
import
InputContext
from
vllm.sequence
import
Logprob
,
PromptLogprobs
,
SampleLogprobs
from
vllm.utils
import
is_cpu
TokensText
=
Tuple
[
List
[
int
],
str
]
...
...
@@ -247,6 +250,7 @@ def check_logprobs_close(
def
build_model_context
(
model_name
:
str
,
tokenizer_name
:
Optional
[
str
]
=
None
,
trust_remote_code
:
bool
=
False
,
dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
]
=
None
,
limit_mm_per_prompt
:
Optional
[
Dict
]
=
None
):
"""Creates an InputContext for a given model.
...
...
@@ -264,12 +268,15 @@ def build_model_context(model_name: str,
"""
if
tokenizer_name
is
None
:
tokenizer_name
=
model_name
if
dtype
is
None
:
dtype
=
"bfloat16"
if
is_cpu
()
else
"half"
model_config
=
ModelConfig
(
model_name
,
tokenizer_name
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
trust_remote_code
,
dtype
=
"float32"
,
dtype
=
dtype
,
seed
=
0
,
mm_processor_kwargs
=
mm_processor_kwargs
,
limit_mm_per_prompt
=
limit_mm_per_prompt
,
...
...
vllm/inputs/registry.py
View file @
e1a3f5e8
...
...
@@ -185,16 +185,8 @@ class InputRegistry:
return
wrapper
def
_get_dummy_encoder_data_factory
(
self
,
model_cls
:
Type
[
nn
.
Module
]):
if
model_cls
in
self
.
_dummy_encoder_factories_by_model_type
:
dummy_factory
=
self
.
_dummy_encoder_factories_by_model_type
[
model_cls
]
else
:
logger
.
warning
(
"No dummy encoder data factory registered to %s. "
"Using the dummy data factory for the model instead."
,
model_cls
)
dummy_factory
=
self
.
_get_dummy_data_factory
(
model_cls
)
return
dummy_factory
return
self
.
_dummy_encoder_factories_by_model_type
\
.
get
(
model_cls
,
self
.
_default_dummy_data_factory
)
def
dummy_data_for_profiling
(
self
,
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
e1a3f5e8
...
...
@@ -159,7 +159,8 @@ def apply_fp8_linear(
# Making sure the dummy tensor is on the same device as the weight
global
TORCH_DEVICE_IDENTITY
if
TORCH_DEVICE_IDENTITY
.
device
!=
weight
.
device
:
if
(
TORCH_DEVICE_IDENTITY
is
not
None
and
TORCH_DEVICE_IDENTITY
.
device
!=
weight
.
device
):
TORCH_DEVICE_IDENTITY
=
TORCH_DEVICE_IDENTITY
.
to
(
weight
.
device
)
# GEMM
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment