Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
14ce5242
Unverified
Commit
14ce5242
authored
Jan 15, 2026
by
Lucas Wilkinson
Committed by
GitHub
Jan 16, 2026
Browse files
[CI] Breakup h200 tests (#30499)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
4ae77dfd
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
390 additions
and
296 deletions
+390
-296
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+53
-4
tests/compile/distributed/test_fusions_e2e.py
tests/compile/distributed/test_fusions_e2e.py
+15
-283
tests/compile/fusion_test_utils.py
tests/compile/fusion_test_utils.py
+208
-0
tests/compile/test_fusion_attn.py
tests/compile/test_fusion_attn.py
+113
-8
vllm/env_override.py
vllm/env_override.py
+1
-1
No files found.
.buildkite/test-pipeline.yaml
View file @
14ce5242
...
...
@@ -1047,6 +1047,47 @@ steps:
# Run all e2e fusion tests
-
pytest -v -s tests/compile/distributed/test_fusions_e2e.py
-
label
:
Hopper Fusion E2E Tests (H100)
# 10min
timeout_in_minutes
:
70
working_dir
:
"
/vllm-workspace/"
gpu
:
h100
optional
:
true
source_file_dependencies
:
-
csrc/quantization/fp4/
-
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
-
vllm/v1/attention/backends/flashinfer.py
-
vllm/compilation/
# can affect pattern matching
-
vllm/model_executor/layers/layernorm.py
-
vllm/model_executor/layers/activation.py
-
vllm/model_executor/layers/quantization/input_quant_fp8.py
-
tests/compile/test_fusion_attn.py
commands
:
-
export VLLM_TEST_CLEAN_GPU_MEMORY=1
-
pytest -v -s tests/compile/test_fusion_attn.py
-
label
:
Hopper Fusion Distributed E2E Tests (2xH100)
# 70min
timeout_in_minutes
:
70
working_dir
:
"
/vllm-workspace/"
gpu
:
h100
optional
:
true
num_gpus
:
2
source_file_dependencies
:
-
csrc/quantization/fp4/
-
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
-
vllm/v1/attention/backends/flashinfer.py
-
vllm/compilation/
# can affect pattern matching
-
vllm/model_executor/layers/layernorm.py
-
vllm/model_executor/layers/activation.py
-
vllm/model_executor/layers/quantization/input_quant_fp8.py
-
tests/compile/distributed/test_fusions_e2e.py
commands
:
-
export VLLM_TEST_CLEAN_GPU_MEMORY=1
# Run all e2e fusion tests
-
pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'
-
pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
-
label
:
Blackwell GPT-OSS Eval
timeout_in_minutes
:
60
working_dir
:
"
/vllm-workspace/"
...
...
@@ -1346,6 +1387,18 @@ steps:
-
export VLLM_USE_DEEP_GEMM=0
# We found Triton is faster than DeepGEMM for H100
-
pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4
-
label
:
Sequence Parallel Tests (H100)
# 60 min
timeout_in_minutes
:
60
working_dir
:
"
/vllm-workspace/"
gpu
:
h100
optional
:
true
num_gpus
:
2
commands
:
-
export VLLM_TEST_CLEAN_GPU_MEMORY=1
# Run sequence parallel tests
-
pytest -v -s tests/distributed/test_sequence_parallel.py
-
pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
##### H200 test #####
-
label
:
Distributed Tests (H200)
# optional
gpu
:
h200
...
...
@@ -1354,10 +1407,6 @@ steps:
num_gpus
:
2
commands
:
-
VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py
-
pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
-
pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
-
"
VLLM_TEST_CLEAN_GPU_MEMORY=1
pytest
-v
-s
tests/compile/distributed/test_fusions_e2e.py
-k
'not
Llama-4'"
-
VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
-
pytest -v -s tests/distributed/test_context_parallel.py
-
CUDA_VISIBLE_DEVICES=1,2 VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=deepep_high_throughput
-
pytest -v -s tests/v1/distributed/test_dbo.py
...
...
tests/compile/distributed/test_fusions_e2e.py
View file @
14ce5242
...
...
@@ -3,16 +3,26 @@
from
__future__
import
annotations
import
itertools
import
logging
from
collections.abc
import
Iterable
from
typing
import
Any
,
NamedTuple
from
typing
import
Any
import
pytest
import
regex
as
re
from
tests.compile.fusion_test_utils
import
(
CUSTOM_OPS_FP8
,
CUSTOM_OPS_QUANT_RMS_NORM
,
CUSTOM_OPS_RMS_NORM
,
MODELS
,
MODELS_FP4
,
MODELS_FP8
,
MODELS_GROUP_FP8
,
Matches
,
custom_ops_product
,
is_blackwell
,
run_model
,
)
from
tests.v1.attention.utils
import
AttentionBackendEnum
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationConfig
,
CompilationMode
,
CUDAGraphMode
,
PassConfig
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer
...
...
@@ -20,228 +30,6 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
from
...utils
import
flat_product
,
multi_gpu_test
is_blackwell
=
lambda
:
current_platform
.
is_device_capability_family
(
100
)
"""Are we running on Blackwell, a lot of tests depend on it"""
class
Matches
(
NamedTuple
):
attention_fusion
:
int
=
0
allreduce_fusion
:
int
=
0
rms_quant_norm_fusion
:
int
=
0
sequence_parallel
:
int
=
0
async_tp
:
int
=
0
class
ModelBackendTestCase
(
NamedTuple
):
model_name
:
str
model_kwargs
:
dict
[
str
,
Any
]
backend
:
AttentionBackendEnum
matches
:
Matches
MODELS_FP8
:
list
[
ModelBackendTestCase
]
=
[]
MODELS_FP4
:
list
[
ModelBackendTestCase
]
=
[]
MODELS_GROUP_FP8
:
list
[
ModelBackendTestCase
]
=
[]
MODELS
:
list
[
ModelBackendTestCase
]
=
[]
# tp-only
if
current_platform
.
is_cuda
():
MODELS_FP8
=
[
ModelBackendTestCase
(
# Use smaller model for L40s in CI
model_name
=
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
,
model_kwargs
=
dict
(
max_model_len
=
1024
,
kv_cache_dtype
=
"fp8"
),
backend
=
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
attention_fusion
=
32
,
allreduce_fusion
=
65
,
sequence_parallel
=
65
,
async_tp
=
128
,
),
),
ModelBackendTestCase
(
model_name
=
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
,
model_kwargs
=
dict
(
max_model_len
=
1024
,
kv_cache_dtype
=
"fp8"
),
# TODO FlashInfer attn broken on Hopper with kvcache=fp8:
# https://github.com/vllm-project/vllm/issues/28568
backend
=
AttentionBackendEnum
.
FLASHINFER
if
is_blackwell
()
else
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
attention_fusion
=
48
,
allreduce_fusion
=
96
,
sequence_parallel
=
96
,
async_tp
=
95
,
# mlp is moe, no fusion there
),
),
]
MODELS_FP4
=
[
ModelBackendTestCase
(
model_name
=
"nvidia/Llama-3.1-8B-Instruct-FP4"
,
model_kwargs
=
dict
(
max_model_len
=
1024
,
kv_cache_dtype
=
"fp8"
),
backend
=
AttentionBackendEnum
.
FLASHINFER
,
matches
=
Matches
(
attention_fusion
=
32
,
allreduce_fusion
=
65
,
sequence_parallel
=
65
,
async_tp
=
128
,
),
),
]
# TP only
MODELS
=
[
ModelBackendTestCase
(
model_name
=
"meta-llama/Llama-3.1-8B-Instruct"
,
model_kwargs
=
dict
(
max_model_len
=
1024
),
backend
=
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
attention_fusion
=
0
,
allreduce_fusion
=
65
,
sequence_parallel
=
65
,
async_tp
=
128
,
),
),
ModelBackendTestCase
(
model_name
=
"Qwen/Qwen3-30B-A3B"
,
model_kwargs
=
dict
(
max_model_len
=
1024
),
backend
=
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
attention_fusion
=
0
,
allreduce_fusion
=
97
,
sequence_parallel
=
97
,
async_tp
=
96
,
# MLP is MoE, half the fusions of dense
),
),
]
elif
current_platform
.
is_rocm
():
MODELS_FP8
=
[
ModelBackendTestCase
(
model_name
=
"amd/Llama-3.1-8B-Instruct-FP8-KV"
,
model_kwargs
=
dict
(
max_model_len
=
1024
),
backend
=
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
attention_fusion
=
32
),
),
ModelBackendTestCase
(
model_name
=
"amd/Llama-3.1-8B-Instruct-FP8-KV"
,
model_kwargs
=
dict
(
max_model_len
=
1024
),
backend
=
AttentionBackendEnum
.
ROCM_ATTN
,
matches
=
Matches
(
attention_fusion
=
32
),
),
ModelBackendTestCase
(
model_name
=
"amd/Llama-3.1-8B-Instruct-FP8-KV"
,
model_kwargs
=
dict
(
max_model_len
=
1024
),
backend
=
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
,
matches
=
Matches
(
attention_fusion
=
32
),
),
]
CUSTOM_OPS_FP8
=
[
"-quant_fp8"
,
"+quant_fp8"
]
def
has_cuda_graph_wrapper_metadata
()
->
bool
:
from
importlib
import
import_module
try
:
module
=
import_module
(
"torch._inductor.utils"
)
module
.
CUDAGraphWrapperMetadata
# noqa B018
except
AttributeError
:
return
False
return
True
@
pytest
.
mark
.
parametrize
(
"model_name, model_kwargs, backend, matches, custom_ops"
,
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list
(
flat_product
(
MODELS_FP8
,
CUSTOM_OPS_FP8
))
# quant_fp4 only has the custom impl
+
list
(
flat_product
(
MODELS_FP4
,
[
""
])),
)
@
pytest
.
mark
.
parametrize
(
"inductor_graph_partition"
,
[
pytest
.
param
(
True
,
marks
=
pytest
.
mark
.
skipif
(
not
has_cuda_graph_wrapper_metadata
(),
reason
=
"This test requires"
"torch._inductor.utils.CUDAGraphWrapperMetadata to run"
,
),
),
False
,
],
)
def
test_attn_quant
(
model_name
:
str
,
model_kwargs
:
dict
[
str
,
Any
],
backend
:
AttentionBackendEnum
,
matches
:
Matches
,
custom_ops
:
str
,
inductor_graph_partition
:
bool
,
caplog_mp_spawn
,
monkeypatch
,
):
if
backend
==
AttentionBackendEnum
.
FLASHINFER
and
(
not
is_blackwell
()
or
not
has_flashinfer
()
):
pytest
.
skip
(
"FlashInfer attn fusion requires Blackwell and flashinfer"
)
if
inductor_graph_partition
and
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
pytest
.
skip
(
"Inductor graph partition requires torch>=2.9"
)
custom_ops_list
=
custom_ops
.
split
(
","
)
if
custom_ops
else
[]
if
inductor_graph_partition
:
mode
=
CUDAGraphMode
.
FULL_AND_PIECEWISE
splitting_ops
:
list
[
str
]
|
None
=
None
else
:
# FIXME: Llama-4-Scout-17B-16E-Instruct-FP8 + FlashInfer + Blackwell end at
# CUDAGraphMode.NONE here because it derives an attention backend that
# does not support full cudagraphs
mode
=
CUDAGraphMode
.
FULL_DECODE_ONLY
splitting_ops
=
[]
# Disable, compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch
.
setenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"1"
)
# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch
.
setenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
)
model_kwargs
[
"attention_config"
]
=
{
"backend"
:
backend
.
name
}
compilation_config
=
CompilationConfig
(
# Testing properties
custom_ops
=
custom_ops_list
,
use_inductor_graph_partition
=
inductor_graph_partition
,
cudagraph_mode
=
mode
,
splitting_ops
=
splitting_ops
,
# Common
mode
=
CompilationMode
.
VLLM_COMPILE
,
pass_config
=
PassConfig
(
fuse_attn_quant
=
True
,
eliminate_noops
=
True
),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config
=
{
"force_disable_caches"
:
True
},
)
with
caplog_mp_spawn
(
logging
.
DEBUG
)
as
log_holder
:
run_model
(
compilation_config
,
model_name
,
**
model_kwargs
)
log_matches
=
re
.
findall
(
r
"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes"
,
log_holder
.
text
,
)
assert
len
(
log_matches
)
==
1
,
log_holder
.
text
assert
int
(
log_matches
[
0
])
==
matches
.
attention_fusion
CUSTOM_OPS_RMS_NORM
=
[
"-rms_norm"
,
"+rms_norm"
]
def
custom_ops_product
(
*
custom_ops_lists
:
list
[
str
])
->
Iterable
[
str
]:
for
op_list
in
itertools
.
product
(
*
custom_ops_lists
):
yield
","
.
join
(
op_list
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -421,7 +209,7 @@ def test_tp2_attn_quant_async_tp(
custom_ops
=
custom_ops_list
,
splitting_ops
=
splitting_ops
,
# Common
level
=
CompilationMode
.
VLLM_COMPILE
,
mode
=
CompilationMode
.
VLLM_COMPILE
,
pass_config
=
PassConfig
(
fuse_attn_quant
=
True
,
eliminate_noops
=
True
,
...
...
@@ -464,62 +252,6 @@ def test_tp2_attn_quant_async_tp(
assert
int
(
log_matches
[
1
])
==
matches
.
async_tp
def
run_model
(
compile_config
:
int
|
CompilationConfig
,
model
:
str
,
**
model_kwargs
):
compilation_config
=
(
compile_config
if
isinstance
(
compile_config
,
CompilationConfig
)
else
CompilationConfig
(
mode
=
compile_config
)
)
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0
)
# Allow override from model_kwargs
model_kwargs
=
{
"tensor_parallel_size"
:
1
,
**
model_kwargs
}
model_kwargs
=
{
"disable_custom_all_reduce"
:
True
,
**
model_kwargs
}
# No cudagraphs by default
if
compilation_config
.
cudagraph_mode
is
None
:
compilation_config
.
cudagraph_mode
=
CUDAGraphMode
.
NONE
llm
=
LLM
(
model
=
model
,
compilation_config
=
compilation_config
,
**
model_kwargs
,
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
# Get the compile ranges split points after vllm config post init
# in order to compute compile ranges correctly
compilation_config
.
compile_ranges_split_points
=
(
llm
.
llm_engine
.
vllm_config
.
compilation_config
.
compile_ranges_split_points
)
if
current_platform
.
is_cuda
():
MODELS_GROUP_FP8
=
[
ModelBackendTestCase
(
model_name
=
"Qwen/Qwen3-30B-A3B-FP8"
,
model_kwargs
=
dict
(
max_model_len
=
1024
,
kv_cache_dtype
=
"fp8"
),
backend
=
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
rms_quant_norm_fusion
=
48
,
),
),
]
CUSTOM_OPS_QUANT_RMS_NORM
=
[
"+quant_fp8,+rms_norm"
]
@
pytest
.
mark
.
parametrize
(
"model_name, model_kwargs, backend, matches, custom_ops"
,
# Test rms norm+group quant_fp8 fusion
...
...
tests/compile/fusion_test_utils.py
0 → 100644
View file @
14ce5242
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Shared utilities for fusion tests (e.g. test_fusion_attn.py)."""
from
__future__
import
annotations
import
itertools
from
collections.abc
import
Iterable
from
typing
import
Any
,
NamedTuple
from
tests.v1.attention.utils
import
AttentionBackendEnum
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationConfig
,
CUDAGraphMode
from
vllm.platforms
import
current_platform
is_blackwell
=
lambda
:
current_platform
.
is_device_capability_family
(
100
)
"""Are we running on Blackwell, a lot of tests depend on it"""
def
has_cuda_graph_wrapper_metadata
()
->
bool
:
from
importlib
import
import_module
try
:
module
=
import_module
(
"torch._inductor.utils"
)
module
.
CUDAGraphWrapperMetadata
# noqa B018
except
AttributeError
:
return
False
return
True
class
Matches
(
NamedTuple
):
attention_fusion
:
int
=
0
allreduce_fusion
:
int
=
0
sequence_parallel
:
int
=
0
async_tp
:
int
=
0
rms_quant_norm_fusion
:
int
=
0
class
ModelBackendTestCase
(
NamedTuple
):
model_name
:
str
model_kwargs
:
dict
[
str
,
Any
]
backend
:
AttentionBackendEnum
matches
:
Matches
# E2E model test cases
MODELS_FP8
:
list
[
ModelBackendTestCase
]
=
[]
MODELS_FP4
:
list
[
ModelBackendTestCase
]
=
[]
MODELS
:
list
[
ModelBackendTestCase
]
=
[]
# tp-only (unquantized)
MODELS_GROUP_FP8
:
list
[
ModelBackendTestCase
]
=
[]
if
current_platform
.
is_cuda
():
MODELS_FP8
=
[
ModelBackendTestCase
(
# Use smaller model for L40s in CI
model_name
=
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
,
model_kwargs
=
dict
(
max_model_len
=
1024
,
kv_cache_dtype
=
"fp8"
),
backend
=
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
attention_fusion
=
32
,
allreduce_fusion
=
65
,
sequence_parallel
=
65
,
async_tp
=
128
,
),
),
ModelBackendTestCase
(
model_name
=
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
,
model_kwargs
=
dict
(
max_model_len
=
1024
,
kv_cache_dtype
=
"fp8"
),
# TODO FlashInfer attn broken on Hopper with kvcache=fp8:
# https://github.com/vllm-project/vllm/issues/28568
backend
=
AttentionBackendEnum
.
FLASHINFER
if
is_blackwell
()
else
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
attention_fusion
=
48
,
allreduce_fusion
=
96
,
sequence_parallel
=
96
,
async_tp
=
95
,
# mlp is moe, no fusion there
),
),
]
MODELS_FP4
=
[
ModelBackendTestCase
(
model_name
=
"nvidia/Llama-3.1-8B-Instruct-FP4"
,
model_kwargs
=
dict
(
max_model_len
=
1024
,
kv_cache_dtype
=
"fp8"
),
backend
=
AttentionBackendEnum
.
FLASHINFER
,
matches
=
Matches
(
attention_fusion
=
32
,
allreduce_fusion
=
65
,
sequence_parallel
=
65
,
async_tp
=
128
,
),
),
]
# TP only (unquantized models)
MODELS
=
[
ModelBackendTestCase
(
model_name
=
"meta-llama/Llama-3.1-8B-Instruct"
,
model_kwargs
=
dict
(
max_model_len
=
1024
),
backend
=
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
attention_fusion
=
0
,
allreduce_fusion
=
65
,
sequence_parallel
=
65
,
async_tp
=
128
,
),
),
ModelBackendTestCase
(
model_name
=
"Qwen/Qwen3-30B-A3B"
,
model_kwargs
=
dict
(
max_model_len
=
1024
),
backend
=
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
attention_fusion
=
0
,
allreduce_fusion
=
97
,
sequence_parallel
=
97
,
async_tp
=
96
,
# MLP is MoE, half the fusions of dense
),
),
]
MODELS_GROUP_FP8
=
[
ModelBackendTestCase
(
model_name
=
"Qwen/Qwen3-30B-A3B-FP8"
,
model_kwargs
=
dict
(
max_model_len
=
1024
,
kv_cache_dtype
=
"fp8"
),
backend
=
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
rms_quant_norm_fusion
=
48
,
),
),
]
elif
current_platform
.
is_rocm
():
MODELS_FP8
=
[
ModelBackendTestCase
(
model_name
=
"amd/Llama-3.1-8B-Instruct-FP8-KV"
,
model_kwargs
=
dict
(
max_model_len
=
1024
),
backend
=
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
attention_fusion
=
32
),
),
ModelBackendTestCase
(
model_name
=
"amd/Llama-3.1-8B-Instruct-FP8-KV"
,
model_kwargs
=
dict
(
max_model_len
=
1024
),
backend
=
AttentionBackendEnum
.
ROCM_ATTN
,
matches
=
Matches
(
attention_fusion
=
32
),
),
ModelBackendTestCase
(
model_name
=
"amd/Llama-3.1-8B-Instruct-FP8-KV"
,
model_kwargs
=
dict
(
max_model_len
=
1024
),
backend
=
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
,
matches
=
Matches
(
attention_fusion
=
32
),
),
]
# Custom ops toggle lists for parametrization
CUSTOM_OPS_FP8
=
[
"-quant_fp8"
,
"+quant_fp8"
]
CUSTOM_OPS_RMS_NORM
=
[
"-rms_norm"
,
"+rms_norm"
]
CUSTOM_OPS_QUANT_RMS_NORM
=
[
"+quant_fp8,+rms_norm"
]
def
custom_ops_product
(
*
custom_ops_lists
:
list
[
str
])
->
Iterable
[
str
]:
"""Generate all combinations of custom ops for parametrization."""
for
op_list
in
itertools
.
product
(
*
custom_ops_lists
):
yield
","
.
join
(
op_list
)
def
run_model
(
compile_config
:
int
|
CompilationConfig
,
model
:
str
,
**
model_kwargs
):
"""Run a model with the given compilation config for E2E fusion tests."""
compilation_config
=
(
compile_config
if
isinstance
(
compile_config
,
CompilationConfig
)
else
CompilationConfig
(
mode
=
compile_config
)
)
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0
)
# Allow override from model_kwargs
model_kwargs
=
{
"tensor_parallel_size"
:
1
,
**
model_kwargs
}
model_kwargs
=
{
"disable_custom_all_reduce"
:
True
,
**
model_kwargs
}
# No cudagraphs by default
if
compilation_config
.
cudagraph_mode
is
None
:
compilation_config
.
cudagraph_mode
=
CUDAGraphMode
.
NONE
llm
=
LLM
(
model
=
model
,
compilation_config
=
compilation_config
,
**
model_kwargs
,
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
# Get the compile ranges split points after vllm config post init
# in order to compute compile ranges correctly
compilation_config
.
compile_ranges_split_points
=
(
llm
.
llm_engine
.
vllm_config
.
compilation_config
.
compile_ranges_split_points
)
tests/compile/test_fusion_attn.py
View file @
14ce5242
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
logging
from
typing
import
Any
import
pytest
import
regex
as
re
import
torch._dynamo
from
tests.compile.backend
import
LazyInitPass
,
TestBackend
from
tests.utils
import
flat_product
from
tests.compile.fusion_test_utils
import
(
CUSTOM_OPS_FP8
,
MODELS_FP4
,
MODELS_FP8
,
Matches
,
has_cuda_graph_wrapper_metadata
,
is_blackwell
,
run_model
,
)
from
tests.utils
import
cuda_device_count_stateless
,
flat_product
from
tests.v1.attention.utils
import
BatchSpec
,
create_common_attn_metadata
from
vllm._custom_ops
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
vllm.attention.layer
import
Attention
...
...
@@ -20,6 +32,7 @@ from vllm.config import (
CacheConfig
,
CompilationConfig
,
CompilationMode
,
CUDAGraphMode
,
ModelConfig
,
PassConfig
,
SchedulerConfig
,
...
...
@@ -35,6 +48,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
Fp8LinearOp
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
...
@@ -241,8 +255,8 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
)
MODELS_FP8
:
list
[
tuple
[
str
,
type
]]
=
[]
MODELS_FP4
:
list
[
tuple
[
str
,
type
]]
=
[]
PATTERN_TEST_
MODELS_FP8
:
list
[
tuple
[
str
,
type
]]
=
[]
PATTERN_TEST_
MODELS_FP4
:
list
[
tuple
[
str
,
type
]]
=
[]
HEADS
:
list
[
tuple
[
int
,
int
]]
=
[]
SPLIT_ATTENTION
:
list
[
bool
]
=
[]
BACKENDS_FP8
:
list
[
AttentionBackendEnum
]
=
[]
...
...
@@ -250,13 +264,13 @@ BACKENDS_FP4: list[AttentionBackendEnum] = []
if
current_platform
.
is_cuda
():
HEADS
=
[(
64
,
8
),
(
40
,
8
)]
MODELS_FP8
=
[
PATTERN_TEST_
MODELS_FP8
=
[
(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
,
TestAttentionFp8StaticQuantPatternModel
,
)
]
MODELS_FP4
=
[
PATTERN_TEST_
MODELS_FP4
=
[
(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4"
,
TestAttentionNvfp4QuantPatternModel
,
...
...
@@ -267,7 +281,7 @@ if current_platform.is_cuda():
elif
current_platform
.
is_rocm
():
HEADS
=
[(
32
,
8
),
(
40
,
8
)]
MODELS_FP8
=
[
PATTERN_TEST_
MODELS_FP8
=
[
(
"amd/Llama-3.1-8B-Instruct-FP8-KV"
,
TestAttentionFp8StaticQuantPatternModel
)
]
BACKENDS
=
[
...
...
@@ -286,9 +300,13 @@ elif current_platform.is_rocm():
@
pytest
.
mark
.
parametrize
(
"backend, model_name, model_class, custom_ops"
,
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list
(
flat_product
(
BACKENDS_FP8
,
MODELS_FP8
,
[
"+quant_fp8"
,
"-quant_fp8"
]))
list
(
flat_product
(
BACKENDS_FP8
,
PATTERN_TEST_MODELS_FP8
,
[
"+quant_fp8"
,
"-quant_fp8"
]
)
)
# quant_fp4 only has the custom impl
+
list
(
flat_product
(
BACKENDS_FP4
,
MODELS_FP4
,
[
""
])),
+
list
(
flat_product
(
BACKENDS_FP4
,
PATTERN_TEST_
MODELS_FP4
,
[
""
])),
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Only test ROCm or CUDA"
...
...
@@ -315,6 +333,8 @@ def test_attention_quant_pattern(
not
current_platform
.
is_device_capability
((
10
,
0
))
or
not
has_flashinfer
()
):
pytest
.
skip
(
"FlashInfer attn fusion requires Blackwell and flashinfer"
)
if
"Llama-4-Scout"
in
model_name
and
cuda_device_count_stateless
()
<
2
:
pytest
.
skip
(
"Llama-4-Scout requires at least 2 GPUs"
)
custom_ops_list
=
custom_ops
.
split
(
","
)
if
custom_ops
else
[]
...
...
@@ -483,3 +503,88 @@ def test_attention_quant_pattern(
# Check that results are close
torch
.
testing
.
assert_close
(
result_unfused
,
result_fused
,
atol
=
1e-2
,
rtol
=
1e-2
)
@
pytest
.
mark
.
parametrize
(
"model_name, model_kwargs, backend, matches, custom_ops"
,
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list
(
flat_product
(
MODELS_FP8
,
CUSTOM_OPS_FP8
))
# quant_fp4 only has the custom impl
+
list
(
flat_product
(
MODELS_FP4
,
[
""
])),
)
@
pytest
.
mark
.
parametrize
(
"inductor_graph_partition"
,
[
pytest
.
param
(
True
,
marks
=
pytest
.
mark
.
skipif
(
not
has_cuda_graph_wrapper_metadata
(),
reason
=
"This test requires"
"torch._inductor.utils.CUDAGraphWrapperMetadata to run"
,
),
),
False
,
],
)
def
test_attn_quant
(
model_name
:
str
,
model_kwargs
:
dict
[
str
,
Any
],
backend
:
AttentionBackendEnum
,
matches
:
Matches
,
custom_ops
:
str
,
inductor_graph_partition
:
bool
,
caplog_mp_spawn
,
monkeypatch
,
):
if
not
current_platform
.
has_device_capability
(
90
):
pytest
.
skip
(
"test_attn_quant requires H100 (SM90) or B200 (SM100) GPU"
)
if
backend
==
AttentionBackendEnum
.
FLASHINFER
and
(
not
is_blackwell
()
or
not
has_flashinfer
()
):
pytest
.
skip
(
"FlashInfer attn fusion requires Blackwell and flashinfer"
)
if
inductor_graph_partition
and
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
pytest
.
skip
(
"Inductor graph partition requires torch>=2.9"
)
custom_ops_list
=
custom_ops
.
split
(
","
)
if
custom_ops
else
[]
if
inductor_graph_partition
:
mode
=
CUDAGraphMode
.
FULL_AND_PIECEWISE
splitting_ops
:
list
[
str
]
|
None
=
None
else
:
# FIXME: Llama-4-Scout-17B-16E-Instruct-FP8 + FlashInfer + Blackwell end at
# CUDAGraphMode.NONE here because it derives an attention backend that
# does not support full cudagraphs
mode
=
CUDAGraphMode
.
FULL_DECODE_ONLY
splitting_ops
=
[]
# Disable, compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch
.
setenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"1"
)
# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch
.
setenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
)
model_kwargs
[
"attention_config"
]
=
{
"backend"
:
backend
.
name
}
compilation_config
=
CompilationConfig
(
# Testing properties
custom_ops
=
custom_ops_list
,
use_inductor_graph_partition
=
inductor_graph_partition
,
cudagraph_mode
=
mode
,
splitting_ops
=
splitting_ops
,
# Common
mode
=
CompilationMode
.
VLLM_COMPILE
,
pass_config
=
PassConfig
(
fuse_attn_quant
=
True
,
eliminate_noops
=
True
),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config
=
{
"force_disable_caches"
:
True
},
)
with
caplog_mp_spawn
(
logging
.
DEBUG
)
as
log_holder
:
run_model
(
compilation_config
,
model_name
,
**
model_kwargs
)
log_matches
=
re
.
findall
(
r
"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes"
,
log_holder
.
text
,
)
assert
len
(
log_matches
)
==
1
,
log_holder
.
text
assert
int
(
log_matches
[
0
])
==
matches
.
attention_fusion
vllm/env_override.py
View file @
14ce5242
...
...
@@ -95,7 +95,7 @@ def memory_plan_reuse_patched(self):
# ===================================================
# This change monkeypatches get_graph_partition_signature in pytorch 2.9.0 to
# fix inductor partition + attention-nvfp4 quant fusion, tested in
# `tests/compile/
distributed/
test_fusion
s_e2e
.py::test_attn_quant`.
# `tests/compile/test_fusion
_attn
.py::test_attn_quant`.
# For more context, see https://github.com/pytorch/pytorch/pull/165815.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment