Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d75f22e
Commit
8d75f22e
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori
parents
ce888aa4
7d80c73d
Changes
656
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1340 additions
and
217 deletions
+1340
-217
tests/basic_correctness/test_basic_correctness.py
tests/basic_correctness/test_basic_correctness.py
+4
-1
tests/basic_correctness/test_cumem.py
tests/basic_correctness/test_cumem.py
+7
-2
tests/benchmarks/test_param_sweep.py
tests/benchmarks/test_param_sweep.py
+257
-0
tests/benchmarks/test_plot_filters.py
tests/benchmarks/test_plot_filters.py
+171
-0
tests/compile/distributed/test_fusions_e2e.py
tests/compile/distributed/test_fusions_e2e.py
+19
-4
tests/compile/test_aot_compile.py
tests/compile/test_aot_compile.py
+66
-0
tests/compile/test_compile_ranges.py
tests/compile/test_compile_ranges.py
+174
-0
tests/compile/test_config.py
tests/compile/test_config.py
+80
-7
tests/compile/test_dynamic_shapes_compilation.py
tests/compile/test_dynamic_shapes_compilation.py
+132
-7
tests/compile/test_fusion.py
tests/compile/test_fusion.py
+147
-18
tests/compile/test_fusion_attn.py
tests/compile/test_fusion_attn.py
+2
-3
tests/compile/test_pass_manager.py
tests/compile/test_pass_manager.py
+40
-33
tests/compile/test_silu_mul_quant_fusion.py
tests/compile/test_silu_mul_quant_fusion.py
+57
-5
tests/conftest.py
tests/conftest.py
+58
-27
tests/distributed/test_context_parallel.py
tests/distributed/test_context_parallel.py
+98
-93
tests/distributed/test_eplb_algo.py
tests/distributed/test_eplb_algo.py
+18
-14
tests/distributed/test_eplb_spec_decode.py
tests/distributed/test_eplb_spec_decode.py
+7
-0
tests/distributed/test_kvlayout.py
tests/distributed/test_kvlayout.py
+1
-1
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+1
-1
tests/distributed/test_shm_storage.py
tests/distributed/test_shm_storage.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
656 of 656+
files are displayed.
Plain diff
Email patch
tests/basic_correctness/test_basic_correctness.py
View file @
8d75f22e
...
@@ -13,12 +13,15 @@ import pytest
...
@@ -13,12 +13,15 @@ import pytest
import
torch
import
torch
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.platforms
import
current_platform
from
vllm.v1.engine.llm_engine
import
LLMEngine
from
vllm.v1.engine.llm_engine
import
LLMEngine
from
..conftest
import
HfRunner
,
VllmRunner
from
..conftest
import
HfRunner
,
VllmRunner
from
..models.utils
import
check_outputs_equal
from
..models.utils
import
check_outputs_equal
from
..utils
import
multi_gpu_test
from
..utils
import
multi_gpu_test
ATTN_BACKEND
=
[
"ROCM_ATTN"
]
if
current_platform
.
is_rocm
()
else
[
"FLASH_ATTN"
]
MODELS
=
[
MODELS
=
[
"hmellor/tiny-random-Gemma2ForCausalLM"
,
"hmellor/tiny-random-Gemma2ForCausalLM"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
...
@@ -57,7 +60,7 @@ def _fix_prompt_embed_outputs(
...
@@ -57,7 +60,7 @@ def _fix_prompt_embed_outputs(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
]
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
ATTN_BACKEND
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"async_scheduling"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"async_scheduling"
,
[
True
,
False
])
...
...
tests/basic_correctness/test_cumem.py
View file @
8d75f22e
...
@@ -260,13 +260,18 @@ def test_deep_sleep_fp8_kvcache():
...
@@ -260,13 +260,18 @@ def test_deep_sleep_fp8_kvcache():
llm
.
sleep
(
level
=
2
)
llm
.
sleep
(
level
=
2
)
used_bytes
=
current_platform
.
get_current_memory_usage
()
-
used_bytes_baseline
used_bytes
=
current_platform
.
get_current_memory_usage
()
-
used_bytes_baseline
assert
used_bytes
<
3
*
GiB_bytes
# Rocm uses more memory for CudaGraphs, so we add 2 GiB more for the threshold
rocm_extra_mem_bytes
=
2
*
GiB_bytes
if
current_platform
.
is_rocm
()
else
0
mem_threshold_after_sleep
=
3
*
GiB_bytes
+
rocm_extra_mem_bytes
assert
used_bytes
<
mem_threshold_after_sleep
llm
.
wake_up
(
tags
=
[
"weights"
])
llm
.
wake_up
(
tags
=
[
"weights"
])
llm
.
collective_rpc
(
"reload_weights"
)
llm
.
collective_rpc
(
"reload_weights"
)
used_bytes
=
current_platform
.
get_current_memory_usage
()
-
used_bytes_baseline
used_bytes
=
current_platform
.
get_current_memory_usage
()
-
used_bytes_baseline
assert
used_bytes
<
4
*
GiB_bytes
mem_threshold_after_wake_up
=
4
*
GiB_bytes
+
rocm_extra_mem_bytes
assert
used_bytes
<
mem_threshold_after_wake_up
# now allocate kv cache and cuda graph memory
# now allocate kv cache and cuda graph memory
llm
.
wake_up
(
tags
=
[
"kv_cache"
])
llm
.
wake_up
(
tags
=
[
"kv_cache"
])
...
...
tests/benchmarks/test_param_sweep.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
tempfile
from
pathlib
import
Path
import
pytest
from
vllm.benchmarks.sweep.param_sweep
import
ParameterSweep
,
ParameterSweepItem
class
TestParameterSweepItem
:
"""Test ParameterSweepItem functionality."""
@
pytest
.
mark
.
parametrize
(
"input_dict,expected"
,
[
(
{
"compilation_config.use_inductor_graph_partition"
:
False
},
"--compilation-config.use_inductor_graph_partition=false"
,
),
(
{
"compilation_config.use_inductor_graph_partition"
:
True
},
"--compilation-config.use_inductor_graph_partition=true"
,
),
(
{
"compilation_config.use_inductor"
:
False
},
"--compilation-config.use_inductor=false"
,
),
(
{
"compilation_config.use_inductor"
:
True
},
"--compilation-config.use_inductor=true"
,
),
],
)
def
test_nested_boolean_params
(
self
,
input_dict
,
expected
):
"""Test that nested boolean params use =true/false syntax."""
item
=
ParameterSweepItem
.
from_record
(
input_dict
)
cmd
=
item
.
apply_to_cmd
([
"vllm"
,
"serve"
,
"model"
])
assert
expected
in
cmd
@
pytest
.
mark
.
parametrize
(
"input_dict,expected"
,
[
({
"enable_prefix_caching"
:
False
},
"--no-enable-prefix-caching"
),
({
"enable_prefix_caching"
:
True
},
"--enable-prefix-caching"
),
({
"disable_log_stats"
:
False
},
"--no-disable-log-stats"
),
({
"disable_log_stats"
:
True
},
"--disable-log-stats"
),
],
)
def
test_non_nested_boolean_params
(
self
,
input_dict
,
expected
):
"""Test that non-nested boolean params use --no- prefix."""
item
=
ParameterSweepItem
.
from_record
(
input_dict
)
cmd
=
item
.
apply_to_cmd
([
"vllm"
,
"serve"
,
"model"
])
assert
expected
in
cmd
@
pytest
.
mark
.
parametrize
(
"compilation_config"
,
[
{
"cudagraph_mode"
:
"full"
,
"mode"
:
2
,
"use_inductor_graph_partition"
:
True
},
{
"cudagraph_mode"
:
"piecewise"
,
"mode"
:
3
,
"use_inductor_graph_partition"
:
False
,
},
],
)
def
test_nested_dict_value
(
self
,
compilation_config
):
"""Test that nested dict values are serialized as JSON."""
item
=
ParameterSweepItem
.
from_record
(
{
"compilation_config"
:
compilation_config
}
)
cmd
=
item
.
apply_to_cmd
([
"vllm"
,
"serve"
,
"model"
])
assert
"--compilation-config"
in
cmd
# The dict should be JSON serialized
idx
=
cmd
.
index
(
"--compilation-config"
)
assert
json
.
loads
(
cmd
[
idx
+
1
])
==
compilation_config
@
pytest
.
mark
.
parametrize
(
"input_dict,expected_key,expected_value"
,
[
({
"model"
:
"test-model"
},
"--model"
,
"test-model"
),
({
"max_tokens"
:
100
},
"--max-tokens"
,
"100"
),
({
"temperature"
:
0.7
},
"--temperature"
,
"0.7"
),
],
)
def
test_string_and_numeric_values
(
self
,
input_dict
,
expected_key
,
expected_value
):
"""Test that string and numeric values are handled correctly."""
item
=
ParameterSweepItem
.
from_record
(
input_dict
)
cmd
=
item
.
apply_to_cmd
([
"vllm"
,
"serve"
])
assert
expected_key
in
cmd
assert
expected_value
in
cmd
@
pytest
.
mark
.
parametrize
(
"input_dict,expected_key,key_idx_offset"
,
[
({
"max_tokens"
:
200
},
"--max-tokens"
,
1
),
({
"enable_prefix_caching"
:
False
},
"--no-enable-prefix-caching"
,
0
),
],
)
def
test_replace_existing_parameter
(
self
,
input_dict
,
expected_key
,
key_idx_offset
):
"""Test that existing parameters in cmd are replaced."""
item
=
ParameterSweepItem
.
from_record
(
input_dict
)
if
key_idx_offset
==
1
:
# Key-value pair
cmd
=
item
.
apply_to_cmd
([
"vllm"
,
"serve"
,
"--max-tokens"
,
"100"
,
"model"
])
assert
expected_key
in
cmd
idx
=
cmd
.
index
(
expected_key
)
assert
cmd
[
idx
+
1
]
==
"200"
assert
"100"
not
in
cmd
else
:
# Boolean flag
cmd
=
item
.
apply_to_cmd
(
[
"vllm"
,
"serve"
,
"--enable-prefix-caching"
,
"model"
]
)
assert
expected_key
in
cmd
assert
"--enable-prefix-caching"
not
in
cmd
class
TestParameterSweep
:
"""Test ParameterSweep functionality."""
def
test_from_records_list
(
self
):
"""Test creating ParameterSweep from a list of records."""
records
=
[
{
"max_tokens"
:
100
,
"temperature"
:
0.7
},
{
"max_tokens"
:
200
,
"temperature"
:
0.9
},
]
sweep
=
ParameterSweep
.
from_records
(
records
)
assert
len
(
sweep
)
==
2
assert
sweep
[
0
][
"max_tokens"
]
==
100
assert
sweep
[
1
][
"max_tokens"
]
==
200
def
test_read_from_dict
(
self
):
"""Test creating ParameterSweep from a dict format."""
data
=
{
"experiment1"
:
{
"max_tokens"
:
100
,
"temperature"
:
0.7
},
"experiment2"
:
{
"max_tokens"
:
200
,
"temperature"
:
0.9
},
}
sweep
=
ParameterSweep
.
read_from_dict
(
data
)
assert
len
(
sweep
)
==
2
# Check that items have the _benchmark_name field
names
=
{
item
[
"_benchmark_name"
]
for
item
in
sweep
}
assert
names
==
{
"experiment1"
,
"experiment2"
}
# Check that parameters are preserved
for
item
in
sweep
:
if
item
[
"_benchmark_name"
]
==
"experiment1"
:
assert
item
[
"max_tokens"
]
==
100
assert
item
[
"temperature"
]
==
0.7
elif
item
[
"_benchmark_name"
]
==
"experiment2"
:
assert
item
[
"max_tokens"
]
==
200
assert
item
[
"temperature"
]
==
0.9
def
test_read_json_list_format
(
self
):
"""Test reading JSON file with list format."""
records
=
[
{
"max_tokens"
:
100
,
"temperature"
:
0.7
},
{
"max_tokens"
:
200
,
"temperature"
:
0.9
},
]
with
tempfile
.
NamedTemporaryFile
(
mode
=
"w"
,
suffix
=
".json"
,
delete
=
False
)
as
f
:
json
.
dump
(
records
,
f
)
temp_path
=
Path
(
f
.
name
)
try
:
sweep
=
ParameterSweep
.
read_json
(
temp_path
)
assert
len
(
sweep
)
==
2
assert
sweep
[
0
][
"max_tokens"
]
==
100
assert
sweep
[
1
][
"max_tokens"
]
==
200
finally
:
temp_path
.
unlink
()
def
test_read_json_dict_format
(
self
):
"""Test reading JSON file with dict format."""
data
=
{
"experiment1"
:
{
"max_tokens"
:
100
,
"temperature"
:
0.7
},
"experiment2"
:
{
"max_tokens"
:
200
,
"temperature"
:
0.9
},
}
with
tempfile
.
NamedTemporaryFile
(
mode
=
"w"
,
suffix
=
".json"
,
delete
=
False
)
as
f
:
json
.
dump
(
data
,
f
)
temp_path
=
Path
(
f
.
name
)
try
:
sweep
=
ParameterSweep
.
read_json
(
temp_path
)
assert
len
(
sweep
)
==
2
# Check that items have the _benchmark_name field
names
=
{
item
[
"_benchmark_name"
]
for
item
in
sweep
}
assert
names
==
{
"experiment1"
,
"experiment2"
}
finally
:
temp_path
.
unlink
()
def
test_unique_benchmark_names_validation
(
self
):
"""Test that duplicate _benchmark_name values raise an error."""
# Test with duplicate names in list format
records
=
[
{
"_benchmark_name"
:
"exp1"
,
"max_tokens"
:
100
},
{
"_benchmark_name"
:
"exp1"
,
"max_tokens"
:
200
},
]
with
pytest
.
raises
(
ValueError
,
match
=
"Duplicate _benchmark_name values"
):
ParameterSweep
.
from_records
(
records
)
def
test_unique_benchmark_names_multiple_duplicates
(
self
):
"""Test validation with multiple duplicate names."""
records
=
[
{
"_benchmark_name"
:
"exp1"
,
"max_tokens"
:
100
},
{
"_benchmark_name"
:
"exp1"
,
"max_tokens"
:
200
},
{
"_benchmark_name"
:
"exp2"
,
"max_tokens"
:
300
},
{
"_benchmark_name"
:
"exp2"
,
"max_tokens"
:
400
},
]
with
pytest
.
raises
(
ValueError
,
match
=
"Duplicate _benchmark_name values"
):
ParameterSweep
.
from_records
(
records
)
def
test_no_benchmark_names_allowed
(
self
):
"""Test that records without _benchmark_name are allowed."""
records
=
[
{
"max_tokens"
:
100
,
"temperature"
:
0.7
},
{
"max_tokens"
:
200
,
"temperature"
:
0.9
},
]
sweep
=
ParameterSweep
.
from_records
(
records
)
assert
len
(
sweep
)
==
2
def
test_mixed_benchmark_names_allowed
(
self
):
"""Test that mixing records with and without _benchmark_name is allowed."""
records
=
[
{
"_benchmark_name"
:
"exp1"
,
"max_tokens"
:
100
},
{
"max_tokens"
:
200
,
"temperature"
:
0.9
},
]
sweep
=
ParameterSweep
.
from_records
(
records
)
assert
len
(
sweep
)
==
2
class
TestParameterSweepItemKeyNormalization
:
"""Test key normalization in ParameterSweepItem."""
def
test_underscore_to_hyphen_conversion
(
self
):
"""Test that underscores are converted to hyphens in CLI."""
item
=
ParameterSweepItem
.
from_record
({
"max_tokens"
:
100
})
cmd
=
item
.
apply_to_cmd
([
"vllm"
,
"serve"
])
assert
"--max-tokens"
in
cmd
def
test_nested_key_preserves_suffix
(
self
):
"""Test that nested keys preserve the suffix format."""
# The suffix after the dot should preserve underscores
item
=
ParameterSweepItem
.
from_record
(
{
"compilation_config.some_nested_param"
:
"value"
}
)
cmd
=
item
.
apply_to_cmd
([
"vllm"
,
"serve"
])
# The prefix (compilation_config) gets converted to hyphens,
# but the suffix (some_nested_param) is preserved
assert
any
(
"compilation-config.some_nested_param"
in
arg
for
arg
in
cmd
)
tests/benchmarks/test_plot_filters.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pandas
as
pd
import
pytest
from
vllm.benchmarks.sweep.plot
import
(
PlotEqualTo
,
PlotFilterBase
,
PlotFilters
,
PlotGreaterThan
,
PlotGreaterThanOrEqualTo
,
PlotLessThan
,
PlotLessThanOrEqualTo
,
PlotNotEqualTo
,
)
class
TestPlotFilters
:
"""Test PlotFilter functionality including 'inf' edge case."""
def
setup_method
(
self
):
"""Create sample DataFrames for testing."""
# DataFrame with numeric values
self
.
df_numeric
=
pd
.
DataFrame
(
{
"request_rate"
:
[
1.0
,
5.0
,
10.0
,
50.0
,
100.0
],
"value"
:
[
10
,
20
,
30
,
40
,
50
],
}
)
# DataFrame with float('inf') - note: string "inf" values are coerced
# to float when loading data, so we only test with float('inf')
self
.
df_inf_float
=
pd
.
DataFrame
(
{
"request_rate"
:
[
1.0
,
5.0
,
10.0
,
float
(
"inf"
),
float
(
"inf"
)],
"value"
:
[
10
,
20
,
30
,
40
,
50
],
}
)
@
pytest
.
mark
.
parametrize
(
"target,expected_count"
,
[
(
"5.0"
,
1
),
(
"10.0"
,
1
),
(
"1.0"
,
1
),
],
)
def
test_equal_to_numeric
(
self
,
target
,
expected_count
):
"""Test PlotEqualTo with numeric values."""
filter_obj
=
PlotEqualTo
(
"request_rate"
,
target
)
result
=
filter_obj
.
apply
(
self
.
df_numeric
)
assert
len
(
result
)
==
expected_count
def
test_equal_to_inf_float
(
self
):
"""Test PlotEqualTo with float('inf')."""
filter_obj
=
PlotEqualTo
(
"request_rate"
,
"inf"
)
result
=
filter_obj
.
apply
(
self
.
df_inf_float
)
# Should match both float('inf') entries because float('inf') == float('inf')
assert
len
(
result
)
==
2
@
pytest
.
mark
.
parametrize
(
"target,expected_count"
,
[
(
"5.0"
,
4
),
# All except 5.0
(
"1.0"
,
4
),
# All except 1.0
],
)
def
test_not_equal_to_numeric
(
self
,
target
,
expected_count
):
"""Test PlotNotEqualTo with numeric values."""
filter_obj
=
PlotNotEqualTo
(
"request_rate"
,
target
)
result
=
filter_obj
.
apply
(
self
.
df_numeric
)
assert
len
(
result
)
==
expected_count
def
test_not_equal_to_inf_float
(
self
):
"""Test PlotNotEqualTo with float('inf')."""
filter_obj
=
PlotNotEqualTo
(
"request_rate"
,
"inf"
)
result
=
filter_obj
.
apply
(
self
.
df_inf_float
)
# Should exclude float('inf') entries
assert
len
(
result
)
==
3
@
pytest
.
mark
.
parametrize
(
"target,expected_count"
,
[
(
"10.0"
,
2
),
# 1.0, 5.0
(
"50.0"
,
3
),
# 1.0, 5.0, 10.0
(
"5.0"
,
1
),
# 1.0
],
)
def
test_less_than
(
self
,
target
,
expected_count
):
"""Test PlotLessThan with numeric values."""
filter_obj
=
PlotLessThan
(
"request_rate"
,
target
)
result
=
filter_obj
.
apply
(
self
.
df_numeric
)
assert
len
(
result
)
==
expected_count
@
pytest
.
mark
.
parametrize
(
"target,expected_count"
,
[
(
"10.0"
,
3
),
# 1.0, 5.0, 10.0
(
"5.0"
,
2
),
# 1.0, 5.0
],
)
def
test_less_than_or_equal_to
(
self
,
target
,
expected_count
):
"""Test PlotLessThanOrEqualTo with numeric values."""
filter_obj
=
PlotLessThanOrEqualTo
(
"request_rate"
,
target
)
result
=
filter_obj
.
apply
(
self
.
df_numeric
)
assert
len
(
result
)
==
expected_count
@
pytest
.
mark
.
parametrize
(
"target,expected_count"
,
[
(
"10.0"
,
2
),
# 50.0, 100.0
(
"5.0"
,
3
),
# 10.0, 50.0, 100.0
],
)
def
test_greater_than
(
self
,
target
,
expected_count
):
"""Test PlotGreaterThan with numeric values."""
filter_obj
=
PlotGreaterThan
(
"request_rate"
,
target
)
result
=
filter_obj
.
apply
(
self
.
df_numeric
)
assert
len
(
result
)
==
expected_count
@
pytest
.
mark
.
parametrize
(
"target,expected_count"
,
[
(
"10.0"
,
3
),
# 10.0, 50.0, 100.0
(
"5.0"
,
4
),
# 5.0, 10.0, 50.0, 100.0
],
)
def
test_greater_than_or_equal_to
(
self
,
target
,
expected_count
):
"""Test PlotGreaterThanOrEqualTo with numeric values."""
filter_obj
=
PlotGreaterThanOrEqualTo
(
"request_rate"
,
target
)
result
=
filter_obj
.
apply
(
self
.
df_numeric
)
assert
len
(
result
)
==
expected_count
@
pytest
.
mark
.
parametrize
(
"filter_str,expected_var,expected_target,expected_type"
,
[
(
"request_rate==5.0"
,
"request_rate"
,
"5.0"
,
PlotEqualTo
),
(
"request_rate!=10.0"
,
"request_rate"
,
"10.0"
,
PlotNotEqualTo
),
(
"request_rate<50.0"
,
"request_rate"
,
"50.0"
,
PlotLessThan
),
(
"request_rate<=50.0"
,
"request_rate"
,
"50.0"
,
PlotLessThanOrEqualTo
),
(
"request_rate>10.0"
,
"request_rate"
,
"10.0"
,
PlotGreaterThan
),
(
"request_rate>=10.0"
,
"request_rate"
,
"10.0"
,
PlotGreaterThanOrEqualTo
),
(
"request_rate==inf"
,
"request_rate"
,
"inf"
,
PlotEqualTo
),
(
"request_rate!='inf'"
,
"request_rate"
,
"inf"
,
PlotNotEqualTo
),
],
)
def
test_parse_str
(
self
,
filter_str
,
expected_var
,
expected_target
,
expected_type
):
"""Test parsing filter strings."""
filter_obj
=
PlotFilterBase
.
parse_str
(
filter_str
)
assert
isinstance
(
filter_obj
,
expected_type
)
assert
filter_obj
.
var
==
expected_var
assert
filter_obj
.
target
==
expected_target
def
test_parse_str_inf_edge_case
(
self
):
"""Test parsing 'inf' string in filter."""
filter_obj
=
PlotFilterBase
.
parse_str
(
"request_rate==inf"
)
assert
isinstance
(
filter_obj
,
PlotEqualTo
)
assert
filter_obj
.
var
==
"request_rate"
assert
filter_obj
.
target
==
"inf"
def
test_parse_multiple_filters
(
self
):
"""Test parsing multiple filters."""
filters
=
PlotFilters
.
parse_str
(
"request_rate>5.0,value<=40"
)
assert
len
(
filters
)
==
2
assert
isinstance
(
filters
[
0
],
PlotGreaterThan
)
assert
isinstance
(
filters
[
1
],
PlotLessThanOrEqualTo
)
def
test_parse_empty_filter
(
self
):
"""Test parsing empty filter string."""
filters
=
PlotFilters
.
parse_str
(
""
)
assert
len
(
filters
)
==
0
tests/compile/distributed/test_fusions_e2e.py
View file @
8d75f22e
...
@@ -298,10 +298,14 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
...
@@ -298,10 +298,14 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
r
"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes"
,
r
"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes"
,
log_holder
.
text
,
log_holder
.
text
,
)
)
assert
len
(
log_matches
)
==
2
,
log_holder
.
text
# 2 for each compile range
# (global compile range can be split due to fuse_allreduce_rmsnorm)
num_compile_ranges
=
len
(
compilation_config
.
get_compile_ranges
())
assert
num_compile_ranges
in
[
1
,
2
]
assert
int
(
log_matches
[
0
])
==
matches
.
attention_fusion
assert
len
(
log_matches
)
==
2
*
num_compile_ranges
,
log_holder
.
text
assert
int
(
log_matches
[
1
])
==
matches
.
attention_fusion
assert
all
(
int
(
log_match
)
==
matches
.
attention_fusion
for
log_match
in
log_matches
)
log_matches
=
re
.
findall
(
log_matches
=
re
.
findall
(
r
"collective_fusion.py:\d+] Replaced (\d+) patterns"
,
r
"collective_fusion.py:\d+] Replaced (\d+) patterns"
,
...
@@ -312,6 +316,12 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
...
@@ -312,6 +316,12 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
assert
int
(
log_matches
[
0
])
==
matches
.
allreduce_fusion
assert
int
(
log_matches
[
0
])
==
matches
.
allreduce_fusion
assert
int
(
log_matches
[
1
])
==
matches
.
allreduce_fusion
assert
int
(
log_matches
[
1
])
==
matches
.
allreduce_fusion
log_matches
=
re
.
findall
(
r
"pass_manager.py:\d+] Skipping .*AllReduceFusionPass.* with compile range"
,
log_holder
.
text
,
)
assert
len
(
log_matches
)
==
2
*
(
num_compile_ranges
-
1
),
log_holder
.
text
@
multi_gpu_test
(
num_gpus
=
2
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -446,7 +456,6 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
...
@@ -446,7 +456,6 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
# No cudagraphs by default
# No cudagraphs by default
if
compilation_config
.
cudagraph_mode
is
None
:
if
compilation_config
.
cudagraph_mode
is
None
:
compilation_config
.
cudagraph_mode
=
CUDAGraphMode
.
NONE
compilation_config
.
cudagraph_mode
=
CUDAGraphMode
.
NONE
llm
=
LLM
(
llm
=
LLM
(
model
=
model
,
model
=
model
,
compilation_config
=
compilation_config
,
compilation_config
=
compilation_config
,
...
@@ -459,3 +468,9 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
...
@@ -459,3 +468,9 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
prompt
=
output
.
prompt
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
# Get the compile ranges split points after vllm config post init
# in order to compute compile ranges correctly
compilation_config
.
compile_ranges_split_points
=
(
llm
.
llm_engine
.
vllm_config
.
compilation_config
.
compile_ranges_split_points
)
tests/compile/test_aot_compile.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
multiprocessing
import
tempfile
import
tempfile
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
...
@@ -137,3 +139,67 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
...
@@ -137,3 +139,67 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
artifacts
=
compiled_mod
.
aot_compiled_fn
.
_artifacts
artifacts
=
compiled_mod
.
aot_compiled_fn
.
_artifacts
guards_string
=
artifacts
.
compiled_fn
.
shape_env
.
format_guards
()
guards_string
=
artifacts
.
compiled_fn
.
shape_env
.
format_guards
()
assert
guards_string
==
" - s77 <= 42
\n
- Eq(Mod(s77, 2), 0)"
assert
guards_string
==
" - s77 <= 42
\n
- Eq(Mod(s77, 2), 0)"
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
use_vllm_config
(
make_vllm_config
())
def
test_gpt2_cache_hit
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Test that compiling gpt2 twice results in a cache hit and
capture torch dynamic symbol creations to ensure make_symbol
not called on cache hit.
"""
import
torch.fx.experimental.symbolic_shapes
as
symbolic_shapes_module
from
torch.utils._sympy.symbol
import
make_symbol
from
vllm
import
LLM
create_symbol_counter
=
multiprocessing
.
Value
(
"i"
,
0
)
original_make_symbol
=
make_symbol
@
functools
.
wraps
(
original_make_symbol
)
def
counting_make_symbol
(
prefix
,
idx
,
**
kwargs
):
with
create_symbol_counter
.
get_lock
():
create_symbol_counter
.
value
+=
1
return
original_make_symbol
(
prefix
,
idx
,
**
kwargs
)
symbolic_shapes_module
.
make_symbol
=
counting_make_symbol
try
:
with
monkeypatch
.
context
()
as
m
,
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
m
.
setenv
(
"VLLM_CACHE_ROOT"
,
tmpdirname
)
m
.
setenv
(
"VLLM_USE_AOT_COMPILE"
,
"1"
)
# First compilation - initialize model and generate
llm_model
=
LLM
(
model
=
"gpt2"
,
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
),
max_model_len
=
256
,
)
llm_model
.
generate
(
"Hello, my name is"
)
assert
create_symbol_counter
.
value
==
2
create_symbol_counter
.
value
=
0
# Clean up first model
del
llm_model
# Second compilation - should hit cache
m
.
setenv
(
"VLLM_FORCE_AOT_LOAD"
,
"1"
)
llm_model
=
LLM
(
model
=
"gpt2"
,
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
),
max_model_len
=
256
,
)
llm_model
.
generate
(
"Hello, my name is"
)
assert
create_symbol_counter
.
value
==
0
finally
:
# Restore original method
symbolic_shapes_module
.
make_symbol
=
original_make_symbol
tests/compile/test_compile_ranges.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch
from
torch
import
fx
as
fx
from
torch
import
nn
# This import automatically registers `torch.ops.silly.attention`
import
tests.compile.silly_attention
# noqa
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.inductor_pass
import
(
InductorPass
,
get_pass_context
,
)
from
vllm.config
import
(
VllmConfig
,
set_current_vllm_config
,
)
from
vllm.config.compilation
import
CompilationConfig
,
CompilationMode
from
vllm.config.scheduler
import
SchedulerConfig
from
vllm.config.utils
import
Range
from
vllm.forward_context
import
set_forward_context
BATCH_SIZE
=
64
MLP_SIZE
=
128
@
support_torch_compile
class
TestModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
**
kwargs
)
->
None
:
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
x
+
x
attn_output
=
torch
.
empty_like
(
x
)
torch
.
ops
.
silly
.
attention
(
x
,
x
,
x
,
attn_output
)
x
=
attn_output
x
=
x
*
3
return
x
@
torch
.
inference_mode
def
run_model
(
vllm_config
:
VllmConfig
,
model
:
nn
.
Module
,
batch_sizes
:
list
[
int
]):
with
set_forward_context
({},
vllm_config
=
vllm_config
):
model
(
torch
.
randn
(
BATCH_SIZE
,
MLP_SIZE
))
for
batch_size
in
batch_sizes
:
model
(
torch
.
randn
(
batch_size
,
MLP_SIZE
))
class
PostGradRangeChecker
(
InductorPass
):
def
__init__
(
self
,
ranges
:
list
[
Range
]):
self
.
ranges
=
ranges
self
.
num_calls
=
0
def
__call__
(
self
,
graph
:
fx
.
Graph
):
compile_range
=
get_pass_context
().
compile_range
assert
compile_range
in
self
.
ranges
,
(
f
"Compile range
{
compile_range
}
not in
{
self
.
ranges
}
"
)
self
.
num_calls
+=
1
def
uuid
(
self
)
->
str
:
state
:
dict
[
str
,
Any
]
=
{}
return
InductorPass
.
hash_dict
(
state
)
def
test_compile_ranges
(
use_fresh_inductor_cache
):
post_grad_range_checker
=
PostGradRangeChecker
(
[
Range
(
start
=
1
,
end
=
8
),
Range
(
start
=
16
,
end
=
16
),
Range
(
start
=
9
,
end
=
32
),
Range
(
start
=
64
,
end
=
64
),
Range
(
start
=
33
,
end
=
8192
),
]
)
torch
.
set_default_device
(
"cuda"
)
vllm_config
=
VllmConfig
(
scheduler_config
=
SchedulerConfig
(
max_num_batched_tokens
=
8192
,
max_model_len
=
8192
,
is_encoder_decoder
=
False
,
),
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
compile_ranges_split_points
=
[
8
,
32
],
compile_sizes
=
[
16
,
64
,
128
],
inductor_compile_config
=
{
"post_grad_custom_post_pass"
:
post_grad_range_checker
,
},
),
)
with
set_current_vllm_config
(
vllm_config
):
model
=
TestModel
(
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
()
# Number of compilations: 3 for each compile range + 2 compile sizes
batch_sizes
=
[
1
,
4
,
16
,
24
,
48
,
64
,
8192
]
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_piecewise_graphs_seen
=
1
,
num_backend_compilations
=
5
,
):
run_model
(
vllm_config
,
model
,
batch_sizes
)
assert
post_grad_range_checker
.
num_calls
==
5
def
test_compile_config_get_compile_ranges
():
compilation_config
=
CompilationConfig
(
compile_ranges_split_points
=
[
8
,
32
],
)
VllmConfig
(
scheduler_config
=
SchedulerConfig
(
max_num_batched_tokens
=
8192
,
max_model_len
=
8192
,
is_encoder_decoder
=
False
,
),
compilation_config
=
compilation_config
,
)
assert
compilation_config
.
get_compile_ranges
()
==
[
Range
(
start
=
1
,
end
=
8
),
Range
(
start
=
9
,
end
=
32
),
Range
(
start
=
33
,
end
=
8192
),
]
def
test_inductor_cache_compile_ranges
(
monkeypatch
,
use_fresh_inductor_cache
):
# To force multiple compilations, we disable the compile cache
monkeypatch
.
setenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"1"
)
post_grad_range_checker
=
PostGradRangeChecker
(
ranges
=
[
Range
(
start
=
1
,
end
=
8
),
Range
(
start
=
9
,
end
=
8192
),
]
)
scheduler_config
=
SchedulerConfig
(
max_num_batched_tokens
=
8192
,
max_model_len
=
8192
,
is_encoder_decoder
=
False
,
)
torch
.
set_default_device
(
"cuda"
)
def
create_vllm_config
():
return
VllmConfig
(
scheduler_config
=
scheduler_config
,
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
compile_ranges_split_points
=
[
8
],
inductor_compile_config
=
{
"post_grad_custom_post_pass"
:
post_grad_range_checker
,
},
),
)
vllm_config_1
=
create_vllm_config
()
with
set_current_vllm_config
(
vllm_config_1
):
model1
=
TestModel
(
vllm_config
=
vllm_config_1
,
prefix
=
""
).
eval
()
batch_sizes
=
[
1
,
16
]
run_model
(
vllm_config_1
,
model1
,
batch_sizes
)
assert
post_grad_range_checker
.
num_calls
==
2
post_grad_range_checker
.
num_calls
=
0
# Create a new vllm config with the new pass context
vllm_config_2
=
create_vllm_config
()
with
set_current_vllm_config
(
vllm_config_2
):
model2
=
TestModel
(
vllm_config
=
vllm_config_2
,
prefix
=
""
).
eval
()
batch_sizes
=
[
4
,
32
]
run_model
(
vllm_config_2
,
model2
,
batch_sizes
)
# Check that cache is used, so the number of calls
# should be 0
assert
post_grad_range_checker
.
num_calls
==
0
tests/compile/test_config.py
View file @
8d75f22e
...
@@ -10,7 +10,7 @@ from pydantic import ValidationError
...
@@ -10,7 +10,7 @@ from pydantic import ValidationError
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.fix_functionalization
import
FixFunctionalizationPass
from
vllm.compilation.fix_functionalization
import
FixFunctionalizationPass
from
vllm.config
import
CompilationConfig
,
CUDAGraphMode
,
VllmConfig
from
vllm.config
import
CompilationConfig
,
CUDAGraphMode
,
ParallelConfig
,
VllmConfig
from
vllm.config.compilation
import
CompilationMode
,
PassConfig
from
vllm.config.compilation
import
CompilationMode
,
PassConfig
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.logger
import
_print_warning_once
from
vllm.logger
import
_print_warning_once
...
@@ -235,6 +235,70 @@ def test_splitting_ops_dynamic():
...
@@ -235,6 +235,70 @@ def test_splitting_ops_dynamic():
assert
config
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
PIECEWISE
assert
config
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
PIECEWISE
def
test_moe_splitting_ops_deepep_ht_piecewise
():
# Non-inductor, non-attn-fusion case: DeepEP HT with dp>1
# should add MoE ops to splitting_ops on top of attention ops.
config
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
all2all_backend
=
"deepep_high_throughput"
,
data_parallel_size
=
8
,
),
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
),
)
splitting_ops
=
config
.
compilation_config
.
splitting_ops
assert
splitting_ops
is
not
None
assert
"vllm::moe_forward"
in
splitting_ops
assert
"vllm::moe_forward_shared"
in
splitting_ops
def
test_moe_splitting_ops_deepep_ht_inductor_partition
():
# Inductor partition case: user-provided splitting_ops should be
# preserved and MoE ops should be appended for DeepEP HT with dp>1.
config
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
all2all_backend
=
"deepep_high_throughput"
,
data_parallel_size
=
8
,
),
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
use_inductor_graph_partition
=
True
,
splitting_ops
=
[
"vllm::unified_attention"
,
"vllm::moe_forward"
,
"vllm::moe_forward_shared"
,
],
),
)
splitting_ops
=
config
.
compilation_config
.
splitting_ops
assert
splitting_ops
==
[
"vllm::unified_attention"
,
"vllm::moe_forward"
,
"vllm::moe_forward_shared"
,
]
def
test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor
():
# Pure attn-fusion case without inductor partition: even with
# DeepEP HT and dp>1, we should not re-enable piecewise compilation
# or add MoE ops into splitting_ops.
config
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
all2all_backend
=
"deepep_high_throughput"
,
data_parallel_size
=
8
,
),
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
pass_config
=
{
"enable_attn_fusion"
:
True
,
"enable_noop"
:
True
},
custom_ops
=
[
"+quant_fp8"
],
cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
,
),
)
assert
config
.
compilation_config
.
splitting_ops
==
[]
assert
config
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
FULL
def
test_should_split
():
def
test_should_split
():
import
torch
import
torch
...
@@ -392,39 +456,48 @@ def test_pass_config_deprecation(caplog_vllm):
...
@@ -392,39 +456,48 @@ def test_pass_config_deprecation(caplog_vllm):
assert
"enable_fusion is deprecated"
in
caplog_vllm
.
text
assert
"enable_fusion is deprecated"
in
caplog_vllm
.
text
assert
config
.
fuse_norm_quant
is
True
assert
config
.
fuse_norm_quant
is
True
assert
config
.
fuse_act_quant
is
True
assert
config
.
fuse_act_quant
is
True
assert
config
.
enable_fusion
is
Non
e
assert
config
.
enable_fusion
is
Tru
e
# Test enable_attn_fusion -> fuse_attn_quant
# Test enable_attn_fusion -> fuse_attn_quant
caplog_vllm
.
clear
()
caplog_vllm
.
clear
()
config
=
PassConfig
(
enable_attn_fusion
=
True
)
config
=
PassConfig
(
enable_attn_fusion
=
True
)
assert
"enable_attn_fusion is deprecated"
in
caplog_vllm
.
text
assert
"enable_attn_fusion is deprecated"
in
caplog_vllm
.
text
assert
config
.
fuse_attn_quant
is
True
assert
config
.
fuse_attn_quant
is
True
assert
config
.
enable_attn_fusion
is
Non
e
assert
config
.
enable_attn_fusion
is
Tru
e
# Test enable_noop -> eliminate_noops
# Test enable_noop -> eliminate_noops
caplog_vllm
.
clear
()
caplog_vllm
.
clear
()
config
=
PassConfig
(
enable_noop
=
True
)
config
=
PassConfig
(
enable_noop
=
True
)
assert
"enable_noop is deprecated"
in
caplog_vllm
.
text
assert
"enable_noop is deprecated"
in
caplog_vllm
.
text
assert
config
.
eliminate_noops
is
True
assert
config
.
eliminate_noops
is
True
assert
config
.
enable_noop
is
Non
e
assert
config
.
enable_noop
is
Tru
e
# Test enable_sequence_parallelism -> enable_sp
# Test enable_sequence_parallelism -> enable_sp
caplog_vllm
.
clear
()
caplog_vllm
.
clear
()
config
=
PassConfig
(
enable_sequence_parallelism
=
True
)
config
=
PassConfig
(
enable_sequence_parallelism
=
True
)
assert
"enable_sequence_parallelism is deprecated"
in
caplog_vllm
.
text
assert
"enable_sequence_parallelism is deprecated"
in
caplog_vllm
.
text
assert
config
.
enable_sp
is
True
assert
config
.
enable_sp
is
True
assert
config
.
enable_sequence_parallelism
is
Non
e
assert
config
.
enable_sequence_parallelism
is
Tru
e
# Test enable_async_tp -> fuse_gemm_comms
# Test enable_async_tp -> fuse_gemm_comms
caplog_vllm
.
clear
()
caplog_vllm
.
clear
()
config
=
PassConfig
(
enable_async_tp
=
True
)
config
=
PassConfig
(
enable_async_tp
=
True
)
assert
"enable_async_tp is deprecated"
in
caplog_vllm
.
text
assert
"enable_async_tp is deprecated"
in
caplog_vllm
.
text
assert
config
.
fuse_gemm_comms
is
True
assert
config
.
fuse_gemm_comms
is
True
assert
config
.
enable_async_tp
is
Non
e
assert
config
.
enable_async_tp
is
Tru
e
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
caplog_vllm
.
clear
()
caplog_vllm
.
clear
()
config
=
PassConfig
(
enable_fi_allreduce_fusion
=
True
)
config
=
PassConfig
(
enable_fi_allreduce_fusion
=
True
)
assert
"enable_fi_allreduce_fusion is deprecated"
in
caplog_vllm
.
text
assert
"enable_fi_allreduce_fusion is deprecated"
in
caplog_vllm
.
text
assert
config
.
fuse_allreduce_rms
is
True
assert
config
.
fuse_allreduce_rms
is
True
assert
config
.
enable_fi_allreduce_fusion
is
None
assert
config
.
enable_fi_allreduce_fusion
is
True
# Test hash consistency
config_old
=
PassConfig
(
enable_fusion
=
True
)
config_new
=
PassConfig
(
fuse_norm_quant
=
True
,
fuse_act_quant
=
True
)
assert
config_old
.
compute_hash
()
==
config_new
.
compute_hash
()
config_old
=
PassConfig
(
enable_async_tp
=
True
)
config_new
=
PassConfig
(
fuse_gemm_comms
=
True
)
assert
config_old
.
compute_hash
()
==
config_new
.
compute_hash
()
tests/compile/test_dynamic_shapes_compilation.py
View file @
8d75f22e
...
@@ -2,12 +2,21 @@
...
@@ -2,12 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
gc
import
gc
import
tempfile
from
contextlib
import
contextmanager
import
pytest
import
pytest
import
torch
import
torch
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.config.compilation
import
CompilationMode
,
DynamicShapesType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CompilationConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config.compilation
import
(
CompilationMode
,
DynamicShapesConfig
,
DynamicShapesType
,
)
from
vllm.forward_context
import
set_forward_context
from
vllm.tokenizers
import
get_tokenizer
from
vllm.tokenizers
import
get_tokenizer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
...
@@ -29,18 +38,19 @@ def get_test_models():
...
@@ -29,18 +38,19 @@ def get_test_models():
)
)
@
pytest
.
mark
.
parametrize
(
"use_aot_compile"
,
[
"0"
])
@
pytest
.
mark
.
parametrize
(
"use_aot_compile"
,
[
"0"
])
@
pytest
.
mark
.
parametrize
(
"use_bytecode_hook"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bytecode_hook"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"evaluate_guards"
,
[
False
,
True
])
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
)
def
test_dynamic_shapes_compilation
(
def
test_dynamic_shapes_compilation
(
monkeypatch
,
model_name
,
shapes_type
,
use_aot_compile
,
use_bytecode_hook
monkeypatch
,
model_name
,
shapes_type
,
use_aot_compile
,
use_bytecode_hook
,
evaluate_guards
,
):
):
"""Test that all dynamic shapes types compile successfully"""
"""Test that all dynamic shapes types compile successfully"""
print
(
f
"
\n
Testing model:
{
model_name
}
with
{
shapes_type
.
name
}
, "
f
"AOT compile:
{
use_aot_compile
}
, "
f
"Bytecode hook:
{
use_bytecode_hook
}
"
)
if
use_bytecode_hook
and
shapes_type
==
DynamicShapesType
.
UNBACKED
:
if
use_bytecode_hook
and
shapes_type
==
DynamicShapesType
.
UNBACKED
:
pytest
.
skip
(
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0"
)
pytest
.
skip
(
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0"
)
...
@@ -58,6 +68,7 @@ def test_dynamic_shapes_compilation(
...
@@ -58,6 +68,7 @@ def test_dynamic_shapes_compilation(
"mode"
:
CompilationMode
.
VLLM_COMPILE
,
"mode"
:
CompilationMode
.
VLLM_COMPILE
,
"dynamic_shapes_config"
:
{
"dynamic_shapes_config"
:
{
"type"
:
shapes_type
.
value
,
"type"
:
shapes_type
.
value
,
"evaluate_guards"
:
evaluate_guards
,
},
},
},
},
)
)
...
@@ -86,3 +97,117 @@ def test_dynamic_shapes_compilation(
...
@@ -86,3 +97,117 @@ def test_dynamic_shapes_compilation(
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
print
(
"GPU memory cleared"
)
print
(
"GPU memory cleared"
)
@
pytest
.
mark
.
parametrize
(
"use_aot_compile"
,
[
"0"
,
"1"
])
@
pytest
.
mark
.
parametrize
(
"dynamic_shapes_type"
,
[
DynamicShapesType
.
BACKED
,
DynamicShapesType
.
BACKED_SIZE_OBLIVIOUS
,
],
)
@
pytest
.
mark
.
parametrize
(
"evaluate_guards"
,
[
False
,
True
])
def
test_model_specialization_with_evaluate_guards
(
monkeypatch
,
use_aot_compile
,
dynamic_shapes_type
,
evaluate_guards
):
"""Test that evaluate_guards correctly detects shape specialization
violations.
"""
if
(
use_aot_compile
==
"1"
and
dynamic_shapes_type
==
DynamicShapesType
.
BACKED
and
evaluate_guards
):
pytest
.
skip
(
"evaluate_guards for backed does not work with aot_compile =1"
)
@
support_torch_compile
class
ModelWithSizeCheck
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
):
# This will cause specialization - torch.compile will guard on
# sx.shape[0]
if
x
.
shape
[
0
]
>=
10
:
return
x
*
10
else
:
return
x
*
10
@
support_torch_compile
class
ModelWithOneSizeCheck
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
):
# This will cause 0/1 specializations.
if
x
.
shape
[
0
]
==
0
:
return
x
*
10
if
x
.
shape
[
0
]
==
1
:
return
x
*
10
else
:
return
x
*
10
@
contextmanager
def
use_vllm_config
(
vllm_config
:
VllmConfig
):
with
set_forward_context
({},
vllm_config
),
set_current_vllm_config
(
vllm_config
):
yield
monkeypatch
.
setenv
(
"TOKENIZERS_PARALLELISM"
,
"true"
)
monkeypatch
.
setenv
(
"VLLM_USE_AOT_COMPILE"
,
use_aot_compile
)
monkeypatch
.
setenv
(
"VLLM_USE_BYTECODE_HOOK"
,
"0"
)
# Create vllm config with the desired settings
from
vllm.config
import
CompilationMode
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
dynamic_shapes_config
=
DynamicShapesConfig
(
type
=
dynamic_shapes_type
,
evaluate_guards
=
evaluate_guards
,
),
)
)
def
test
(
model_class
,
input1
,
input2
,
is_01_specialization
=
False
):
with
(
torch
.
no_grad
(),
use_vllm_config
(
vllm_config
),
tempfile
.
TemporaryDirectory
()
as
tmpdirname
,
):
monkeypatch
.
setenv
(
"VLLM_CACHE_ROOT"
,
tmpdirname
)
model
=
model_class
(
vllm_config
=
vllm_config
).
cuda
()
model
(
input1
)
if
evaluate_guards
and
(
not
(
is_01_specialization
and
dynamic_shapes_type
==
DynamicShapesType
.
BACKED
)
):
# This should fail because guards were added.
with
pytest
.
raises
(
RuntimeError
)
as
excinfo
:
model
(
input2
)
# Expected failure - guard was violated
error_msg
=
str
(
excinfo
.
value
)
assert
(
"GuardManager check failed"
in
error_msg
or
"Detected recompile when torch.compile stance"
in
error_msg
),
error_msg
else
:
model
(
input2
)
test
(
ModelWithSizeCheck
,
torch
.
randn
(
20
,
10
).
cuda
(),
torch
.
randn
(
5
,
10
).
cuda
())
test
(
ModelWithSizeCheck
,
torch
.
randn
(
5
,
10
).
cuda
(),
torch
.
randn
(
20
,
10
).
cuda
())
test
(
ModelWithOneSizeCheck
,
torch
.
randn
(
20
,
10
).
cuda
(),
torch
.
randn
(
1
,
10
).
cuda
(),
is_01_specialization
=
True
,
)
tests/compile/test_fusion.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
pytest
import
pytest
import
torch
import
torch
import
vllm.plugins
import
vllm.plugins
from
vllm._aiter_ops
import
IS_AITER_FOUND
,
rocm_aiter_ops
from
vllm.compilation.fusion
import
FUSED_OPS
,
FusedRMSQuantKey
,
RMSNormQuantFusionPass
from
vllm.compilation.fusion
import
FUSED_OPS
,
FusedRMSQuantKey
,
RMSNormQuantFusionPass
from
vllm.compilation.fx_utils
import
find_op_nodes
from
vllm.compilation.fx_utils
import
find_op_nodes
from
vllm.compilation.matcher_utils
import
QUANT_OPS
from
vllm.compilation.matcher_utils
import
QUANT_OPS
...
@@ -18,6 +21,9 @@ from vllm.config import (
...
@@ -18,6 +21,9 @@ from vllm.config import (
VllmConfig
,
VllmConfig
,
)
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
W8A8BlockFp8LinearOp
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
GroupShape
,
QuantKey
,
QuantKey
,
...
@@ -25,10 +31,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -25,10 +31,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
Fp8LinearOp
,
cutlass_block_fp8_supported
,
cutlass_fp8_supported
,
cutlass_fp8_supported
,
maybe_create_device_identity
,
maybe_create_device_identity
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
is_deep_gemm_supported
from
..utils
import
override_cutlass_fp8_supported
from
..utils
import
override_cutlass_fp8_supported
from
.backend
import
TestBackend
from
.backend
import
TestBackend
...
@@ -44,7 +52,7 @@ class TestModel(torch.nn.Module):
...
@@ -44,7 +52,7 @@ class TestModel(torch.nn.Module):
self
,
self
,
hidden_size
:
int
,
hidden_size
:
int
,
eps
:
float
,
eps
:
float
,
static
:
bool
,
group_shape
:
GroupShape
,
cuda_force_torch
:
bool
,
cuda_force_torch
:
bool
,
*
args
,
*
args
,
**
kwargs
,
**
kwargs
,
...
@@ -52,8 +60,17 @@ class TestModel(torch.nn.Module):
...
@@ -52,8 +60,17 @@ class TestModel(torch.nn.Module):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
cuda_force_torch
=
cuda_force_torch
self
.
cuda_force_torch
=
cuda_force_torch
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
_
in
range
(
4
)]
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
_
in
range
(
4
)]
self
.
wscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
if
group_shape
.
is_per_group
():
group_shape
=
GroupShape
.
PER_TENSOR
if
static
else
GroupShape
.
PER_TOKEN
self
.
wscale
=
[
torch
.
rand
(
(
hidden_size
//
group_shape
[
1
],
hidden_size
//
group_shape
[
1
]),
dtype
=
torch
.
float32
,
)
for
_
in
range
(
3
)
]
else
:
self
.
wscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
static
=
group_shape
==
GroupShape
.
PER_TENSOR
quant_scale
=
ScaleDesc
(
torch
.
float32
,
static
,
group_shape
)
quant_scale
=
ScaleDesc
(
torch
.
float32
,
static
,
group_shape
)
self
.
quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
quant_scale
,
symmetric
=
True
)
self
.
quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
quant_scale
,
symmetric
=
True
)
if
static
:
if
static
:
...
@@ -61,18 +78,29 @@ class TestModel(torch.nn.Module):
...
@@ -61,18 +78,29 @@ class TestModel(torch.nn.Module):
else
:
else
:
self
.
scale
=
[
None
for
_
in
range
(
3
)]
self
.
scale
=
[
None
for
_
in
range
(
3
)]
self
.
w
=
[
self
.
w
=
[
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
)
for
_
in
range
(
3
)
for
_
in
range
(
3
)
]
]
if
not
group_shape
.
is_per_group
():
self
.
w
=
[
self
.
w
[
0
].
t
()
for
_
in
range
(
3
)]
with
override_cutlass_fp8_supported
(
not
cuda_force_torch
):
if
group_shape
.
is_per_group
(
):
self
.
fp8_linear
=
Fp8LinearOp
(
self
.
fp8_linear
=
W8A8Block
Fp8LinearOp
(
act_quant_static
=
static
,
weight_group_shape
=
GroupShape
(
group_shape
[
1
],
group_shape
[
1
])
,
act_quant_group_shape
=
group_shape
,
act_quant_group_shape
=
group_shape
,
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
(),
use_aiter_and_is_supported
=
False
,
)
)
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
input_quant_op
.
enabled
()
else
:
with
override_cutlass_fp8_supported
(
not
cuda_force_torch
):
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
static
,
act_quant_group_shape
=
group_shape
,
)
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
quant_fp8
.
enabled
()
self
.
enable_rms_norm_custom_op
=
self
.
norm
[
0
].
enabled
()
self
.
enable_rms_norm_custom_op
=
self
.
norm
[
0
].
enabled
()
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
quant_fp8
.
enabled
()
self
.
group_shape
=
group_shape
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
# avoid having graph input be an arg to a pattern directly
# avoid having graph input be an arg to a pattern directly
...
@@ -119,13 +147,87 @@ class TestModel(torch.nn.Module):
...
@@ -119,13 +147,87 @@ class TestModel(torch.nn.Module):
)
)
GROUP_SHAPES
=
[
GroupShape
.
PER_TOKEN
,
GroupShape
.
PER_TENSOR
,
GroupShape
(
1
,
128
),
GroupShape
(
1
,
64
),
]
class
TestRmsnormGroupFp8QuantModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
,
**
kwargs
):
super
().
__init__
()
self
.
w8a8_block_fp8_linear
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
128
,
128
),
act_quant_group_shape
=
GroupShape
(
1
,
128
),
cutlass_block_fp8_supported
=
False
,
use_aiter_and_is_supported
=
True
,
)
self
.
w
=
[
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
for
_
in
range
(
3
)
]
scale_hidden_size
=
(
hidden_size
+
128
-
1
)
//
128
self
.
wscale
=
[
torch
.
rand
((
scale_hidden_size
,
scale_hidden_size
),
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)
]
self
.
norm_weight
=
[
torch
.
ones
(
hidden_size
)
for
_
in
range
(
4
)]
self
.
eps
=
eps
def
forward
(
self
,
x
):
# avoid having graph input be an arg to a pattern directly
x
=
resid
=
torch
.
relu
(
x
)
y
=
rocm_aiter_ops
.
rms_norm
(
x
,
self
.
norm_weight
[
0
],
self
.
eps
)
x2
=
self
.
w8a8_block_fp8_linear
.
apply
(
y
,
self
.
w
[
0
],
self
.
wscale
[
0
])
# make sure resid is used for replacement to work
y2
,
resid
=
rocm_aiter_ops
.
rms_norm2d_with_add
(
x2
,
resid
,
self
.
norm_weight
[
1
],
self
.
eps
)
x3
=
self
.
w8a8_block_fp8_linear
.
apply
(
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
])
y3
,
resid
=
rocm_aiter_ops
.
rms_norm2d_with_add
(
x3
,
resid
,
self
.
norm_weight
[
2
],
self
.
eps
)
x4
=
self
.
w8a8_block_fp8_linear
.
apply
(
y3
,
self
.
w
[
2
],
self
.
wscale
[
2
])
y4
,
resid
=
rocm_aiter_ops
.
rms_norm2d_with_add
(
x4
,
resid
,
self
.
norm_weight
[
3
],
self
.
eps
)
return
y4
def
ops_in_model_before
(
self
):
return
[
torch
.
ops
.
vllm
.
rocm_aiter_rms_norm
,
torch
.
ops
.
vllm
.
rocm_aiter_group_fp8_quant
,
]
def
ops_in_model_before_partial
(
self
):
return
[]
def
ops_in_model_after
(
self
):
return
[
torch
.
ops
.
vllm
.
rocm_aiter_rmsnorm_fp8_group_quant
,
torch
.
ops
.
vllm
.
rocm_aiter_rmsnorm_with_add_fp8_group_quant
,
]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
6
4
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
25
6
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
257
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
257
])
@
pytest
.
mark
.
parametrize
(
"eps"
,
[
1e-5
,
1e-6
])
@
pytest
.
mark
.
parametrize
(
"eps"
,
[
1e-5
,
1e-6
])
@
pytest
.
mark
.
parametrize
(
"static"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"group_shape"
,
GROUP_SHAPES
)
@
pytest
.
mark
.
parametrize
(
"enable_rms_norm_custom_op"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"enable_quant_fp8_custom_op"
,
[
True
,
False
])
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op"
,
list
(
itertools
.
product
([
TestModel
],
[
True
,
False
],
[
True
,
False
]))
+
[(
TestRmsnormGroupFp8QuantModel
,
False
,
False
)],
)
# cuda_force_torch used to test torch code path on platforms that
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
# cutlass_fp8_supported() == True.
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -139,16 +241,29 @@ def test_fusion_rmsnorm_quant(
...
@@ -139,16 +241,29 @@ def test_fusion_rmsnorm_quant(
hidden_size
,
hidden_size
,
num_tokens
,
num_tokens
,
eps
,
eps
,
static
,
group_shape
,
model_class
,
enable_rms_norm_custom_op
,
enable_rms_norm_custom_op
,
enable_quant_fp8_custom_op
,
enable_quant_fp8_custom_op
,
cuda_force_torch
,
cuda_force_torch
,
):
):
if
model_class
is
TestRmsnormGroupFp8QuantModel
and
not
IS_AITER_FOUND
:
pytest
.
skip
(
"AITER is not supported on this GPU."
)
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
dtype
)
torch
.
manual_seed
(
1
)
torch
.
manual_seed
(
1
)
maybe_create_device_identity
()
# needed for certain non-cutlass fp8 paths
maybe_create_device_identity
()
# needed for certain non-cutlass fp8 paths
if
not
enable_quant_fp8_custom_op
and
group_shape
.
is_per_group
():
pytest
.
skip
(
"Unsupported unwrapped quant fp8 op for blockwise quantization"
)
# Skip test for 64-bit group shape when running with cutlass or deepgemm
if
group_shape
==
GroupShape
(
1
,
64
)
and
(
cutlass_block_fp8_supported
()
or
is_deep_gemm_supported
()
):
pytest
.
skip
(
"Unsupported group shape 64 for CUTLASS/DeepGemm"
)
custom_ops
=
[]
custom_ops
=
[]
if
enable_rms_norm_custom_op
:
if
enable_rms_norm_custom_op
:
custom_ops
.
append
(
"+rms_norm"
)
custom_ops
.
append
(
"+rms_norm"
)
...
@@ -167,13 +282,24 @@ def test_fusion_rmsnorm_quant(
...
@@ -167,13 +282,24 @@ def test_fusion_rmsnorm_quant(
with
vllm
.
config
.
set_current_vllm_config
(
vllm_config
):
with
vllm
.
config
.
set_current_vllm_config
(
vllm_config
):
# Reshape pass is needed for the fusion pass to work
# Reshape pass is needed for the fusion pass to work
noop_pass
=
NoOpEliminationPass
(
vllm_config
)
noop_pass
=
NoOpEliminationPass
(
vllm_config
)
fusion_pass
=
RMSNormQuantFusionPass
(
vllm_config
)
if
model_class
is
TestRmsnormGroupFp8QuantModel
:
from
vllm.compilation.rocm_aiter_fusion
import
(
RocmAiterRMSNormFp8GroupQuantFusionPass
,
)
fusion_pass
=
RocmAiterRMSNormFp8GroupQuantFusionPass
(
vllm_config
)
else
:
fusion_pass
=
RMSNormQuantFusionPass
(
vllm_config
)
cleanup_pass
=
PostCleanupPass
(
vllm_config
)
cleanup_pass
=
PostCleanupPass
(
vllm_config
)
backend
=
TestBackend
(
noop_pass
,
fusion_pass
,
cleanup_pass
)
backend
=
TestBackend
(
noop_pass
,
fusion_pass
,
cleanup_pass
)
backend2
=
TestBackend
(
noop_pass
,
cleanup_pass
)
backend2
=
TestBackend
(
noop_pass
,
cleanup_pass
)
model
=
TestModel
(
hidden_size
,
eps
,
static
,
cuda_force_torch
)
model
=
model_class
(
hidden_size
=
hidden_size
,
eps
=
eps
,
group_shape
=
group_shape
,
cuda_force_torch
=
cuda_force_torch
,
)
# First dimension dynamic
# First dimension dynamic
x
=
torch
.
rand
(
num_tokens
,
hidden_size
)
x
=
torch
.
rand
(
num_tokens
,
hidden_size
)
torch
.
_dynamo
.
mark_dynamic
(
x
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
x
,
0
)
...
@@ -202,7 +328,10 @@ def test_fusion_rmsnorm_quant(
...
@@ -202,7 +328,10 @@ def test_fusion_rmsnorm_quant(
# there's a risk that the fused add doesn't get included in the
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if
not
enable_rms_norm_custom_op
:
if
(
not
enable_rms_norm_custom_op
and
model_class
is
not
TestRmsnormGroupFp8QuantModel
):
n_add_nodes
=
lambda
g
:
sum
(
1
for
_
in
find_op_nodes
(
torch
.
ops
.
aten
.
add
,
g
))
n_add_nodes
=
lambda
g
:
sum
(
1
for
_
in
find_op_nodes
(
torch
.
ops
.
aten
.
add
,
g
))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert
n_add_nodes
(
backend
.
graph_pre_pass
)
==
7
assert
n_add_nodes
(
backend
.
graph_pre_pass
)
==
7
...
...
tests/compile/test_fusion_attn.py
View file @
8d75f22e
...
@@ -12,13 +12,13 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
...
@@ -12,13 +12,13 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.attention.selector
import
global_force_attn_backend_context_manager
from
vllm.compilation.fusion_attn
import
ATTN_OP
,
AttnFusionPass
from
vllm.compilation.fusion_attn
import
ATTN_OP
,
AttnFusionPass
from
vllm.compilation.fx_utils
import
find_op_nodes
from
vllm.compilation.fx_utils
import
find_op_nodes
from
vllm.compilation.matcher_utils
import
QUANT_OPS
from
vllm.compilation.matcher_utils
import
QUANT_OPS
from
vllm.compilation.noop_elimination
import
NoOpEliminationPass
from
vllm.compilation.noop_elimination
import
NoOpEliminationPass
from
vllm.compilation.post_cleanup
import
PostCleanupPass
from
vllm.compilation.post_cleanup
import
PostCleanupPass
from
vllm.config
import
(
from
vllm.config
import
(
AttentionConfig
,
CacheConfig
,
CacheConfig
,
CompilationConfig
,
CompilationConfig
,
CompilationMode
,
CompilationMode
,
...
@@ -335,6 +335,7 @@ def test_attention_quant_pattern(
...
@@ -335,6 +335,7 @@ def test_attention_quant_pattern(
custom_ops
=
custom_ops_list
,
custom_ops
=
custom_ops_list
,
),
),
cache_config
=
CacheConfig
(
cache_dtype
=
"fp8"
),
cache_config
=
CacheConfig
(
cache_dtype
=
"fp8"
),
attention_config
=
AttentionConfig
(
backend
=
backend
),
)
)
# Create test inputs
# Create test inputs
...
@@ -352,7 +353,6 @@ def test_attention_quant_pattern(
...
@@ -352,7 +353,6 @@ def test_attention_quant_pattern(
with
(
with
(
set_current_vllm_config
(
vllm_config_unfused
),
set_current_vllm_config
(
vllm_config_unfused
),
set_forward_context
(
attn_metadata
=
None
,
vllm_config
=
vllm_config_unfused
),
set_forward_context
(
attn_metadata
=
None
,
vllm_config
=
vllm_config_unfused
),
global_force_attn_backend_context_manager
(
backend
),
):
):
model_unfused
=
model_class
(
model_unfused
=
model_class
(
num_qo_heads
=
num_qo_heads
,
num_qo_heads
=
num_qo_heads
,
...
@@ -378,7 +378,6 @@ def test_attention_quant_pattern(
...
@@ -378,7 +378,6 @@ def test_attention_quant_pattern(
with
(
with
(
set_current_vllm_config
(
vllm_config
),
set_current_vllm_config
(
vllm_config
),
set_forward_context
(
attn_metadata
=
None
,
vllm_config
=
vllm_config
),
set_forward_context
(
attn_metadata
=
None
,
vllm_config
=
vllm_config
),
global_force_attn_backend_context_manager
(
backend
),
):
):
model_fused
=
model_class
(
model_fused
=
model_class
(
num_qo_heads
=
num_qo_heads
,
num_qo_heads
=
num_qo_heads
,
...
...
tests/compile/test_pass_manager.py
View file @
8d75f22e
...
@@ -5,9 +5,14 @@ import copy
...
@@ -5,9 +5,14 @@ import copy
import
pytest
import
pytest
import
torch
import
torch
from
vllm.compilation.inductor_pass
import
CallableInductorPass
,
InductorPass
from
vllm.compilation.inductor_pass
import
(
CallableInductorPass
,
InductorPass
,
pass_context
,
)
from
vllm.compilation.pass_manager
import
PostGradPassManager
from
vllm.compilation.pass_manager
import
PostGradPassManager
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config.utils
import
Range
# dummy custom pass that doesn't inherit
# dummy custom pass that doesn't inherit
...
@@ -42,35 +47,37 @@ class ProperPass(InductorPass):
...
@@ -42,35 +47,37 @@ class ProperPass(InductorPass):
],
],
)
)
def
test_pass_manager_uuid
(
callable
):
def
test_pass_manager_uuid
(
callable
):
# Some passes need dtype to be set
# Set the pass context as PassManager uuid uses it
config
=
VllmConfig
(
model_config
=
ModelConfig
(
dtype
=
torch
.
bfloat16
))
with
pass_context
(
Range
(
start
=
1
,
end
=
8
)):
# Some passes need dtype to be set
pass_manager
=
PostGradPassManager
()
config
=
VllmConfig
(
model_config
=
ModelConfig
(
dtype
=
torch
.
bfloat16
))
pass_manager
.
configure
(
config
)
pass_manager
=
PostGradPassManager
()
# Check that UUID is different if the same pass is added 2x
pass_manager
.
configure
(
config
)
pass_manager
.
add
(
callable
)
uuid1
=
pass_manager
.
uuid
()
# Check that UUID is different if the same pass is added 2x
pass_manager
.
add
(
callable
)
pass_manager
.
add
(
callable
)
uuid2
=
pass_manager
.
uuid
()
uuid1
=
pass_manager
.
uuid
()
assert
uuid1
!=
uuid2
pass_manager
.
add
(
callable
)
uuid2
=
pass_manager
.
uuid
()
# UUID should be the same as the original one,
assert
uuid1
!=
uuid2
# as we constructed in the same way.
pass_manager2
=
PostGradPassManager
()
# UUID should be the same as the original one,
pass_manager2
.
configure
(
config
)
# as we constructed in the same way.
pass_manager2
.
add
(
callable
)
pass_manager2
=
PostGradPassManager
()
assert
uuid1
==
pass_manager2
.
uuid
()
pass_manager2
.
configure
(
config
)
pass_manager2
.
add
(
callable
)
# UUID should be different due to config change
assert
uuid1
==
pass_manager2
.
uuid
()
config2
=
copy
.
deepcopy
(
config
)
config2
.
compilation_config
.
pass_config
.
fuse_norm_quant
=
(
# UUID should be different due to config change
not
config2
.
compilation_config
.
pass_config
.
fuse_norm_quant
config2
=
copy
.
deepcopy
(
config
)
)
config2
.
compilation_config
.
pass_config
.
fuse_norm_quant
=
(
config2
.
compilation_config
.
pass_config
.
fuse_act_quant
=
(
not
config2
.
compilation_config
.
pass_config
.
fuse_norm_quant
not
config2
.
compilation_config
.
pass_config
.
fuse_act_quant
)
)
config2
.
compilation_config
.
pass_config
.
fuse_act_quant
=
(
pass_manager3
=
PostGradPassManager
()
not
config2
.
compilation_config
.
pass_config
.
fuse_act_quant
pass_manager3
.
configure
(
config2
)
)
pass_manager3
.
add
(
callable
)
pass_manager3
=
PostGradPassManager
()
assert
uuid1
!=
pass_manager3
.
uuid
()
pass_manager3
.
configure
(
config2
)
pass_manager3
.
add
(
callable
)
assert
uuid1
!=
pass_manager3
.
uuid
()
tests/compile/test_silu_mul_quant_fusion.py
View file @
8d75f22e
...
@@ -7,6 +7,7 @@ import torch
...
@@ -7,6 +7,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
tests.kernels.quantization.nvfp4_utils
import
quant_nvfp4_tensor
from
tests.kernels.quantization.nvfp4_utils
import
quant_nvfp4_tensor
from
vllm._aiter_ops
import
IS_AITER_FOUND
from
vllm._custom_ops
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
vllm._custom_ops
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
vllm.compilation.activation_quant_fusion
import
(
from
vllm.compilation.activation_quant_fusion
import
(
FUSED_OPS
,
FUSED_OPS
,
...
@@ -24,6 +25,7 @@ from vllm.config import (
...
@@ -24,6 +25,7 @@ from vllm.config import (
set_current_vllm_config
,
set_current_vllm_config
,
)
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
W8A8BlockFp8LinearOp
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
GroupShape
,
kFp8StaticTensorSym
,
kFp8StaticTensorSym
,
...
@@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
...
@@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
return
[
FUSED_OPS
[
kNvfp4Quant
]]
return
[
FUSED_OPS
[
kNvfp4Quant
]]
class
TestSiluMulGroupFp8QuantModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
**
kwargs
):
super
().
__init__
()
self
.
silu_and_mul
=
SiluAndMul
()
self
.
w8a8_block_fp8_linear
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
128
,
128
),
act_quant_group_shape
=
GroupShape
(
1
,
128
),
cutlass_block_fp8_supported
=
False
,
use_aiter_and_is_supported
=
True
,
)
self
.
w
=
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
scale_hidden_size
=
(
hidden_size
+
128
-
1
)
//
128
self
.
wscale
=
torch
.
rand
(
(
scale_hidden_size
,
scale_hidden_size
),
dtype
=
torch
.
float32
)
self
.
enable_silu_mul_custom_op
=
self
.
silu_and_mul
.
enabled
()
def
forward
(
self
,
x
):
y
=
self
.
silu_and_mul
(
x
)
x2
=
self
.
w8a8_block_fp8_linear
.
apply
(
y
,
self
.
w
,
self
.
wscale
)
return
x2
def
ops_in_model_before
(
self
):
return
[
SILU_MUL_OP
if
self
.
enable_silu_mul_custom_op
else
torch
.
ops
.
aten
.
mul
,
]
def
ops_in_model_after
(
self
):
return
[
torch
.
ops
.
vllm
.
rocm_aiter_act_mul_and_fp8_group_quant
]
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
...
@@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
...
@@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model_class, enable_quant_fp8_custom_op, cuda_force_torch"
,
"model_class, enable_quant_fp8_custom_op, cuda_force_torch"
,
list
(
itertools
.
product
([
TestSiluMulFp8QuantModel
],
[
True
,
False
],
[
True
,
False
]))
list
(
itertools
.
product
([
TestSiluMulFp8QuantModel
],
[
True
,
False
],
[
True
,
False
]))
+
[(
TestSiluMulNvfp4QuantModel
,
False
,
False
)],
+
[
(
TestSiluMulNvfp4QuantModel
,
False
,
False
),
(
TestSiluMulGroupFp8QuantModel
,
False
,
False
),
],
)
)
# cuda_force_torch used to test torch code path on platforms that
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
# cutlass_fp8_supported() == True.
...
@@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant(
...
@@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant(
num_tokens
:
int
,
num_tokens
:
int
,
hidden_size
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
model_class
:
type
[
TestSiluMulFp8QuantModel
|
TestSiluMulNvfp4QuantModel
],
model_class
:
type
[
TestSiluMulFp8QuantModel
|
TestSiluMulNvfp4QuantModel
|
TestSiluMulGroupFp8QuantModel
],
enable_silu_mul_custom_op
:
bool
,
enable_silu_mul_custom_op
:
bool
,
enable_quant_fp8_custom_op
:
bool
,
enable_quant_fp8_custom_op
:
bool
,
cuda_force_torch
:
bool
,
cuda_force_torch
:
bool
,
):
):
if
model_class
is
TestSiluMulNvfp4QuantModel
and
not
is_nvfp4_supported
():
if
model_class
is
TestSiluMulNvfp4QuantModel
and
not
is_nvfp4_supported
():
pytest
.
skip
(
"NVFP4 is not supported on this GPU."
)
pytest
.
skip
(
"NVFP4 is not supported on this GPU."
)
if
model_class
is
TestSiluMulGroupFp8QuantModel
and
not
IS_AITER_FOUND
:
pytest
.
skip
(
"AITER is not supported on this GPU."
)
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
dtype
)
...
@@ -173,9 +217,15 @@ def test_fusion_silu_and_mul_quant(
...
@@ -173,9 +217,15 @@ def test_fusion_silu_and_mul_quant(
)
)
with
set_current_vllm_config
(
config
):
with
set_current_vllm_config
(
config
):
fusion_pass
=
ActivationQuantFusionPass
(
config
)
fusion_passes
=
[
ActivationQuantFusionPass
(
config
)]
if
IS_AITER_FOUND
:
from
vllm.compilation.rocm_aiter_fusion
import
(
RocmAiterSiluMulFp8GroupQuantFusionPass
,
)
fusion_passes
+=
[
RocmAiterSiluMulFp8GroupQuantFusionPass
(
config
)]
passes
=
[
NoOpEliminationPass
(
config
),
fusion_pass
,
PostCleanupPass
(
config
)]
passes
=
[
NoOpEliminationPass
(
config
),
*
fusion_pass
es
,
PostCleanupPass
(
config
)]
backend
=
TestBackend
(
*
passes
)
backend
=
TestBackend
(
*
passes
)
model
=
model_class
(
model
=
model_class
(
hidden_size
=
hidden_size
,
cuda_force_torch
=
cuda_force_torch
,
x
=
x
hidden_size
=
hidden_size
,
cuda_force_torch
=
cuda_force_torch
,
x
=
x
...
@@ -194,12 +244,14 @@ def test_fusion_silu_and_mul_quant(
...
@@ -194,12 +244,14 @@ def test_fusion_silu_and_mul_quant(
atol
,
rtol
=
1e-3
,
1e-3
atol
,
rtol
=
1e-3
,
1e-3
elif
model_class
==
TestSiluMulNvfp4QuantModel
:
elif
model_class
==
TestSiluMulNvfp4QuantModel
:
atol
,
rtol
=
1e-1
,
1e-1
atol
,
rtol
=
1e-1
,
1e-1
elif
model_class
==
TestSiluMulGroupFp8QuantModel
:
atol
,
rtol
=
5e-2
,
5e-2
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
result
[
0
].
to
(
dtype
=
dtype
),
result2
[
0
].
to
(
dtype
=
dtype
),
atol
=
atol
,
rtol
=
rtol
result
[
0
].
to
(
dtype
=
dtype
),
result2
[
0
].
to
(
dtype
=
dtype
),
atol
=
atol
,
rtol
=
rtol
)
)
assert
fusion_pass
.
matched_count
==
1
assert
sum
([
p
.
matched_count
for
p
in
fusion_passes
])
==
1
# In pre-nodes, quant op should be present and fused kernels should not
# In pre-nodes, quant op should be present and fused kernels should not
backend
.
check_before_ops
(
model
.
ops_in_model_before
())
backend
.
check_before_ops
(
model
.
ops_in_model_before
())
...
...
tests/conftest.py
View file @
8d75f22e
...
@@ -27,7 +27,7 @@ import threading
...
@@ -27,7 +27,7 @@ import threading
from
collections.abc
import
Generator
from
collections.abc
import
Generator
from
contextlib
import
nullcontext
from
contextlib
import
nullcontext
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Any
,
Callable
,
TypedDict
,
TypeVar
,
cast
from
typing
import
Any
,
Callable
,
TypedDict
,
TypeVar
,
cast
,
TYPE_CHECKING
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -59,6 +59,7 @@ from vllm.distributed import (
...
@@ -59,6 +59,7 @@ from vllm.distributed import (
)
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
from
vllm.logprobs
import
Logprob
from
vllm.multimodal.base
import
MediaWithBytes
from
vllm.multimodal.utils
import
fetch_image
from
vllm.multimodal.utils
import
fetch_image
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
BeamSearchParams
from
vllm.sampling_params
import
BeamSearchParams
...
@@ -66,6 +67,14 @@ from vllm.transformers_utils.utils import maybe_model_redirect
...
@@ -66,6 +67,14 @@ from vllm.transformers_utils.utils import maybe_model_redirect
from
vllm.utils.collection_utils
import
is_list_of
from
vllm.utils.collection_utils
import
is_list_of
from
vllm.utils.torch_utils
import
set_default_torch_num_threads
from
vllm.utils.torch_utils
import
set_default_torch_num_threads
from
torch._inductor.utils
import
fresh_cache
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
transformers.generation.utils
import
GenerateOutput
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_TEST_DIR
=
os
.
path
.
dirname
(
__file__
)
_TEST_DIR
=
os
.
path
.
dirname
(
__file__
)
...
@@ -201,10 +210,7 @@ def dynamo_reset():
...
@@ -201,10 +210,7 @@ def dynamo_reset():
@
pytest
.
fixture
@
pytest
.
fixture
def
example_prompts
()
->
list
[
str
]:
def
example_prompts
()
->
list
[
str
]:
prompts
=
[]
return
[
prompt
for
filename
in
_TEST_PROMPTS
for
prompt
in
_read_prompts
(
filename
)]
for
filename
in
_TEST_PROMPTS
:
prompts
+=
_read_prompts
(
filename
)
return
prompts
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -223,10 +229,7 @@ class DecoderPromptType(Enum):
...
@@ -223,10 +229,7 @@ class DecoderPromptType(Enum):
@
pytest
.
fixture
@
pytest
.
fixture
def
example_long_prompts
()
->
list
[
str
]:
def
example_long_prompts
()
->
list
[
str
]:
prompts
=
[]
return
[
prompt
for
filename
in
_LONG_PROMPTS
for
prompt
in
_read_prompts
(
filename
)]
for
filename
in
_LONG_PROMPTS
:
prompts
+=
_read_prompts
(
filename
)
return
prompts
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
...
@@ -352,10 +355,13 @@ class HfRunner:
...
@@ -352,10 +355,13 @@ class HfRunner:
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
)
)
else
:
else
:
model
=
auto_cls
.
from_pretrained
(
model
=
cast
(
model_name
,
nn
.
Module
,
trust_remote_code
=
trust_remote_code
,
auto_cls
.
from_pretrained
(
**
model_kwargs
,
model_name
,
trust_remote_code
=
trust_remote_code
,
**
model_kwargs
,
),
)
)
# in case some unquantized custom models are not in same dtype
# in case some unquantized custom models are not in same dtype
...
@@ -373,10 +379,12 @@ class HfRunner:
...
@@ -373,10 +379,12 @@ class HfRunner:
self
.
model
=
model
self
.
model
=
model
if
not
skip_tokenizer_init
:
if
not
skip_tokenizer_init
:
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
tokenizer
:
"PreTrainedTokenizer | PreTrainedTokenizerFast"
=
(
model_name
,
AutoTokenizer
.
from_pretrained
(
dtype
=
dtype
,
model_name
,
trust_remote_code
=
trust_remote_code
,
dtype
=
dtype
,
trust_remote_code
=
trust_remote_code
,
)
)
)
# don't put this import at the top level
# don't put this import at the top level
...
@@ -397,6 +405,7 @@ class HfRunner:
...
@@ -397,6 +405,7 @@ class HfRunner:
images
:
PromptImageInput
|
None
=
None
,
images
:
PromptImageInput
|
None
=
None
,
videos
:
PromptVideoInput
|
None
=
None
,
videos
:
PromptVideoInput
|
None
=
None
,
audios
:
PromptAudioInput
|
None
=
None
,
audios
:
PromptAudioInput
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
list
[
BatchFeature
|
BatchEncoding
|
dict
[
str
,
torch
.
Tensor
]]:
)
->
list
[
BatchFeature
|
BatchEncoding
|
dict
[
str
,
torch
.
Tensor
]]:
if
images
is
not
None
:
if
images
is
not
None
:
assert
len
(
prompts
)
==
len
(
images
)
assert
len
(
prompts
)
==
len
(
images
)
...
@@ -410,10 +419,18 @@ class HfRunner:
...
@@ -410,10 +419,18 @@ class HfRunner:
all_inputs
:
list
[
BatchFeature
|
BatchEncoding
|
dict
[
str
,
torch
.
Tensor
]]
=
[]
all_inputs
:
list
[
BatchFeature
|
BatchEncoding
|
dict
[
str
,
torch
.
Tensor
]]
=
[]
for
i
,
prompt
in
enumerate
(
prompts
):
for
i
,
prompt
in
enumerate
(
prompts
):
if
isinstance
(
prompt
,
str
):
if
isinstance
(
prompt
,
str
):
processor_kwargs
:
dict
[
str
,
Any
]
=
{
# Create a copy to avoid modifying the original dict
"text"
:
prompt
,
processor_kwargs
=
(
"return_tensors"
:
"pt"
,
tokenization_kwargs
.
copy
()
}
if
tokenization_kwargs
is
not
None
else
{}
)
processor_kwargs
.
update
(
{
"text"
:
prompt
,
"return_tensors"
:
"pt"
,
}
)
if
images
is
not
None
and
(
image
:
=
images
[
i
])
is
not
None
:
if
images
is
not
None
and
(
image
:
=
images
[
i
])
is
not
None
:
processor_kwargs
[
"images"
]
=
image
processor_kwargs
[
"images"
]
=
image
if
videos
is
not
None
and
(
video
:
=
videos
[
i
])
is
not
None
:
if
videos
is
not
None
and
(
video
:
=
videos
[
i
])
is
not
None
:
...
@@ -494,7 +511,7 @@ class HfRunner:
...
@@ -494,7 +511,7 @@ class HfRunner:
outputs
:
list
[
tuple
[
list
[
list
[
int
]],
list
[
str
]]]
=
[]
outputs
:
list
[
tuple
[
list
[
list
[
int
]],
list
[
str
]]]
=
[]
for
inputs
in
all_inputs
:
for
inputs
in
all_inputs
:
output_ids
=
self
.
model
.
generate
(
output_ids
:
torch
.
Tensor
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
),
use_cache
=
True
,
use_cache
=
True
,
**
kwargs
,
**
kwargs
,
...
@@ -504,8 +521,7 @@ class HfRunner:
...
@@ -504,8 +521,7 @@ class HfRunner:
skip_special_tokens
=
True
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
,
clean_up_tokenization_spaces
=
False
,
)
)
output_ids
=
output_ids
.
cpu
().
tolist
()
outputs
.
append
((
output_ids
.
cpu
().
tolist
(),
output_str
))
outputs
.
append
((
output_ids
,
output_str
))
return
outputs
return
outputs
def
generate_greedy
(
def
generate_greedy
(
...
@@ -573,7 +589,7 @@ class HfRunner:
...
@@ -573,7 +589,7 @@ class HfRunner:
all_logprobs
:
list
[
list
[
torch
.
Tensor
]]
=
[]
all_logprobs
:
list
[
list
[
torch
.
Tensor
]]
=
[]
for
inputs
in
all_inputs
:
for
inputs
in
all_inputs
:
output
=
self
.
model
.
generate
(
output
:
"GenerateOutput"
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
),
use_cache
=
True
,
use_cache
=
True
,
do_sample
=
False
,
do_sample
=
False
,
...
@@ -655,7 +671,7 @@ class HfRunner:
...
@@ -655,7 +671,7 @@ class HfRunner:
all_output_strs
:
list
[
str
]
=
[]
all_output_strs
:
list
[
str
]
=
[]
for
inputs
in
all_inputs
:
for
inputs
in
all_inputs
:
output
=
self
.
model
.
generate
(
output
:
"GenerateOutput"
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
),
use_cache
=
True
,
use_cache
=
True
,
do_sample
=
False
,
do_sample
=
False
,
...
@@ -1389,7 +1405,11 @@ class LocalAssetServer:
...
@@ -1389,7 +1405,11 @@ class LocalAssetServer:
return
f
"
{
self
.
base_url
}
/
{
name
}
"
return
f
"
{
self
.
base_url
}
/
{
name
}
"
def
get_image_asset
(
self
,
name
:
str
)
->
Image
.
Image
:
def
get_image_asset
(
self
,
name
:
str
)
->
Image
.
Image
:
return
fetch_image
(
self
.
url_for
(
name
))
image
=
fetch_image
(
self
.
url_for
(
name
))
# Unwrap MediaWithBytes if present
if
isinstance
(
image
,
MediaWithBytes
):
image
=
image
.
media
return
image
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
...
@@ -1457,3 +1477,14 @@ def clean_gpu_memory_between_tests():
...
@@ -1457,3 +1477,14 @@ def clean_gpu_memory_between_tests():
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
gc
.
collect
()
gc
.
collect
()
@
pytest
.
fixture
def
use_fresh_inductor_cache
():
"""
Use a fresh inductor cache for the test.
This is useful to ensure that the test is not affected by the
previous test calls.
"""
with
fresh_cache
():
yield
tests/distributed/test_context_parallel.py
View file @
8d75f22e
...
@@ -16,16 +16,35 @@ from typing import Literal, NamedTuple
...
@@ -16,16 +16,35 @@ from typing import Literal, NamedTuple
import
pytest
import
pytest
import
torch
import
torch
from
tests.evals.gsm8k.gsm8k_eval
import
evaluate_gsm8k
from
tests.utils
import
RemoteOpenAIServer
,
create_new_process_for_each_test
from
vllm.config.model
import
RunnerOption
from
vllm.config.model
import
RunnerOption
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
..models.registry
import
HF_EXAMPLE_MODELS
from
..models.registry
import
HF_EXAMPLE_MODELS
from
..utils
import
compare_two_settings
,
create_new_process_for_each_test
logger
=
init_logger
(
"test_context_parallel"
)
logger
=
init_logger
(
"test_context_parallel"
)
VLLM_MULTI_NODE
=
os
.
getenv
(
"VLLM_MULTI_NODE"
,
"0"
)
==
"1"
VLLM_MULTI_NODE
=
os
.
getenv
(
"VLLM_MULTI_NODE"
,
"0"
)
==
"1"
CP_TEST_MODELS
=
[
# TODO support other models
# [LANGUAGE GENERATION]
"deepseek-ai/DeepSeek-V2-Lite-Chat"
,
"Qwen/Qwen2.5-1.5B-Instruct"
,
]
# GSM8K eval configuration
NUM_QUESTIONS
=
256
# Fast eval for CI
NUM_SHOTS
=
5
# Few-shot examples
# tp accuracy with 2% buffer
MIN_ACCURACY
=
{
# .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
0.64
,
# .buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
"Qwen/Qwen2.5-1.5B-Instruct"
:
0.52
,
}
class
ParallelSetup
(
NamedTuple
):
class
ParallelSetup
(
NamedTuple
):
tp_size
:
int
tp_size
:
int
...
@@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple):
...
@@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple):
class
CPTestOptions
(
NamedTuple
):
class
CPTestOptions
(
NamedTuple
):
multi_node_only
:
bool
multi_node_only
:
bool
load_format
:
str
|
None
=
None
attn_backend
:
str
|
None
=
None
attn_backend
:
str
|
None
=
None
...
@@ -54,17 +72,20 @@ class CPTestSettings:
...
@@ -54,17 +72,20 @@ class CPTestSettings:
*
,
*
,
tp_base
:
int
=
4
,
tp_base
:
int
=
4
,
pp_base
:
int
=
1
,
pp_base
:
int
=
1
,
dcp_
base
:
int
=
1
,
dcp_
multipliers
:
list
[
float
]
|
None
=
None
,
cp_kv_cache_interleave_size
:
int
=
1
,
cp_kv_cache_interleave_size
:
int
=
1
,
multi_node_only
:
bool
=
False
,
multi_node_only
:
bool
=
False
,
runner
:
RunnerOption
=
"auto"
,
runner
:
RunnerOption
=
"auto"
,
load_format
:
str
|
None
=
None
,
attn_backend
:
str
|
None
=
None
,
attn_backend
:
str
|
None
=
None
,
):
):
parallel_setups
=
[]
parallel_setups
=
[]
if
dcp_multipliers
is
None
:
dcp_multipliers
=
[
0.5
,
]
for
eager_mode_val
in
[
False
]:
for
eager_mode_val
in
[
False
]:
for
pp_multiplier
in
[
1
]:
for
pp_multiplier
in
[
1
]:
for
dcp_multiplier
in
[
0.5
,
1
]
:
for
dcp_multiplier
in
dcp_multipliers
:
for
chunked_prefill_val
in
[
True
]:
for
chunked_prefill_val
in
[
True
]:
parallel_setups
.
append
(
parallel_setups
.
append
(
ParallelSetup
(
ParallelSetup
(
...
@@ -82,7 +103,6 @@ class CPTestSettings:
...
@@ -82,7 +103,6 @@ class CPTestSettings:
runner
=
runner
,
runner
=
runner
,
test_options
=
CPTestOptions
(
test_options
=
CPTestOptions
(
multi_node_only
=
multi_node_only
,
multi_node_only
=
multi_node_only
,
load_format
=
load_format
,
attn_backend
=
attn_backend
,
attn_backend
=
attn_backend
,
),
),
)
)
...
@@ -101,7 +121,27 @@ class CPTestSettings:
...
@@ -101,7 +121,27 @@ class CPTestSettings:
)
)
def
_compare_cp_with_tp
(
CP_TEXT_GENERATION_MODELS
=
{
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
[
CPTestSettings
.
detailed
(
dcp_multipliers
=
[
1
]),
CPTestSettings
.
detailed
(
dcp_multipliers
=
[
0.5
],
cp_kv_cache_interleave_size
=
64
,
attn_backend
=
"FLASHMLA"
,
),
],
"Qwen/Qwen2.5-1.5B-Instruct"
:
[
CPTestSettings
.
detailed
(
cp_kv_cache_interleave_size
=
16
,
attn_backend
=
"FLASH_ATTN"
),
CPTestSettings
.
detailed
(
cp_kv_cache_interleave_size
=
16
,
attn_backend
=
"FLASHINFER"
),
],
}
def
_test_cp_gsm8k
(
model_id
:
str
,
model_id
:
str
,
parallel_setup
:
ParallelSetup
,
parallel_setup
:
ParallelSetup
,
distributed_backend
:
str
,
distributed_backend
:
str
,
...
@@ -121,7 +161,7 @@ def _compare_cp_with_tp(
...
@@ -121,7 +161,7 @@ def _compare_cp_with_tp(
chunked_prefill
,
chunked_prefill
,
)
=
parallel_setup
)
=
parallel_setup
multi_node_only
,
load_format
,
attn_backend
=
test_options
multi_node_only
,
attn_backend
=
test_options
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model_id
)
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model_id
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
...
@@ -130,22 +170,7 @@ def _compare_cp_with_tp(
...
@@ -130,22 +170,7 @@ def _compare_cp_with_tp(
tokenizer_mode
=
model_info
.
tokenizer_mode
tokenizer_mode
=
model_info
.
tokenizer_mode
hf_overrides
=
model_info
.
hf_overrides
hf_overrides
=
model_info
.
hf_overrides
if
load_format
==
"dummy"
:
model_info
.
check_available_online
(
on_fail
=
"skip"
)
# Avoid OOM
text_overrides
=
{
"num_hidden_layers"
:
4
,
"hidden_size"
:
512
,
"intermediate_size"
:
800
,
"num_attention_heads"
:
4
,
"num_key_value_heads"
:
1
,
}
if
is_multimodal
:
hf_overrides
.
update
({
"text_config"
:
text_overrides
})
else
:
hf_overrides
.
update
(
text_overrides
)
else
:
model_info
.
check_available_online
(
on_fail
=
"skip"
)
if
num_gpus_available
<
tp_size
*
pp_size
:
if
num_gpus_available
<
tp_size
*
pp_size
:
pytest
.
skip
(
f
"Need at least
{
tp_size
}
x
{
pp_size
}
GPUs"
)
pytest
.
skip
(
f
"Need at least
{
tp_size
}
x
{
pp_size
}
GPUs"
)
...
@@ -157,90 +182,70 @@ def _compare_cp_with_tp(
...
@@ -157,90 +182,70 @@ def _compare_cp_with_tp(
if
multi_node_only
and
not
VLLM_MULTI_NODE
:
if
multi_node_only
and
not
VLLM_MULTI_NODE
:
pytest
.
skip
(
"Not in multi-node setting"
)
pytest
.
skip
(
"Not in multi-node setting"
)
common
_args
=
[
server
_args
=
[
# use half precision for speed and memory savings in CI environment
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"--dtype"
,
"bfloat16"
,
"bfloat16"
,
"--max-model-len"
,
"--max-model-len"
,
"
2048
"
,
"
4096
"
,
"--max-num-seqs"
,
"--max-num-seqs"
,
"
8
"
,
"
64
"
,
]
]
if
chunked_prefill
:
if
chunked_prefill
:
common
_args
.
append
(
"--enable-chunked-prefill"
)
server
_args
.
append
(
"--enable-chunked-prefill"
)
if
eager_mode
:
if
eager_mode
:
common
_args
.
append
(
"--enforce-eager"
)
server
_args
.
append
(
"--enforce-eager"
)
if
runner
!=
"auto"
:
if
runner
!=
"auto"
:
common
_args
.
extend
([
"--runner"
,
runner
])
server
_args
.
extend
([
"--runner"
,
runner
])
if
trust_remote_code
:
if
trust_remote_code
:
common
_args
.
append
(
"--trust-remote-code"
)
server
_args
.
append
(
"--trust-remote-code"
)
if
tokenizer_mode
:
if
tokenizer_mode
:
common_args
.
extend
([
"--tokenizer-mode"
,
tokenizer_mode
])
server_args
.
extend
([
"--tokenizer-mode"
,
tokenizer_mode
])
if
load_format
:
common_args
.
extend
([
"--load-format"
,
load_format
])
if
hf_overrides
:
if
hf_overrides
:
common_args
.
extend
([
"--hf-overrides"
,
json
.
dumps
(
hf_overrides
)])
server_args
.
extend
([
"--hf-overrides"
,
json
.
dumps
(
hf_overrides
)])
if
not
attn_backend
:
server_args
.
extend
(
cp_env
=
tp_env
=
{}
[
else
:
"--tensor-parallel-size"
,
cp_env
=
tp_env
=
{
str
(
tp_size
),
"VLLM_ATTENTION_BACKEND"
:
attn_backend
,
"--pipeline-parallel-size"
,
}
str
(
pp_size
),
"--decode-context-parallel-size"
,
cp_args
=
[
str
(
dcp_size
),
*
common_args
,
"--dcp-kv-cache-interleave-size"
,
"--tensor-parallel-size"
,
str
(
cp_kv_cache_interleave_size
),
str
(
tp_size
),
"--distributed-executor-backend"
,
"--pipeline-parallel-size"
,
distributed_backend
,
str
(
pp_size
),
]
"--decode-context-parallel-size"
,
)
str
(
dcp_size
),
"--dcp-kv-cache-interleave-size"
,
str
(
cp_kv_cache_interleave_size
),
"--distributed-executor-backend"
,
distributed_backend
,
]
tp_args
=
[
server_env
=
{}
*
common_args
,
if
attn_backend
:
"--tensor-parallel-size"
,
server_env
[
"VLLM_ATTENTION_BACKEND"
]
=
attn_backend
str
(
tp_size
),
"--pipeline-parallel-size"
,
str
(
pp_size
),
"--distributed-executor-backend"
,
distributed_backend
,
]
compare_two_settings
(
with
RemoteOpenAIServer
(
model_id
,
model_id
,
cp_args
,
server_args
,
tp_args
,
env_dict
=
server_env
,
cp_env
,
tp_env
,
method
=
method
,
max_wait_seconds
=
720
,
max_wait_seconds
=
720
,
)
)
as
remote_server
:
host
=
f
"http://
{
remote_server
.
host
}
"
port
=
remote_server
.
port
CP_TEXT_GENERATION_MODELS
=
{
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
[
# Run GSM8K evaluation
CPTestSettings
.
detailed
(),
results
=
evaluate_gsm8k
(
CPTestSettings
.
detailed
(
tp_base
=
2
),
num_questions
=
NUM_QUESTIONS
,
CPTestSettings
.
detailed
(
tp_base
=
2
,
cp_kv_cache_interleave_size
=
64
),
num_shots
=
NUM_SHOTS
,
],
host
=
host
,
"bigcode/gpt_bigcode-santacoder"
:
[
port
=
port
,
CPTestSettings
.
detailed
(),
)
CPTestSettings
.
detailed
(
tp_base
=
2
),
],
}
CP_TEST_MODELS
=
[
# Validate accuracy is reasonable
# TODO support other models
accuracy
=
results
[
"accuracy"
]
# [LANGUAGE GENERATION
]
min_accuracy
=
MIN_ACCURACY
[
model_id
]
"deepseek-ai/DeepSeek-V2-Lite-Chat"
,
assert
accuracy
>=
min_accuracy
,
(
"bigcode/gpt_bigcode-santacoder"
,
f
"TP+DCP accuracy too low:
{
accuracy
:.
3
f
}
<
{
min_accuracy
:.
3
f
}
"
]
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -274,12 +279,12 @@ def test_cp_generation(
...
@@ -274,12 +279,12 @@ def test_cp_generation(
):
):
pytest
.
skip
(
reason
=
"MLA+DCP requires compute capability of 9.0 or higher"
)
pytest
.
skip
(
reason
=
"MLA+DCP requires compute capability of 9.0 or higher"
)
if
(
if
(
model_id
==
"
bigcode/gpt_bigcode-santacoder
"
model_id
==
"
Qwen/Qwen2.5-1.5B-Instruct
"
and
torch
.
cuda
.
get_device_capability
()
!=
(
9
,
0
)
and
torch
.
cuda
.
get_device_capability
()
!=
(
9
,
0
)
):
):
pytest
.
skip
(
reason
=
"GQA+DCP currently requires compute capability of 9.0"
)
pytest
.
skip
(
reason
=
"GQA+DCP currently requires compute capability of 9.0"
)
_
compare_cp_with_tp
(
_
test_cp_gsm8k
(
model_id
,
model_id
,
parallel_setup
,
parallel_setup
,
distributed_backend
,
distributed_backend
,
...
...
tests/distributed/test_eplb_algo.py
View file @
8d75f22e
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
import
pytest
import
pytest
import
torch
import
torch
from
vllm.distributed.eplb.
rebalance_algo
import
rebalance_experts
from
vllm.distributed.eplb.
policy.default
import
DefaultEplbPolicy
def
test_basic_rebalance
():
def
test_basic_rebalance
():
...
@@ -23,7 +23,7 @@ def test_basic_rebalance():
...
@@ -23,7 +23,7 @@ def test_basic_rebalance():
num_nodes
=
2
num_nodes
=
2
num_gpus
=
8
num_gpus
=
8
phy2log
,
log2phy
,
logcnt
=
rebalance_experts
(
phy2log
,
log2phy
,
logcnt
=
DefaultEplbPolicy
.
rebalance_experts
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
)
)
...
@@ -77,7 +77,7 @@ def test_single_gpu_case():
...
@@ -77,7 +77,7 @@ def test_single_gpu_case():
num_nodes
=
1
num_nodes
=
1
num_gpus
=
1
num_gpus
=
1
phy2log
,
log2phy
,
logcnt
=
rebalance_experts
(
phy2log
,
log2phy
,
logcnt
=
DefaultEplbPolicy
.
rebalance_experts
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
)
)
...
@@ -99,7 +99,7 @@ def test_equal_weights():
...
@@ -99,7 +99,7 @@ def test_equal_weights():
num_nodes
=
2
num_nodes
=
2
num_gpus
=
4
num_gpus
=
4
phy2log
,
log2phy
,
logcnt
=
rebalance_experts
(
phy2log
,
log2phy
,
logcnt
=
DefaultEplbPolicy
.
rebalance_experts
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
)
)
...
@@ -122,7 +122,7 @@ def test_extreme_weight_imbalance():
...
@@ -122,7 +122,7 @@ def test_extreme_weight_imbalance():
num_nodes
=
2
num_nodes
=
2
num_gpus
=
4
num_gpus
=
4
phy2log
,
log2phy
,
logcnt
=
rebalance_experts
(
phy2log
,
log2phy
,
logcnt
=
DefaultEplbPolicy
.
rebalance_experts
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
)
)
...
@@ -150,7 +150,7 @@ def test_multiple_layers():
...
@@ -150,7 +150,7 @@ def test_multiple_layers():
num_nodes
=
2
num_nodes
=
2
num_gpus
=
4
num_gpus
=
4
phy2log
,
log2phy
,
logcnt
=
rebalance_experts
(
phy2log
,
log2phy
,
logcnt
=
DefaultEplbPolicy
.
rebalance_experts
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
)
)
...
@@ -175,14 +175,14 @@ def test_parameter_validation():
...
@@ -175,14 +175,14 @@ def test_parameter_validation():
# Test non-divisible case - this should handle normally without throwing
# Test non-divisible case - this should handle normally without throwing
# errors because the function will fall back to global load balancing
# errors because the function will fall back to global load balancing
# strategy
# strategy
phy2log
,
log2phy
,
logcnt
=
rebalance_experts
(
weight
,
8
,
3
,
2
,
4
)
phy2log
,
log2phy
,
logcnt
=
DefaultEplbPolicy
.
rebalance_experts
(
weight
,
8
,
3
,
2
,
4
)
assert
phy2log
.
shape
==
(
1
,
8
)
assert
phy2log
.
shape
==
(
1
,
8
)
assert
logcnt
.
shape
==
(
1
,
4
)
assert
logcnt
.
shape
==
(
1
,
4
)
# Test cases that will actually cause errors:
# Test cases that will actually cause errors:
# num_physical_experts not divisible by num_gpus
# num_physical_experts not divisible by num_gpus
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
rebalance_experts
(
weight
,
7
,
2
,
2
,
4
)
# 7 not divisible by 4
DefaultEplbPolicy
.
rebalance_experts
(
weight
,
7
,
2
,
2
,
4
)
# 7 not divisible by 4
def
test_small_scale_hierarchical
():
def
test_small_scale_hierarchical
():
...
@@ -197,7 +197,7 @@ def test_small_scale_hierarchical():
...
@@ -197,7 +197,7 @@ def test_small_scale_hierarchical():
num_nodes
=
2
# 2 nodes
num_nodes
=
2
# 2 nodes
num_gpus
=
4
# 4 GPUs
num_gpus
=
4
# 4 GPUs
phy2log
,
log2phy
,
logcnt
=
rebalance_experts
(
phy2log
,
log2phy
,
logcnt
=
DefaultEplbPolicy
.
rebalance_experts
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
)
)
...
@@ -224,7 +224,7 @@ def test_global_load_balance_fallback():
...
@@ -224,7 +224,7 @@ def test_global_load_balance_fallback():
num_nodes
=
2
num_nodes
=
2
num_gpus
=
4
num_gpus
=
4
phy2log
,
log2phy
,
logcnt
=
rebalance_experts
(
phy2log
,
log2phy
,
logcnt
=
DefaultEplbPolicy
.
rebalance_experts
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
)
)
...
@@ -246,7 +246,7 @@ def test_device_compatibility(device):
...
@@ -246,7 +246,7 @@ def test_device_compatibility(device):
num_nodes
=
1
num_nodes
=
1
num_gpus
=
2
num_gpus
=
2
phy2log
,
log2phy
,
logcnt
=
rebalance_experts
(
phy2log
,
log2phy
,
logcnt
=
DefaultEplbPolicy
.
rebalance_experts
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
)
)
...
@@ -263,7 +263,9 @@ def test_additional_cases():
...
@@ -263,7 +263,9 @@ def test_additional_cases():
weight1
=
torch
.
tensor
(
weight1
=
torch
.
tensor
(
[[
50
,
100
,
75
,
120
,
90
,
60
,
80
,
110
,
40
,
70
,
95
,
85
,
65
,
55
,
45
,
35
]]
[[
50
,
100
,
75
,
120
,
90
,
60
,
80
,
110
,
40
,
70
,
95
,
85
,
65
,
55
,
45
,
35
]]
)
)
phy2log1
,
log2phy1
,
logcnt1
=
rebalance_experts
(
weight1
,
24
,
8
,
4
,
8
)
phy2log1
,
log2phy1
,
logcnt1
=
DefaultEplbPolicy
.
rebalance_experts
(
weight1
,
24
,
8
,
4
,
8
)
assert
phy2log1
.
shape
==
(
1
,
24
)
assert
phy2log1
.
shape
==
(
1
,
24
)
assert
logcnt1
.
shape
==
(
1
,
16
)
assert
logcnt1
.
shape
==
(
1
,
16
)
...
@@ -276,7 +278,9 @@ def test_additional_cases():
...
@@ -276,7 +278,9 @@ def test_additional_cases():
[
12
,
25
,
50
,
100
,
150
,
200
],
# Increasing weights
[
12
,
25
,
50
,
100
,
150
,
200
],
# Increasing weights
]
]
)
)
phy2log2
,
log2phy2
,
logcnt2
=
rebalance_experts
(
weight2
,
10
,
3
,
1
,
2
)
phy2log2
,
log2phy2
,
logcnt2
=
DefaultEplbPolicy
.
rebalance_experts
(
weight2
,
10
,
3
,
1
,
2
)
assert
phy2log2
.
shape
==
(
2
,
10
)
assert
phy2log2
.
shape
==
(
2
,
10
)
assert
logcnt2
.
shape
==
(
2
,
6
)
assert
logcnt2
.
shape
==
(
2
,
6
)
...
@@ -300,7 +304,7 @@ if __name__ == "__main__":
...
@@ -300,7 +304,7 @@ if __name__ == "__main__":
num_nodes
=
2
num_nodes
=
2
num_gpus
=
8
num_gpus
=
8
phy2log
,
log2phy
,
logcnt
=
rebalance_experts
(
phy2log
,
log2phy
,
logcnt
=
DefaultEplbPolicy
.
rebalance_experts
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
)
)
print
(
phy2log
)
print
(
phy2log
)
...
...
tests/distributed/test_eplb_spec_decode.py
View file @
8d75f22e
...
@@ -6,6 +6,7 @@ import lm_eval
...
@@ -6,6 +6,7 @@ import lm_eval
import
pytest
import
pytest
from
tests.utils
import
large_gpu_mark
from
tests.utils
import
large_gpu_mark
from
vllm.platforms
import
current_platform
def
get_model_args
(
def
get_model_args
(
...
@@ -45,6 +46,12 @@ def get_model_args(
...
@@ -45,6 +46,12 @@ def get_model_args(
return
model_args
return
model_args
pytestmark
=
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"EPLB with Spec Decode is a work in progress on ROCm."
,
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model_setup"
,
"model_setup"
,
[
[
...
...
tests/distributed/test_kvlayout.py
View file @
8d75f22e
...
@@ -61,7 +61,7 @@ def test_get_kv_connector_cache_layout_with_multi_connector():
...
@@ -61,7 +61,7 @@ def test_get_kv_connector_cache_layout_with_multi_connector():
kv_role
=
"kv_both"
,
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
kv_connector_extra_config
=
{
"connectors"
:
[
"connectors"
:
[
{
"kv_connector"
:
"
SharedStorag
eConnector"
,
"kv_role"
:
"kv_both"
},
{
"kv_connector"
:
"
Exampl
eConnector"
,
"kv_role"
:
"kv_both"
},
{
"kv_connector"
:
"NixlConnector"
,
"kv_role"
:
"kv_both"
},
{
"kv_connector"
:
"NixlConnector"
,
"kv_role"
:
"kv_both"
},
]
]
},
},
...
...
tests/distributed/test_pipeline_parallel.py
View file @
8d75f22e
...
@@ -109,7 +109,7 @@ TEXT_GENERATION_MODELS = {
...
@@ -109,7 +109,7 @@ TEXT_GENERATION_MODELS = {
"baichuan-inc/Baichuan2-13B-Chat"
:
PPTestSettings
.
fast
(),
"baichuan-inc/Baichuan2-13B-Chat"
:
PPTestSettings
.
fast
(),
"bigscience/bloomz-1b1"
:
PPTestSettings
.
fast
(),
"bigscience/bloomz-1b1"
:
PPTestSettings
.
fast
(),
"zai-org/chatglm3-6b"
:
PPTestSettings
.
fast
(),
"zai-org/chatglm3-6b"
:
PPTestSettings
.
fast
(),
"Cohere
ForAI
/c4ai-command-r-v01"
:
PPTestSettings
.
fast
(
load_format
=
"dummy"
),
"Cohere
Labs
/c4ai-command-r-v01"
:
PPTestSettings
.
fast
(
load_format
=
"dummy"
),
"databricks/dbrx-instruct"
:
PPTestSettings
.
fast
(
load_format
=
"dummy"
),
"databricks/dbrx-instruct"
:
PPTestSettings
.
fast
(
load_format
=
"dummy"
),
"Deci/DeciLM-7B-instruct"
:
PPTestSettings
.
fast
(),
"Deci/DeciLM-7B-instruct"
:
PPTestSettings
.
fast
(),
"deepseek-ai/deepseek-llm-7b-chat"
:
PPTestSettings
.
fast
(),
"deepseek-ai/deepseek-llm-7b-chat"
:
PPTestSettings
.
fast
(),
...
...
tests/distributed/test_shm_storage.py
View file @
8d75f22e
...
@@ -28,7 +28,7 @@ def _dummy_elem(modality: str, key: str, size: int):
...
@@ -28,7 +28,7 @@ def _dummy_elem(modality: str, key: str, size: int):
modality
=
modality
,
modality
=
modality
,
key
=
key
,
key
=
key
,
data
=
torch
.
empty
((
size
,),
dtype
=
torch
.
int8
),
data
=
torch
.
empty
((
size
,),
dtype
=
torch
.
int8
),
field
=
MultiModalSharedField
(
1
),
field
=
MultiModalSharedField
(
batch_size
=
1
),
)
)
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
33
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment