Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
006693ed
Commit
006693ed
authored
Dec 01, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.11.2' into v0.11.2-ori
parents
4b51e6f1
275de341
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2198 additions
and
1020 deletions
+2198
-1020
tests/basic_correctness/test_basic_correctness.py
tests/basic_correctness/test_basic_correctness.py
+69
-78
tests/basic_correctness/test_cpu_offload.py
tests/basic_correctness/test_cpu_offload.py
+3
-2
tests/basic_correctness/test_cumem.py
tests/basic_correctness/test_cumem.py
+104
-70
tests/benchmarks/test_latency_cli.py
tests/benchmarks/test_latency_cli.py
+12
-2
tests/benchmarks/test_random_dataset.py
tests/benchmarks/test_random_dataset.py
+195
-55
tests/benchmarks/test_random_multimodal_dataset_video.py
tests/benchmarks/test_random_multimodal_dataset_video.py
+398
-0
tests/benchmarks/test_serve_cli.py
tests/benchmarks/test_serve_cli.py
+2
-3
tests/benchmarks/test_throughput_cli.py
tests/benchmarks/test_throughput_cli.py
+12
-2
tests/ci_envs.py
tests/ci_envs.py
+11
-4
tests/compile/backend.py
tests/compile/backend.py
+29
-15
tests/compile/piecewise/test_full_cudagraph.py
tests/compile/piecewise/test_full_cudagraph.py
+65
-126
tests/compile/piecewise/test_multiple_graphs.py
tests/compile/piecewise/test_multiple_graphs.py
+165
-102
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+74
-62
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+227
-176
tests/compile/silly_attention.py
tests/compile/silly_attention.py
+7
-6
tests/compile/test_aot_compile.py
tests/compile/test_aot_compile.py
+139
-0
tests/compile/test_async_tp.py
tests/compile/test_async_tp.py
+167
-110
tests/compile/test_basic_correctness.py
tests/compile/test_basic_correctness.py
+47
-37
tests/compile/test_config.py
tests/compile/test_config.py
+307
-63
tests/compile/test_decorator.py
tests/compile/test_decorator.py
+165
-107
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
tests/basic_correctness/test_basic_correctness.py
View file @
006693ed
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
"""
import
os
import
os
import
weakref
import
weakref
from
unittest.mock
import
Mock
from
unittest.mock
import
Mock
...
@@ -12,14 +13,14 @@ import pytest
...
@@ -12,14 +13,14 @@ import pytest
import
torch
import
torch
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.v1.engine.llm_engine
import
LLMEngine
as
LLMEngineV1
from
vllm.v1.engine.llm_engine
import
LLMEngine
from
..conftest
import
HfRunner
,
VllmRunner
from
..conftest
import
HfRunner
,
VllmRunner
from
..models.utils
import
check_outputs_equal
from
..models.utils
import
check_outputs_equal
from
..utils
import
multi_gpu_test
from
..utils
import
multi_gpu_test
MODELS
=
[
MODELS
=
[
"
google/gemma-2-2b-it
"
,
"
hmellor/tiny-random-Gemma2ForCausalLM
"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
]
]
...
@@ -28,7 +29,7 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
...
@@ -28,7 +29,7 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
def
test_vllm_gc_ed
():
def
test_vllm_gc_ed
():
"""Verify vllm instance is GC'ed when it is deleted"""
"""Verify vllm instance is GC'ed when it is deleted"""
llm
=
LLM
(
"
distilbert/distilgpt2
"
)
llm
=
LLM
(
"
hmellor/tiny-random-LlamaForCausalLM
"
)
weak_llm
=
weakref
.
ref
(
llm
)
weak_llm
=
weakref
.
ref
(
llm
)
del
llm
del
llm
# If there's any circular reference to vllm, this fails
# If there's any circular reference to vllm, this fails
...
@@ -37,16 +38,21 @@ def test_vllm_gc_ed():
...
@@ -37,16 +38,21 @@ def test_vllm_gc_ed():
def
_fix_prompt_embed_outputs
(
def
_fix_prompt_embed_outputs
(
vllm_outputs
:
list
[
tuple
[
list
[
int
],
str
]],
hf_model
:
HfRunner
,
vllm_outputs
:
list
[
tuple
[
list
[
int
],
str
]],
example_prompts
:
list
[
str
])
->
list
[
tuple
[
list
[
int
],
str
]]:
hf_model
:
HfRunner
,
example_prompts
:
list
[
str
],
)
->
list
[
tuple
[
list
[
int
],
str
]]:
fixed_vllm_outputs
=
[]
fixed_vllm_outputs
=
[]
for
vllm_output
,
hf_input
,
prompt
in
zip
(
for
vllm_output
,
hf_input
,
prompt
in
zip
(
vllm_outputs
,
hf_model
.
get_inputs
(
example_prompts
),
vllm_outputs
,
hf_model
.
get_inputs
(
example_prompts
),
example_prompts
example_prompts
):
):
hf_input_ids
=
hf_input
[
"input_ids"
].
tolist
()[
0
]
hf_input_ids
=
hf_input
[
"input_ids"
].
tolist
()[
0
]
fixed_vllm_outputs
.
append
(
fixed_vllm_outputs
.
append
(
(
hf_input_ids
+
vllm_output
[
0
][
len
(
hf_input_ids
):],
(
prompt
+
vllm_output
[
1
]))
hf_input_ids
+
vllm_output
[
0
][
len
(
hf_input_ids
)
:],
prompt
+
vllm_output
[
1
],
)
)
return
fixed_vllm_outputs
return
fixed_vllm_outputs
...
@@ -69,8 +75,7 @@ def test_models(
...
@@ -69,8 +75,7 @@ def test_models(
enable_prompt_embeds
:
bool
,
enable_prompt_embeds
:
bool
,
)
->
None
:
)
->
None
:
if
backend
==
"XFORMERS"
and
model
==
"google/gemma-2-2b-it"
:
if
backend
==
"XFORMERS"
and
model
==
"google/gemma-2-2b-it"
:
pytest
.
skip
(
pytest
.
skip
(
f
"
{
backend
}
does not support gemma2 with full context length."
)
f
"
{
backend
}
does not support gemma2 with full context length."
)
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
...
@@ -78,34 +83,35 @@ def test_models(
...
@@ -78,34 +83,35 @@ def test_models(
# 5042 tokens for gemma2
# 5042 tokens for gemma2
# gemma2 has alternating sliding window size of 4096
# gemma2 has alternating sliding window size of 4096
# we need a prompt with more than 4096 tokens to test the sliding window
# we need a prompt with more than 4096 tokens to test the sliding window
prompt
=
"The following numbers of the sequence "
+
", "
.
join
(
prompt
=
(
str
(
i
)
for
i
in
range
(
1024
))
+
" are:"
"The following numbers of the sequence "
+
", "
.
join
(
str
(
i
)
for
i
in
range
(
1024
))
+
" are:"
)
example_prompts
=
[
prompt
]
example_prompts
=
[
prompt
]
with
hf_runner
(
model
)
as
hf_model
:
with
hf_runner
(
model
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
if
enable_prompt_embeds
:
if
enable_prompt_embeds
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
prompt_embeds
=
hf_model
.
get_prompt_embeddings
(
prompt_embeds
=
hf_model
.
get_prompt_embeddings
(
example_prompts
)
example_prompts
)
with
VllmRunner
(
with
VllmRunner
(
model
,
model
,
max_model_len
=
8192
,
max_model_len
=
8192
,
enforce_eager
=
enforce_eager
,
enforce_eager
=
enforce_eager
,
enable_prompt_embeds
=
enable_prompt_embeds
,
enable_prompt_embeds
=
enable_prompt_embeds
,
gpu_memory_utilization
=
0.7
,
gpu_memory_utilization
=
0.7
,
async_scheduling
=
async_scheduling
,
async_scheduling
=
async_scheduling
,
distributed_executor_backend
=
model_executor
,
distributed_executor_backend
=
model_executor
,
)
as
vllm_model
:
)
as
vllm_model
:
if
enable_prompt_embeds
:
if
enable_prompt_embeds
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
vllm_outputs
=
vllm_model
.
generate_greedy
(
prompt_embeds
,
max_tokens
)
prompt_embeds
,
max_tokens
)
vllm_outputs
=
_fix_prompt_embed_outputs
(
vllm_outputs
=
_fix_prompt_embed_outputs
(
vllm_outputs
,
hf_model
,
example_prompts
)
vllm_outputs
,
hf_model
,
example_prompts
)
else
:
else
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
example_prompts
,
max_tokens
)
check_outputs_equal
(
check_outputs_equal
(
outputs_0_lst
=
hf_outputs
,
outputs_0_lst
=
hf_outputs
,
...
@@ -117,21 +123,18 @@ def test_models(
...
@@ -117,21 +123,18 @@ def test_models(
@
multi_gpu_test
(
num_gpus
=
2
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model, distributed_executor_backend, attention_backend, "
"model, distributed_executor_backend, attention_backend, test_suite, extra_env"
,
"test_suite, extra_env"
,
[
[
(
"distilbert/distilgpt2"
,
"ray"
,
""
,
"L4"
,
{}),
(
"facebook/opt-125m"
,
"ray"
,
""
,
"L4"
,
{}),
(
"distilbert/distilgpt2"
,
"mp"
,
""
,
"L4"
,
{}),
(
"facebook/opt-125m"
,
"mp"
,
""
,
"L4"
,
{}),
(
"distilbert/distilgpt2"
,
"ray"
,
""
,
"L4"
,
{
(
"facebook/opt-125m"
,
"ray"
,
""
,
"L4"
,
{
"VLLM_SLEEP_WHEN_IDLE"
:
"1"
}),
"VLLM_SLEEP_WHEN_IDLE"
:
"1"
(
"facebook/opt-125m"
,
"mp"
,
""
,
"L4"
,
{
"VLLM_SLEEP_WHEN_IDLE"
:
"1"
}),
}),
(
"distilbert/distilgpt2"
,
"mp"
,
""
,
"L4"
,
{
"VLLM_SLEEP_WHEN_IDLE"
:
"1"
}),
(
"meta-llama/Llama-3.2-1B-Instruct"
,
"ray"
,
""
,
"L4"
,
{}),
(
"meta-llama/Llama-3.2-1B-Instruct"
,
"ray"
,
""
,
"L4"
,
{}),
(
"meta-llama/Llama-3.2-1B-Instruct"
,
"mp"
,
""
,
"L4"
,
{}),
(
"meta-llama/Llama-3.2-1B-Instruct"
,
"mp"
,
""
,
"L4"
,
{}),
(
"distilbert/distilgpt2"
,
"ray"
,
""
,
"A100"
,
{}),
(
"facebook/opt-125m"
,
"ray"
,
""
,
"A100"
,
{}),
(
"distilbert/distilgpt2"
,
"mp"
,
""
,
"A100"
,
{}),
(
"facebook/opt-125m"
,
"mp"
,
""
,
"A100"
,
{}),
])
],
)
@
pytest
.
mark
.
parametrize
(
"enable_prompt_embeds"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_prompt_embeds"
,
[
True
,
False
])
def
test_models_distributed
(
def
test_models_distributed
(
monkeypatch
:
pytest
.
MonkeyPatch
,
monkeypatch
:
pytest
.
MonkeyPatch
,
...
@@ -149,13 +152,14 @@ def test_models_distributed(
...
@@ -149,13 +152,14 @@ def test_models_distributed(
pytest
.
skip
(
f
"Skip test for
{
test_suite
}
"
)
pytest
.
skip
(
f
"Skip test for
{
test_suite
}
"
)
with
monkeypatch
.
context
()
as
monkeypatch_context
:
with
monkeypatch
.
context
()
as
monkeypatch_context
:
if
model
==
"meta-llama/Llama-3.2-1B-Instruct"
and
distributed_executor_backend
==
"ray"
and
attention_backend
==
""
and
test_suite
==
"L4"
:
# noqa
if
(
if
enable_prompt_embeds
:
model
==
"meta-llama/Llama-3.2-1B-Instruct"
pytest
.
skip
(
and
distributed_executor_backend
==
"ray"
"enable_prompt_embeds does not work with ray compiled dag."
and
attention_backend
==
""
)
and
test_suite
==
"L4"
monkeypatch_context
.
setenv
(
"VLLM_USE_RAY_SPMD_WORKER"
,
"1"
)
and
enable_prompt_embeds
monkeypatch_context
.
setenv
(
"VLLM_USE_RAY_COMPILED_DAG"
,
"1"
)
):
# noqa
pytest
.
skip
(
"enable_prompt_embeds does not work with ray compiled dag."
)
if
attention_backend
:
if
attention_backend
:
monkeypatch_context
.
setenv
(
monkeypatch_context
.
setenv
(
...
@@ -175,30 +179,26 @@ def test_models_distributed(
...
@@ -175,30 +179,26 @@ def test_models_distributed(
# will hurt multiprocessing backend with fork method
# will hurt multiprocessing backend with fork method
# (the default method).
# (the default method).
with
vllm_runner
(
with
vllm_runner
(
model
,
model
,
dtype
=
dtype
,
dtype
=
dtype
,
tensor_parallel_size
=
2
,
tensor_parallel_size
=
2
,
distributed_executor_backend
=
distributed_executor_backend
,
distributed_executor_backend
=
distributed_executor_backend
,
enable_prompt_embeds
=
enable_prompt_embeds
,
enable_prompt_embeds
=
enable_prompt_embeds
,
gpu_memory_utilization
=
0.7
,
gpu_memory_utilization
=
0.7
,
)
as
vllm_model
:
)
as
vllm_model
:
if
enable_prompt_embeds
:
if
enable_prompt_embeds
:
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
prompt_embeds
=
hf_model
.
get_prompt_embeddings
(
prompt_embeds
=
hf_model
.
get_prompt_embeddings
(
example_prompts
)
example_prompts
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
prompt_embeds
,
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
prompt_embeds
,
max_tokens
)
vllm_outputs
=
_fix_prompt_embed_outputs
(
vllm_outputs
=
_fix_prompt_embed_outputs
(
vllm_outputs
,
hf_model
,
example_prompts
)
vllm_outputs
,
hf_model
,
example_prompts
hf_outputs
=
hf_model
.
generate_greedy
(
)
example_prompts
,
max_tokens
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
else
:
else
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
example_prompts
,
max_tokens
)
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy
(
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
example_prompts
,
max_tokens
)
check_outputs_equal
(
check_outputs_equal
(
outputs_0_lst
=
hf_outputs
,
outputs_0_lst
=
hf_outputs
,
...
@@ -209,27 +209,18 @@ def test_models_distributed(
...
@@ -209,27 +209,18 @@ def test_models_distributed(
def
test_failed_model_execution
(
vllm_runner
,
monkeypatch
)
->
None
:
def
test_failed_model_execution
(
vllm_runner
,
monkeypatch
)
->
None
:
from
vllm.envs
import
VLLM_USE_V1
if
not
VLLM_USE_V1
:
pytest
.
skip
(
"Skipping V0 test, dump input not supported"
)
# Needed to mock an error in the same process
# Needed to mock an error in the same process
monkeypatch
.
setenv
(
'
VLLM_ENABLE_V1_MULTIPROCESSING
'
,
'0'
)
monkeypatch
.
setenv
(
"
VLLM_ENABLE_V1_MULTIPROCESSING
"
,
"0"
)
with
vllm_runner
(
'
facebook/opt-125m
'
,
enforce_eager
=
True
)
as
vllm_model
:
with
vllm_runner
(
"
facebook/opt-125m
"
,
enforce_eager
=
True
)
as
vllm_model
:
if
isinstance
(
vllm_model
.
llm
.
llm_engine
,
LLMEngine
V1
):
if
isinstance
(
vllm_model
.
llm
.
llm_engine
,
LLMEngine
):
v1_test_failed_model_execution
(
vllm_model
)
v1_test_failed_model_execution
(
vllm_model
)
def
v1_test_failed_model_execution
(
vllm_model
):
def
v1_test_failed_model_execution
(
vllm_model
):
engine
=
vllm_model
.
llm
.
llm_engine
engine
=
vllm_model
.
llm
.
llm_engine
mocked_execute_model
=
Mock
(
mocked_execute_model
=
Mock
(
side_effect
=
RuntimeError
(
"Mocked Critical Error"
))
side_effect
=
RuntimeError
(
"Mocked Critical Error"
))
engine
.
engine_core
.
engine_core
.
model_executor
.
execute_model
=
mocked_execute_model
engine
.
engine_core
.
engine_core
.
model_executor
.
execute_model
=
\
mocked_execute_model
with
pytest
.
raises
(
RuntimeError
)
as
exc_info
:
with
pytest
.
raises
(
RuntimeError
)
as
exc_info
:
prompts
=
[
prompts
=
[
...
...
tests/basic_correctness/test_cpu_offload.py
View file @
006693ed
...
@@ -5,5 +5,6 @@ from ..utils import compare_two_settings
...
@@ -5,5 +5,6 @@ from ..utils import compare_two_settings
def
test_cpu_offload
():
def
test_cpu_offload
():
compare_two_settings
(
"meta-llama/Llama-3.2-1B-Instruct"
,
[],
compare_two_settings
(
[
"--cpu-offload-gb"
,
"1"
])
"hmellor/tiny-random-LlamaForCausalLM"
,
[],
[
"--cpu-offload-gb"
,
"1"
]
)
tests/basic_correctness/test_cumem.py
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
pytest
import
pytest
import
torch
import
torch
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
AsyncEngineArgs
,
AsyncLLMEngine
,
SamplingParams
from
vllm.device_allocator.cumem
import
CuMemAllocator
from
vllm.device_allocator.cumem
import
CuMemAllocator
from
vllm.utils
import
GiB_bytes
from
vllm.platforms
import
current_platform
from
vllm.utils.mem_constants
import
GiB_bytes
from
..utils
import
create_new_process_for_each_test
from
..utils
import
create_new_process_for_each_test
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
(
"fork"
if
not
current_platform
.
is_rocm
()
else
"spawn"
)
def
test_python_error
():
def
test_python_error
():
"""
"""
Test if Python error occurs when there's low-level
Test if Python error occurs when there's low-level
...
@@ -23,13 +26,13 @@ def test_python_error():
...
@@ -23,13 +26,13 @@ def test_python_error():
tensors
=
[]
tensors
=
[]
with
allocator
.
use_memory_pool
():
with
allocator
.
use_memory_pool
():
# allocate 70% of the total memory
# allocate 70% of the total memory
x
=
torch
.
empty
(
alloc_bytes
,
dtype
=
torch
.
uint8
,
device
=
'
cuda
'
)
x
=
torch
.
empty
(
alloc_bytes
,
dtype
=
torch
.
uint8
,
device
=
"
cuda
"
)
tensors
.
append
(
x
)
tensors
.
append
(
x
)
# release the memory
# release the memory
allocator
.
sleep
()
allocator
.
sleep
()
# allocate more memory than the total memory
# allocate more memory than the total memory
y
=
torch
.
empty
(
alloc_bytes
,
dtype
=
torch
.
uint8
,
device
=
'
cuda
'
)
y
=
torch
.
empty
(
alloc_bytes
,
dtype
=
torch
.
uint8
,
device
=
"
cuda
"
)
tensors
.
append
(
y
)
tensors
.
append
(
y
)
with
pytest
.
raises
(
RuntimeError
):
with
pytest
.
raises
(
RuntimeError
):
# when the allocator is woken up, it should raise an error
# when the allocator is woken up, it should raise an error
...
@@ -37,21 +40,21 @@ def test_python_error():
...
@@ -37,21 +40,21 @@ def test_python_error():
allocator
.
wake_up
()
allocator
.
wake_up
()
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
(
"fork"
if
not
current_platform
.
is_rocm
()
else
"spawn"
)
def
test_basic_cumem
():
def
test_basic_cumem
():
# some tensors from default memory pool
# some tensors from default memory pool
shape
=
(
1024
,
1024
)
shape
=
(
1024
,
1024
)
x
=
torch
.
empty
(
shape
,
device
=
'
cuda
'
)
x
=
torch
.
empty
(
shape
,
device
=
"
cuda
"
)
x
.
zero_
()
x
.
zero_
()
# some tensors from custom memory pool
# some tensors from custom memory pool
allocator
=
CuMemAllocator
.
get_instance
()
allocator
=
CuMemAllocator
.
get_instance
()
with
allocator
.
use_memory_pool
():
with
allocator
.
use_memory_pool
():
# custom memory pool
# custom memory pool
y
=
torch
.
empty
(
shape
,
device
=
'
cuda
'
)
y
=
torch
.
empty
(
shape
,
device
=
"
cuda
"
)
y
.
zero_
()
y
.
zero_
()
y
+=
1
y
+=
1
z
=
torch
.
empty
(
shape
,
device
=
'
cuda
'
)
z
=
torch
.
empty
(
shape
,
device
=
"
cuda
"
)
z
.
zero_
()
z
.
zero_
()
z
+=
2
z
+=
2
...
@@ -70,20 +73,20 @@ def test_basic_cumem():
...
@@ -70,20 +73,20 @@ def test_basic_cumem():
assert
torch
.
allclose
(
output
,
torch
.
ones_like
(
output
)
*
3
)
assert
torch
.
allclose
(
output
,
torch
.
ones_like
(
output
)
*
3
)
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
(
"fork"
if
not
current_platform
.
is_rocm
()
else
"spawn"
)
def
test_cumem_with_cudagraph
():
def
test_cumem_with_cudagraph
():
allocator
=
CuMemAllocator
.
get_instance
()
allocator
=
CuMemAllocator
.
get_instance
()
with
allocator
.
use_memory_pool
():
with
allocator
.
use_memory_pool
():
weight
=
torch
.
eye
(
1024
,
device
=
'
cuda
'
)
weight
=
torch
.
eye
(
1024
,
device
=
"
cuda
"
)
with
allocator
.
use_memory_pool
(
tag
=
"discard"
):
with
allocator
.
use_memory_pool
(
tag
=
"discard"
):
cache
=
torch
.
empty
(
1024
,
1024
,
device
=
'
cuda
'
)
cache
=
torch
.
empty
(
1024
,
1024
,
device
=
"
cuda
"
)
def
model
(
x
):
def
model
(
x
):
out
=
x
@
weight
out
=
x
@
weight
cache
[:
out
.
size
(
0
)].
copy_
(
out
)
cache
[:
out
.
size
(
0
)].
copy_
(
out
)
return
out
+
1
return
out
+
1
x
=
torch
.
empty
(
128
,
1024
,
device
=
'
cuda
'
)
x
=
torch
.
empty
(
128
,
1024
,
device
=
"
cuda
"
)
# warmup
# warmup
model
(
x
)
model
(
x
)
...
@@ -109,80 +112,72 @@ def test_cumem_with_cudagraph():
...
@@ -109,80 +112,72 @@ def test_cumem_with_cudagraph():
model_graph
.
replay
()
model_graph
.
replay
()
# cache content is as expected
# cache content is as expected
assert
torch
.
allclose
(
x
,
cache
[:
x
.
size
(
0
)])
assert
torch
.
allclose
(
x
,
cache
[:
x
.
size
(
0
)])
# output content is as expected
# output content is as expected
assert
torch
.
allclose
(
y
,
x
+
1
)
assert
torch
.
allclose
(
y
,
x
+
1
)
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
(
"fork"
if
not
current_platform
.
is_rocm
()
else
"spawn"
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model
, use_v1
"
,
"model"
,
[
[
# sleep mode with safetensors
# sleep mode with safetensors
(
"meta-llama/Llama-3.2-1B"
,
True
)
,
"hmellor/tiny-random-LlamaForCausalLM"
,
# sleep mode with pytorch checkpoint
# sleep mode with pytorch checkpoint
(
"facebook/opt-125m"
,
True
),
"facebook/opt-125m"
,
])
],
def
test_end_to_end
(
monkeypatch
:
pytest
.
MonkeyPatch
,
model
:
str
,
use_v1
:
bool
):
)
with
monkeypatch
.
context
()
as
m
:
def
test_end_to_end
(
model
:
str
):
assert
use_v1
free
,
total
=
torch
.
cuda
.
mem_get_info
()
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
used_bytes_baseline
=
total
-
free
# in case other process is running
free
,
total
=
torch
.
cuda
.
mem_get_info
()
llm
=
LLM
(
model
,
enable_sleep_mode
=
True
)
used_bytes_baseline
=
total
-
free
# in case other process is running
prompt
=
"How are you?"
llm
=
LLM
(
model
,
enable_sleep_mode
=
True
)
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
10
)
prompt
=
"How are you?"
output
=
llm
.
generate
(
prompt
,
sampling_params
)
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
10
)
output
=
llm
.
generate
(
prompt
,
sampling_params
)
# the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
# which is difficult to measure in the test. therefore, we only
# test sleep level 1 here.
llm
.
sleep
(
level
=
1
)
free_gpu_bytes_after_sleep
,
total
=
torch
.
cuda
.
mem_get_info
()
used_bytes
=
total
-
free_gpu_bytes_after_sleep
-
used_bytes_baseline
# now the memory usage is mostly cudagraph memory pool,
# and it should be less than the model weights (1B model, 2GiB weights)
# NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
# is captured but cannot be releasesd from PyTorch due to a known bug,
# therefore high memory usage after `llm.sleep` is called is expected.
# FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
# in V1.
if
use_v1
:
assert
used_bytes
<
7
*
GiB_bytes
else
:
assert
used_bytes
<
2
*
GiB_bytes
llm
.
wake_up
()
output2
=
llm
.
generate
(
prompt
,
sampling_params
)
# cmp output
assert
output
[
0
].
outputs
[
0
].
text
==
output2
[
0
].
outputs
[
0
].
text
llm
.
sleep
(
level
=
1
)
# the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
llm
.
wake_up
(
tags
=
[
"weights"
])
# which is difficult to measure in the test. therefore, we only
# test sleep level 1 here.
llm
.
sleep
(
level
=
1
)
free_gpu_bytes_wake_up_w
,
total
=
torch
.
cuda
.
mem_get_info
()
free_gpu_bytes_after_sleep
,
total
=
torch
.
cuda
.
mem_get_info
()
used_bytes
=
total
-
free_gpu_bytes_wake_up_w
-
used_bytes_baseline
used_bytes
=
total
-
free_gpu_bytes_after_sleep
-
used_bytes_baseline
# now the memory usage is mostly cudagraph memory pool,
# and it should be less than the model weights (1B model, 2GiB weights)
# should just reallocate memory for weights (1B model, ~2GiB weights)
# NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
if
use_v1
:
# is captured but cannot be releasesd from PyTorch due to a known bug,
assert
used_bytes
<
10
*
GiB_bytes
# therefore high memory usage after `llm.sleep` is called is expected.
else
:
# FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
assert
used_bytes
<
6
*
GiB_bytes
# in V1.
assert
used_bytes
<
7
*
GiB_bytes
# now allocate kv cache memory
llm
.
wake_up
()
llm
.
wake_up
(
tags
=
[
"kv_cache"
])
output2
=
llm
.
generate
(
prompt
,
sampling_params
)
output3
=
llm
.
generate
(
prompt
,
sampling_params
)
# cmp output
assert
output
[
0
].
outputs
[
0
].
text
==
output2
[
0
].
outputs
[
0
].
text
# cmp output
llm
.
sleep
(
level
=
1
)
assert
output
[
0
].
outputs
[
0
].
text
==
output3
[
0
].
outputs
[
0
].
text
llm
.
wake_up
(
tags
=
[
"weights"
])
free_gpu_bytes_wake_up_w
,
total
=
torch
.
cuda
.
mem_get_info
()
used_bytes
=
total
-
free_gpu_bytes_wake_up_w
-
used_bytes_baseline
# should just reallocate memory for weights (1B model, ~2GiB weights)
assert
used_bytes
<
10
*
GiB_bytes
# now allocate kv cache memory
llm
.
wake_up
(
tags
=
[
"kv_cache"
])
output3
=
llm
.
generate
(
prompt
,
sampling_params
)
# cmp output
assert
output
[
0
].
outputs
[
0
].
text
==
output3
[
0
].
outputs
[
0
].
text
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
()
def
test_deep_sleep
():
def
test_deep_sleep
():
model
=
"
Qwen/Qwen3-0.6B
"
model
=
"
hmellor/tiny-random-LlamaForCausalLM
"
free
,
total
=
torch
.
cuda
.
mem_get_info
()
free
,
total
=
torch
.
cuda
.
mem_get_info
()
used_bytes_baseline
=
total
-
free
# in case other process is running
used_bytes_baseline
=
total
-
free
# in case other process is running
llm
=
LLM
(
model
,
enable_sleep_mode
=
True
)
llm
=
LLM
(
model
,
enable_sleep_mode
=
True
)
...
@@ -209,3 +204,42 @@ def test_deep_sleep():
...
@@ -209,3 +204,42 @@ def test_deep_sleep():
# cmp output
# cmp output
assert
output
[
0
].
outputs
[
0
].
text
==
output2
[
0
].
outputs
[
0
].
text
assert
output
[
0
].
outputs
[
0
].
text
==
output2
[
0
].
outputs
[
0
].
text
@
create_new_process_for_each_test
()
def
test_deep_sleep_async
():
async
def
test
():
model
=
"hmellor/tiny-random-LlamaForCausalLM"
free
,
total
=
torch
.
cuda
.
mem_get_info
()
used_bytes_baseline
=
total
-
free
# in case other process is running
engine_args
=
AsyncEngineArgs
(
model
=
model
,
enable_sleep_mode
=
True
,
)
llm
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
prompt
=
"How are you?"
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
10
)
outputs
=
llm
.
generate
(
prompt
,
sampling_params
,
request_id
=
"test_request_id1"
)
async
for
output
in
outputs
:
pass
# Put the engine to deep sleep
await
llm
.
sleep
(
level
=
2
)
await
llm
.
wake_up
(
tags
=
[
"weights"
])
await
llm
.
collective_rpc
(
"reload_weights"
)
free_gpu_bytes_wake_up_w
,
total
=
torch
.
cuda
.
mem_get_info
()
used_bytes
=
total
-
free_gpu_bytes_wake_up_w
-
used_bytes_baseline
assert
used_bytes
<
4
*
GiB_bytes
# now allocate kv cache and cuda graph memory
await
llm
.
wake_up
(
tags
=
[
"kv_cache"
])
outputs2
=
llm
.
generate
(
prompt
,
sampling_params
,
request_id
=
"test_request_id2"
)
async
for
output2
in
outputs2
:
pass
# cmp output
assert
output
.
outputs
[
0
].
text
==
output2
.
outputs
[
0
].
text
asyncio
.
run
(
test
())
tests/benchmarks/test_latency_cli.py
View file @
006693ed
...
@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
...
@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@
pytest
.
mark
.
benchmark
@
pytest
.
mark
.
benchmark
def
test_bench_latency
():
def
test_bench_latency
():
command
=
[
command
=
[
"vllm"
,
"bench"
,
"latency"
,
"--model"
,
MODEL_NAME
,
"--input-len"
,
"32"
,
"vllm"
,
"--output-len"
,
"1"
,
"--enforce-eager"
,
"--load-format"
,
"dummy"
"bench"
,
"latency"
,
"--model"
,
MODEL_NAME
,
"--input-len"
,
"32"
,
"--output-len"
,
"1"
,
"--enforce-eager"
,
"--load-format"
,
"dummy"
,
]
]
result
=
subprocess
.
run
(
command
,
capture_output
=
True
,
text
=
True
)
result
=
subprocess
.
run
(
command
,
capture_output
=
True
,
text
=
True
)
print
(
result
.
stdout
)
print
(
result
.
stdout
)
...
...
tests/benchmarks/test_random_dataset.py
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
import
random
from
typing
import
Any
,
NamedTuple
,
Optional
,
cast
from
typing
import
Any
,
NamedTuple
,
cast
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
vllm.benchmarks.datasets
import
(
RandomDataset
,
RandomMultiModalDataset
,
from
vllm.benchmarks.datasets
import
(
SampleRequest
)
RandomDataset
,
RandomMultiModalDataset
,
SampleRequest
,
)
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
...
@@ -27,11 +30,9 @@ class Params(NamedTuple):
...
@@ -27,11 +30,9 @@ class Params(NamedTuple):
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
random_dataset_params
()
->
Params
:
def
random_dataset_params
()
->
Params
:
return
Params
(
num_requests
=
16
,
return
Params
(
prefix_len
=
7
,
num_requests
=
16
,
prefix_len
=
7
,
range_ratio
=
0.3
,
input_len
=
50
,
output_len
=
20
range_ratio
=
0.3
,
)
input_len
=
50
,
output_len
=
20
)
def
_fingerprint_sample
(
req
:
SampleRequest
)
->
tuple
[
str
,
int
,
int
]:
def
_fingerprint_sample
(
req
:
SampleRequest
)
->
tuple
[
str
,
int
,
int
]:
...
@@ -39,13 +40,15 @@ def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
...
@@ -39,13 +40,15 @@ def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
return
(
req
.
prompt
,
req
.
prompt_len
,
req
.
expected_output_len
)
return
(
req
.
prompt
,
req
.
prompt_len
,
req
.
expected_output_len
)
def
_collect_samples
(
dataset
:
RandomDataset
,
def
_collect_samples
(
tokenizer
:
PreTrainedTokenizerBase
,
dataset
:
RandomDataset
,
num_requests
:
int
=
16
,
tokenizer
:
PreTrainedTokenizerBase
,
prefix_len
:
int
=
7
,
num_requests
:
int
=
16
,
range_ratio
:
float
=
0.3
,
prefix_len
:
int
=
7
,
input_len
:
int
=
50
,
range_ratio
:
float
=
0.3
,
output_len
:
int
=
20
)
->
list
[
tuple
[
str
,
int
,
int
]]:
input_len
:
int
=
50
,
output_len
:
int
=
20
,
)
->
list
[
tuple
[
str
,
int
,
int
]]:
samples
=
dataset
.
sample
(
samples
=
dataset
.
sample
(
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
num_requests
=
num_requests
,
num_requests
=
num_requests
,
...
@@ -59,8 +62,8 @@ def _collect_samples(dataset: RandomDataset,
...
@@ -59,8 +62,8 @@ def _collect_samples(dataset: RandomDataset,
@
pytest
.
mark
.
benchmark
@
pytest
.
mark
.
benchmark
def
test_random_dataset_same_seed
(
def
test_random_dataset_same_seed
(
hf_tokenizer
:
PreTrainedTokenizerBase
,
hf_tokenizer
:
PreTrainedTokenizerBase
,
random_dataset_params
:
Params
random_dataset_params
:
Params
)
->
None
:
)
->
None
:
"""Same seed should yield identical outputs, even if global RNGs change.
"""Same seed should yield identical outputs, even if global RNGs change.
This guards against accidental reliance on Python's random or np.random
This guards against accidental reliance on Python's random or np.random
...
@@ -70,13 +73,15 @@ def test_random_dataset_same_seed(
...
@@ -70,13 +73,15 @@ def test_random_dataset_same_seed(
common_seed
=
123
common_seed
=
123
dataset_a
=
RandomDataset
(
random_seed
=
common_seed
)
dataset_a
=
RandomDataset
(
random_seed
=
common_seed
)
dataset_b
=
RandomDataset
(
random_seed
=
common_seed
)
dataset_b
=
RandomDataset
(
random_seed
=
common_seed
)
a
=
_collect_samples
(
dataset_a
,
a
=
_collect_samples
(
hf_tokenizer
,
dataset_a
,
num_requests
=
p
.
num_requests
,
hf_tokenizer
,
prefix_len
=
p
.
prefix_len
,
num_requests
=
p
.
num_requests
,
range_ratio
=
p
.
range_ratio
,
prefix_len
=
p
.
prefix_len
,
input_len
=
p
.
input_len
,
range_ratio
=
p
.
range_ratio
,
output_len
=
p
.
output_len
)
input_len
=
p
.
input_len
,
output_len
=
p
.
output_len
,
)
# Perturb global RNG state to ensure isolation
# Perturb global RNG state to ensure isolation
random
.
seed
(
999
)
random
.
seed
(
999
)
...
@@ -84,43 +89,50 @@ def test_random_dataset_same_seed(
...
@@ -84,43 +89,50 @@ def test_random_dataset_same_seed(
np
.
random
.
seed
(
888
)
np
.
random
.
seed
(
888
)
_
=
[
np
.
random
.
random
()
for
_
in
range
(
100
)]
_
=
[
np
.
random
.
random
()
for
_
in
range
(
100
)]
b
=
_collect_samples
(
dataset_b
,
b
=
_collect_samples
(
hf_tokenizer
,
dataset_b
,
num_requests
=
p
.
num_requests
,
hf_tokenizer
,
prefix_len
=
p
.
prefix_len
,
num_requests
=
p
.
num_requests
,
range_ratio
=
p
.
range_ratio
,
prefix_len
=
p
.
prefix_len
,
input_len
=
p
.
input_len
,
range_ratio
=
p
.
range_ratio
,
output_len
=
p
.
output_len
)
input_len
=
p
.
input_len
,
output_len
=
p
.
output_len
,
)
assert
a
==
b
assert
a
==
b
@
pytest
.
mark
.
benchmark
@
pytest
.
mark
.
benchmark
def
test_random_dataset_different_seeds
(
def
test_random_dataset_different_seeds
(
hf_tokenizer
:
PreTrainedTokenizerBase
,
hf_tokenizer
:
PreTrainedTokenizerBase
,
random_dataset_params
:
Params
random_dataset_params
:
Params
)
->
None
:
)
->
None
:
"""Different seeds should change outputs with overwhelming likelihood."""
"""Different seeds should change outputs with overwhelming likelihood."""
p
=
random_dataset_params
p
=
random_dataset_params
seed_a
=
0
seed_a
=
0
dataset_a
=
RandomDataset
(
random_seed
=
seed_a
)
dataset_a
=
RandomDataset
(
random_seed
=
seed_a
)
a
=
_collect_samples
(
dataset_a
,
a
=
_collect_samples
(
hf_tokenizer
,
dataset_a
,
num_requests
=
p
.
num_requests
,
hf_tokenizer
,
prefix_len
=
p
.
prefix_len
,
num_requests
=
p
.
num_requests
,
range_ratio
=
p
.
range_ratio
,
prefix_len
=
p
.
prefix_len
,
input_len
=
p
.
input_len
,
range_ratio
=
p
.
range_ratio
,
output_len
=
p
.
output_len
)
input_len
=
p
.
input_len
,
output_len
=
p
.
output_len
,
)
seed_b
=
999
seed_b
=
999
dataset_b
=
RandomDataset
(
random_seed
=
seed_b
)
dataset_b
=
RandomDataset
(
random_seed
=
seed_b
)
# Perturb global RNG with same seed as dataset_a to ensure isolation
# Perturb global RNG with same seed as dataset_a to ensure isolation
random
.
seed
(
seed_a
)
random
.
seed
(
seed_a
)
np
.
random
.
seed
(
seed_a
)
np
.
random
.
seed
(
seed_a
)
b
=
_collect_samples
(
dataset_b
,
b
=
_collect_samples
(
hf_tokenizer
,
dataset_b
,
num_requests
=
p
.
num_requests
,
hf_tokenizer
,
prefix_len
=
p
.
prefix_len
,
num_requests
=
p
.
num_requests
,
range_ratio
=
p
.
range_ratio
,
prefix_len
=
p
.
prefix_len
,
input_len
=
p
.
input_len
,
range_ratio
=
p
.
range_ratio
,
output_len
=
p
.
output_len
)
input_len
=
p
.
input_len
,
output_len
=
p
.
output_len
,
)
assert
a
!=
b
assert
a
!=
b
...
@@ -128,6 +140,7 @@ def test_random_dataset_different_seeds(
...
@@ -128,6 +140,7 @@ def test_random_dataset_different_seeds(
# RandomMultiModalDataset tests
# RandomMultiModalDataset tests
# -----------------------------
# -----------------------------
def
_mm_fingerprint_sample
(
def
_mm_fingerprint_sample
(
req
:
SampleRequest
,
req
:
SampleRequest
,
)
->
tuple
[
str
,
int
,
int
,
int
,
list
[
str
]]:
)
->
tuple
[
str
,
int
,
int
,
int
,
list
[
str
]]:
...
@@ -152,8 +165,13 @@ def _mm_fingerprint_sample(
...
@@ -152,8 +165,13 @@ def _mm_fingerprint_sample(
item_prefixes
.
append
(
f
"video:
{
url
[:
22
]
}
"
)
item_prefixes
.
append
(
f
"video:
{
url
[:
22
]
}
"
)
else
:
else
:
item_prefixes
.
append
(
"unknown:"
)
item_prefixes
.
append
(
"unknown:"
)
return
(
req
.
prompt
,
req
.
prompt_len
,
req
.
expected_output_len
,
len
(
items
),
return
(
item_prefixes
)
req
.
prompt
,
req
.
prompt_len
,
req
.
expected_output_len
,
len
(
items
),
item_prefixes
,
)
def
_collect_mm_samples
(
def
_collect_mm_samples
(
...
@@ -167,8 +185,8 @@ def _collect_mm_samples(
...
@@ -167,8 +185,8 @@ def _collect_mm_samples(
output_len
:
int
=
5
,
output_len
:
int
=
5
,
base_items_per_request
:
int
=
2
,
base_items_per_request
:
int
=
2
,
num_mm_items_range_ratio
:
float
=
0.0
,
num_mm_items_range_ratio
:
float
=
0.0
,
limit_mm_per_prompt
:
Optional
[
dict
[
str
,
int
]
]
=
None
,
limit_mm_per_prompt
:
dict
[
str
,
int
]
|
None
=
None
,
bucket_config
:
Optional
[
dict
[
tuple
[
int
,
int
,
int
],
float
]
]
=
None
,
bucket_config
:
dict
[
tuple
[
int
,
int
,
int
],
float
]
|
None
=
None
,
enable_multimodal_chat
:
bool
=
False
,
enable_multimodal_chat
:
bool
=
False
,
)
->
list
[
SampleRequest
]:
)
->
list
[
SampleRequest
]:
if
limit_mm_per_prompt
is
None
:
if
limit_mm_per_prompt
is
None
:
...
@@ -214,6 +232,7 @@ def test_random_mm_different_seeds(
...
@@ -214,6 +232,7 @@ def test_random_mm_different_seeds(
fb
=
[
_mm_fingerprint_sample
(
s
)
for
s
in
b
]
fb
=
[
_mm_fingerprint_sample
(
s
)
for
s
in
b
]
assert
fa
!=
fb
assert
fa
!=
fb
@
pytest
.
mark
.
benchmark
@
pytest
.
mark
.
benchmark
def
test_random_mm_respects_limits
(
def
test_random_mm_respects_limits
(
hf_tokenizer
:
PreTrainedTokenizerBase
,
hf_tokenizer
:
PreTrainedTokenizerBase
,
...
@@ -271,9 +290,9 @@ def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
...
@@ -271,9 +290,9 @@ def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
for
s
in
samples
:
for
s
in
samples
:
assert
s
.
multi_modal_data
==
[]
assert
s
.
multi_modal_data
==
[]
@
pytest
.
mark
.
benchmark
@
pytest
.
mark
.
benchmark
def
test_random_mm_num_items_per_prompt
(
def
test_random_mm_num_items_per_prompt
(
hf_tokenizer
:
PreTrainedTokenizerBase
)
->
None
:
hf_tokenizer
:
PreTrainedTokenizerBase
)
->
None
:
ds
=
RandomMultiModalDataset
(
random_seed
=
0
)
ds
=
RandomMultiModalDataset
(
random_seed
=
0
)
# Fixed number of images per prompt
# Fixed number of images per prompt
# set num_mm_items_range_ratio to 0.0
# set num_mm_items_range_ratio to 0.0
...
@@ -300,7 +319,6 @@ def test_random_mm_num_items_per_prompt(
...
@@ -300,7 +319,6 @@ def test_random_mm_num_items_per_prompt(
def
test_random_mm_bucket_config_not_mutated
(
def
test_random_mm_bucket_config_not_mutated
(
hf_tokenizer
:
PreTrainedTokenizerBase
,
hf_tokenizer
:
PreTrainedTokenizerBase
,
)
->
None
:
)
->
None
:
ds
=
RandomMultiModalDataset
(
random_seed
=
0
)
ds
=
RandomMultiModalDataset
(
random_seed
=
0
)
# This bucket config is not normalized to sum to 1
# This bucket config is not normalized to sum to 1
# and has more buckets than requested images
# and has more buckets than requested images
...
@@ -321,7 +339,6 @@ def test_random_mm_bucket_config_not_mutated(
...
@@ -321,7 +339,6 @@ def test_random_mm_bucket_config_not_mutated(
# Ensure the original dict content is unchanged
# Ensure the original dict content is unchanged
assert
original
==
snapshot
assert
original
==
snapshot
# Vary number of mm items per prompt
# Vary number of mm items per prompt
# set num_mm_items_range_ratio to 0.5
# set num_mm_items_range_ratio to 0.5
samples_varying_items
=
_collect_mm_samples
(
samples_varying_items
=
_collect_mm_samples
(
...
@@ -342,3 +359,126 @@ def test_random_mm_bucket_config_not_mutated(
...
@@ -342,3 +359,126 @@ def test_random_mm_bucket_config_not_mutated(
assert
len
(
mm_data
)
>=
1
assert
len
(
mm_data
)
>=
1
for
it
in
mm_data
:
for
it
in
mm_data
:
assert
it
.
get
(
"type"
)
==
"image_url"
assert
it
.
get
(
"type"
)
==
"image_url"
@pytest.mark.benchmark
def test_random_mm_video_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """Sample from a mixed image/video bucket config and check both kinds appear.

    Every request is expected to carry exactly one multimodal item whose data
    URL prefix matches its declared type.
    """
    ds = RandomMultiModalDataset(random_seed=42)

    bucket_config = {
        (64, 64, 1): 0.3,  # image bucket
        (64, 64, 8): 0.7,  # video bucket
    }
    limit_mm_per_prompt = {"image": 2, "video": 2}

    samples = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=5,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
    )
    assert len(samples) == 5

    # Expected data-URL prefix for each item type that is counted.
    expected_prefix = {
        "video_url": "data:video/mp4;base64,",
        "image_url": "data:image/jpeg;base64,",
    }
    seen = {"video_url": 0, "image_url": 0}

    for request in samples:
        entries = cast(list[dict[str, Any]], request.multi_modal_data)
        assert len(entries) == 1
        entry = entries[0]
        kind = entry.get("type")
        if kind not in expected_prefix:
            # Items of any other type are neither counted nor validated.
            continue
        seen[kind] += 1
        url = entry.get(kind, {}).get("url", "")
        assert url.startswith(expected_prefix[kind])

    # With the fixed seed, both modalities are expected among the 5 requests.
    assert seen["video_url"] > 0
    assert seen["image_url"] > 0
@pytest.mark.benchmark
def test_random_mm_video_only_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """Check that a video-only bucket config yields one video item per request."""
    ds = RandomMultiModalDataset(random_seed=42)

    # A single bucket with probability 1.0, and images disallowed per prompt.
    only_video_bucket = {(64, 64, 8): 1.0}
    per_prompt_caps = {"image": 0, "video": 1}

    samples = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=per_prompt_caps,
        bucket_config=only_video_bucket,
    )
    assert len(samples) == 3

    for request in samples:
        entries = cast(list[dict[str, Any]], request.multi_modal_data)
        assert len(entries) == 1
        entry = entries[0]
        assert entry.get("type") == "video_url"
        video_url = entry.get("video_url", {}).get("url", "")
        assert video_url.startswith("data:video/mp4;base64,")
@pytest.mark.benchmark
def test_random_mm_video_deterministic_sampling(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    """Two datasets built with the same seed must produce equal fingerprints."""
    seed = 123
    # Build both datasets before sampling from either of them.
    datasets = [RandomMultiModalDataset(random_seed=seed) for _ in range(2)]

    # Identical sampling arguments (and the same dict objects) for both runs.
    sampling_kwargs = {
        "num_requests": 3,
        "base_items_per_request": 1,
        "num_mm_items_range_ratio": 0.0,
        "limit_mm_per_prompt": {"image": 0, "video": 1},
        "bucket_config": {(64, 64, 8): 1.0},  # only videos
    }

    runs = [
        _collect_mm_samples(dataset, hf_tokenizer, **sampling_kwargs)
        for dataset in datasets
    ]

    fa, fb = ([_mm_fingerprint_sample(s) for s in run] for run in runs)
    assert fa == fb
tests/benchmarks/test_random_multimodal_dataset_video.py
0 → 100644
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
base64
import
os
from
tempfile
import
NamedTemporaryFile
from
typing
import
Any
,
cast
import
cv2
import
pytest
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
vllm.benchmarks.datasets
import
RandomMultiModalDataset
,
SampleRequest
@pytest.fixture(scope="session")
def hf_tokenizer() -> PreTrainedTokenizerBase:
    """Session-scoped fixture that loads the pretrained ``gpt2`` tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return tokenizer
@pytest.fixture
def video_dataset() -> RandomMultiModalDataset:
    """Function-scoped fixture: a fresh ``RandomMultiModalDataset`` seeded with 42."""
    dataset = RandomMultiModalDataset(random_seed=42)
    return dataset
@pytest.mark.benchmark
def test_generate_synthetic_video_different_seeds():
    """Datasets with different seeds must not generate identical video bytes."""
    first_dataset = RandomMultiModalDataset(random_seed=123)
    second_dataset = RandomMultiModalDataset(random_seed=456)

    dims = (64, 48, 8)  # width, height, num_frames

    first_video = first_dataset.generate_synthetic_video(*dims)
    second_video = second_dataset.generate_synthetic_video(*dims)

    # Only the encoded payload under the "bytes" key is compared.
    assert first_video["bytes"] != second_video["bytes"]
@pytest.mark.benchmark
def test_map_config_to_modality(video_dataset: RandomMultiModalDataset):
    """Check ``map_config_to_modality`` for image, video and invalid configs."""
    # (config, expected modality): last element 1 -> image, > 1 -> video.
    expected_modality = [
        ((256, 256, 1), "image"),
        ((720, 1280, 1), "image"),
        ((256, 256, 8), "video"),
        ((720, 1280, 16), "video"),
        ((64, 64, 32), "video"),
    ]
    for config, modality in expected_modality:
        assert video_dataset.map_config_to_modality(config) == modality

    # A zero or negative last element must be rejected with a ValueError.
    for bad_config in ((256, 256, 0), (256, 256, -1)):
        with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
            video_dataset.map_config_to_modality(bad_config)
@pytest.mark.benchmark
def test_generate_mm_item_video(video_dataset: RandomMultiModalDataset):
    """Generate a video item and verify its structure and decoded properties.

    The item must be a ``video_url`` dict whose URL is a base64 MP4 data URL.
    The payload is written to a temporary file and reopened with OpenCV to
    confirm frame count and frame dimensions.
    """
    video_config = (64, 48, 8)  # height, width, num_frames
    result = video_dataset.generate_mm_item(video_config)

    # Check the result structure matches OpenAI API format
    assert isinstance(result, dict)
    assert result["type"] == "video_url"
    assert "video_url" in result
    assert "url" in result["video_url"]

    # Check that the URL is a data URL with base64 encoded video
    url = result["video_url"]["url"]
    assert url.startswith("data:video/mp4;base64,")

    # Decode and verify the video content
    base64_data = url.split(",")[1]
    video_bytes = base64.b64decode(base64_data)
    assert len(video_bytes) > 0

    # delete=False: the file is closed when the `with` block exits and is then
    # reopened by path below, so it has to outlive the context manager.
    with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
        temp_path = temp_file.name
        temp_file.write(video_bytes)

    try:
        cap = cv2.VideoCapture(temp_path)
        try:
            assert cap.isOpened()
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            # Release the capture even when an assertion fails, so the file is
            # no longer held open when it is removed in the outer `finally`.
            cap.release()

        assert frame_count == 8
        assert frame_width == 48
        assert frame_height == 64
    finally:
        if os.path.exists(temp_path):
            os.unlink(temp_path)
@pytest.mark.benchmark
def test_generate_mm_item_image(video_dataset: RandomMultiModalDataset):
    """Generate a single-frame item and verify it is returned as an image URL."""
    image_config = (64, 48, 1)  # height, width, num_frames=1
    item = video_dataset.generate_mm_item(image_config)

    # Structure: {"type": "image_url", "image_url": {"url": ...}}
    assert isinstance(item, dict)
    assert item["type"] == "image_url"
    assert "image_url" in item
    assert "url" in item["image_url"]

    # The URL must be a base64 JPEG data URL.
    image_url = item["image_url"]["url"]
    assert image_url.startswith("data:image/jpeg;base64,")
@pytest.mark.benchmark
def test_generate_mm_item_invalid_config(video_dataset: RandomMultiModalDataset):
    """A config whose last element is 0 must make ``generate_mm_item`` raise."""
    zero_frame_config = (256, 256, 0)
    expected_message = "Invalid multimodal item configuration"
    with pytest.raises(ValueError, match=expected_message):
        video_dataset.generate_mm_item(zero_frame_config)
@pytest.mark.benchmark
def test_sample_with_video_buckets(
    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
    """Sample with image and video buckets and check both kinds are produced.

    Each request must carry exactly two multimodal items, and every counted
    item must expose a data URL with the prefix matching its type.
    """
    bucket_config = {
        (64, 64, 1): 0.3,  # image bucket
        (64, 64, 8): 0.7,  # video bucket
    }
    limit_mm_per_prompt = {"image": 5, "video": 3}

    samples = video_dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=5,
        base_items_per_request=2,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )
    assert len(samples) == 5

    # Expected data-URL prefix for each item type that is counted.
    expected_prefix = {
        "video_url": "data:video/mp4;base64,",
        "image_url": "data:image/jpeg;base64,",
    }
    seen = {"video_url": 0, "image_url": 0}

    for request in samples:
        assert isinstance(request, SampleRequest)
        assert request.multi_modal_data is not None
        assert isinstance(request.multi_modal_data, list)

        entries = cast(list[dict[str, Any]], request.multi_modal_data)
        assert len(entries) == 2  # base_items_per_request

        for entry in entries:
            kind = entry["type"]
            if kind not in expected_prefix:
                # Items of any other type are neither counted nor validated.
                continue
            seen[kind] += 1
            assert entry[kind]["url"].startswith(expected_prefix[kind])

    # With the fixed seed, both modalities are expected across the requests.
    assert seen["video_url"] > 0
    assert seen["image_url"] > 0
@pytest.mark.benchmark
def test_sample_video_only_buckets(
    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
    """Check that a video-only bucket config yields exactly one video per request."""
    only_video_bucket = {(64, 64, 8): 1.0}
    per_prompt_caps = {"image": 0, "video": 2}

    samples = video_dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=per_prompt_caps,
        bucket_config=only_video_bucket,
        input_len=20,
        output_len=5,
    )
    assert len(samples) == 3

    for request in samples:
        assert isinstance(request, SampleRequest)
        assert request.multi_modal_data is not None
        assert isinstance(request.multi_modal_data, list)

        entries = cast(list[dict[str, Any]], request.multi_modal_data)
        assert len(entries) == 1

        entry = entries[0]
        assert entry["type"] == "video_url"
        video_url = entry["video_url"]["url"]
        assert video_url.startswith("data:video/mp4;base64,")
@pytest.mark.benchmark
def test_sample_respects_video_limits(
    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
    """No request may carry more items than the per-prompt video limit of 1."""
    only_video_bucket = {(64, 64, 8): 1.0}
    # Deliberately tight cap: no images, at most one video per prompt.
    tight_caps = {"image": 0, "video": 1}

    samples = video_dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=tight_caps,
        bucket_config=only_video_bucket,
        input_len=20,
        output_len=5,
    )
    assert len(samples) == 3

    for request in samples:
        entries = cast(list[dict[str, Any]], request.multi_modal_data)
        # The item count must stay within the configured video limit.
        assert len(entries) <= 1
@pytest.mark.benchmark
def test_sample_mixed_buckets_with_zero_probability(
    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
    """Videos from a zero-probability bucket must never be sampled.

    Every sampled video is decoded with OpenCV and its frame size must be
    64x64, i.e. never the 128x128 size of the zero-probability bucket.
    """
    bucket_config = {
        (64, 64, 1): 0.5,  # image bucket
        (64, 64, 8): 0.5,  # video bucket
        (128, 128, 16): 0.0,  # video bucket with zero probability
    }
    limit_mm_per_prompt = {"image": 2, "video": 2}

    samples = video_dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=4,
        base_items_per_request=2,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )
    assert len(samples) == 4

    for request in samples:
        entries = cast(list[dict[str, Any]], request.multi_modal_data)
        for entry in entries:
            if entry["type"] != "video_url":
                # Only video items are decoded and size-checked.
                continue

            data_url = entry["video_url"]["url"]
            payload = base64.b64decode(data_url.split(",")[1])

            # delete=False: the file is reopened by path after it is closed.
            with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
                temp_path = temp_file.name
                temp_file.write(payload)

            try:
                cap = cv2.VideoCapture(temp_path)
                frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                cap.release()

                # Must match the 64x64 bucket, not the 128x128 one.
                assert frame_width == 64
                assert frame_height == 64
            finally:
                if os.path.exists(temp_path):
                    os.unlink(temp_path)
@pytest.mark.benchmark
def test_sample_deterministic_with_videos(hf_tokenizer: PreTrainedTokenizerBase):
    """Two datasets with the same seed must sample identical multimodal data."""
    dataset1 = RandomMultiModalDataset(random_seed=123)
    dataset2 = RandomMultiModalDataset(random_seed=123)

    # Identical sampling arguments (and the same dict objects) for both runs.
    sample_kwargs = {
        "tokenizer": hf_tokenizer,
        "num_requests": 3,
        "base_items_per_request": 1,
        "num_mm_items_range_ratio": 0.0,
        "limit_mm_per_prompt": {"image": 2, "video": 2},
        "bucket_config": {
            (64, 64, 1): 0.3,  # image bucket
            (64, 64, 8): 0.7,  # video bucket
        },
        "input_len": 20,
        "output_len": 5,
    }

    samples1 = dataset1.sample(**sample_kwargs)
    samples2 = dataset2.sample(**sample_kwargs)

    assert len(samples1) == len(samples2)

    # Request by request, the multimodal payloads must be equal.
    for left, right in zip(samples1, samples2):
        assert left.multi_modal_data == right.multi_modal_data
@pytest.mark.benchmark
def test_sample_different_seeds_produce_different_videos(
    hf_tokenizer: PreTrainedTokenizerBase,
):
    """Datasets with different seeds must sample different video content.

    Both datasets use a video-only bucket config, so each request carries one
    video item. The video data URLs are compared pairwise between the runs.
    """
    dataset1 = RandomMultiModalDataset(random_seed=123)
    dataset2 = RandomMultiModalDataset(random_seed=456)

    bucket_config = {
        (64, 64, 8): 1.0,  # Only videos
    }

    limit_mm_per_prompt = {"image": 0, "video": 1}

    samples1 = dataset1.sample(
        tokenizer=hf_tokenizer,
        num_requests=2,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )

    samples2 = dataset2.sample(
        tokenizer=hf_tokenizer,
        num_requests=2,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )

    # Guard against a vacuous pass: zip() below stops at the shorter input, so
    # empty or truncated results would otherwise skip every comparison.
    assert len(samples1) == len(samples2) == 2

    # Video content should be different
    for s1, s2 in zip(samples1, samples2):
        mm_data1 = cast(list[dict[str, Any]], s1.multi_modal_data)
        mm_data2 = cast(list[dict[str, Any]], s2.multi_modal_data)

        assert len(mm_data1) == len(mm_data2) == 1

        url1 = mm_data1[0]["video_url"]["url"]
        url2 = mm_data2[0]["video_url"]["url"]

        assert url1 != url2  # Different video content
tests/benchmarks/test_serve_cli.py
View file @
006693ed
...
@@ -11,9 +11,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
...
@@ -11,9 +11,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
def
server
():
args
=
[
args
=
[
"--max-model-len"
,
"1024"
,
"--enforce-eager"
,
"--load-format"
,
"dummy"
]
"--max-model-len"
,
"1024"
,
"--enforce-eager"
,
"--load-format"
,
"dummy"
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
yield
remote_server
...
@@ -46,6 +44,7 @@ def test_bench_serve(server):
...
@@ -46,6 +44,7 @@ def test_bench_serve(server):
assert
result
.
returncode
==
0
,
f
"Benchmark failed:
{
result
.
stderr
}
"
assert
result
.
returncode
==
0
,
f
"Benchmark failed:
{
result
.
stderr
}
"
@
pytest
.
mark
.
benchmark
@
pytest
.
mark
.
benchmark
def
test_bench_serve_chat
(
server
):
def
test_bench_serve_chat
(
server
):
command
=
[
command
=
[
...
...
tests/benchmarks/test_throughput_cli.py
View file @
006693ed
...
@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
...
@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@
pytest
.
mark
.
benchmark
@
pytest
.
mark
.
benchmark
def
test_bench_throughput
():
def
test_bench_throughput
():
command
=
[
command
=
[
"vllm"
,
"bench"
,
"throughput"
,
"--model"
,
MODEL_NAME
,
"--input-len"
,
"vllm"
,
"32"
,
"--output-len"
,
"1"
,
"--enforce-eager"
,
"--load-format"
,
"dummy"
"bench"
,
"throughput"
,
"--model"
,
MODEL_NAME
,
"--input-len"
,
"32"
,
"--output-len"
,
"1"
,
"--enforce-eager"
,
"--load-format"
,
"dummy"
,
]
]
result
=
subprocess
.
run
(
command
,
capture_output
=
True
,
text
=
True
)
result
=
subprocess
.
run
(
command
,
capture_output
=
True
,
text
=
True
)
print
(
result
.
stdout
)
print
(
result
.
stdout
)
...
...
tests/ci_envs.py
View file @
006693ed
...
@@ -5,13 +5,16 @@ These envs only work for a small part of the tests, fix what you need!
...
@@ -5,13 +5,16 @@ These envs only work for a small part of the tests, fix what you need!
"""
"""
import
os
import
os
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
from
collections.abc
import
Callable
from
typing
import
TYPE_CHECKING
,
Any
from
vllm.envs
import
maybe_convert_bool
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
VLLM_CI_NO_SKIP
:
bool
=
False
VLLM_CI_NO_SKIP
:
bool
=
False
VLLM_CI_DTYPE
:
Optional
[
str
]
=
None
VLLM_CI_DTYPE
:
str
|
None
=
None
VLLM_CI_HEAD_DTYPE
:
Optional
[
str
]
=
None
VLLM_CI_HEAD_DTYPE
:
str
|
None
=
None
VLLM_CI_HF_DTYPE
:
Optional
[
str
]
=
None
VLLM_CI_HF_DTYPE
:
str
|
None
=
None
environment_variables
:
dict
[
str
,
Callable
[[],
Any
]]
=
{
environment_variables
:
dict
[
str
,
Callable
[[],
Any
]]
=
{
# A model family has many models with the same architecture.
# A model family has many models with the same architecture.
...
@@ -24,6 +27,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -24,6 +27,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CI_HEAD_DTYPE"
:
lambda
:
os
.
getenv
(
"VLLM_CI_HEAD_DTYPE"
,
None
),
"VLLM_CI_HEAD_DTYPE"
:
lambda
:
os
.
getenv
(
"VLLM_CI_HEAD_DTYPE"
,
None
),
# Allow changing the head dtype used by transformers in tests
# Allow changing the head dtype used by transformers in tests
"VLLM_CI_HF_DTYPE"
:
lambda
:
os
.
getenv
(
"VLLM_CI_HF_DTYPE"
,
None
),
"VLLM_CI_HF_DTYPE"
:
lambda
:
os
.
getenv
(
"VLLM_CI_HF_DTYPE"
,
None
),
# Allow control over whether tests use enforce_eager
"VLLM_CI_ENFORCE_EAGER"
:
lambda
:
maybe_convert_bool
(
os
.
getenv
(
"VLLM_CI_ENFORCE_EAGER"
,
None
)
),
}
}
...
...
tests/compile/backend.py
View file @
006693ed
...
@@ -2,18 +2,23 @@
...
@@ -2,18 +2,23 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
weakref
import
weakref
from
collections.abc
import
Sequence
from
collections.abc
import
Callable
,
Sequence
from
contextlib
import
nullcontext
from
copy
import
deepcopy
from
copy
import
deepcopy
from
typing
import
Callable
,
Union
import
depyf
from
torch
import
fx
from
torch
import
fx
from
torch._ops
import
OpOverload
from
torch._ops
import
OpOverload
from
torch.fx._utils
import
lazy_format_graph_code
from
vllm.compilation.fx_utils
import
find_op_nodes
from
vllm.compilation.fx_utils
import
find_op_nodes
from
vllm.compilation.inductor_pass
import
InductorPass
from
vllm.compilation.inductor_pass
import
InductorPass
from
vllm.compilation.pass_manager
import
with_pattern_match_debug
from
vllm.compilation.pass_manager
import
with_pattern_match_debug
from
vllm.compilation.vllm_inductor_pass
import
VllmInductorPass
from
vllm.compilation.vllm_inductor_pass
import
VllmInductorPass
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.logger
import
init_logger
logger
=
init_logger
(
"vllm.tests.compile.backend"
)
class
LazyInitPass
(
InductorPass
):
class
LazyInitPass
(
InductorPass
):
...
@@ -23,8 +28,7 @@ class LazyInitPass(InductorPass):
...
@@ -23,8 +28,7 @@ class LazyInitPass(InductorPass):
and then immediately invoke it.
and then immediately invoke it.
"""
"""
def
__init__
(
self
,
pass_cls
:
type
[
VllmInductorPass
],
def
__init__
(
self
,
pass_cls
:
type
[
VllmInductorPass
],
vllm_config
:
VllmConfig
):
vllm_config
:
VllmConfig
):
self
.
pass_cls
=
pass_cls
self
.
pass_cls
=
pass_cls
self
.
vllm_config
=
weakref
.
proxy
(
vllm_config
)
# avoid cycle
self
.
vllm_config
=
weakref
.
proxy
(
vllm_config
)
# avoid cycle
...
@@ -45,24 +49,34 @@ class TestBackend:
...
@@ -45,24 +49,34 @@ class TestBackend:
Inductor config is default-initialized from VllmConfig.CompilationConfig.
Inductor config is default-initialized from VllmConfig.CompilationConfig.
"""
"""
def
__init__
(
self
,
*
passes
:
Union
[
InductorPass
,
Callable
[[
fx
.
Graph
],
def
__init__
(
self
,
*
passes
:
InductorPass
|
Callable
[[
fx
.
Graph
],
None
]):
None
]]):
self
.
custom_passes
=
list
(
passes
)
self
.
custom_passes
=
list
(
passes
)
compile_config
=
get_current_vllm_config
().
compilation_config
vllm_config
=
get_current_vllm_config
()
self
.
inductor_config
=
compile_config
.
inductor_compile_config
compile_config
=
vllm_config
.
compilation_config
self
.
inductor_config
[
'force_disable_caches'
]
=
True
# Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
self
.
inductor_config
[
'post_grad_custom_post_pass'
]
=
self
.
post_pass
self
.
inductor_config
=
deepcopy
(
compile_config
.
inductor_compile_config
)
self
.
inductor_config
[
"force_disable_caches"
]
=
True
self
.
inductor_config
[
"post_grad_custom_post_pass"
]
=
self
.
post_pass
if
debug_dump_path
:
=
vllm_config
.
compile_debug_dump_path
():
logger
.
debug
(
"Dumping depyf output to %s"
,
debug_dump_path
)
self
.
debug_ctx
=
depyf
.
prepare_debug
(
debug_dump_path
.
as_posix
())
else
:
self
.
debug_ctx
=
nullcontext
()
def
__call__
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
):
def
__call__
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
):
self
.
graph_pre_compile
=
deepcopy
(
graph
)
self
.
graph_pre_compile
=
deepcopy
(
graph
)
from
torch._inductor.compile_fx
import
compile_fx
from
torch._inductor.compile_fx
import
compile_fx
return
compile_fx
(
graph
,
example_inputs
,
with
self
.
debug_ctx
:
config_patches
=
self
.
inductor_config
)
return
compile_fx
(
graph
,
example_inputs
,
config_patches
=
self
.
inductor_config
)
@
with_pattern_match_debug
@
with_pattern_match_debug
def
post_pass
(
self
,
graph
:
fx
.
Graph
):
def
post_pass
(
self
,
graph
:
fx
.
Graph
):
self
.
graph_pre_pass
=
deepcopy
(
graph
)
self
.
graph_pre_pass
=
deepcopy
(
graph
)
lazy_format_graph_code
(
"graph_pre_pass"
,
graph
.
owning_module
)
VllmInductorPass
.
dump_prefix
=
0
VllmInductorPass
.
dump_prefix
=
0
for
pass_
in
self
.
custom_passes
:
for
pass_
in
self
.
custom_passes
:
...
@@ -72,6 +86,7 @@ class TestBackend:
...
@@ -72,6 +86,7 @@ class TestBackend:
VllmInductorPass
.
dump_prefix
=
None
VllmInductorPass
.
dump_prefix
=
None
self
.
graph_post_pass
=
deepcopy
(
graph
)
self
.
graph_post_pass
=
deepcopy
(
graph
)
lazy_format_graph_code
(
"graph_post_pass"
,
graph
.
owning_module
)
# assign by reference, will reflect the final state of the graph
# assign by reference, will reflect the final state of the graph
self
.
final_graph
=
graph
self
.
final_graph
=
graph
...
@@ -82,8 +97,7 @@ class TestBackend:
...
@@ -82,8 +97,7 @@ class TestBackend:
assert
num_pre
>
0
,
f
"Op
{
op
.
name
()
}
not found in pre-pass graph"
assert
num_pre
>
0
,
f
"Op
{
op
.
name
()
}
not found in pre-pass graph"
assert
num_pre
>
num_post
,
f
"All nodes remain for op
{
op
.
name
()
}
"
assert
num_pre
>
num_post
,
f
"All nodes remain for op
{
op
.
name
()
}
"
if
fully_replaced
:
if
fully_replaced
:
assert
num_post
==
0
,
\
assert
num_post
==
0
,
f
"Unexpected op
{
op
.
name
()
}
in post-pass graph"
f
"Unexpected op
{
op
.
name
()
}
in post-pass graph"
def
check_after_ops
(
self
,
ops
:
Sequence
[
OpOverload
]):
def
check_after_ops
(
self
,
ops
:
Sequence
[
OpOverload
]):
for
op
in
ops
:
for
op
in
ops
:
...
...
tests/compile/piecewise/test_full_cudagraph.py
View file @
006693ed
...
@@ -3,15 +3,15 @@
...
@@ -3,15 +3,15 @@
import
contextlib
import
contextlib
import
os
import
os
import
weakref
import
weakref
from
dataclasses
import
dataclass
from
typing
import
Optional
import
pytest
import
pytest
from
tests.utils
import
wait_for_gpu_memory_to_clear
from
tests.utils
import
wait_for_gpu_memory_to_clear
from
tests.v1.attention.utils
import
full_cg_backend_configs
as
backend_configs
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationConfig
from
vllm.config
import
CompilationConfig
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
...
@@ -33,121 +33,44 @@ def temporary_environ(env_vars):
...
@@ -33,121 +33,44 @@ def temporary_environ(env_vars):
os
.
environ
[
k
]
=
v
os
.
environ
[
k
]
=
v
@
dataclass
model_backends_full_cudagraph
=
[]
class
BackendConfig
:
name
:
str
env_vars
:
dict
comp_config
:
dict
specific_gpu_arch
:
Optional
[
tuple
]
=
None
# Define all backend configurations of full cudagraph to be tested
backend_configs
=
{
# FA3 on Hopper
"FA3"
:
BackendConfig
(
name
=
"FA3"
,
env_vars
=
{
"VLLM_FLASH_ATTN_VERSION"
:
"3"
,
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"
:
"16"
,
},
comp_config
=
{
"cudagraph_mode"
:
"FULL"
,
},
specific_gpu_arch
=
(
9
,
0
)),
# FlashMLA on Hopper
"FlashMLA"
:
BackendConfig
(
name
=
"FlashMLA"
,
env_vars
=
{
"VLLM_ATTENTION_BACKEND"
:
"FLASHMLA"
,
},
comp_config
=
{
"cudagraph_mode"
:
"FULL_AND_PIECEWISE"
,
},
specific_gpu_arch
=
(
9
,
0
)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA"
:
BackendConfig
(
name
=
"FlashAttentionMLA"
,
env_vars
=
{
"VLLM_ATTENTION_BACKEND"
:
"FLASH_ATTN_MLA"
,
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"
:
"16"
,
},
comp_config
=
{
"cudagraph_mode"
:
"FULL_DECODE_ONLY"
,
},
specific_gpu_arch
=
(
9
,
0
)),
# Cutlass MLA on Blackwell
"CutlassMLA"
:
BackendConfig
(
name
=
"CutlassMLA"
,
env_vars
=
{
"VLLM_USE_V1"
:
"1"
,
"VLLM_ATTENTION_BACKEND"
:
"CUTLASS_MLA"
,
"FORCE_NUM_KV_SPLITS"
:
"1"
,
# TODO: remove this when hang issue is fixed
},
comp_config
=
{
"cudagraph_mode"
:
"FULL_AND_PIECEWISE"
,
"cudagraph_capture_sizes"
:
[
16
,
32
,
64
,
128
,
256
,
512
],
},
specific_gpu_arch
=
(
10
,
0
)),
# FA2
"FA2"
:
BackendConfig
(
name
=
"FA2"
,
env_vars
=
{
"VLLM_FLASH_ATTN_VERSION"
:
"2"
,
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"
:
"16"
,
},
comp_config
=
{
"cudagraph_mode"
:
"FULL"
,
}),
# Triton Attention
"TritonAttn"
:
BackendConfig
(
name
=
"TritonAttn"
,
env_vars
=
{
"VLLM_ATTENTION_BACKEND"
:
"TRITON_ATTN"
},
comp_config
=
{
"cudagraph_mode"
:
"FULL"
,
}),
# FlashInfer
"FlashInfer"
:
BackendConfig
(
name
=
"FlashInfer"
,
env_vars
=
{
"VLLM_ATTENTION_BACKEND"
:
"FLASHINFER"
},
comp_config
=
{
"cudagraph_mode"
:
"FULL_AND_PIECEWISE"
,
}),
}
test_params_full_cudagraph
=
[]
# deepseek-ai/DeepSeek-V2-Lite with MLA
# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends
=
[
"FlashMLA"
,
"FlashAttentionMLA"
,
"CutlassMLA"
]
MLA_backends
=
[
"FlashMLA"
,
"FlashAttentionMLA"
,
"CutlassMLA"
]
for
mla_backend
in
MLA_backends
:
for
mla_backend
in
MLA_backends
:
test_param
s_full_cudagraph
.
append
(
model_backend
s_full_cudagraph
.
append
(
pytest
.
param
(
(
"deepseek-ai/DeepSeek-V2-Lite"
,
backend_configs
[
mla_backend
])
(
"deepseek-ai/DeepSeek-V2-Lite"
,
backend_configs
[
mla_backend
]))
)
)
# Qwen/Qwen2-1.5B-Instruct with other backends
# Qwen/Qwen2-1.5B-Instruct with other backends
other_backend_configs
=
[
other_backend_configs
=
[
backend_configs
[
c
]
for
c
in
backend_configs
if
c
not
in
MLA_backends
backend_configs
[
c
]
for
c
in
backend_configs
if
c
not
in
MLA_backends
]
]
for
backend_config
in
other_backend_configs
:
for
backend_config
in
other_backend_configs
:
test_params_full_cudagraph
.
append
(
model_backends_full_cudagraph
.
append
((
"Qwen/Qwen2-1.5B-Instruct"
,
backend_config
))
pytest
.
param
((
"Qwen/Qwen2-1.5B-Instruct"
,
backend_config
)))
@
pytest
.
fixture
(
scope
=
"class"
)
@
pytest
.
fixture
(
scope
=
"class"
)
def
llm_pair
(
request
):
def
llm_pair
(
request
):
model
,
backend_config
=
request
.
param
model
,
backend_config
,
use_inductor_graph_partition
=
request
.
param
backend_config
.
comp_config
[
"use_inductor_graph_partition"
]
=
(
use_inductor_graph_partition
)
if
use_inductor_graph_partition
and
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
pytest
.
skip
(
"Inductor graph partition only supported in torch>=2.9"
)
# Dynamically skip test if GPU capability is not met
# Dynamically skip test if GPU capability is not met
if
backend_config
.
specific_gpu_arch
and
backend_config
.
specific_gpu_arch
\
if
(
!=
current_platform
.
get_device_capability
():
backend_config
.
specific_gpu_arch
and
backend_config
.
specific_gpu_arch
!=
current_platform
.
get_device_capability
()
):
if
backend_config
.
specific_gpu_arch
==
(
9
,
0
):
if
backend_config
.
specific_gpu_arch
==
(
9
,
0
):
pytest
.
skip
(
"Only Hopper GPUs support FA3 and FlashMLA"
)
pytest
.
skip
(
"Only Hopper GPUs support FA3 and FlashMLA"
)
elif
backend_config
.
specific_gpu_arch
==
(
10
,
0
):
elif
backend_config
.
specific_gpu_arch
==
(
10
,
0
):
pytest
.
skip
(
"Only Blackwell GPUs support Cutlass MLA"
)
pytest
.
skip
(
"Only Blackwell GPUs support Cutlass MLA"
)
env_vars
=
{
env_vars
=
{
"VLLM_USE_V1"
:
"1"
,
# Force native sampler to avoid potential nondeterminism in FlashInfer
# Force native sampler to avoid potential nondeterminism in FlashInfer
# when per-request generators are not used in V1.
# when per-request generators are not used in V1.
"VLLM_USE_FLASHINFER_SAMPLER"
:
"0"
,
"VLLM_USE_FLASHINFER_SAMPLER"
:
"0"
,
...
@@ -160,8 +83,7 @@ def llm_pair(request):
...
@@ -160,8 +83,7 @@ def llm_pair(request):
trust_remote_code
=
True
,
trust_remote_code
=
True
,
max_model_len
=
1024
,
max_model_len
=
1024
,
max_num_seqs
=
128
,
max_num_seqs
=
128
,
compilation_config
=
\
compilation_config
=
CompilationConfig
(
**
backend_config
.
comp_config
),
CompilationConfig
(
**
backend_config
.
comp_config
),
generation_config
=
"vllm"
,
generation_config
=
"vllm"
,
seed
=
42
,
seed
=
42
,
)
)
...
@@ -187,7 +109,15 @@ def llm_pair(request):
...
@@ -187,7 +109,15 @@ def llm_pair(request):
)
)
@
pytest
.
mark
.
parametrize
(
"llm_pair"
,
test_params_full_cudagraph
,
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"llm_pair"
,
[
pytest
.
param
((
model
,
backend_config
,
use_inductor_graph_partition
))
for
model
,
backend_config
in
model_backends_full_cudagraph
for
use_inductor_graph_partition
in
[
True
,
False
]
],
indirect
=
True
,
)
class
TestFullCUDAGraph
:
class
TestFullCUDAGraph
:
"""
"""
Use a class such that an llm pair is constructed once for all
Use a class such that an llm pair is constructed once for all
...
@@ -197,20 +127,22 @@ class TestFullCUDAGraph:
...
@@ -197,20 +127,22 @@ class TestFullCUDAGraph:
meaning there would be multiple LLM instances hogging memory simultaneously.
meaning there would be multiple LLM instances hogging memory simultaneously.
"""
"""
@
pytest
.
mark
.
parametrize
((
"batch_size"
,
"max_tokens"
),
[
@
pytest
.
mark
.
parametrize
(
(
1
,
10
),
(
"batch_size"
,
"max_tokens"
),
(
7
,
10
),
[
(
16
,
10
),
(
1
,
10
),
(
25
,
10
),
(
7
,
10
),
(
32
,
10
),
(
16
,
10
),
(
45
,
10
),
(
25
,
10
),
(
64
,
10
),
(
32
,
10
),
(
123
,
10
),
(
45
,
10
),
(
8
,
5
),
(
64
,
10
),
(
8
,
30
),
(
123
,
10
),
])
(
8
,
5
),
def
test_full_cudagraph
(
self
,
batch_size
,
max_tokens
,
(
8
,
30
),
llm_pair
:
tuple
[
LLM
,
LLM
]):
],
)
def
test_full_cudagraph
(
self
,
batch_size
,
max_tokens
,
llm_pair
:
tuple
[
LLM
,
LLM
]):
"""
"""
Test various batch sizes and max_tokens to ensure that the
Test various batch sizes and max_tokens to ensure that the
full cudagraph compilation works for padded cases too.
full cudagraph compilation works for padded cases too.
...
@@ -221,26 +153,33 @@ class TestFullCUDAGraph:
...
@@ -221,26 +153,33 @@ class TestFullCUDAGraph:
prompts
=
[
"the quick brown fox"
]
*
batch_size
prompts
=
[
"the quick brown fox"
]
*
batch_size
# Use purely greedy decoding to avoid top-p truncation sensitivity
# Use purely greedy decoding to avoid top-p truncation sensitivity
# that can amplify tiny numeric differences across runtimes.
# that can amplify tiny numeric differences across runtimes.
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
temperature
=
0.0
,
max_tokens
=
max_tokens
,
top_p
=
1.0
top_p
=
1.0
)
)
piecewise_responses
=
piecewise_llm
.
generate
(
prompts
,
sampling_params
)
piecewise_responses
=
piecewise_llm
.
generate
(
prompts
,
sampling_params
)
full_responses
=
full_cudagraph_llm
.
generate
(
prompts
,
sampling_params
)
full_responses
=
full_cudagraph_llm
.
generate
(
prompts
,
sampling_params
)
# Check that all responses are the same
# Check that all responses are the same
for
piecewise_res
,
full_res
in
zip
(
piecewise_responses
,
for
piecewise_res
,
full_res
in
zip
(
piecewise_responses
,
full_responses
):
full_responses
):
assert
(
assert
piecewise_res
.
outputs
[
0
].
text
.
lower
()
==
\
piecewise_res
.
outputs
[
0
].
text
.
lower
()
full_res
.
outputs
[
0
].
text
.
lower
()
==
full_res
.
outputs
[
0
].
text
.
lower
()
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"Skip if not cuda"
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"Skip if not cuda"
)
def
test_full_cudagraph_with_invalid_backend
():
def
test_full_cudagraph_with_invalid_backend
():
with
temporary_environ
({
with
(
"VLLM_USE_V1"
:
"1"
,
temporary_environ
(
"VLLM_ATTENTION_BACKEND"
:
"FLEX_ATTENTION"
{
# Flex_Attention is not supported with full cuda graph
"VLLM_ATTENTION_BACKEND"
:
"FLEX_ATTENTION"
,
}),
pytest
.
raises
(
RuntimeError
):
# Flex_Attention is not supported with full cuda graph
LLM
(
model
=
"Qwen/Qwen2-1.5B-Instruct"
,
}
compilation_config
=
CompilationConfig
(
cudagraph_mode
=
"FULL"
))
),
pytest
.
raises
(
RuntimeError
),
):
LLM
(
model
=
"Qwen/Qwen2-1.5B-Instruct"
,
compilation_config
=
CompilationConfig
(
cudagraph_mode
=
"FULL"
),
)
tests/compile/piecewise/test_multiple_graphs.py
View file @
006693ed
...
@@ -5,16 +5,24 @@ Test (piecewise) compilation with a simple model where multiple submodules
...
@@ -5,16 +5,24 @@ Test (piecewise) compilation with a simple model where multiple submodules
are compiled and graph captured separately.
are compiled and graph captured separately.
"""
"""
import
pytest
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm.compilation.backends
import
set_model_tag
from
vllm.compilation.backends
import
set_model_tag
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
(
ignore_torch_compile
,
from
vllm.compilation.decorators
import
ignore_torch_compile
,
support_torch_compile
support_torch_compile
)
from
vllm.config
import
(
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
CUDAGraphMode
,
CompilationConfig
,
VllmConfig
,
set_current_vllm_config
)
CompilationMode
,
CUDAGraphMode
,
VllmConfig
,
set_current_vllm_config
,
)
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
...utils
import
create_new_process_for_each_test
# This import automatically registers `torch.ops.silly.attention`
# This import automatically registers `torch.ops.silly.attention`
from
..
import
silly_attention
# noqa: F401
from
..
import
silly_attention
# noqa: F401
...
@@ -27,12 +35,7 @@ RANDOM_SEED = 0
...
@@ -27,12 +35,7 @@ RANDOM_SEED = 0
@
support_torch_compile
@
support_torch_compile
class
ParentModel
(
nn
.
Module
):
class
ParentModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
**
kwargs
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -40,7 +43,6 @@ class ParentModel(nn.Module):
...
@@ -40,7 +43,6 @@ class ParentModel(nn.Module):
class
Attention
(
nn
.
Module
):
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
mlp_size
:
int
,
hidden_size
:
int
)
->
None
:
def
__init__
(
self
,
mlp_size
:
int
,
hidden_size
:
int
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
pre_attn
=
nn
.
Linear
(
mlp_size
,
hidden_size
,
bias
=
False
)
self
.
pre_attn
=
nn
.
Linear
(
mlp_size
,
hidden_size
,
bias
=
False
)
...
@@ -51,17 +53,21 @@ class Attention(nn.Module):
...
@@ -51,17 +53,21 @@ class Attention(nn.Module):
nn
.
init
.
xavier_normal_
(
nn
.
init
.
xavier_normal_
(
self
.
pre_attn
.
weight
.
data
,
self
.
pre_attn
.
weight
.
data
,
generator
=
torch
.
Generator
().
manual_seed
(
RANDOM_SEED
),
generator
=
torch
.
Generator
().
manual_seed
(
RANDOM_SEED
),
gain
=
0.001
)
gain
=
0.001
,
)
nn
.
init
.
xavier_normal_
(
nn
.
init
.
xavier_normal_
(
self
.
post_attn
.
weight
.
data
,
self
.
post_attn
.
weight
.
data
,
generator
=
torch
.
Generator
().
manual_seed
(
RANDOM_SEED
),
generator
=
torch
.
Generator
().
manual_seed
(
RANDOM_SEED
),
gain
=
0.001
)
gain
=
0.001
,
)
def
rms_norm_ref
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
rms_norm_ref
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x_f32
=
x
.
float
()
x_f32
=
x
.
float
()
return
(
x_f32
*
torch
.
rsqrt
(
return
(
torch
.
mean
(
x_f32
.
square
(),
dim
=-
1
,
keepdim
=
True
)
+
1e-6
)
*
x_f32
self
.
rms_norm_weight
).
to
(
x
.
dtype
)
*
torch
.
rsqrt
(
torch
.
mean
(
x_f32
.
square
(),
dim
=-
1
,
keepdim
=
True
)
+
1e-6
)
*
self
.
rms_norm_weight
).
to
(
x
.
dtype
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
self
.
pre_attn
(
x
)
x
=
self
.
pre_attn
(
x
)
...
@@ -76,14 +82,15 @@ class Attention(nn.Module):
...
@@ -76,14 +82,15 @@ class Attention(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
CompiledAttention
(
nn
.
Module
):
class
CompiledAttention
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
*
,
*
,
mlp_size
:
int
,
mlp_size
:
int
,
hidden_size
:
int
,
hidden_size
:
int
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
prefix
:
str
=
""
,
**
kwargs
)
->
None
:
**
kwargs
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
attn
=
Attention
(
mlp_size
,
hidden_size
)
self
.
attn
=
Attention
(
mlp_size
,
hidden_size
)
...
@@ -93,21 +100,21 @@ class CompiledAttention(nn.Module):
...
@@ -93,21 +100,21 @@ class CompiledAttention(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
CompiledAttentionTwo
(
CompiledAttention
):
class
CompiledAttentionTwo
(
CompiledAttention
):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
attn
(
x
)
+
x
return
self
.
attn
(
x
)
+
x
@
ignore_torch_compile
@
ignore_torch_compile
class
SimpleModelWithTwoGraphs
(
ParentModel
):
class
SimpleModelWithTwoGraphs
(
ParentModel
):
def
__init__
(
def
__init__
(
self
,
self
,
*
,
*
,
mlp_size
:
int
,
mlp_size
:
int
,
hidden_size
:
int
,
hidden_size
:
int
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
prefix
:
str
=
""
,
**
kwargs
)
->
None
:
**
kwargs
,
)
->
None
:
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
# Test will fail without set_model_tag here with error:
# Test will fail without set_model_tag here with error:
# "ValueError: too many values to unpack (expected 3)"
# "ValueError: too many values to unpack (expected 3)"
...
@@ -142,118 +149,174 @@ class SimpleModelWithTwoGraphs(ParentModel):
...
@@ -142,118 +149,174 @@ class SimpleModelWithTwoGraphs(ParentModel):
@
torch
.
inference_mode
@
torch
.
inference_mode
def
run_model
(
vllm_config
:
VllmConfig
,
model
:
nn
.
Module
,
inputs
:
torch
.
Tensor
,
def
run_model
(
cudagraph_runtime_mode
:
CUDAGraphMode
):
vllm_config
:
VllmConfig
,
model
:
nn
.
Module
,
inputs
:
torch
.
Tensor
,
cudagraph_runtime_mode
:
CUDAGraphMode
,
):
with
set_forward_context
({},
vllm_config
=
vllm_config
):
with
set_forward_context
({},
vllm_config
=
vllm_config
):
# warmup for the model with cudagraph_mode NONE
# warmup for the model with cudagraph_mode NONE
model
(
inputs
)
model
(
inputs
)
# simulate cudagraphs capturing
# simulate cudagraphs capturing
with
set_forward_context
({},
with
set_forward_context
(
vllm_config
=
vllm_config
,
{},
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
vllm_config
=
vllm_config
,
batch_descriptor
=
BatchDescriptor
(
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
num_tokens
=
2
,
)):
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
2
,
),
):
model
(
inputs
[:
2
])
model
(
inputs
[:
2
])
with
set_forward_context
({},
with
set_forward_context
(
vllm_config
=
vllm_config
,
{},
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
vllm_config
=
vllm_config
,
batch_descriptor
=
BatchDescriptor
(
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
num_tokens
=
1
,
)):
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
1
,
),
):
model
(
inputs
[:
1
])
model
(
inputs
[:
1
])
# simulate cudagraphs replay
# simulate cudagraphs replay
with
set_forward_context
({},
with
set_forward_context
(
vllm_config
=
vllm_config
,
{},
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
vllm_config
=
vllm_config
,
batch_descriptor
=
BatchDescriptor
(
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
num_tokens
=
2
,
)):
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
2
,
),
):
output
=
model
(
inputs
[:
2
])
output
=
model
(
inputs
[:
2
])
output
=
output
.
cpu
()
output
=
output
.
cpu
()
return
output
.
cpu
()
return
output
.
cpu
()
def
test_multi_graph_piecewise_compile_outputs_equal
():
@
pytest
.
mark
.
parametrize
(
"use_inductor_graph_partition"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"use_bytecode_hook"
,
[
True
,
False
])
@
create_new_process_for_each_test
(
"spawn"
)
def
test_multi_graph_piecewise_compile
(
use_inductor_graph_partition
:
bool
,
use_bytecode_hook
:
bool
,
monkeypatch
):
# Set the environment variable for this test
monkeypatch
.
setenv
(
"VLLM_USE_BYTECODE_HOOK"
,
"1"
if
use_bytecode_hook
else
"0"
)
if
use_inductor_graph_partition
and
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
pytest
.
skip
(
"inductor graph partition is only available in PyTorch 2.9+"
)
outputs
=
[]
outputs
=
[]
# piecewise compile
# vllmcompile compile
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
vllm_config
=
VllmConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
compilation_config
=
CompilationConfig
(
use_cudagraph
=
True
,
mode
=
CompilationMode
.
VLLM_COMPILE
,
splitting_ops
=
[
"silly.attention"
],
cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
,
cudagraph_capture_sizes
=
[
1
,
2
],
splitting_ops
=
[
"silly::attention"
],
))
cudagraph_capture_sizes
=
[
1
,
2
],
use_inductor_graph_partition
=
use_inductor_graph_partition
,
)
)
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
SimpleModelWithTwoGraphs
(
mlp_size
=
MLP_SIZE
,
model
=
(
hidden_size
=
HIDDEN_SIZE
,
SimpleModelWithTwoGraphs
(
vllm_config
=
vllm_config
,
mlp_size
=
MLP_SIZE
,
prefix
=
''
).
eval
().
cuda
()
hidden_size
=
HIDDEN_SIZE
,
vllm_config
=
vllm_config
,
prefix
=
""
,
)
.
eval
()
.
cuda
()
)
# Pre-allocate memory for CUDAGraph which expects
# Pre-allocate memory for CUDAGraph which expects
# static tensor addresses
# static tensor addresses
inputs
=
torch
.
randn
(
BATCH_SIZE
,
MLP_SIZE
).
cuda
()
inputs
=
torch
.
randn
(
BATCH_SIZE
,
MLP_SIZE
).
cuda
()
if
use_inductor_graph_partition
:
# Splitting happens at Inductor lowering level,
# total piecewise fx graphs is equal to total graphs
num_piecewise_fx
=
2
num_piecewise_capturable_fx
=
2
else
:
# attn_one, attn_two each has 3 piecewise graphs
# (pre attn, post attn, silly_attention) each
num_piecewise_fx
=
6
# attn_one, attn_two has pre attn and post attn each, total=4
num_piecewise_capturable_fx
=
4
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
2
,
# two graphs for the model
num_graphs_seen
=
2
,
# two graphs for the model
num_piecewise_graphs_seen
=
6
,
num_piecewise_graphs_seen
=
num_piecewise_fx
,
# attn_one, attn_two each has 3 piecewise graphs
num_piecewise_capturable_graphs_seen
=
num_piecewise_capturable_fx
,
# (pre attn, post attn, silly_attention) each
num_backend_compilations
=
num_piecewise_capturable_fx
,
num_piecewise_capturable_graphs_seen
=
4
,
num_cudagraph_captured
=
8
,
# num_cudagraph_sizes * num_partitions
# attn_one, attn_two has pre attn and post attn each, total=4
num_backend_compilations
=
4
,
# num_piecewise_capturable_graphs_seen
num_cudagraph_captured
=
8
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
):
outputs
.
append
(
outputs
.
append
(
run_model
(
vllm_config
,
model
,
inputs
,
cudagraph_runtime_mode
))
run_model
(
vllm_config
,
model
,
inputs
,
cudagraph_runtime_mode
))
# no compile or cudagraph
# no compile or cudagraph
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
vllm_config
=
VllmConfig
(
level
=
CompilationLevel
.
NO_COMPILATION
,
))
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
NONE
,
)
)
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
SimpleModelWithTwoGraphs
(
mlp_size
=
MLP_SIZE
,
model
=
(
hidden_size
=
HIDDEN_SIZE
,
SimpleModelWithTwoGraphs
(
vllm_config
=
vllm_config
,
mlp_size
=
MLP_SIZE
,
prefix
=
''
).
eval
().
cuda
()
hidden_size
=
HIDDEN_SIZE
,
vllm_config
=
vllm_config
,
prefix
=
""
,
)
.
eval
()
.
cuda
()
)
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
0
,
num_graphs_seen
=
0
,
num_piecewise_graphs_seen
=
0
,
num_piecewise_graphs_seen
=
0
,
num_piecewise_capturable_graphs_seen
=
0
,
num_piecewise_capturable_graphs_seen
=
0
,
num_backend_compilations
=
0
,
num_backend_compilations
=
0
,
num_cudagraph_captured
=
0
,
num_cudagraph_captured
=
0
,
):
):
outputs
.
append
(
outputs
.
append
(
run_model
(
vllm_config
,
model
,
inputs
,
cudagraph_runtime_mode
))
run_model
(
vllm_config
,
model
,
inputs
,
cudagraph_runtime_mode
))
# piecewise compile without CUDA graph
# piecewise compile without CUDA graph
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
vllm_config
=
VllmConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
compilation_config
=
CompilationConfig
(
use_cudagraph
=
False
,
mode
=
CompilationMode
.
VLLM_COMPILE
,
splitting_ops
=
[
"silly.attention"
],
cudagraph_mode
=
CUDAGraphMode
.
NONE
,
))
splitting_ops
=
[
"silly::attention"
],
use_inductor_graph_partition
=
use_inductor_graph_partition
,
)
)
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
SimpleModelWithTwoGraphs
(
mlp_size
=
MLP_SIZE
,
model
=
(
hidden_size
=
HIDDEN_SIZE
,
SimpleModelWithTwoGraphs
(
vllm_config
=
vllm_config
,
mlp_size
=
MLP_SIZE
,
prefix
=
''
).
eval
().
cuda
()
hidden_size
=
HIDDEN_SIZE
,
vllm_config
=
vllm_config
,
prefix
=
""
,
)
.
eval
()
.
cuda
()
)
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
2
,
num_graphs_seen
=
2
,
num_piecewise_graphs_seen
=
6
,
num_piecewise_graphs_seen
=
num_piecewise_fx
,
num_piecewise_capturable_graphs_seen
=
4
,
num_piecewise_capturable_graphs_seen
=
num_piecewise_capturable_fx
,
num_backend_compilations
=
4
,
num_backend_compilations
=
num_piecewise_capturable_fx
,
num_cudagraph_captured
=
0
,
# no cudagraph captured
num_cudagraph_captured
=
0
,
# no cudagraph captured
):
):
outputs
.
append
(
outputs
.
append
(
run_model
(
vllm_config
,
model
,
inputs
,
cudagraph_runtime_mode
))
run_model
(
vllm_config
,
model
,
inputs
,
cudagraph_runtime_mode
))
# Generally don't expect outputs with and without inductor
# Generally don't expect outputs with and without inductor
# to be bitwise equivalent
# to be bitwise equivalent
...
...
tests/compile/piecewise/test_simple.py
View file @
006693ed
...
@@ -11,11 +11,17 @@ from torch import nn
...
@@ -11,11 +11,17 @@ from torch import nn
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
CUDAGraphMode
,
from
vllm.config
import
(
VllmConfig
,
set_current_vllm_config
)
CompilationConfig
,
from
vllm.envs
import
VLLM_USE_V1
CompilationMode
,
CUDAGraphMode
,
VllmConfig
,
set_current_vllm_config
,
)
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.utils
import
is_torch_equal_or_newer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
...utils
import
create_new_process_for_each_test
# This import automatically registers `torch.ops.silly.attention`
# This import automatically registers `torch.ops.silly.attention`
from
..silly_attention
import
get_global_counter
,
reset_global_counter
from
..silly_attention
import
get_global_counter
,
reset_global_counter
...
@@ -23,12 +29,7 @@ from ..silly_attention import get_global_counter, reset_global_counter
...
@@ -23,12 +29,7 @@ from ..silly_attention import get_global_counter, reset_global_counter
@
support_torch_compile
@
support_torch_compile
class
SillyModel
(
nn
.
Module
):
class
SillyModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
**
kwargs
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -60,53 +61,64 @@ def _run_simple_model(
...
@@ -60,53 +61,64 @@ def _run_simple_model(
expected_num_backend_compilations
,
expected_num_backend_compilations
,
expected_num_cudagraph_captured
,
expected_num_cudagraph_captured
,
):
):
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
vllm_config
=
VllmConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
compilation_config
=
CompilationConfig
(
use_cudagraph
=
True
,
mode
=
CompilationMode
.
VLLM_COMPILE
,
use_inductor
=
use_inductor
,
use_inductor
=
use_inductor
,
splitting_ops
=
splitting_ops
,
splitting_ops
=
splitting_ops
,
use_inductor_graph_partition
=
use_inductor_graph_partition
,
use_inductor_graph_partition
=
use_inductor_graph_partition
,
cudagraph_copy_inputs
=
True
,
cudagraph_copy_inputs
=
True
,
cudagraph_capture_sizes
=
[
1
,
2
],
cudagraph_capture_sizes
=
[
1
,
2
],
))
)
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
SillyModel
(
vllm_config
=
vllm_config
,
prefix
=
''
)
model
=
SillyModel
(
vllm_config
=
vllm_config
,
prefix
=
""
)
inputs
=
torch
.
randn
(
100
).
cuda
()
inputs
=
torch
.
randn
(
100
).
cuda
()
with
compilation_counter
.
expect
(
with
(
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
# one graph for the model
num_graphs_seen
=
1
,
# one graph for the model
num_piecewise_graphs_seen
=
expected_num_piecewise_graphs_seen
,
num_piecewise_graphs_seen
=
expected_num_piecewise_graphs_seen
,
num_piecewise_capturable_graphs_seen
=
num_piecewise_capturable_graphs_seen
=
expected_num_piecewise_capturable_graphs_seen
,
expected_num_piecewise_capturable_graphs_seen
,
num_backend_compilations
=
expected_num_backend_compilations
,
num_backend_compilations
=
expected_num_backend_compilations
,
num_cudagraph_captured
=
expected_num_cudagraph_captured
,
num_cudagraph_captured
=
expected_num_cudagraph_captured
,
),
set_forward_context
(
None
,
),
vllm_config
=
vllm_config
):
# background context
set_forward_context
(
None
,
vllm_config
=
vllm_config
),
):
# background context
# warm up with background context
# warm up with background context
model
(
inputs
)
model
(
inputs
)
# capturing/replaying should under context of cudagraph dispatching
# capturing/replaying should under context of cudagraph dispatching
with
set_forward_context
(
with
set_forward_context
(
None
,
None
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
,
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
2
,
)):
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
2
,
),
):
model
(
torch
.
randn
(
2
).
cuda
())
model
(
torch
.
randn
(
2
).
cuda
())
with
set_forward_context
(
with
set_forward_context
(
None
,
None
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
,
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
1
,
)):
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
1
,
),
):
model
(
torch
.
randn
(
1
).
cuda
())
model
(
torch
.
randn
(
1
).
cuda
())
input
=
torch
.
zeros
(
2
).
cuda
()
input
=
torch
.
zeros
(
2
).
cuda
()
reset_global_counter
()
reset_global_counter
()
with
set_forward_context
(
with
set_forward_context
(
None
,
None
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
,
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
2
,
)):
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
2
,
),
):
output
=
model
(
input
)
output
=
model
(
input
)
assert
get_global_counter
()
==
2
assert
get_global_counter
()
==
2
assert
torch
.
allclose
(
output
.
cpu
(),
torch
.
tensor
([
19.0
,
19.0
]))
assert
torch
.
allclose
(
output
.
cpu
(),
torch
.
tensor
([
19.0
,
19.0
]))
...
@@ -114,42 +126,42 @@ def _run_simple_model(
...
@@ -114,42 +126,42 @@ def _run_simple_model(
@
pytest
.
mark
.
parametrize
(
"use_inductor"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_inductor"
,
[
True
,
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
@
create_new_process_for_each_test
(
"spawn"
)
def
test_simple_piecewise_compile
(
use_inductor
):
def
test_simple_piecewise_compile
(
use_inductor
):
assert
VLLM_USE_V1
_run_simple_model
(
_run_simple_model
(
splitting_ops
=
[
"silly
.
attention"
],
splitting_ops
=
[
"silly
::
attention"
],
use_inductor_graph_partition
=
False
,
use_inductor_graph_partition
=
False
,
use_inductor
=
use_inductor
,
use_inductor
=
use_inductor
,
expected_num_piecewise_graphs_seen
=
5
,
# 2 * num_layers + 1
# 2 * num_layers + 1
expected_num_piecewise_capturable_graphs_seen
=
3
,
# 1 + num_layers
expected_num_piecewise_graphs_seen
=
5
,
expected_num_backend_compilations
=
# 1 + num_layers
3
,
# num_piecewise_capturable_graphs_seen
expected_num_piecewise_capturable_graphs_seen
=
3
,
expected_num_cudagraph_captured
=
# num_piecewise_capturable_graphs_seen
6
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
expected_num_backend_compilations
=
3
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
expected_num_cudagraph_captured
=
6
,
)
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"splitting_ops"
,
[[
"silly.attention"
],
[]])
def
test_simple_inductor_graph_partition
(
monkeypatch
):
def
test_simple_inductor_graph_partition
(
splitting_ops
):
assert
VLLM_USE_V1
if
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
if
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
pytest
.
skip
(
"inductor graph partition is only available "
pytest
.
skip
(
"inductor graph partition is only available in PyTorch 2.9+"
)
"in PyTorch 2.9+"
)
# disable compile cache so that we run separately for different splitting_ops
# and get the expected number of cudagraphs captured.
monkeypatch
.
setenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"1"
)
_run_simple_model
(
_run_simple_model
(
# inductor graph partition automatically resets splitting_ops
splitting_ops
=
[
"silly::attention"
],
# to be an empty list
splitting_ops
=
splitting_ops
,
use_inductor_graph_partition
=
True
,
use_inductor_graph_partition
=
True
,
use_inductor
=
True
,
use_inductor
=
True
,
expected_num_piecewise_graphs_seen
=
# Since not splitting at fx graph level
1
,
# since not splitting at fx graph level
expected_num_piecewise_graphs_seen
=
1
,
expected_num_piecewise_capturable_graphs_seen
=
# Since not splitting at fx graph level
1
,
# since not splitting at fx graph level
expected_num_piecewise_capturable_graphs_seen
=
1
,
expected_num_backend_compilations
=
# Since not splitting at fx graph level
1
,
# since not splitting at fx graph level
expected_num_backend_compilations
=
1
,
expected_num_cudagraph_captured
=
# Inductor graph partition still captures 6 graph, same as fx graph partition
6
,
# inductor graph partition still captures 6
expected_num_cudagraph_captured
=
6
,
# graph, same as fx graph partition.
)
)
tests/compile/piecewise/test_toy_llama.py
View file @
006693ed
...
@@ -8,8 +8,10 @@ This is a tractable model, the weights and computation are specially designed
...
@@ -8,8 +8,10 @@ This is a tractable model, the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
initialized randomly with a fixed seed.
"""
"""
from
copy
import
deepcopy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
from
typing
import
Any
import
pytest
import
pytest
import
torch
import
torch
...
@@ -17,9 +19,17 @@ from torch import nn
...
@@ -17,9 +19,17 @@ from torch import nn
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
CUDAGraphMode
,
from
vllm.config
import
(
VllmConfig
,
set_current_vllm_config
)
CompilationConfig
,
CompilationMode
,
CUDAGraphMode
,
VllmConfig
,
set_current_vllm_config
,
)
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
...utils
import
create_new_process_for_each_test
# This import automatically registers `torch.ops.silly.attention`
# This import automatically registers `torch.ops.silly.attention`
from
..
import
silly_attention
# noqa: F401
from
..
import
silly_attention
# noqa: F401
...
@@ -43,15 +53,14 @@ class LlamaConfig:
...
@@ -43,15 +53,14 @@ class LlamaConfig:
factors
.
append
((
k
,
v
))
factors
.
append
((
k
,
v
))
factors
.
sort
()
factors
.
sort
()
import
hashlib
import
hashlib
return
hashlib
.
md5
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()
return
hashlib
.
md5
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()
def
__post_init__
(
self
):
def
__post_init__
(
self
):
assert
self
.
mlp_size
>=
self
.
hidden_size
assert
self
.
mlp_size
>=
self
.
hidden_size
class
LlamaMLP
(
nn
.
Module
):
class
LlamaMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
LlamaConfig
)
->
None
:
def
__init__
(
self
,
config
:
LlamaConfig
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_projection
=
nn
.
Linear
(
self
.
gate_up_projection
=
nn
.
Linear
(
...
@@ -66,31 +75,31 @@ class LlamaMLP(nn.Module):
...
@@ -66,31 +75,31 @@ class LlamaMLP(nn.Module):
)
)
if
config
.
tractable_init
:
if
config
.
tractable_init
:
nn
.
init
.
eye_
(
self
.
gate_up_projection
.
weight
.
data
[:
config
.
mlp_size
])
nn
.
init
.
eye_
(
self
.
gate_up_projection
.
weight
.
data
[:
config
.
mlp_size
])
nn
.
init
.
eye_
(
self
.
gate_up_projection
.
weight
.
data
[
config
.
mlp_size
:])
nn
.
init
.
eye_
(
self
.
gate_up_projection
.
weight
.
data
[
config
.
mlp_size
:])
nn
.
init
.
eye_
(
self
.
down_projection
.
weight
.
data
)
nn
.
init
.
eye_
(
self
.
down_projection
.
weight
.
data
)
else
:
else
:
nn
.
init
.
xavier_normal_
(
self
.
gate_up_projection
.
weight
.
data
,
nn
.
init
.
xavier_normal_
(
generator
=
torch
.
Generator
().
manual_seed
(
self
.
gate_up_projection
.
weight
.
data
,
config
.
random_seed
),
generator
=
torch
.
Generator
().
manual_seed
(
config
.
random_seed
),
gain
=
0.001
)
gain
=
0.001
,
nn
.
init
.
xavier_normal_
(
self
.
down_projection
.
weight
.
data
,
)
generator
=
torch
.
Generator
().
manual_seed
(
nn
.
init
.
xavier_normal_
(
config
.
random_seed
),
self
.
down_projection
.
weight
.
data
,
gain
=
0.001
)
generator
=
torch
.
Generator
().
manual_seed
(
config
.
random_seed
),
gain
=
0.001
,
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
# for tractable_init and positive input, this is
# for tractable_init and positive input, this is
# essentially an elementwise-square
# essentially an elementwise-square
x
=
self
.
gate_up_projection
(
x
)
x
=
self
.
gate_up_projection
(
x
)
x
=
x
[:,
:
x
.
size
(
1
)
//
2
]
*
torch
.
nn
.
functional
.
relu
(
x
=
x
[:,
:
x
.
size
(
1
)
//
2
]
*
torch
.
nn
.
functional
.
relu
(
x
[:,
x
.
size
(
1
)
//
2
:])
x
[:,
x
.
size
(
1
)
//
2
:])
x
=
self
.
down_projection
(
x
)
x
=
self
.
down_projection
(
x
)
return
x
return
x
class
LlamaAttention
(
nn
.
Module
):
class
LlamaAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
LlamaConfig
)
->
None
:
def
__init__
(
self
,
config
:
LlamaConfig
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
qkv_projection
=
nn
.
Linear
(
self
.
qkv_projection
=
nn
.
Linear
(
...
@@ -106,21 +115,25 @@ class LlamaAttention(nn.Module):
...
@@ -106,21 +115,25 @@ class LlamaAttention(nn.Module):
)
)
if
config
.
tractable_init
:
if
config
.
tractable_init
:
nn
.
init
.
eye_
(
self
.
qkv_projection
.
weight
.
data
[:
config
.
hidden_size
])
nn
.
init
.
eye_
(
self
.
qkv_projection
.
weight
.
data
[:
config
.
hidden_size
])
nn
.
init
.
eye_
(
self
.
qkv_projection
.
weight
.
data
[
config
.
hidden_size
:
2
*
nn
.
init
.
eye_
(
config
.
hidden_size
])
self
.
qkv_projection
.
weight
.
data
[
nn
.
init
.
eye_
(
self
.
qkv_projection
.
weight
.
data
[
2
*
config
.
hidden_size
:
2
*
config
.
hidden_size
config
.
hidden_size
:])
]
)
nn
.
init
.
eye_
(
self
.
qkv_projection
.
weight
.
data
[
2
*
config
.
hidden_size
:])
nn
.
init
.
eye_
(
self
.
output_projection
.
weight
.
data
)
nn
.
init
.
eye_
(
self
.
output_projection
.
weight
.
data
)
else
:
else
:
nn
.
init
.
xavier_normal_
(
self
.
qkv_projection
.
weight
.
data
,
nn
.
init
.
xavier_normal_
(
generator
=
torch
.
Generator
().
manual_seed
(
self
.
qkv_projection
.
weight
.
data
,
config
.
random_seed
),
generator
=
torch
.
Generator
().
manual_seed
(
config
.
random_seed
),
gain
=
0.001
)
gain
=
0.001
,
nn
.
init
.
xavier_normal_
(
self
.
output_projection
.
weight
.
data
,
)
generator
=
torch
.
Generator
().
manual_seed
(
nn
.
init
.
xavier_normal_
(
config
.
random_seed
),
self
.
output_projection
.
weight
.
data
,
gain
=
0.001
)
generator
=
torch
.
Generator
().
manual_seed
(
config
.
random_seed
),
gain
=
0.001
,
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -144,7 +157,6 @@ class LlamaAttention(nn.Module):
...
@@ -144,7 +157,6 @@ class LlamaAttention(nn.Module):
class
LlamaDecoderLayer
(
nn
.
Module
):
class
LlamaDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
LlamaConfig
)
->
None
:
def
__init__
(
self
,
config
:
LlamaConfig
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
self_attention
=
LlamaAttention
(
config
)
self
.
self_attention
=
LlamaAttention
(
config
)
...
@@ -154,7 +166,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -154,7 +166,7 @@ class LlamaDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
,
residual
:
torch
.
Tensor
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
For tractable computation:
For tractable computation:
...
@@ -164,7 +176,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -164,7 +176,7 @@ class LlamaDecoderLayer(nn.Module):
- if residual is not None, the outputs are:
- if residual is not None, the outputs are:
- residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
- residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
- hidden_states = (residual + 1) ** 2
- hidden_states = (residual + 1) ** 2
"""
# noqa
"""
# noqa
if
residual
is
None
:
if
residual
is
None
:
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
hidden_states
+
1
hidden_states
=
hidden_states
+
1
...
@@ -173,8 +185,9 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -173,8 +185,9 @@ class LlamaDecoderLayer(nn.Module):
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
hidden_states
+
1
hidden_states
=
hidden_states
+
1
hidden_states
=
self
.
self_attention
(
positions
=
positions
,
hidden_states
=
self
.
self_attention
(
hidden_states
=
hidden_states
)
positions
=
positions
,
hidden_states
=
hidden_states
)
hidden_states
=
hidden_states
+
residual
hidden_states
=
hidden_states
+
residual
residual
=
hidden_states
residual
=
hidden_states
...
@@ -186,27 +199,29 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -186,27 +199,29 @@ class LlamaDecoderLayer(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
LlamaModel
(
nn
.
Module
):
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
*
,
*
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
config
:
LlamaConfig
,
config
:
LlamaConfig
,
prefix
:
str
=
''
,
prefix
:
str
=
""
,
**
kwargs
)
->
None
:
**
kwargs
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
embedding_tokens
=
nn
.
Embedding
(
self
.
embedding_tokens
=
nn
.
Embedding
(
num_embeddings
=
config
.
vocab_size
,
num_embeddings
=
config
.
vocab_size
,
embedding_dim
=
config
.
hidden_size
,
embedding_dim
=
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
(
self
.
layers
=
nn
.
ModuleList
(
[
LlamaDecoderLayer
(
config
)
for
_
in
range
(
config
.
num_layers
)])
[
LlamaDecoderLayer
(
config
)
for
_
in
range
(
config
.
num_layers
)]
)
# this is the initial value of the hidden states
# this is the initial value of the hidden states
self
.
embedding_tokens
.
weight
.
data
.
fill_
(
config
.
init_value
)
self
.
embedding_tokens
.
weight
.
data
.
fill_
(
config
.
init_value
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
Optional
[
torch
.
Tensor
]
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embedding_tokens
(
input_ids
)
hidden_states
=
self
.
embedding_tokens
(
input_ids
)
...
@@ -216,168 +231,195 @@ class LlamaModel(nn.Module):
...
@@ -216,168 +231,195 @@ class LlamaModel(nn.Module):
return
hidden_states
return
hidden_states
def
tractable_computation
(
input_ids
:
torch
.
Tensor
,
def
tractable_computation
(
positions
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
config
:
LlamaConfig
,
positions
:
torch
.
Tensor
,
init_value
:
float
=
1.0
)
->
torch
.
Tensor
:
config
:
LlamaConfig
,
hidden_states
=
torch
.
ones
(
input_ids
.
size
(
0
),
init_value
:
float
=
1.0
,
config
.
hidden_size
,
)
->
torch
.
Tensor
:
device
=
input_ids
.
device
,
hidden_states
=
(
dtype
=
input_ids
.
dtype
)
*
init_value
torch
.
ones
(
input_ids
.
size
(
0
),
config
.
hidden_size
,
device
=
input_ids
.
device
,
dtype
=
input_ids
.
dtype
,
)
*
init_value
)
# first layer
# first layer
residual
=
hidden_states
*
4
+
positions
.
unsqueeze
(
1
)
*
2
+
3
residual
=
hidden_states
*
4
+
positions
.
unsqueeze
(
1
)
*
2
+
3
hidden_states
=
(
residual
+
1
)
**
2
hidden_states
=
(
residual
+
1
)
**
2
# following layers
# following layers
for
_
in
range
(
config
.
num_layers
-
1
):
for
_
in
range
(
config
.
num_layers
-
1
):
hidden_states
=
hidden_states
+
residual
hidden_states
=
hidden_states
+
residual
residual
=
hidden_states
*
4
+
positions
.
unsqueeze
(
1
)
*
2
+
3
residual
=
hidden_states
*
4
+
positions
.
unsqueeze
(
1
)
*
2
+
3
hidden_states
=
(
residual
+
1
)
**
2
hidden_states
=
(
residual
+
1
)
**
2
return
hidden_states
return
hidden_states
@
torch
.
inference_mode
@
torch
.
inference_mode
def
run_model
(
llama_config
,
def
run_model
(
llama_config
,
compile_config
:
CompilationConfig
)
->
torch
.
Tensor
:
use_compile
:
bool
,
# Start with a fresh copy to make sure there's no cache dir sharing
use_inductor
:
bool
,
compile_config
=
deepcopy
(
compile_config
)
split_attn
:
bool
=
False
)
->
torch
.
Tensor
:
cudagraph_runtime_mode
=
compile_config
.
cudagraph_mode
if
use_compile
:
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
compilation_config
=
compile_config
,
additional_config
=
llama_config
level
=
CompilationLevel
.
PIECEWISE
,
)
use_cudagraph
=
True
,
use_inductor
=
use_inductor
,
cudagraph_capture_sizes
=
[
1
,
2
],
)
if
split_attn
:
compilation_config
.
splitting_ops
=
[
"silly.attention"
]
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
else
:
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
NO_COMPILATION
,
)
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
,
additional_config
=
llama_config
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
LlamaModel
(
config
=
llama_config
,
model
=
(
vllm_config
=
vllm_config
,
LlamaModel
(
config
=
llama_config
,
vllm_config
=
vllm_config
,
prefix
=
""
)
prefix
=
""
).
eval
().
cuda
()
.
eval
()
.
cuda
()
)
with
set_forward_context
({},
with
set_forward_context
({},
vllm_config
=
vllm_config
):
# background context
vllm_config
=
vllm_config
):
# background context
B
=
16
# max batch size
B
=
16
# max batch size
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,)).
cuda
()
positions
=
torch
.
arange
(
B
).
cuda
()
positions
=
torch
.
arange
(
B
).
cuda
()
# warmup for the model with cudagraph_mode NONE
# warmup for the model with cudagraph_mode NONE
model
(
input_ids
,
positions
)
model
(
input_ids
,
positions
)
# simulate cudagraphs capturing
# simulate cudagraphs capturing
with
set_forward_context
({},
with
set_forward_context
(
vllm_config
=
vllm_config
,
{},
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
vllm_config
=
vllm_config
,
batch_descriptor
=
BatchDescriptor
(
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
num_tokens
=
2
,
)):
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
2
,
),
):
model
(
input_ids
[:
2
],
positions
[:
2
])
model
(
input_ids
[:
2
],
positions
[:
2
])
with
set_forward_context
({},
with
set_forward_context
(
vllm_config
=
vllm_config
,
{},
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
vllm_config
=
vllm_config
,
batch_descriptor
=
BatchDescriptor
(
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
num_tokens
=
1
,
)):
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
1
,
),
):
model
(
input_ids
[:
1
],
positions
[:
1
])
model
(
input_ids
[:
1
],
positions
[:
1
])
input_ids
[:
2
].
zero_
()
input_ids
[:
2
].
zero_
()
# simulate cudagraphs replay
# simulate cudagraphs replay
with
set_forward_context
({},
with
set_forward_context
(
vllm_config
=
vllm_config
,
{},
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
vllm_config
=
vllm_config
,
batch_descriptor
=
BatchDescriptor
(
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
num_tokens
=
2
,
)):
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
2
,
),
):
output
=
model
(
input_ids
[:
2
],
positions
[:
2
])
output
=
model
(
input_ids
[:
2
],
positions
[:
2
])
output
=
output
.
cpu
()
output
=
output
.
cpu
()
if
llama_config
.
tractable_init
:
if
llama_config
.
tractable_init
:
expected_output
=
tractable_computation
(
input_ids
[:
2
],
expected_output
=
tractable_computation
(
positions
[:
2
],
input_ids
[:
2
],
positions
[:
2
],
llama_config
llama_config
).
cpu
()
).
cpu
()
assert
torch
.
allclose
(
output
,
expected_output
)
assert
torch
.
allclose
(
output
,
expected_output
)
else
:
else
:
return
output
.
cpu
()
return
output
.
cpu
()
@
pytest
.
mark
.
parametrize
(
"use_inductor"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
def
test_toy_llama
(
use_inductor
:
bool
):
"backend, use_inductor_graph_partition"
,
[
(
"eager"
,
False
),
# No inductor
(
"inductor"
,
False
),
# Inductor, Dynamo partition
(
"inductor"
,
True
),
# Inductor, Inductor partition
],
)
@
create_new_process_for_each_test
(
"spawn"
)
def
test_toy_llama
(
backend
:
str
,
use_inductor_graph_partition
:
bool
,
monkeypatch
,
tmp_path
):
# We disable the vLLM compile cache into a new tmp dir for 1 reason:
# 1. To make sure we can properly track the number of Inductor compilations.
monkeypatch
.
setenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"1"
)
if
use_inductor_graph_partition
and
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
pytest
.
skip
(
"Inductor graph partition only supported in torch>=2.9"
)
# compare output with and without piecewise compilation
# compare output with and without piecewise compilation
llama_config
=
LlamaConfig
(
hidden_size
=
128
,
llama_config
=
LlamaConfig
(
mlp_size
=
256
,
hidden_size
=
128
,
mlp_size
=
256
,
vocab_size
=
128
,
num_layers
=
12
vocab_size
=
128
,
)
num_layers
=
12
)
tractable_config
=
LlamaConfig
(
hidden_size
=
128
,
mlp_size
=
256
,
vocab_size
=
128
,
num_layers
=
2
,
tractable_init
=
True
)
compile_config_no_compile
=
CompilationConfig
(
mode
=
CompilationMode
.
NONE
,
cudagraph_mode
=
CUDAGraphMode
.
NONE
,
backend
=
"eager"
,
)
compile_config_no_split
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
use_inductor_graph_partition
=
use_inductor_graph_partition
,
cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
,
backend
=
backend
,
cudagraph_capture_sizes
=
[
1
,
2
],
)
tractable_config
=
LlamaConfig
(
hidden_size
=
128
,
compile_config_split
=
deepcopy
(
compile_config_no_split
)
mlp_size
=
256
,
compile_config_split
.
splitting_ops
=
[
"silly::attention"
]
vocab_size
=
128
,
num_layers
=
2
,
tractable_init
=
True
)
outputs
=
[]
outputs
=
[]
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
0
,
num_graphs_seen
=
0
,
num_piecewise_graphs_seen
=
0
,
num_piecewise_graphs_seen
=
0
,
num_piecewise_capturable_graphs_seen
=
0
,
num_piecewise_capturable_graphs_seen
=
0
,
num_backend_compilations
=
0
,
num_backend_compilations
=
0
,
num_cudagraph_captured
=
0
,
num_cudagraph_captured
=
0
,
):
):
outputs
.
append
(
outputs
.
append
(
run_model
(
llama_config
,
compile_config_no_compile
))
run_model
(
llama_config
,
use_inductor
=
False
,
use_compile
=
False
))
run_model
(
tractable_config
,
use_inductor
=
False
,
use_compile
=
False
)
if
use_inductor
:
run_model
(
tractable_config
,
compile_config_no_compile
)
if
backend
==
"inductor"
:
kwargs
=
{
"num_inductor_compiles"
:
1
,
"num_eager_compiles"
:
0
}
kwargs
=
{
"num_inductor_compiles"
:
1
,
"num_eager_compiles"
:
0
}
else
:
else
:
kwargs
=
{
"num_eager_compiles"
:
1
,
"num_inductor_compiles"
:
0
}
kwargs
=
{
"num_eager_compiles"
:
1
,
"num_inductor_compiles"
:
0
}
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
# one graph for the model
num_graphs_seen
=
1
,
# one graph for the model
num_piecewise_graphs_seen
=
1
,
num_piecewise_graphs_seen
=
1
,
num_piecewise_capturable_graphs_seen
=
1
,
num_piecewise_capturable_graphs_seen
=
1
,
num_backend_compilations
=
1
,
# num_piecewise_capturable_graphs_seen
num_backend_compilations
=
1
,
# num_piecewise_capturable_graphs_seen
num_cudagraph_captured
=
num_cudagraph_captured
=
2
,
2
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
**
kwargs
,
**
kwargs
,
):
):
outputs
.
append
(
outputs
.
append
(
run_model
(
llama_config
,
compile_config_no_split
))
run_model
(
llama_config
,
use_inductor
=
use_inductor
,
run_model
(
tractable_config
,
compile_config_no_split
)
use_compile
=
True
))
run_model
(
tractable_config
,
use_inductor
=
use_inductor
,
use_compile
=
True
)
if
use_inductor_graph_partition
:
num_piecewise_fx
=
1
num_piecewise_capturable_fx
=
1
else
:
num_piecewise_fx
=
2
*
llama_config
.
num_layers
+
1
num_piecewise_capturable_fx
=
1
+
llama_config
.
num_layers
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
# one graph for the model
num_graphs_seen
=
1
,
# one graph for the model
num_piecewise_graphs_seen
=
2
*
llama_config
.
num_layers
+
num_piecewise_graphs_seen
=
num_piecewise_fx
,
1
,
# 2 * num_layers + 1
num_piecewise_capturable_graphs_seen
=
num_piecewise_capturable_fx
,
num_piecewise_capturable_graphs_seen
=
1
+
num_backend_compilations
=
num_piecewise_capturable_fx
,
llama_config
.
num_layers
,
# 1 + num_layers
# num_cudagraph_sizes * num_partitions
num_backend_compilations
=
1
+
num_cudagraph_captured
=
2
*
(
1
+
llama_config
.
num_layers
),
llama_config
.
num_layers
,
# num_piecewise_capturable_graphs_seen
num_cudagraph_captured
=
2
*
(
1
+
llama_config
.
num_layers
),
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
):
outputs
.
append
(
outputs
.
append
(
run_model
(
llama_config
,
compile_config_split
))
run_model
(
llama_config
,
run_model
(
tractable_config
,
compile_config_split
)
use_inductor
=
use_inductor
,
use_compile
=
True
,
split_attn
=
True
))
run_model
(
tractable_config
,
use_inductor
=
use_inductor
,
use_compile
=
True
,
split_attn
=
True
)
for
i
in
range
(
1
,
len
(
outputs
)):
for
i
in
range
(
1
,
len
(
outputs
)):
assert
torch
.
allclose
(
outputs
[
0
],
outputs
[
i
])
assert
torch
.
allclose
(
outputs
[
0
],
outputs
[
i
])
...
@@ -388,17 +430,15 @@ def benchmark():
...
@@ -388,17 +430,15 @@ def benchmark():
from
triton.testing
import
do_bench
from
triton.testing
import
do_bench
# similar to llama 3.1-8B
# similar to llama 3.1-8B
llama_config
=
LlamaConfig
(
hidden_size
=
4096
,
llama_config
=
LlamaConfig
(
mlp_size
=
14336
,
hidden_size
=
4096
,
mlp_size
=
14336
,
vocab_size
=
128
*
1024
,
num_layers
=
32
vocab_size
=
128
*
1024
,
)
num_layers
=
32
)
# a tiny model to measure the overhead
# a tiny model to measure the overhead
# of piecewise cudagraph
# of piecewise cudagraph
llama_config
=
LlamaConfig
(
hidden_size
=
40
,
llama_config
=
LlamaConfig
(
mlp_size
=
80
,
hidden_size
=
40
,
mlp_size
=
80
,
vocab_size
=
128
,
num_layers
=
2
vocab_size
=
128
,
)
num_layers
=
2
)
cudagraph_sizes
=
[
1
,
2
,
4
]
+
[
i
*
8
for
i
in
range
(
1
,
33
)]
cudagraph_sizes
=
[
1
,
2
,
4
]
+
[
i
*
8
for
i
in
range
(
1
,
33
)]
...
@@ -411,25 +451,27 @@ def benchmark():
...
@@ -411,25 +451,27 @@ def benchmark():
for
piecewise
in
[
False
,
True
]:
for
piecewise
in
[
False
,
True
]:
if
piecewise
:
if
piecewise
:
compilation_config
=
CompilationConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
mode
=
CompilationMode
.
VLLM_COMPILE
,
use_cudagraph
=
True
,
splitting_ops
=
[
"silly::attention"
],
splitting_ops
=
[
"silly.attention"
],
cudagraph_capture_sizes
=
cudagraph_sizes
,
cudagraph_capture_sizes
=
cudagraph_sizes
,
)
)
else
:
else
:
compilation_config
=
CompilationConfig
(
compilation_config
=
CompilationConfig
(
level
=
Compilation
Level
.
PIECEWIS
E
,
mode
=
Compilation
Mode
.
VLLM_COMPIL
E
,
cudagraph_capture_sizes
=
cudagraph_sizes
,
cudagraph_capture_sizes
=
cudagraph_sizes
,
)
)
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
)
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
LlamaModel
(
config
=
llama_config
,
model
=
(
vllm_config
=
vllm_config
,
LlamaModel
(
config
=
llama_config
,
vllm_config
=
vllm_config
,
prefix
=
""
)
prefix
=
""
).
eval
().
cuda
().
to
(
torch
.
bfloat16
)
.
eval
()
.
cuda
()
.
to
(
torch
.
bfloat16
)
)
B
=
256
# max batch size
B
=
256
# max batch size
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,)).
cuda
()
positions
=
torch
.
arange
(
B
).
cuda
().
to
(
torch
.
bfloat16
)
positions
=
torch
.
arange
(
B
).
cuda
().
to
(
torch
.
bfloat16
)
graphs
=
{}
graphs
=
{}
...
@@ -451,22 +493,31 @@ def benchmark():
...
@@ -451,22 +493,31 @@ def benchmark():
# and use it later, because it will look up the name `b` in the
# and use it later, because it will look up the name `b` in the
# enclosing scope, and the value of `b` will always be 256.
# enclosing scope, and the value of `b` will always be 256.
# it is fine here, because we only use the lambda function once.
# it is fine here, because we only use the lambda function once.
runtime
=
do_bench
(
lambda
:
graphs
[
b
][
0
]
# noqa
runtime
=
do_bench
(
(
input_ids
[:
b
],
positions
[:
b
]))
# noqa
lambda
:
graphs
[
b
][
0
](
# noqa
input_ids
[:
b
],
# noqa
positions
[:
b
],
# noqa
)
)
piecewise_cudagraph_time
[
b
]
=
runtime
piecewise_cudagraph_time
[
b
]
=
runtime
else
:
else
:
runtime
=
do_bench
(
lambda
:
graphs
[
b
][
0
].
replay
())
# noqa
runtime
=
do_bench
(
lambda
:
graphs
[
b
][
0
].
replay
())
# noqa
eager_runtime
=
do_bench
(
eager_runtime
=
do_bench
(
lambda
:
model
(
input_ids
[:
b
],
positions
[:
b
]))
# noqa
lambda
:
model
(
input_ids
[:
b
],
positions
[:
b
]))
# noqa
full_cudagraph_time
[
b
]
=
runtime
full_cudagraph_time
[
b
]
=
runtime
eager_time
[
b
]
=
eager_runtime
eager_time
[
b
]
=
eager_runtime
# print in tabular format
# print in tabular format
print
(
"batch size
\t
eager mode
\t
full cudagraph
\t
piecewise cudagraph"
)
print
(
"batch size
\t
eager mode
\t
full cudagraph
\t
piecewise cudagraph"
)
for
b
in
cudagraph_sizes
:
for
b
in
cudagraph_sizes
:
print
(
f
"
{
b
}
\t
{
eager_time
[
b
]:.
3
f
}
\t
{
full_cudagraph_time
[
b
]:.
3
f
}
"
print
(
f
"
\t
{
piecewise_cudagraph_time
[
b
]:.
3
f
}
"
)
f
"
{
b
}
\t
{
eager_time
[
b
]:.
3
f
}
\t
{
full_cudagraph_time
[
b
]:.
3
f
}
"
f
"
\t
{
piecewise_cudagraph_time
[
b
]:.
3
f
}
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
benchmark
()
# Protect against subprocess reimport when using spawn_new_process_for_each_test
import
os
if
os
.
environ
.
get
(
"RUNNING_IN_SUBPROCESS"
)
!=
"1"
:
benchmark
()
tests/compile/silly_attention.py
View file @
006693ed
...
@@ -8,7 +8,7 @@ Centralizes custom operation definitions to avoid duplicate registrations.
...
@@ -8,7 +8,7 @@ Centralizes custom operation definitions to avoid duplicate registrations.
import
torch
import
torch
from
torch.library
import
Library
from
torch.library
import
Library
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
.torch_utils
import
direct_register_custom_op
# Shared library for all compilation test operations
# Shared library for all compilation test operations
# Using "silly" namespace to match existing test expectations
# Using "silly" namespace to match existing test expectations
...
@@ -31,8 +31,9 @@ def reset_global_counter():
...
@@ -31,8 +31,9 @@ def reset_global_counter():
_global_counter
=
0
_global_counter
=
0
def
silly_attention
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
def
silly_attention
(
out
:
torch
.
Tensor
)
->
None
:
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
"""
"""
Unified attention implementation that depends on
Unified attention implementation that depends on
all inputs and affects the output.
all inputs and affects the output.
...
@@ -47,8 +48,9 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
...
@@ -47,8 +48,9 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out
.
copy_
(
q
+
k
+
v
)
out
.
copy_
(
q
+
k
+
v
)
def
silly_attention_fake
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
def
silly_attention_fake
(
out
:
torch
.
Tensor
)
->
None
:
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
"""Fake implementation for testing"""
"""Fake implementation for testing"""
return
return
...
@@ -60,5 +62,4 @@ direct_register_custom_op(
...
@@ -60,5 +62,4 @@ direct_register_custom_op(
mutates_args
=
[
"out"
],
mutates_args
=
[
"out"
],
fake_impl
=
silly_attention_fake
,
fake_impl
=
silly_attention_fake
,
target_lib
=
silly_lib
,
target_lib
=
silly_lib
,
tags
=
(
torch
.
_C
.
Tag
.
cudagraph_unsafe
,
),
)
)
tests/compile/test_aot_compile.py
0 → 100644
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
tempfile
from
contextlib
import
contextmanager
import
pytest
import
torch
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationMode
,
VllmConfig
,
set_current_vllm_config
,
)
from
vllm.forward_context
import
set_forward_context
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
def
reference_fn
(
x
:
torch
.
Tensor
):
assert
x
.
shape
[
0
]
<=
42
assert
x
.
shape
[
0
]
%
2
==
0
for
_
in
range
(
3000
):
x
=
x
+
x
.
shape
[
0
]
return
x
@
support_torch_compile
class
CompiledMod
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
):
return
reference_fn
(
x
)
def
make_vllm_config
()
->
VllmConfig
:
return
VllmConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
)
)
@
contextmanager
def
use_vllm_config
(
vllm_config
:
VllmConfig
):
with
set_forward_context
({},
vllm_config
),
set_current_vllm_config
(
vllm_config
):
yield
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
def
test_no_dynamo_cache_entry
(
monkeypatch
:
pytest
.
MonkeyPatch
):
with
monkeypatch
.
context
()
as
m
:
vllm_config
=
make_vllm_config
()
args
=
(
torch
.
randn
(
10
,
10
),)
expected
=
reference_fn
(
*
args
)
with
use_vllm_config
(
vllm_config
):
m
.
setenv
(
"VLLM_USE_AOT_COMPILE"
,
"0"
)
with
(
pytest
.
raises
(
RuntimeError
,
match
=
"Detected recompile"
),
torch
.
compiler
.
set_stance
(
"fail_on_recompile"
),
):
CompiledMod
(
vllm_config
=
vllm_config
)(
*
args
)
m
.
setenv
(
"VLLM_USE_AOT_COMPILE"
,
"1"
)
torch
.
_dynamo
.
reset
()
with
torch
.
compiler
.
set_stance
(
"fail_on_recompile"
):
actual
=
CompiledMod
(
vllm_config
=
vllm_config
)(
*
args
)
assert
torch
.
allclose
(
actual
,
expected
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
def
test_force_aot_load
(
monkeypatch
:
pytest
.
MonkeyPatch
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
,
monkeypatch
.
context
()
as
m
:
args
=
(
torch
.
randn
(
10
,
10
),)
m
.
setenv
(
"VLLM_USE_AOT_COMPILE"
,
"1"
)
m
.
setenv
(
"VLLM_FORCE_AOT_LOAD"
,
"1"
)
m
.
setenv
(
"VLLM_CACHE_ROOT"
,
tmpdirname
)
vllm_config
=
make_vllm_config
()
with
use_vllm_config
(
vllm_config
),
pytest
.
raises
(
FileNotFoundError
):
CompiledMod
(
vllm_config
=
vllm_config
)(
*
args
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
def
test_save_and_load
(
monkeypatch
:
pytest
.
MonkeyPatch
):
with
monkeypatch
.
context
()
as
m
:
args
=
(
torch
.
randn
(
10
,
10
),)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
m
.
setenv
(
"VLLM_CACHE_ROOT"
,
tmpdirname
)
m
.
setenv
(
"VLLM_USE_AOT_COMPILE"
,
"1"
)
vllm_config
=
make_vllm_config
()
with
use_vllm_config
(
vllm_config
):
expected
=
CompiledMod
(
vllm_config
=
vllm_config
)(
*
args
)
m
.
setenv
(
"VLLM_FORCE_AOT_LOAD"
,
"1"
)
vllm_config
=
make_vllm_config
()
with
use_vllm_config
(
vllm_config
):
ret
=
CompiledMod
(
vllm_config
=
vllm_config
)(
*
args
)
assert
torch
.
allclose
(
ret
,
expected
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
def
test_shape_env
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Test that the shape environment is correctly serialized and preserved
when loading from cache.
"""
with
monkeypatch
.
context
()
as
m
:
args
=
(
torch
.
randn
(
10
,
10
),)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
m
.
setenv
(
"VLLM_CACHE_ROOT"
,
tmpdirname
)
m
.
setenv
(
"VLLM_USE_AOT_COMPILE"
,
"1"
)
vllm_config
=
make_vllm_config
()
with
use_vllm_config
(
vllm_config
):
compiled_mod
=
CompiledMod
(
vllm_config
=
vllm_config
)
compiled_mod
(
*
args
)
artifacts
=
compiled_mod
.
aot_compiled_fn
.
_artifacts
guards_string
=
artifacts
.
compiled_fn
.
shape_env
.
format_guards
()
assert
guards_string
==
" - s77 <= 42
\n
- Eq(Mod(s77, 2), 0)"
m
.
setenv
(
"VLLM_FORCE_AOT_LOAD"
,
"1"
)
vllm_config
=
make_vllm_config
()
with
use_vllm_config
(
vllm_config
):
compiled_mod
=
CompiledMod
(
vllm_config
=
vllm_config
)
compiled_mod
(
*
args
)
artifacts
=
compiled_mod
.
aot_compiled_fn
.
_artifacts
guards_string
=
artifacts
.
compiled_fn
.
shape_env
.
format_guards
()
assert
guards_string
==
" - s77 <= 42
\n
- Eq(Mod(s77, 2), 0)"
tests/compile/test_async_tp.py
View file @
006693ed
...
@@ -8,18 +8,31 @@ import torch
...
@@ -8,18 +8,31 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.compilation.collective_fusion
import
AsyncTPPass
from
vllm.compilation.collective_fusion
import
AsyncTPPass
from
vllm.config
import
(
CompilationConfig
,
DeviceConfig
,
ModelConfig
,
from
vllm.config
import
(
PassConfig
,
VllmConfig
)
CompilationConfig
,
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
CompilationMode
,
tensor_model_parallel_reduce_scatter
)
DeviceConfig
,
from
vllm.distributed.parallel_state
import
(
init_distributed_environment
,
ModelConfig
,
initialize_model_parallel
)
PassConfig
,
VllmConfig
,
)
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_reduce_scatter
,
)
from
vllm.distributed.parallel_state
import
(
init_distributed_environment
,
initialize_model_parallel
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
update_environment_variables
from
vllm.utils
.system_utils
import
update_environment_variables
from
..models.registry
import
HF_EXAMPLE_MODELS
from
..models.registry
import
HF_EXAMPLE_MODELS
from
..utils
import
(
compare_two_settings
,
create_new_process_for_each_test
,
from
..utils
import
(
multi_gpu_test
)
compare_two_settings
,
create_new_process_for_each_test
,
multi_gpu_test
,
)
from
.backend
import
TestBackend
from
.backend
import
TestBackend
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
@@ -33,21 +46,20 @@ prompts = [
...
@@ -33,21 +46,20 @@ prompts = [
class
TestMMRSModel
(
torch
.
nn
.
Module
):
class
TestMMRSModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
=
16
,
dtype
=
torch
.
float16
):
def
__init__
(
self
,
hidden_size
=
16
,
dtype
=
torch
.
float16
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
gate_proj
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
self
.
gate_proj
=
torch
.
nn
.
Parameter
(
(
self
.
hidden_size
*
2
,
hidden_size
)),
torch
.
empty
(
(
self
.
hidden_size
*
2
,
hidden_size
)),
requires_grad
=
False
requires_grad
=
False
)
)
# Initialize weights
# Initialize weights
torch
.
nn
.
init
.
normal_
(
self
.
gate_proj
,
std
=
0.02
)
torch
.
nn
.
init
.
normal_
(
self
.
gate_proj
,
std
=
0.02
)
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
"""
"""
Forward pass implementing the mm + reduce scatter in the FX graph
Forward pass implementing the mm + reduce scatter in the FX graph
"""
"""
# Reshape input
# Reshape input
view
=
hidden_states
.
reshape
(
-
1
,
self
.
hidden_size
)
view
=
hidden_states
.
reshape
(
-
1
,
self
.
hidden_size
)
...
@@ -66,14 +78,13 @@ class TestMMRSModel(torch.nn.Module):
...
@@ -66,14 +78,13 @@ class TestMMRSModel(torch.nn.Module):
class
TestAGMMModel
(
torch
.
nn
.
Module
):
class
TestAGMMModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
=
16
,
dtype
=
torch
.
float16
):
def
__init__
(
self
,
hidden_size
=
16
,
dtype
=
torch
.
float16
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
self
.
weight
=
torch
.
nn
.
Parameter
(
(
hidden_size
,
hidden_size
)),
torch
.
empty
(
(
hidden_size
,
hidden_size
)),
requires_grad
=
False
requires_grad
=
False
)
)
# Initialize weights
# Initialize weights
torch
.
nn
.
init
.
normal_
(
self
.
weight
,
std
=
0.02
)
torch
.
nn
.
init
.
normal_
(
self
.
weight
,
std
=
0.02
)
...
@@ -96,32 +107,35 @@ class TestAGMMModel(torch.nn.Module):
...
@@ -96,32 +107,35 @@ class TestAGMMModel(torch.nn.Module):
class
_BaseScaledMMModel
(
torch
.
nn
.
Module
):
class
_BaseScaledMMModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
=
16
,
dtype
=
torch
.
float16
):
def
__init__
(
self
,
hidden_size
=
16
,
dtype
=
torch
.
float16
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
weight
=
torch
.
empty
([
hidden_size
,
hidden_size
],
dtype
=
FP8_DTYPE
)
\
self
.
weight
=
(
.
contiguous
().
transpose
(
0
,
1
)
torch
.
empty
([
hidden_size
,
hidden_size
],
dtype
=
FP8_DTYPE
)
.
contiguous
()
.
transpose
(
0
,
1
)
)
# Initialize scale_b for _scaled_mm.
# Initialize scale_b for _scaled_mm.
self
.
scale_b
=
torch
.
ones
(
1
,
self
.
hidden_size
,
dtype
=
torch
.
float32
)
self
.
scale_b
=
torch
.
ones
(
1
,
self
.
hidden_size
,
dtype
=
torch
.
float32
)
class
TestScaledMMRSModel
(
_BaseScaledMMModel
):
class
TestScaledMMRSModel
(
_BaseScaledMMModel
):
def
forward
(
self
,
input
:
torch
.
Tensor
):
def
forward
(
self
,
input
:
torch
.
Tensor
):
"""
"""
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
"""
"""
fp8_input
=
input
.
to
(
FP8_DTYPE
)
fp8_input
=
input
.
to
(
FP8_DTYPE
)
scale_a
=
torch
.
ones
(
input
.
shape
[
0
],
1
,
dtype
=
torch
.
float32
)
scale_a
=
torch
.
ones
(
input
.
shape
[
0
],
1
,
dtype
=
torch
.
float32
)
scaled_mm
=
torch
.
_scaled_mm
(
fp8_input
,
scaled_mm
=
torch
.
_scaled_mm
(
self
.
weight
,
fp8_input
,
scale_a
=
scale_a
,
self
.
weight
,
scale_b
=
self
.
scale_b
,
scale_a
=
scale_a
,
out_dtype
=
self
.
dtype
)
scale_b
=
self
.
scale_b
,
out_dtype
=
self
.
dtype
,
)
reduce_scatter
=
tensor_model_parallel_reduce_scatter
(
scaled_mm
,
dim
=
0
)
reduce_scatter
=
tensor_model_parallel_reduce_scatter
(
scaled_mm
,
dim
=
0
)
return
reduce_scatter
return
reduce_scatter
...
@@ -129,11 +143,10 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
...
@@ -129,11 +143,10 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
return
[
torch
.
ops
.
vllm
.
reduce_scatter
.
default
]
return
[
torch
.
ops
.
vllm
.
reduce_scatter
.
default
]
def
ops_in_model_after
(
self
):
def
ops_in_model_after
(
self
):
return
[
torch
.
ops
.
symm_mem
.
fused_scaled_matmul_reduce_scatter
.
default
]
return
[
torch
.
ops
.
vllm
.
patched_
fused_scaled_matmul_reduce_scatter
.
default
]
class
TestAGScaledMMModel
(
_BaseScaledMMModel
):
class
TestAGScaledMMModel
(
_BaseScaledMMModel
):
def
forward
(
self
,
input
:
torch
.
Tensor
):
def
forward
(
self
,
input
:
torch
.
Tensor
):
"""
"""
Forward pass implementing the all gather + scaled_mm in the FX graph
Forward pass implementing the all gather + scaled_mm in the FX graph
...
@@ -143,11 +156,13 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
...
@@ -143,11 +156,13 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
all_gather
=
tensor_model_parallel_all_gather
(
fp8_input
,
dim
=
0
)
all_gather
=
tensor_model_parallel_all_gather
(
fp8_input
,
dim
=
0
)
scale_a
=
torch
.
ones
(
all_gather
.
shape
[
0
],
1
,
dtype
=
torch
.
float32
)
scale_a
=
torch
.
ones
(
all_gather
.
shape
[
0
],
1
,
dtype
=
torch
.
float32
)
scaled_mm
=
torch
.
_scaled_mm
(
all_gather
,
scaled_mm
=
torch
.
_scaled_mm
(
self
.
weight
,
all_gather
,
scale_a
=
scale_a
,
self
.
weight
,
scale_b
=
self
.
scale_b
,
scale_a
=
scale_a
,
out_dtype
=
self
.
dtype
)
scale_b
=
self
.
scale_b
,
out_dtype
=
self
.
dtype
,
)
return
scaled_mm
return
scaled_mm
def
ops_in_model_before
(
self
):
def
ops_in_model_before
(
self
):
...
@@ -158,20 +173,22 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
...
@@ -158,20 +173,22 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
class
TestCutlassScaledMMRSModel
(
_BaseScaledMMModel
):
class
TestCutlassScaledMMRSModel
(
_BaseScaledMMModel
):
def
forward
(
self
,
input
:
torch
.
Tensor
):
def
forward
(
self
,
input
:
torch
.
Tensor
):
"""
"""
Forward pass implementing the cutlass_scaled_mm + reduce scatter
Forward pass implementing the cutlass_scaled_mm + reduce scatter
in the FX graph
in the FX graph
"""
"""
fp8_input
=
input
.
to
(
FP8_DTYPE
)
fp8_input
=
input
.
to
(
FP8_DTYPE
)
scale_a
=
torch
.
ones
(
input
.
shape
[
0
],
1
,
dtype
=
torch
.
float32
)
scale_a
=
torch
.
ones
(
input
.
shape
[
0
],
1
,
dtype
=
torch
.
float32
)
mm_out
=
torch
.
empty
((
fp8_input
.
shape
[
0
],
self
.
weight
.
shape
[
1
]),
mm_out
=
torch
.
empty
(
dtype
=
self
.
dtype
,
(
fp8_input
.
shape
[
0
],
self
.
weight
.
shape
[
1
]),
device
=
input
.
device
)
dtype
=
self
.
dtype
,
torch
.
ops
.
_C
.
cutlass_scaled_mm
(
mm_out
,
fp8_input
,
self
.
weight
,
scale_a
,
device
=
input
.
device
,
self
.
scale_b
,
None
)
)
torch
.
ops
.
_C
.
cutlass_scaled_mm
(
mm_out
,
fp8_input
,
self
.
weight
,
scale_a
,
self
.
scale_b
,
None
)
reduce_scatter
=
tensor_model_parallel_reduce_scatter
(
mm_out
,
dim
=
0
)
reduce_scatter
=
tensor_model_parallel_reduce_scatter
(
mm_out
,
dim
=
0
)
return
reduce_scatter
return
reduce_scatter
...
@@ -179,14 +196,13 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
...
@@ -179,14 +196,13 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
return
[
torch
.
ops
.
vllm
.
reduce_scatter
.
default
]
return
[
torch
.
ops
.
vllm
.
reduce_scatter
.
default
]
def
ops_in_model_after
(
self
):
def
ops_in_model_after
(
self
):
return
[
torch
.
ops
.
symm_mem
.
fused_scaled_matmul_reduce_scatter
.
default
]
return
[
torch
.
ops
.
vllm
.
patched_
fused_scaled_matmul_reduce_scatter
.
default
]
class
TestAGCutlassScaledMMModel
(
_BaseScaledMMModel
):
class
TestAGCutlassScaledMMModel
(
_BaseScaledMMModel
):
def
forward
(
self
,
input
:
torch
.
Tensor
):
def
forward
(
self
,
input
:
torch
.
Tensor
):
"""
"""
Forward pass implementing the all gather + cutlass_scaled_mm
Forward pass implementing the all gather + cutlass_scaled_mm
in the FX graph
in the FX graph
"""
"""
# Reshape input
# Reshape input
...
@@ -195,11 +211,14 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
...
@@ -195,11 +211,14 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
scale_a
=
torch
.
ones
(
all_gather
.
shape
[
0
],
1
,
dtype
=
torch
.
float32
)
scale_a
=
torch
.
ones
(
all_gather
.
shape
[
0
],
1
,
dtype
=
torch
.
float32
)
mm_out
=
torch
.
empty
((
all_gather
.
shape
[
0
],
self
.
weight
.
shape
[
1
]),
mm_out
=
torch
.
empty
(
dtype
=
self
.
dtype
,
(
all_gather
.
shape
[
0
],
self
.
weight
.
shape
[
1
]),
device
=
all_gather
.
device
)
dtype
=
self
.
dtype
,
torch
.
ops
.
_C
.
cutlass_scaled_mm
(
mm_out
,
all_gather
,
self
.
weight
,
device
=
all_gather
.
device
,
scale_a
,
self
.
scale_b
,
None
)
)
torch
.
ops
.
_C
.
cutlass_scaled_mm
(
mm_out
,
all_gather
,
self
.
weight
,
scale_a
,
self
.
scale_b
,
None
)
return
mm_out
return
mm_out
def
ops_in_model_before
(
self
):
def
ops_in_model_before
(
self
):
...
@@ -210,23 +229,43 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
...
@@ -210,23 +229,43 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
@
multi_gpu_test
(
num_gpus
=
2
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"test_model"
,
[
@
pytest
.
mark
.
parametrize
(
TestMMRSModel
,
TestAGMMModel
,
TestScaledMMRSModel
,
TestAGScaledMMModel
,
"test_model"
,
TestCutlassScaledMMRSModel
,
TestAGCutlassScaledMMModel
[
])
TestMMRSModel
,
TestAGMMModel
,
TestScaledMMRSModel
,
TestAGScaledMMModel
,
TestCutlassScaledMMRSModel
,
TestAGCutlassScaledMMModel
,
],
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
not
in
[
"cuda"
],
@
pytest
.
mark
.
parametrize
(
"dynamic"
,
[
True
,
False
])
reason
=
"Only test on CUDA"
)
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
not
in
[
"cuda"
],
reason
=
"Only test on CUDA"
)
def
test_async_tp_pass_replace
(
test_model
:
str
,
batch_size
:
int
,
seq_len
:
int
,
def
test_async_tp_pass_replace
(
hidden_size
:
int
,
dtype
:
torch
.
dtype
):
test_model
:
str
,
if
test_model
in
(
TestScaledMMRSModel
,
TestAGScaledMMModel
,
batch_size
:
int
,
TestCutlassScaledMMRSModel
,
seq_len
:
int
,
TestAGCutlassScaledMMModel
)
and
dtype
==
torch
.
float16
:
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
dynamic
:
bool
,
):
if
(
test_model
in
(
TestScaledMMRSModel
,
TestAGScaledMMModel
,
TestCutlassScaledMMRSModel
,
TestAGCutlassScaledMMModel
,
)
and
dtype
==
torch
.
float16
):
pytest
.
skip
(
pytest
.
skip
(
"Only bf16 high precision output types are supported for "
\
"Only bf16 high precision output types are supported for "
"per-token (row-wise) scaling"
"per-token (row-wise) scaling"
)
)
...
@@ -235,19 +274,33 @@ def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
...
@@ -235,19 +274,33 @@ def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
def
run_torch_spawn
(
fn
,
nprocs
):
def
run_torch_spawn
(
fn
,
nprocs
):
# need to use torch.mp.spawn otherwise will have problems with
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
# torch.distributed and cuda
torch
.
multiprocessing
.
spawn
(
fn
,
torch
.
multiprocessing
.
spawn
(
args
=
(
num_processes
,
test_model
,
fn
,
batch_size
,
seq_len
,
hidden_size
,
args
=
(
dtype
),
num_processes
,
nprocs
=
nprocs
)
test_model
,
batch_size
,
seq_len
,
hidden_size
,
dtype
,
dynamic
,
),
nprocs
=
nprocs
,
)
run_torch_spawn
(
async_tp_pass_on_test_model
,
num_processes
)
run_torch_spawn
(
async_tp_pass_on_test_model
,
num_processes
)
def
async_tp_pass_on_test_model
(
local_rank
:
int
,
world_size
:
int
,
def
async_tp_pass_on_test_model
(
test_model_cls
:
torch
.
nn
.
Module
,
local_rank
:
int
,
batch_size
:
int
,
seq_len
:
int
,
world_size
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
):
test_model_cls
:
torch
.
nn
.
Module
,
batch_size
:
int
,
seq_len
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
dynamic
:
bool
,
):
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
...
@@ -255,13 +308,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
...
@@ -255,13 +308,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
dtype
)
update_environment_variables
({
update_environment_variables
(
'RANK'
:
str
(
local_rank
),
{
'LOCAL_RANK'
:
str
(
local_rank
),
"RANK"
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
"LOCAL_RANK"
:
str
(
local_rank
),
'MASTER_ADDR'
:
'localhost'
,
"WORLD_SIZE"
:
str
(
world_size
),
'MASTER_PORT'
:
'12345'
,
"MASTER_ADDR"
:
"localhost"
,
})
"MASTER_PORT"
:
"12345"
,
}
)
# initialize distributed
# initialize distributed
init_distributed_environment
()
init_distributed_environment
()
...
@@ -269,27 +324,40 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
...
@@ -269,27 +324,40 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for SequenceParallelismPass
# configure vllm config for SequenceParallelismPass
vllm_config
=
VllmConfig
()
vllm_config
=
VllmConfig
()
vllm_config
.
compilation_config
=
CompilationConfig
(
pass_config
=
PassConfig
(
vllm_config
.
compilation_config
=
CompilationConfig
(
enable_async_tp
=
True
,
),
)
pass_config
=
PassConfig
(
enable_async_tp
=
True
,
),
)
vllm_config
.
device_config
=
DeviceConfig
(
device
=
torch
.
device
(
"cuda"
))
vllm_config
.
device_config
=
DeviceConfig
(
device
=
torch
.
device
(
"cuda"
))
# this is a fake model name to construct the model config
# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
# in the vllm_config, it's not really used.
model_name
=
"nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
model_name
=
"RedHatAI/Llama-3.2-1B-Instruct-FP8"
vllm_config
.
model_config
=
ModelConfig
(
model
=
model_name
,
vllm_config
.
model_config
=
ModelConfig
(
trust_remote_code
=
True
,
model
=
model_name
,
trust_remote_code
=
True
,
dtype
=
dtype
,
seed
=
42
dtype
=
dtype
,
)
seed
=
42
)
async_tp_pass
=
AsyncTPPass
(
vllm_config
)
async_tp_pass
=
AsyncTPPass
(
vllm_config
)
backend
=
TestBackend
(
async_tp_pass
)
backend
=
TestBackend
(
async_tp_pass
)
model
=
test_model_cls
(
hidden_size
,
assert
(
dtype
)
# Pass dtype to model constructor
async_tp_pass
.
compilation_config
.
splitting_ops
==
vllm_config
.
compilation_config
.
splitting_ops
)
assert
(
async_tp_pass
.
compilation_config
.
use_inductor_graph_partition
==
vllm_config
.
compilation_config
.
use_inductor_graph_partition
)
model
=
test_model_cls
(
hidden_size
,
dtype
)
# Pass dtype to model constructor
hidden_states
=
torch
.
randn
((
batch_size
*
seq_len
,
hidden_size
),
hidden_states
=
torch
.
randn
(
dtype
=
dtype
,
(
batch_size
*
seq_len
,
hidden_size
),
dtype
=
dtype
,
requires_grad
=
False
requires_grad
=
False
)
)
if
dynamic
:
torch
.
_dynamo
.
mark_dynamic
(
hidden_states
,
0
)
compiled_model
=
torch
.
compile
(
model
,
backend
=
backend
)
compiled_model
=
torch
.
compile
(
model
,
backend
=
backend
)
compiled_model
(
hidden_states
)
compiled_model
(
hidden_states
)
...
@@ -306,10 +374,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
...
@@ -306,10 +374,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
()
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
@
pytest
.
mark
.
parametrize
(
"m
eta-llama/Llama-3.2-1B-Instruct
"
,
"m
odel_id
"
,
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
[
"meta-llama/Llama-3.2-1B-Instruct"
,
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
],
]
)
)
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"async_tp_enabled"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"async_tp_enabled"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"distributed_backend"
,
[
"mp"
])
@
pytest
.
mark
.
parametrize
(
"distributed_backend"
,
[
"mp"
])
...
@@ -342,16 +410,10 @@ def test_async_tp_pass_correctness(
...
@@ -342,16 +410,10 @@ def test_async_tp_pass_correctness(
common_args
.
append
(
"--enforce-eager"
)
common_args
.
append
(
"--enforce-eager"
)
compilation_config
=
{
compilation_config
=
{
'level'
:
3
,
"mode"
:
CompilationMode
.
VLLM_COMPILE
,
'compile_sizes'
:
[
2
,
4
,
8
],
"compile_sizes"
:
[
2
,
4
,
8
],
'splitting_ops'
:
[],
"splitting_ops"
:
[],
'pass_config'
:
{
"pass_config"
:
{
"enable_async_tp"
:
async_tp_enabled
},
'enable_async_tp'
:
async_tp_enabled
},
}
async_tp_env
=
tp_env
=
{
"VLLM_USE_V1"
:
"1"
,
}
}
async_tp_args
=
[
async_tp_args
=
[
...
@@ -372,9 +434,4 @@ def test_async_tp_pass_correctness(
...
@@ -372,9 +434,4 @@ def test_async_tp_pass_correctness(
"mp"
,
"mp"
,
]
]
compare_two_settings
(
model_id
,
compare_two_settings
(
model_id
,
async_tp_args
,
tp_args
,
method
=
"generate"
)
async_tp_args
,
tp_args
,
async_tp_env
,
tp_env
,
method
=
"generate"
)
tests/compile/test_basic_correctness.py
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
dataclasses
import
dataclasses
import
pytest
import
pytest
from
vllm.config
import
Compilation
Level
from
vllm.config
import
Compilation
Mode
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.utils
.torch_utils
import
cuda_device_count_stateless
from
..utils
import
compare_all_settings
from
..utils
import
compare_all_settings
...
@@ -23,7 +21,7 @@ class TestSetting:
...
@@ -23,7 +21,7 @@ class TestSetting:
# we cannot afford testing the full Cartesian product
# we cannot afford testing the full Cartesian product
# of all models and all
level
s
# of all models and all
mode
s
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"test_setting"
,
"test_setting"
,
[
[
...
@@ -79,14 +77,15 @@ class TestSetting:
...
@@ -79,14 +77,15 @@ class TestSetting:
method
=
"encode"
,
method
=
"encode"
,
),
),
# vision language model
# vision language model
TestSetting
(
# See https://github.com/vllm-project/vllm/issues/26716.
model
=
"microsoft/Phi-3.5-vision-instruct"
,
# TestSetting(
model_args
=
[
"--trust-remote-code"
,
"--max-model-len"
,
"2048"
],
# model="microsoft/Phi-3.5-vision-instruct",
pp_size
=
2
,
# model_args=["--trust-remote-code", "--max-model-len", "2048"],
tp_size
=
1
,
# pp_size=2,
attn_backend
=
"FLASH_ATTN"
,
# tp_size=1,
method
=
"generate_with_image"
,
# attn_backend="FLASH_ATTN",
),
# method="generate_with_image",
# ),
],
],
)
)
def
test_compile_correctness
(
def
test_compile_correctness
(
...
@@ -103,43 +102,54 @@ def test_compile_correctness(
...
@@ -103,43 +102,54 @@ def test_compile_correctness(
attn_backend
=
test_setting
.
attn_backend
attn_backend
=
test_setting
.
attn_backend
method
=
test_setting
.
method
method
=
test_setting
.
method
if
cuda_device_count_stateless
()
<
pp_size
*
tp_size
:
if
cuda_device_count_stateless
()
<
pp_size
*
tp_size
:
pytest
.
skip
(
f
"Need at least
{
pp_size
}
*
{
tp_size
}
CUDA gpus but got "
pytest
.
skip
(
f
"
{
cuda_device_count_stateless
()
}
"
)
f
"Need at least
{
pp_size
}
*
{
tp_size
}
CUDA gpus but got "
f
"
{
cuda_device_count_stateless
()
}
"
)
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
final_args
=
[
final_args
=
[
"--enforce-eager"
,
*
model_args
,
"-pp"
,
*
model_args
,
str
(
pp_size
),
"-tp"
,
"-pp"
,
str
(
tp_size
)
str
(
pp_size
),
"-tp"
,
str
(
tp_size
),
"-O.cudagraph_mode=none"
,
]
]
all_args
:
list
[
list
[
str
]]
=
[]
all_args
:
list
[
list
[
str
]]
=
[]
all_envs
:
list
[
dict
[
str
,
str
]
|
None
]
=
[]
all_envs
:
list
[
dict
[
str
,
str
]
|
None
]
=
[]
for
level
in
[
for
comp_mode
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationMode
.
STOCK_TORCH_COMPILE
,
CompilationLevel
.
PIECEWISE
,
CompilationMode
.
DYNAMO_TRACE_ONCE
,
CompilationMode
.
VLLM_COMPILE
,
]:
]:
all_args
.
append
(
final_args
+
[
f
"-O
{
level
}
"
])
for
mode
in
[
CompilationMode
.
NONE
,
comp_mode
]:
all_envs
.
append
({})
all_args
.
append
(
final_args
+
[
f
"-O.mode=
{
mode
.
name
}
"
,
"-O.backend=inductor"
]
)
# inductor will change the output, so we only compare if the output
# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
# is close, not exactly the same.
compare_all_settings
(
compare_all_settings
(
model
,
model
,
all_args
,
all_args
,
all_envs
,
all_envs
,
method
=
method
if
method
!=
"generate"
else
"generate_close"
)
method
=
method
if
method
!=
"generate"
else
"generate_close"
,
all_envs
.
clear
()
)
all_args
.
clear
()
all_envs
.
clear
()
all_args
.
clear
()
for
level
in
[
for
mode
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationMode
.
NONE
,
CompilationLevel
.
DYNAMO_AS_IS
,
CompilationMode
.
STOCK_TORCH_COMPILE
,
CompilationLevel
.
DYNAMO_ONCE
,
CompilationMode
.
DYNAMO_TRACE_ONCE
,
CompilationMode
.
VLLM_COMPILE
,
]:
]:
all_args
.
append
(
final_args
+
[
f
"-O
{
level
}
"
])
all_args
.
append
(
final_args
+
[
f
"-O.mode=
{
mode
.
name
}
"
,
"-O.backend=eager"
])
all_envs
.
append
({})
all_envs
.
append
({})
all_envs
.
append
({})
compare_all_settings
(
model
,
all_args
*
3
,
all_envs
,
method
=
method
)
compare_all_settings
(
model
,
all_args
*
3
,
all_envs
,
method
=
method
)
tests/compile/test_config.py
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
from
contextlib
import
nullcontext
from
unittest.mock
import
patch
import
pytest
import
pytest
from
pydantic
import
ValidationError
import
vllm
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.config
import
CompilationConfig
,
VllmConfig
from
vllm.compilation.fix_functionalization
import
FixFunctionalizationPass
from
vllm.utils
import
_is_torch_equal_or_newer
from
vllm.config
import
CompilationConfig
,
CUDAGraphMode
,
VllmConfig
from
vllm.config.compilation
import
CompilationMode
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
_is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from
.
import
silly_attention
# noqa: F401
def
test_version
():
assert
_is_torch_equal_or_newer
(
'2.8.0.dev20250624+cu128'
,
'2.8.0.dev'
)
assert
_is_torch_equal_or_newer
(
'2.8.0a0+gitc82a174'
,
'2.8.0.dev'
)
assert
_is_torch_equal_or_newer
(
'2.8.0'
,
'2.8.0.dev'
)
assert
_is_torch_equal_or_newer
(
'2.8.1'
,
'2.8.0.dev'
)
assert
not
_is_torch_equal_or_newer
(
'2.7.1'
,
'2.8.0.dev'
)
def
test_version
():
# Test the version comparison logic using the private function
assert
_is_torch_equal_or_newer
(
"2.8.0.dev20250624+cu128"
,
"2.8.0.dev"
)
assert
_is_torch_equal_or_newer
(
"2.8.0a0+gitc82a174"
,
"2.8.0.dev"
)
assert
_is_torch_equal_or_newer
(
"2.8.0"
,
"2.8.0.dev"
)
assert
_is_torch_equal_or_newer
(
"2.8.1"
,
"2.8.0.dev"
)
assert
not
_is_torch_equal_or_newer
(
"2.7.1"
,
"2.8.0.dev"
)
def
test_use_cudagraphs_dynamic
(
monkeypatch
):
assert
vllm
.
envs
.
VLLM_USE_V1
vllm_config
=
VllmConfig
()
assert
vllm_config
.
compilation_config
.
use_cudagraph
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
def
test_copy_pass
():
vllm_config
=
VllmConfig
()
vllm_config
=
VllmConfig
()
assert
not
vllm_config
.
compilation_config
.
use_cudagraph
inductor_pass
=
FixFunctionalizationPass
(
vllm_config
)
copied_inductor_pass
=
copy
.
deepcopy
(
inductor_pass
)
assert
(
copied_inductor_pass
.
compilation_config
.
use_inductor_graph_partition
==
vllm_config
.
compilation_config
.
use_inductor_graph_partition
)
assert
(
copied_inductor_pass
.
compilation_config
.
splitting_ops
==
vllm_config
.
compilation_config
.
splitting_ops
)
def
test_custom_op
():
def
test_custom_op
():
...
@@ -41,63 +57,80 @@ def test_custom_op():
...
@@ -41,63 +57,80 @@ def test_custom_op():
# may be influenced by other tests.
# may be influenced by other tests.
@
pytest
.
mark
.
parametrize
(
"val"
,
[
"1"
])
@
pytest
.
mark
.
parametrize
(
"val"
,
[
"1"
])
def
test_VLLM_DISABLE_COMPILE_CACHE
(
vllm_runner
,
monkeypatch
,
val
):
def
test_VLLM_DISABLE_COMPILE_CACHE
(
vllm_runner
,
monkeypatch
,
val
):
assert
vllm
.
envs
.
VLLM_USE_V1
# Disable multiprocessing so that the counter is in the same process
# Disable multiprocessing so that the counter is in the same process
monkeypatch
.
setenv
(
'
VLLM_ENABLE_V1_MULTIPROCESSING
'
,
'0'
)
monkeypatch
.
setenv
(
"
VLLM_ENABLE_V1_MULTIPROCESSING
"
,
"0"
)
monkeypatch
.
setenv
(
'
VLLM_DISABLE_COMPILE_CACHE
'
,
val
)
monkeypatch
.
setenv
(
"
VLLM_DISABLE_COMPILE_CACHE
"
,
val
)
compilation_config
=
{
compilation_config
=
{
"
use_
cudagraph
"
:
False
,
# speed things up a bit
"cudagraph
_mode"
:
CUDAGraphMode
.
NONE
,
# speed things up a bit
}
}
with
(
with
(
compilation_counter
.
expect
(
num_cache_entries_updated
=
0
,
compilation_counter
.
expect
(
num_compiled_artifacts_saved
=
0
),
num_cache_entries_updated
=
0
,
num_compiled_artifacts_saved
=
0
# loading the model causes compilation (if enabled) to happen
),
vllm_runner
(
'facebook/opt-125m'
,
# loading the model causes compilation (if enabled) to happen
compilation_config
=
compilation_config
,
vllm_runner
(
gpu_memory_utilization
=
0.4
)
as
_
):
"facebook/opt-125m"
,
compilation_config
=
compilation_config
,
gpu_memory_utilization
=
0.4
,
)
as
_
,
):
pass
pass
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
parametrize
(
"enabled"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
def
test_use_cudagraphs
(
vllm_runner
,
monkeypatch
,
enabled
):
"cudagraph_mode,num_cudagraph_captured"
,
assert
vllm
.
envs
.
VLLM_USE_V1
[
(
CUDAGraphMode
.
NONE
,
0
),
(
CUDAGraphMode
.
FULL_DECODE_ONLY
,
1
),
(
CUDAGraphMode
.
PIECEWISE
,
13
),
(
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
14
),
],
)
def
test_use_cudagraphs
(
vllm_runner
,
monkeypatch
,
cudagraph_mode
,
num_cudagraph_captured
):
# Disable multiprocessing so that the counter is in the same process
# Disable multiprocessing so that the counter is in the same process
monkeypatch
.
setenv
(
'
VLLM_ENABLE_V1_MULTIPROCESSING
'
,
'0'
)
monkeypatch
.
setenv
(
"
VLLM_ENABLE_V1_MULTIPROCESSING
"
,
"0"
)
compilation_config
=
{
compilation_config
=
{
"cudagraph_capture_sizes"
:
[
100
],
"cudagraph_capture_sizes"
:
[
100
],
"
use_
cudagraph
"
:
enabled
,
"cudagraph
_mode"
:
cudagraph_mode
,
}
}
num_gpu_runner_capture_triggers
=
1
if
cudagraph_mode
!=
CUDAGraphMode
.
NONE
else
0
with
(
with
(
compilation_counter
.
expect
(
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_graphs_seen
=
1
,
num_gpu_runner_capture_triggers
=
1
if
enabled
else
0
,
num_gpu_runner_capture_triggers
=
num_gpu_runner_capture_triggers
,
num_cudagraph_captured
=
13
if
enabled
else
0
,
num_cudagraph_captured
=
num_cudagraph_captured
,
),
),
# loading the model causes compilation (if enabled) to happen
# loading the model causes compilation (if enabled) to happen
vllm_runner
(
'facebook/opt-125m'
,
vllm_runner
(
compilation_config
=
compilation_config
,
"facebook/opt-125m"
,
gpu_memory_utilization
=
0.4
)
as
_
):
compilation_config
=
compilation_config
,
gpu_memory_utilization
=
0.4
,
)
as
_
,
):
pass
pass
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_
dynamo_as_is
(
vllm_runner
,
monkeypatch
):
def
test_
stock_torch_compile
(
vllm_runner
,
monkeypatch
):
# Disable multiprocessing so that the counter is in the same process
# Disable multiprocessing so that the counter is in the same process
monkeypatch
.
setenv
(
'
VLLM_ENABLE_V1_MULTIPROCESSING
'
,
'0'
)
monkeypatch
.
setenv
(
"
VLLM_ENABLE_V1_MULTIPROCESSING
"
,
"0"
)
with
(
with
(
compilation_counter
.
expect
(
dynamo_as_is_count
=
1
),
compilation_counter
.
expect
(
stock_torch_compile_count
=
1
),
# loading the model causes compilation (if enabled) to happen
# loading the model causes compilation (if enabled) to happen
vllm_runner
(
'facebook/opt-125m'
,
vllm_runner
(
compilation_config
=
{
"level"
:
1
},
"facebook/opt-125m"
,
gpu_memory_utilization
=
0.4
)
as
_
):
compilation_config
=
{
"mode"
:
CompilationMode
.
STOCK_TORCH_COMPILE
},
gpu_memory_utilization
=
0.4
,
)
as
_
,
):
pass
pass
...
@@ -105,15 +138,16 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
...
@@ -105,15 +138,16 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_no_compilation
(
vllm_runner
,
monkeypatch
):
def
test_no_compilation
(
vllm_runner
,
monkeypatch
):
# Disable multiprocessing so that the counter is in the same process
# Disable multiprocessing so that the counter is in the same process
monkeypatch
.
setenv
(
'VLLM_ENABLE_V1_MULTIPROCESSING'
,
'0'
)
monkeypatch
.
setenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
)
with
(
with
(
compilation_counter
.
expect
(
num_graphs_seen
=
0
,
compilation_counter
.
expect
(
num_graphs_seen
=
0
,
stock_torch_compile_count
=
0
),
dynamo_as_is_count
=
0
),
# loading the model causes compilation (if enabled) to happen
# loading the model causes compilation (if enabled) to happen
vllm_runner
(
vllm_runner
(
'facebook/opt-125m'
,
"facebook/opt-125m"
,
compilation_config
=
{
"level"
:
0
},
compilation_config
=
{
"mode"
:
CompilationMode
.
NONE
},
gpu_memory_utilization
=
0.4
)
as
_
):
gpu_memory_utilization
=
0.4
,
)
as
_
,
):
pass
pass
...
@@ -121,13 +155,223 @@ def test_no_compilation(vllm_runner, monkeypatch):
...
@@ -121,13 +155,223 @@ def test_no_compilation(vllm_runner, monkeypatch):
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_enforce_eager
(
vllm_runner
,
monkeypatch
):
def
test_enforce_eager
(
vllm_runner
,
monkeypatch
):
# Disable multiprocessing so that the counter is in the same process
# Disable multiprocessing so that the counter is in the same process
monkeypatch
.
setenv
(
'
VLLM_ENABLE_V1_MULTIPROCESSING
'
,
'0'
)
monkeypatch
.
setenv
(
"
VLLM_ENABLE_V1_MULTIPROCESSING
"
,
"0"
)
with
(
with
(
compilation_counter
.
expect
(
num_graphs_seen
=
0
,
compilation_counter
.
expect
(
num_graphs_seen
=
0
,
stock_torch_compile_count
=
0
),
dynamo_as_is_count
=
0
),
# loading the model causes compilation (if enabled) to happen
# loading the model causes compilation (if enabled) to happen
vllm_runner
(
vllm_runner
(
'
facebook/opt-125m
'
,
"
facebook/opt-125m
"
,
enforce_eager
=
True
,
gpu_memory_utilization
=
0.4
enforce_eager
=
True
,
)
as
_
,
gpu_memory_utilization
=
0.4
)
as
_
):
):
pass
pass
def
test_splitting_ops_dynamic
():
# Default config
config
=
VllmConfig
()
# Default V1 config leaves cudagraph mode unset; splitting ops are only
# populated when the engine decides to use piecewise compilation.
assert
config
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
NONE
assert
not
config
.
compilation_config
.
splitting_ops_contain_attention
()
# When use_inductor_graph_partition=True
config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
use_inductor_graph_partition
=
True
,
splitting_ops
=
[
"vllm::unified_attention"
],
)
)
# with inductor partition we use splitting_ops directly for
# partition rules
assert
config
.
compilation_config
.
splitting_ops
==
[
"vllm::unified_attention"
]
# When attn_fusion pass enabled.
config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
pass_config
=
{
"enable_attn_fusion"
:
True
,
"enable_noop"
:
True
},
custom_ops
=
[
"+quant_fp8"
],
cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
,
)
)
assert
config
.
compilation_config
.
splitting_ops
==
[]
# cudagraph mode also fall back to FULL
assert
config
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
FULL
# splitting_ops can not contain attention ops when attn_fusion
# pass enabled.
with
pytest
.
raises
(
ValidationError
):
config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
pass_config
=
{
"enable_attn_fusion"
:
True
,
"enable_noop"
:
True
},
custom_ops
=
[
"+quant_fp8"
],
cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
,
# work around for accessing all attntion ops
splitting_ops
=
CompilationConfig
().
_attention_ops
,
)
)
# When both use_inductor_graph_partition and attn_fusion pass enabled.
config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
use_inductor_graph_partition
=
True
,
pass_config
=
{
"enable_attn_fusion"
:
True
,
"enable_noop"
:
True
},
custom_ops
=
[
"+quant_fp8"
],
cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
,
)
)
# With inductor graph partition, attn_fusion and splitting_ops
# work together. Default splitting_ops include attention ops.
assert
config
.
compilation_config
.
splitting_ops_contain_attention
()
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert
config
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
PIECEWISE
def
test_should_split
():
import
torch
from
vllm.compilation.partition_rules
import
should_split
graph
=
torch
.
fx
.
Graph
()
node
=
torch
.
fx
.
Node
(
graph
=
graph
,
name
=
"dummy_node"
,
op
=
"call_function"
,
target
=
torch
.
ops
.
aten
.
add
.
default
,
args
=
(),
kwargs
=
{},
)
# supports OpOverloadPacket
splitting_ops
=
[
"aten::add"
]
assert
should_split
(
node
,
splitting_ops
)
# supports OpOverload
splitting_ops
=
[
"aten::add.default"
]
assert
should_split
(
node
,
splitting_ops
)
# supports OpOverload
splitting_ops
=
[
"aten::add.Tensor"
]
assert
not
should_split
(
node
,
splitting_ops
)
q
,
k
,
v
,
out
=
[
torch
.
randn
(
1
)]
*
4
# supports custom ops as OpOverloadPacket
node
=
torch
.
fx
.
Node
(
graph
=
graph
,
name
=
"dummy_node"
,
op
=
"call_function"
,
target
=
torch
.
ops
.
silly
.
attention
,
args
=
(
q
,
k
,
v
,
out
),
kwargs
=
{},
)
splitting_ops
=
[
"silly::attention"
]
assert
should_split
(
node
,
splitting_ops
)
# supports custom ops as OpOverload
node
=
torch
.
fx
.
Node
(
graph
=
graph
,
name
=
"dummy_node"
,
op
=
"call_function"
,
target
=
torch
.
ops
.
silly
.
attention
.
default
,
args
=
(
q
,
k
,
v
,
out
),
kwargs
=
{},
)
splitting_ops
=
[
"silly::attention"
]
assert
should_split
(
node
,
splitting_ops
)
splitting_ops
=
[
"silly::attention.default"
]
assert
should_split
(
node
,
splitting_ops
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
support_static_graph_mode
(),
reason
=
"Skip if not cudagraph mode supported"
,
)
@
pytest
.
mark
.
parametrize
(
(
"cudagraph_capture_sizes"
,
"max_cudagraph_capture_size"
,
"tp_size"
,
"enable_sequence_parallelism"
,
"max_num_batched_tokens"
,
"cudagraph_mode"
,
"expected_max_size"
,
),
[
(
None
,
None
,
1
,
False
,
2048
,
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
256
),
([
1
,
2
,
4
],
4
,
1
,
False
,
2048
,
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
4
),
(
[
1
,
2
,
4
],
8
,
1
,
False
,
2048
,
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
ValidationError
,
),
([
1
,
256
],
None
,
1
,
False
,
2048
,
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
256
),
([],
None
,
1
,
False
,
2048
,
CUDAGraphMode
.
NONE
,
0
),
(
None
,
0
,
1
,
False
,
2048
,
CUDAGraphMode
.
NONE
,
0
),
# truncated to nearest multiple of 8 or 16
(
None
,
257
,
1
,
False
,
2048
,
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
256
),
# max from list
([
1
,
2
,
4
,
15
],
None
,
1
,
False
,
2048
,
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
15
),
# filtered out 15 due to SP
([
1
,
2
,
4
,
15
],
None
,
2
,
True
,
2048
,
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
4
),
# limited by the max_tokens
([
1
,
2
,
4
,
15
],
None
,
1
,
False
,
8
,
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
4
),
# the list should contain at least 1 element when use cudagraph
([],
None
,
1
,
False
,
2048
,
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
ValidationError
),
# the max capturing size should be >= 1 when use cudagraph
(
None
,
0
,
1
,
False
,
2048
,
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
ValidationError
),
],
)
def
test_cudagraph_sizes_post_init
(
cudagraph_capture_sizes
,
max_cudagraph_capture_size
,
tp_size
,
enable_sequence_parallelism
,
max_num_batched_tokens
,
cudagraph_mode
,
expected_max_size
,
):
ctx
=
nullcontext
()
if
expected_max_size
==
ValidationError
:
ctx
=
pytest
.
raises
(
expected_max_size
)
with
(
ctx
,
patch
(
"vllm.config.parallel.cuda_device_count_stateless"
,
return_value
=
tp_size
),
):
compilation_config
=
CompilationConfig
(
cudagraph_capture_sizes
=
cudagraph_capture_sizes
,
max_cudagraph_capture_size
=
max_cudagraph_capture_size
,
pass_config
=
{
"enable_sequence_parallelism"
:
enable_sequence_parallelism
,
"enable_fusion"
:
True
,
"enable_noop"
:
True
,
},
cudagraph_mode
=
cudagraph_mode
,
)
engine_args
=
EngineArgs
(
model
=
"facebook/opt-125m"
,
tensor_parallel_size
=
tp_size
,
max_num_seqs
=
min
(
max_num_batched_tokens
,
128
),
max_num_batched_tokens
=
max_num_batched_tokens
,
compilation_config
=
compilation_config
,
)
vllm_config
=
engine_args
.
create_engine_config
()
assert
(
vllm_config
.
compilation_config
.
max_cudagraph_capture_size
==
expected_max_size
)
tests/compile/test_decorator.py
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
(
ignore_torch_compile
,
from
vllm.compilation.decorators
import
ignore_torch_compile
,
support_torch_compile
support_torch_compile
)
from
vllm.config
import
(
from
vllm.config
import
(
CacheConfig
,
CompilationConfig
,
CompilationLevel
,
CacheConfig
,
CUDAGraphMode
,
VllmConfig
,
set_current_vllm_config
)
CompilationConfig
,
CompilationMode
,
CUDAGraphMode
,
VllmConfig
,
set_current_vllm_config
,
)
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
# This import automatically registers `torch.ops.silly.attention`
from
.
import
silly_attention
# noqa: F401
from
.
import
silly_attention
# noqa: F401
...
@@ -18,56 +25,86 @@ MLP_SIZE = 128
...
@@ -18,56 +25,86 @@ MLP_SIZE = 128
@
torch
.
inference_mode
@
torch
.
inference_mode
def
run_model
(
vllm_config
:
VllmConfig
,
model
:
nn
.
Module
,
def
run_model
(
cudagraph_runtime_mode
:
CUDAGraphMode
):
vllm_config
:
VllmConfig
,
model
:
nn
.
Module
,
cudagraph_runtime_mode
:
CUDAGraphMode
):
with
set_forward_context
({},
vllm_config
=
vllm_config
):
with
set_forward_context
({},
vllm_config
=
vllm_config
):
# warmup for the model with cudagraph_mode NONE
# warmup for the model with cudagraph_mode NONE
model
(
torch
.
randn
(
BATCH_SIZE
,
MLP_SIZE
).
cuda
())
model
(
torch
.
randn
(
BATCH_SIZE
,
MLP_SIZE
).
cuda
())
# simulate cudagraphs capturing
# simulate cudagraphs capturing
with
set_forward_context
({},
with
set_forward_context
(
vllm_config
=
vllm_config
,
{},
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
vllm_config
=
vllm_config
,
batch_descriptor
=
BatchDescriptor
(
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
num_tokens
=
2
,
)):
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
2
,
),
):
model
(
torch
.
randn
(
2
,
MLP_SIZE
).
cuda
())
model
(
torch
.
randn
(
2
,
MLP_SIZE
).
cuda
())
with
set_forward_context
({},
with
set_forward_context
(
vllm_config
=
vllm_config
,
{},
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
vllm_config
=
vllm_config
,
batch_descriptor
=
BatchDescriptor
(
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
num_tokens
=
1
,
)):
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
1
,
),
):
model
(
torch
.
randn
(
1
,
MLP_SIZE
).
cuda
())
model
(
torch
.
randn
(
1
,
MLP_SIZE
).
cuda
())
# simulate cudagraphs replay
# simulate cudagraphs replay
with
set_forward_context
({},
with
set_forward_context
(
vllm_config
=
vllm_config
,
{},
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
vllm_config
=
vllm_config
,
batch_descriptor
=
BatchDescriptor
(
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
num_tokens
=
2
,
)):
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
2
,
),
):
output
=
model
(
torch
.
randn
(
2
,
MLP_SIZE
).
cuda
())
output
=
model
(
torch
.
randn
(
2
,
MLP_SIZE
).
cuda
())
output
=
output
.
cpu
()
output
=
output
.
cpu
()
return
output
.
cpu
()
return
output
.
cpu
()
def
test_ignore_torch_compile_decorator
():
@
pytest
.
mark
.
parametrize
(
"use_inductor_graph_partition"
,
[
True
,
False
])
def
test_ignore_torch_compile_decorator
(
use_inductor_graph_partition
,
monkeypatch
):
# disable compile cache so that we can count the number of compilations
# appropriately
monkeypatch
.
setenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"1"
)
if
use_inductor_graph_partition
and
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
pytest
.
skip
(
"inductor graph partition is only available in PyTorch 2.9+"
)
# piecewise
# piecewise
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
vllm_config
=
VllmConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
compilation_config
=
CompilationConfig
(
use_cudagraph
=
True
,
mode
=
CompilationMode
.
VLLM_COMPILE
,
splitting_ops
=
[
"silly.attention"
],
splitting_ops
=
[
"silly::attention"
],
cudagraph_capture_sizes
=
[
1
,
2
],
cudagraph_capture_sizes
=
[
1
,
2
],
))
use_inductor_graph_partition
=
use_inductor_graph_partition
,
)
)
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
expected_num_graphs_seen
=
1
expected_num_cudagraph_captured
=
(
4
# num_cudagraph_sizes * num cudagraphs to capture
)
if
use_inductor_graph_partition
:
expected_num_piecewise_graphs_seen
=
1
expected_num_piecewise_capturable_graphs_seen
=
1
expected_num_backend_compilations
=
1
else
:
expected_num_piecewise_graphs_seen
=
3
expected_num_piecewise_capturable_graphs_seen
=
2
expected_num_backend_compilations
=
2
@
support_torch_compile
@
support_torch_compile
class
A
(
nn
.
Module
):
class
A
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
**
kwargs
*
,
)
->
None
:
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -79,66 +116,58 @@ def test_ignore_torch_compile_decorator():
...
@@ -79,66 +116,58 @@ def test_ignore_torch_compile_decorator():
return
x
return
x
@
ignore_torch_compile
@
ignore_torch_compile
class
B
(
A
):
class
B
(
A
):
...
...
@
support_torch_compile
@
support_torch_compile
class
C
(
B
):
class
C
(
B
):
...
...
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
mod_A
=
A
(
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
mod_A
=
A
(
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
().
cuda
()
# A has support_torch_compile
# A has support_torch_compile
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_graphs_seen
=
expected_num_graphs_seen
,
num_piecewise_graphs_seen
=
3
,
num_piecewise_graphs_seen
=
expected_num_piecewise_graphs_seen
,
num_piecewise_capturable_graphs_seen
=
2
,
num_piecewise_capturable_graphs_seen
=
expected_num_piecewise_capturable_graphs_seen
,
num_backend_compilations
=
2
,
num_backend_compilations
=
expected_num_backend_compilations
,
num_cudagraph_captured
=
4
,
num_cudagraph_captured
=
expected_num_cudagraph_captured
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
):
run_model
(
vllm_config
,
mod_A
,
cudagraph_runtime_mode
)
run_model
(
vllm_config
,
mod_A
,
cudagraph_runtime_mode
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
mod_B
=
B
(
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
mod_B
=
B
(
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
().
cuda
()
# B's ignore_torch_compile should override A's support_torch_compile
# B's ignore_torch_compile should override A's support_torch_compile
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
0
,
num_graphs_seen
=
0
,
num_piecewise_graphs_seen
=
0
,
num_piecewise_graphs_seen
=
0
,
num_piecewise_capturable_graphs_seen
=
0
,
num_piecewise_capturable_graphs_seen
=
0
,
num_backend_compilations
=
0
,
num_backend_compilations
=
0
,
num_cudagraph_captured
=
0
,
num_cudagraph_captured
=
0
,
):
):
run_model
(
vllm_config
,
mod_B
,
cudagraph_runtime_mode
)
run_model
(
vllm_config
,
mod_B
,
cudagraph_runtime_mode
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
mod_C
=
C
(
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
mod_C
=
C
(
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
().
cuda
()
# C's support_torch_compile should override B's ignore_torch_compile
# C's support_torch_compile should override B's ignore_torch_compile
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_graphs_seen
=
expected_num_graphs_seen
,
num_piecewise_graphs_seen
=
3
,
num_piecewise_graphs_seen
=
expected_num_piecewise_graphs_seen
,
num_piecewise_capturable_graphs_seen
=
2
,
num_piecewise_capturable_graphs_seen
=
expected_num_piecewise_capturable_graphs_seen
,
num_backend_compilations
=
2
,
num_backend_compilations
=
expected_num_backend_compilations
,
num_cudagraph_captured
=
4
,
num_cudagraph_captured
=
expected_num_cudagraph_captured
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
):
run_model
(
vllm_config
,
mod_C
,
cudagraph_runtime_mode
)
run_model
(
vllm_config
,
mod_C
,
cudagraph_runtime_mode
)
# Only enable torch.compile if
# Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=True
# vllm_config.cache_config.kv_sharing_fast_prefill=True
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
vllm_config
.
cache_config
.
@
support_torch_compile
(
kv_sharing_fast_prefill
)
enable_if
=
lambda
vllm_config
:
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
B
(
nn
.
Module
):
class
B
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
**
kwargs
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -152,15 +181,11 @@ class B(nn.Module):
...
@@ -152,15 +181,11 @@ class B(nn.Module):
# Only enable torch.compile if
# Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=False
# vllm_config.cache_config.kv_sharing_fast_prefill=False
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
not
vllm_config
.
@
support_torch_compile
(
cache_config
.
kv_sharing_fast_prefill
)
enable_if
=
lambda
vllm_config
:
not
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
A
(
nn
.
Module
):
class
A
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
**
kwargs
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
mod1
=
B
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
.
mod1
=
B
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
.
mod2
=
B
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
.
mod2
=
B
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
...
@@ -174,55 +199,88 @@ class A(nn.Module):
...
@@ -174,55 +199,88 @@ class A(nn.Module):
return
x
return
x
def
test_conditional_compile_enable_if
():
@
pytest
.
mark
.
parametrize
(
"use_inductor_graph_partition"
,
[
True
,
False
])
vllm_config
=
VllmConfig
(
cache_config
=
CacheConfig
(
def
test_conditional_compile_enable_if
(
use_inductor_graph_partition
,
monkeypatch
):
kv_sharing_fast_prefill
=
True
,
),
# disable compile cache so that we can count the number of compilations
compilation_config
=
CompilationConfig
(
# appropriately
level
=
CompilationLevel
.
PIECEWISE
,
monkeypatch
.
setenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"1"
)
use_cudagraph
=
True
,
splitting_ops
=
[
"silly.attention"
],
if
use_inductor_graph_partition
and
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
cudagraph_capture_sizes
=
[
1
,
2
],
pytest
.
skip
(
"inductor graph partition is only available in PyTorch 2.9+"
)
))
vllm_config
=
VllmConfig
(
cache_config
=
CacheConfig
(
kv_sharing_fast_prefill
=
True
,
),
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
splitting_ops
=
[
"silly::attention"
],
cudagraph_capture_sizes
=
[
1
,
2
],
use_inductor_graph_partition
=
use_inductor_graph_partition
,
),
)
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
mod_A
=
A
(
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
mod_A
=
A
(
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
().
cuda
()
if
use_inductor_graph_partition
:
expected_num_piecewise_graphs_seen
=
2
expected_num_piecewise_capturable_graphs_seen
=
2
expected_num_backend_compilations
=
2
else
:
expected_num_piecewise_graphs_seen
=
6
expected_num_piecewise_capturable_graphs_seen
=
4
expected_num_backend_compilations
=
4
# A has support_torch_compile but enable_if fn returns False
# A has support_torch_compile but enable_if fn returns False
# enalbe_if will be True for B, so we expect mod1 and mod2
# enalbe_if will be True for B, so we expect mod1 and mod2
# to be compiled
# to be compiled
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
2
,
num_graphs_seen
=
2
,
num_piecewise_graphs_seen
=
6
,
num_piecewise_graphs_seen
=
expected_num_piecewise_graphs_seen
,
# 3 piecewise graphs per instance of B()
# 3 piecewise graphs per instance of B()
num_piecewise_capturable_graphs_seen
=
4
,
num_piecewise_capturable_graphs_seen
=
expected_num_piecewise_capturable_graphs_seen
,
num_backend_compilations
=
4
,
num_backend_compilations
=
expected_num_backend_compilations
,
num_cudagraph_captured
=
8
,
num_cudagraph_captured
=
8
,
# num_cudagraph_sizes * num
_piecewise_captur
able
_
graphs
_seen
# num_cudagraph_sizes * num
cudagraph
able
graphs
to capture
):
):
run_model
(
vllm_config
,
mod_A
,
cudagraph_runtime_mode
)
run_model
(
vllm_config
,
mod_A
,
cudagraph_runtime_mode
)
# Set kv_sharing_fast_prefill=False
# Set kv_sharing_fast_prefill=False
# which will cause A to be compiled and B to not be compiled
# which will cause A to be compiled and B to not be compiled
vllm_config
=
VllmConfig
(
cache_config
=
CacheConfig
(
vllm_config
=
VllmConfig
(
kv_sharing_fast_prefill
=
False
,
),
cache_config
=
CacheConfig
(
compilation_config
=
CompilationConfig
(
kv_sharing_fast_prefill
=
False
,
level
=
CompilationLevel
.
PIECEWISE
,
),
use_cudagraph
=
True
,
compilation_config
=
CompilationConfig
(
splitting_ops
=
[
"silly.attention"
],
mode
=
CompilationMode
.
VLLM_COMPILE
,
cudagraph_capture_sizes
=
[
1
,
2
],
splitting_ops
=
[
"silly::attention"
],
))
cudagraph_capture_sizes
=
[
1
,
2
],
use_inductor_graph_partition
=
use_inductor_graph_partition
,
),
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
mod_A
=
A
(
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
mod_A
=
A
(
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
().
cuda
()
if
use_inductor_graph_partition
:
expected_num_piecewise_graphs_seen
=
1
expected_num_piecewise_capturable_graphs_seen
=
1
expected_num_backend_compilations
=
1
else
:
# 3 attn ops and 4 non-attn ops
expected_num_piecewise_graphs_seen
=
7
expected_num_piecewise_capturable_graphs_seen
=
4
expected_num_backend_compilations
=
4
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_graphs_seen
=
1
,
num_piecewise_graphs_seen
=
7
,
num_piecewise_graphs_seen
=
expected_num_piecewise_graphs_seen
,
# 3 attn ops and 4 non-attn ops
# 3 attn ops and 4 non-attn ops
num_piecewise_capturable_graphs_seen
=
4
,
num_piecewise_capturable_graphs_seen
=
expected_num_piecewise_capturable_graphs_seen
,
num_backend_compilations
=
4
,
num_backend_compilations
=
expected_num_backend_compilations
,
num_cudagraph_captured
=
8
,
num_cudagraph_captured
=
8
,
# num_cudagraph_sizes * num
_piecewise_captur
able
_
graphs
_seen
# num_cudagraph_sizes * num
cudagraph
able
graphs
to capture
):
):
run_model
(
vllm_config
,
mod_A
,
cudagraph_runtime_mode
)
run_model
(
vllm_config
,
mod_A
,
cudagraph_runtime_mode
)
Prev
1
…
23
24
25
26
27
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment