Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
006693ed
Commit
006693ed
authored
Dec 01, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.11.2' into v0.11.2-ori
parents
4b51e6f1
275de341
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
742 additions
and
381 deletions
+742
-381
tests/compile/test_full_graph.py
tests/compile/test_full_graph.py
+151
-103
tests/compile/test_functionalization.py
tests/compile/test_functionalization.py
+241
-93
tests/compile/test_fusion.py
tests/compile/test_fusion.py
+122
-53
tests/compile/test_fusion_all_reduce.py
tests/compile/test_fusion_all_reduce.py
+228
-132
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
tests/compile/test_full_graph.py
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
logging
import
tempfile
from
typing
import
Any
,
Optional
,
Union
from
pathlib
import
Path
from
typing
import
Any
import
pytest
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.v1.attention.utils
import
_Backend
from
vllm
import
LLM
,
SamplingParams
from
vllm.attention.selector
import
global_force_attn_backend_context_manager
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
CUDAGraphMode
,
PassConfig
)
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.config
import
CompilationConfig
,
CompilationMode
,
CUDAGraphMode
,
PassConfig
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_torch_equal_or_newer
from
vllm.utils
.torch_utils
import
is_torch_equal_or_newer
from
..utils
import
create_new_process_for_each_test
def
models_list
(
*
,
all
:
bool
=
True
,
keywords
:
Optional
[
list
[
str
]
]
=
None
):
def
models_list
(
*
,
all
:
bool
=
True
,
keywords
:
list
[
str
]
|
None
=
None
):
TEST_MODELS
:
list
[
tuple
[
str
,
dict
[
str
,
Any
]]]
=
[
(
"facebook/opt-125m"
,
{}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
{
"dtype"
:
torch
.
float16
,
}),
(
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic"
,
{
"dtype"
:
torch
.
float16
,
}),
(
"neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8"
,
{}),
(
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic"
,
{
"dtype"
:
torch
.
float16
},
),
(
"meta-llama/Llama-3.2-1B-Instruct"
,
{}),
]
if
all
:
TEST_MODELS
.
extend
(
[
(
"neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8"
,
{}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
{
"dtype"
:
torch
.
float16
},
),
]
)
# TODO: figure out why this fails.
if
False
and
is_quant_method_supported
(
"gguf"
):
# noqa: SIM223
TEST_MODELS
.
append
(
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
,
{
"quantization"
:
"gguf"
})
)
TEST_MODELS
.
append
(
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
,
{
"quantization"
:
"gguf"
})
)
if
is_quant_method_supported
(
"gptq"
):
TEST_MODELS
.
append
(
(
"TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"
,
{
"quantization"
:
"gptq"
})
)
TEST_MODELS
.
append
(
(
"TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"
,
{
"quantization"
:
"gptq"
})
)
if
is_quant_method_supported
(
"gptq_marlin"
):
TEST_MODELS
.
append
((
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
{
"quantization"
:
"gptq_marlin"
}))
TEST_MODELS
.
append
(
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
{
"quantization"
:
"gptq_marlin"
},
)
)
if
is_quant_method_supported
(
"gptq_marlin_24"
):
TEST_MODELS
.
append
((
"alexm-nm/tinyllama-24-marlin24-4bit-g128"
,
{
"quantization"
:
"gptq_marlin_24"
}))
TEST_MODELS
.
append
(
(
"alexm-nm/tinyllama-24-marlin24-4bit-g128"
,
{
"quantization"
:
"gptq_marlin_24"
},
)
)
if
not
current_platform
.
is_rocm
()
and
is_quant_method_supported
(
"awq"
):
TEST_MODELS
.
append
(
(
"TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ"
,
{
"quantization"
:
"AWQ"
})
)
TEST_MODELS
.
append
(
(
"TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ"
,
{
"quantization"
:
"AWQ"
})
)
if
keywords
is
None
:
return
TEST_MODELS
...
...
@@ -72,110 +80,145 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
@
pytest
.
mark
.
parametrize
(
"
optimiz
ation_
level
"
,
[
Compilation
Level
.
DYNAMO_ONCE
,
Compilation
Level
.
PIECEWIS
E
],
"
compil
ation_
mode
"
,
[
Compilation
Mode
.
DYNAMO_
TRACE_
ONCE
,
Compilation
Mode
.
VLLM_COMPIL
E
],
)
@
pytest
.
mark
.
parametrize
(
"model
_info
"
,
models_list
(
all
=
True
))
@
pytest
.
mark
.
parametrize
(
"model
, model_kwargs
"
,
models_list
(
all
=
True
))
@
create_new_process_for_each_test
()
def
test_full_graph
(
monkeypatch
:
pytest
.
MonkeyPatch
,
model_info
:
tuple
[
str
,
dict
[
str
,
Any
]],
optimization_level
:
int
,
model
:
str
,
model_kwargs
:
dict
[
str
,
Any
],
compilation_mode
:
int
,
):
model
,
model_kwargs
=
model_info
if
(
"w8a8"
in
model
or
"w8w8"
in
model
and
current_platform
.
has_device_capability
((
10
,
0
))
):
# int8 removed on Blackwell:
pytest
.
skip
(
"int8 support removed on Blackwell"
)
with
monkeypatch
.
context
():
print
(
f
"MODEL=
{
model
}
"
)
run_model
(
optimiz
ation_
level
,
model
,
model_kwargs
)
run_model
(
compil
ation_
mode
,
model
,
**
model_kwargs
)
# TODO(luka) add other supported compilation config scenarios here
@
pytest
.
mark
.
parametrize
(
"compilation_config, model
_info
"
,
"compilation_config, model
, model_kwargs
"
,
[
# additional compile sizes, only some of the models
(
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
compile_sizes
=
[
1
,
2
]),
model
)
for
model
in
models_list
(
all
=
False
)
]
+
[
(
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
compile_sizes
=
[
1
,
2
]),
*
model_info
,
)
for
model_info
in
models_list
(
all
=
False
)
]
+
[
# RMSNorm + quant fusion, only 8-bit quant models
(
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
custom_ops
=
[
"+rms_norm"
],
pass_config
=
PassConfig
(
enable_fusion
=
True
,
enable_noop
=
True
)),
model
)
for
model
in
models_list
(
keywords
=
[
"FP8-dynamic"
,
"quantized.w8a8"
])
]
+
[
(
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
custom_ops
=
[
"+rms_norm"
],
pass_config
=
PassConfig
(
enable_fusion
=
True
,
enable_noop
=
True
),
),
*
model_info
,
)
for
model_info
in
models_list
(
keywords
=
[
"FP8-dynamic"
,
"quantized.w8a8"
])
]
+
[
# Test depyf integration works
(
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
debug_dump_path
=
tempfile
.
gettempdir
()),
(
"facebook/opt-125m"
,
{})),
]
+
[
(
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
debug_dump_path
=
Path
(
tempfile
.
gettempdir
()),
),
"facebook/opt-125m"
,
{},
),
]
+
[
# graph inductor partition
(
CompilationConfig
(
level
=
Compilation
Level
.
PIECEWIS
E
,
mode
=
Compilation
Mode
.
VLLM_COMPIL
E
,
# inductor graph partition uses
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
use_inductor_graph_partition
=
True
,
cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
,
compile_sizes
=
[
1
,
2
]),
model
)
for
model
in
models_list
(
all
=
False
)
compile_sizes
=
[
1
,
2
],
),
*
model_info
,
)
for
model_info
in
models_list
(
all
=
False
)
if
is_torch_equal_or_newer
(
"2.9.0.dev"
)
])
],
)
# only test some of the models
@
create_new_process_for_each_test
()
def
test_custom_compile_config
(
compilation_config
:
CompilationConfig
,
model_info
:
tuple
[
str
,
dict
[
str
,
Any
]],
model
:
str
,
model_kwargs
:
dict
[
str
,
Any
],
):
if
(
compilation_config
.
use_inductor_graph_partition
and
not
is_torch_equal_or_newer
(
"2.9.0.dev"
)):
pytest
.
skip
(
"inductor graph partition is only available "
"in PyTorch 2.9+"
)
if
(
"w8a8"
in
model
or
"w8w8"
in
model
and
current_platform
.
has_device_capability
((
10
,
0
))
):
# int8 removed on Blackwell:
pytest
.
skip
(
"int8 support removed on Blackwell"
)
if
compilation_config
.
use_inductor_graph_partition
and
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
pytest
.
skip
(
"inductor graph partition is only available in PyTorch 2.9+"
)
model
,
model_kwargs
=
model_info
print
(
f
"MODEL=
{
model
}
"
)
run_model
(
compilation_config
,
model
,
model_kwargs
)
run_model
(
compilation_config
,
model
,
**
model_kwargs
)
def
test_inductor_graph_partition_attn_fusion
(
caplog_vllm
):
if
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
pytest
.
skip
(
"inductor graph partition is only available "
"in PyTorch 2.9+"
)
@
pytest
.
mark
.
parametrize
(
"compilation_mode"
,
[
CompilationMode
.
NONE
,
CompilationMode
.
VLLM_COMPILE
],
)
@
pytest
.
mark
.
parametrize
(
"model, backend"
,
[
(
"Qwen/Qwen2-0.5B"
,
None
),
# Standard attention model
(
"deepseek-ai/DeepSeek-V2-Lite"
,
AttentionBackendEnum
.
FLASHINFER_MLA
,
),
# MLA (Multi-head Latent Attention) model
],
)
def
test_fp8_kv_scale_compile
(
monkeypatch
:
pytest
.
MonkeyPatch
,
compilation_mode
:
int
,
model
:
str
,
backend
:
AttentionBackendEnum
|
None
,
):
if
backend
:
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
.
name
)
model
=
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
use_inductor_graph_partition
=
True
,
cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
,
custom_ops
=
[
"+quant_fp8"
],
pass_config
=
PassConfig
(
enable_attn_fusion
=
True
,
enable_noop
=
True
),
)
model_kwargs
=
{
"kv_cache_dtype"
:
"fp8"
,
"max_model_len"
:
1024
,
"quantization"
:
"fp8"
,
"kv_cache_dtype"
:
"fp8_e4m3"
,
"calculate_kv_scales"
:
True
,
"max_model_len"
:
512
,
}
with
caplog_vllm
.
at_level
(
logging
.
DEBUG
),
global_force_attn_backend_context_manager
(
_Backend
.
FLASHINFER
):
run_model
(
compilation_config
,
model
,
model_kwargs
)
try
:
assert
(
"Fused quantization onto 48 attention nodes"
in
caplog_vllm
.
text
),
caplog_vllm
.
text
except
AssertionError
:
# Note: this message is only triggered when the compilation goes
# through the custom pass. Due to multiple layers of cache on
# PyTorch side, the compilation of a graph may be cached such
# that custom pass directly goes through cache. In this case,
# we go through this branch and assert that the pass is not
# triggered.
assert
"Fused quantization"
not
in
caplog_vllm
.
text
def
run_model
(
compile_config
:
Union
[
int
,
CompilationConfig
],
model
:
str
,
model_kwargs
:
dict
[
str
,
Any
]):
run_model
(
compilation_mode
,
model
,
**
model_kwargs
)
def
run_model
(
compile_config
:
int
|
CompilationConfig
,
model
:
str
,
**
model_kwargs
):
compilation_config
=
(
compile_config
if
isinstance
(
compile_config
,
CompilationConfig
)
else
CompilationConfig
(
mode
=
compile_config
)
)
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
...
...
@@ -183,12 +226,17 @@ def run_model(compile_config: Union[int, CompilationConfig], model: str,
"The future of AI is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0
)
# Allow override from model_kwargs
model_kwargs
=
{
"tensor_parallel_size"
:
1
,
**
model_kwargs
}
model_kwargs
=
{
"disable_custom_all_reduce"
:
True
,
**
model_kwargs
}
# No cudagraphs by default
if
compilation_config
.
cudagraph_mode
is
None
:
compilation_config
.
cudagraph_mode
=
CUDAGraphMode
.
NONE
llm
=
LLM
(
model
=
model
,
enforce_eager
=
True
,
tensor_parallel_size
=
1
,
disable_custom_all_reduce
=
True
,
compilation_config
=
compile_config
,
compilation_config
=
compilation_config
,
**
model_kwargs
,
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
...
...
tests/compile/test_functionalization.py
View file @
006693ed
...
...
@@ -5,114 +5,262 @@ import pytest
import
torch
import
vllm.envs
as
envs
from
vllm
import
LLM
,
SamplingParams
from
vllm.compilation.activation_quant_fusion
import
ActivationQuantFusionPass
from
vllm.compilation.fix_functionalization
import
FixFunctionalizationPass
from
vllm.compilation.fusion
import
FUSED_OPS
,
RMSNormQuantFusionPass
from
vllm.compilation.fusion
import
RMSNormQuantFusionPass
from
vllm.compilation.fx_utils
import
find_auto_fn
,
find_auto_fn_maybe
,
is_func
from
vllm.compilation.noop_elimination
import
NoOpEliminationPass
from
vllm.compilation.post_cleanup
import
PostCleanupPass
from
vllm.config
import
CompilationConfig
,
PassConfig
,
VllmConfig
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
)
from
vllm.config
import
(
CompilationConfig
,
ModelConfig
,
PassConfig
,
VllmConfig
,
set_current_vllm_config
,
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
Fp8LinearOp
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.platforms
import
current_platform
from
.backend
import
TestBackend
OPS_IN_MODEL
=
[
torch
.
ops
.
_C
.
rotary_embedding
.
default
,
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
,
]
TEST_FP8
=
current_platform
.
supports_fp8
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
class
TestSiluMul
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
=
128
):
super
().
__init__
()
self
.
silu_and_mul
=
SiluAndMul
()
self
.
wscale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
self
.
scale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
if
TEST_FP8
:
self
.
w
=
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
True
,
act_quant_group_shape
=
GroupShape
.
PER_TENSOR
,
)
def
forward
(
self
,
x
):
y
=
self
.
silu_and_mul
(
x
)
if
TEST_FP8
:
x2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
,
self
.
wscale
,
input_scale
=
self
.
wscale
)
return
x2
else
:
return
y
def
example_inputs
(
self
,
num_tokens
=
32
,
hidden_size
=
128
):
return
(
torch
.
rand
(
num_tokens
,
hidden_size
*
2
),)
def
ops_in_model
(
self
,
do_fusion
):
if
TEST_FP8
and
do_fusion
:
return
[
torch
.
ops
.
_C
.
silu_and_mul_quant
.
default
]
else
:
return
[
torch
.
ops
.
_C
.
silu_and_mul
.
default
]
def
ops_not_in_model
(
self
):
return
[]
class
TestFusedAddRMSNorm
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
=
16
,
intermediate_size
=
32
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
gate_proj
=
torch
.
nn
.
Parameter
(
torch
.
empty
((
intermediate_size
,
hidden_size
))
)
self
.
norm
=
RMSNorm
(
intermediate_size
,
1e-05
)
self
.
norm
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
intermediate_size
))
torch
.
nn
.
init
.
normal_
(
self
.
gate_proj
,
std
=
0.02
)
if
TEST_FP8
:
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
True
)
self
.
scale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
self
.
w
=
torch
.
rand
(
hidden_size
,
intermediate_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
self
.
wscale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
def
forward
(
self
,
hidden_states
,
residual
):
# Reshape input
view
=
hidden_states
.
reshape
(
-
1
,
self
.
hidden_size
)
# matrix multiplication
permute
=
self
.
gate_proj
.
permute
(
1
,
0
)
mm
=
torch
.
mm
(
view
,
permute
)
# layer normalization
norm_output
,
residual_output
=
self
.
norm
(
mm
,
residual
)
if
TEST_FP8
:
# scaled_mm with static input quantization
fp8_linear_result
=
self
.
fp8_linear
.
apply
(
norm_output
,
self
.
w
,
self
.
wscale
,
input_scale
=
self
.
scale
.
to
(
norm_output
.
device
),
)
return
fp8_linear_result
,
residual_output
else
:
return
norm_output
,
residual_output
def
example_inputs
(
self
,
batch_size
=
8
,
hidden_size
=
16
,
seq_len
=
16
):
hidden_states
=
torch
.
randn
((
batch_size
*
seq_len
,
hidden_size
))
residual
=
torch
.
randn
((
batch_size
*
seq_len
,
hidden_size
))
return
(
hidden_states
,
residual
)
def
ops_in_model
(
self
,
do_fusion
):
if
TEST_FP8
and
do_fusion
:
return
[
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
.
default
]
else
:
return
[
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
]
def
ops_not_in_model
(
self
):
return
[]
RMS_OP
=
torch
.
ops
.
_C
.
rms_norm
.
default
RMS_QUANT_OPS
=
{
"static_fp8"
:
[
torch
.
ops
.
_C
.
rms_norm_static_fp8_quant
.
default
,
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
.
default
],
}
class
TestRotaryEmbedding
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
head_dim
=
64
,
rotary_dim
=
None
,
max_position
=
2048
,
base
=
10000
):
super
().
__init__
()
self
.
head_dim
=
head_dim
self
.
rotary_dim
=
rotary_dim
or
head_dim
SILU_MUL_OP
=
torch
.
ops
.
_C
.
silu_and_mul
.
default
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
rotary_dim
,
max_position
=
max_position
,
base
=
base
,
)
SILU_MUL_QUANT_OP
=
torch
.
ops
.
_C
.
silu_and_mul_quant
.
default
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
def
forward
(
self
,
positions
,
q
,
k
):
q_rotated
,
k_rotated
=
self
.
rotary_emb
(
positions
,
q
,
k
)
return
q_rotated
,
k_rotated
def
example_inputs
(
self
,
num_tokens
=
32
,
head_dim
=
64
):
positions
=
torch
.
arange
(
num_tokens
,
dtype
=
torch
.
long
)
q
=
torch
.
randn
(
num_tokens
,
head_dim
)
k
=
torch
.
randn
(
num_tokens
,
head_dim
)
return
(
positions
,
q
,
k
)
def
ops_in_model
(
self
,
do_fusion
):
return
[
torch
.
ops
.
_C
.
rotary_embedding
.
default
]
def
ops_not_in_model
(
self
):
return
[]
class
TestRotaryEmbeddingSliceScatter
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
head_dim
=
64
,
num_heads
=
4
,
max_position
=
2048
,
base
=
10000
):
super
().
__init__
()
self
.
head_dim
=
head_dim
self
.
num_heads
=
num_heads
self
.
hidden_size
=
head_dim
*
num_heads
self
.
qkv_proj
=
torch
.
nn
.
Linear
(
self
.
hidden_size
,
self
.
hidden_size
*
3
,
bias
=
False
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position
,
base
=
base
,
)
def
forward
(
self
,
positions
,
hidden_states
):
# Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
# -> slice_scatter -> split_with_sizes
qkv
=
self
.
qkv_proj
(
hidden_states
)
split_sizes
=
[
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
]
q
,
k
,
v
=
torch
.
split
(
qkv
,
split_sizes
,
dim
=-
1
)
q_rotated
,
k_rotated
=
self
.
rotary_emb
(
positions
,
q
,
k
)
qkv_updated
=
torch
.
cat
([
q_rotated
,
k_rotated
,
v
],
dim
=-
1
)
return
qkv_updated
def
example_inputs
(
self
,
num_tokens
=
32
,
head_dim
=
64
,
num_heads
=
4
):
hidden_size
=
head_dim
*
num_heads
positions
=
torch
.
arange
(
num_tokens
,
dtype
=
torch
.
long
)
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
)
return
(
positions
,
hidden_states
)
def
ops_in_model
(
self
,
do_fusion
):
return
[
torch
.
ops
.
_C
.
rotary_embedding
.
default
]
def
ops_not_in_model
(
self
):
return
[
torch
.
ops
.
aten
.
slice_scatter
.
default
]
MODELS
=
[
TestSiluMul
,
TestFusedAddRMSNorm
,
TestRotaryEmbedding
,
TestRotaryEmbeddingSliceScatter
,
]
@
pytest
.
mark
.
parametrize
(
"model, quant_key"
,
[(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
,
kFp8StaticTensorSym
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e"
,
kFp8DynamicTokenSym
)])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"model_class"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"do_fusion"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
!=
"cuda"
,
reason
=
"Only test on CUDA"
)
def
test_fix_functionalization
(
model
:
str
,
quant_key
:
QuantKey
,
do_fusion
:
bool
):
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
!=
"cuda"
,
reason
=
"Only test on CUDA"
)
def
test_fix_functionalization
(
model_class
:
torch
.
nn
.
Module
,
do_fusion
:
bool
,
dtype
:
torch
.
dtype
):
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
dtype
)
vllm_config
=
VllmConfig
(
model_config
=
ModelConfig
(
dtype
=
dtype
),
compilation_config
=
CompilationConfig
(
custom_ops
=
[
"all"
],
pass_config
=
PassConfig
(
enable_fusion
=
do_fusion
,
enable_noop
=
True
),
),
)
with
set_current_vllm_config
(
vllm_config
):
assert
RMSNorm
.
enabled
()
noop_pass
=
NoOpEliminationPass
(
vllm_config
)
fusion_pass
=
RMSNormQuantFusionPass
(
vllm_config
)
cleanup_pass
=
PostCleanupPass
(
vllm_config
)
act_quant_fusion_pass
=
ActivationQuantFusionPass
(
vllm_config
)
passes
=
(
[
noop_pass
,
fusion_pass
,
act_quant_fusion_pass
,
cleanup_pass
]
if
do_fusion
else
[
noop_pass
,
cleanup_pass
]
)
func_pass
=
FixFunctionalizationPass
(
vllm_config
)
backend_func
=
TestBackend
(
*
passes
,
func_pass
)
backend_no_func
=
TestBackend
(
*
passes
)
model
=
model_class
()
torch
.
compile
(
model
,
backend
=
backend_func
)(
*
model
.
example_inputs
())
torch
.
compile
(
model
,
backend
=
backend_no_func
)(
*
model
.
example_inputs
())
# check if the functionalization pass is applied
for
op
in
model
.
ops_in_model
(
do_fusion
):
find_auto_fn
(
backend_no_func
.
graph_post_pass
.
nodes
,
op
)
assert
find_auto_fn_maybe
(
backend_func
.
graph_post_pass
.
nodes
,
op
)
is
None
vllm_config
=
VllmConfig
()
vllm_config
.
compilation_config
=
CompilationConfig
(
pass_config
=
PassConfig
(
enable_fusion
=
do_fusion
,
enable_noop
=
True
))
noop_pass
=
NoOpEliminationPass
(
vllm_config
)
fusion_pass
=
RMSNormQuantFusionPass
(
vllm_config
)
cleanup_pass
=
PostCleanupPass
(
vllm_config
)
act_quant_fusion_pass
=
ActivationQuantFusionPass
(
vllm_config
)
passes
=
[
noop_pass
,
fusion_pass
,
act_quant_fusion_pass
,
cleanup_pass
]
if
do_fusion
else
[
noop_pass
,
cleanup_pass
]
func_pass
=
FixFunctionalizationPass
(
vllm_config
)
backend_func
=
TestBackend
(
*
passes
,
func_pass
)
backend_no_func
=
TestBackend
(
*
passes
)
# instantiate a full engine and manually compile the model 2x
# (with and without FixFunctionalizationPass)
llm
=
LLM
(
model
=
model
,
enforce_eager
=
True
)
model_runner
=
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
orig_model
=
model_runner
.
model
# TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
# Can only do that by using the decorator but then we'd have to instantiate
# 2 LLM instances.
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
model_runner
.
model
=
torch
.
compile
(
orig_model
,
fullgraph
=
True
,
backend
=
backend_func
)
gen_func
=
llm
.
generate
(
prompts
,
sampling_params
)
model_runner
.
model
=
torch
.
compile
(
orig_model
,
fullgraph
=
True
,
backend
=
backend_no_func
)
gen_no_func
=
llm
.
generate
(
prompts
,
sampling_params
)
for
output_func
,
output_no_func
in
zip
(
gen_func
,
gen_no_func
):
assert
output_func
.
outputs
[
0
].
text
==
output_no_func
.
outputs
[
0
].
text
# OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
# and replaced by fused quantized ops in RMS_QUANT_OPS.
rms_ops
=
[
FUSED_OPS
[(
quant_key
,
True
)],
FUSED_OPS
[(
quant_key
,
False
)]
]
if
do_fusion
else
[
RMS_OP
]
silu_mul_ops
=
[
SILU_MUL_QUANT_OP
]
if
do_fusion
and
\
quant_key
==
kFp8StaticTensorSym
else
[
SILU_MUL_OP
]
ops
=
OPS_IN_MODEL
+
rms_ops
+
silu_mul_ops
for
op
in
ops
:
find_auto_fn
(
backend_no_func
.
graph_post_pass
.
nodes
,
op
)
assert
find_auto_fn_maybe
(
backend_func
.
graph_post_pass
.
nodes
,
op
)
is
None
# noqa: E501
# make sure the ops were all de-functionalized
found
=
dict
()
for
node
in
backend_func
.
graph_post_pass
.
nodes
:
for
op
in
ops
:
if
is_func
(
node
,
op
):
found
[
op
]
=
True
assert
all
(
found
[
op
]
for
op
in
ops
)
# make sure the ops were all de-functionalized
found
=
dict
()
for
node
in
backend_func
.
graph_post_pass
.
nodes
:
for
op
in
model
.
ops_in_model
(
do_fusion
):
if
is_func
(
node
,
op
):
found
[
op
]
=
True
for
op
in
model
.
ops_not_in_model
():
if
is_func
(
node
,
op
):
found
[
op
]
=
True
assert
all
(
found
[
op
]
for
op
in
model
.
ops_in_model
(
do_fusion
))
assert
all
(
not
found
.
get
(
op
)
for
op
in
model
.
ops_not_in_model
())
tests/compile/test_fusion.py
View file @
006693ed
...
...
@@ -5,17 +5,29 @@ import pytest
import
torch
import
vllm.plugins
from
vllm.compilation.fusion
import
(
FUSED_OPS
,
QUANT_OPS
,
FusedRMSQuantKey
,
RMSNormQuantFusionPass
)
from
vllm.compilation.fusion
import
FUSED_OPS
,
FusedRMSQuantKey
,
RMSNormQuantFusionPass
from
vllm.compilation.fx_utils
import
find_op_nodes
from
vllm.compilation.matcher_utils
import
QUANT_OPS
from
vllm.compilation.noop_elimination
import
NoOpEliminationPass
from
vllm.compilation.post_cleanup
import
PostCleanupPass
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
PassConfig
,
VllmConfig
)
from
vllm.config
import
(
CompilationConfig
,
CompilationMode
,
ModelConfig
,
PassConfig
,
VllmConfig
,
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
QuantKey
,
ScaleDesc
)
GroupShape
,
QuantKey
,
ScaleDesc
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
cutlass_fp8_supported
,
maybe_create_device_identity
)
Fp8LinearOp
,
cutlass_fp8_supported
,
maybe_create_device_identity
,
)
from
vllm.platforms
import
current_platform
from
..utils
import
override_cutlass_fp8_supported
...
...
@@ -23,25 +35,34 @@ from .backend import TestBackend
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
RMS_OP
=
torch
.
ops
.
_C
.
rms_norm
.
default
RMS_ADD_OP
=
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
class
TestModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
,
static
:
bool
,
cuda_force_torch
:
bool
,
*
args
,
**
kwargs
):
class
TestModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
,
static
:
bool
,
cuda_force_torch
:
bool
,
*
args
,
**
kwargs
,
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
cuda_force_torch
=
cuda_force_torch
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
_
in
range
(
3
)]
self
.
wscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
2
)]
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
_
in
range
(
4
)]
self
.
wscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
group_shape
=
GroupShape
.
PER_TENSOR
if
static
else
GroupShape
.
PER_TOKEN
quant_scale
=
ScaleDesc
(
torch
.
float32
,
static
,
group_shape
)
self
.
key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
quant_scale
,
symmetric
=
True
)
self
.
quant_
key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
quant_scale
,
symmetric
=
True
)
if
static
:
self
.
scale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
2
)]
self
.
scale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
else
:
self
.
scale
=
[
None
for
_
in
range
(
2
)]
self
.
scale
=
[
None
for
_
in
range
(
3
)]
self
.
w
=
[
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
for
_
in
range
(
2
)
for
_
in
range
(
3
)
]
with
override_cutlass_fp8_supported
(
not
cuda_force_torch
):
...
...
@@ -50,57 +71,97 @@ class TestModel(torch.nn.Module):
act_quant_group_shape
=
group_shape
,
)
self
.
enable_rms_norm_custom_op
=
self
.
norm
[
0
].
enabled
()
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
quant_fp8
.
enabled
()
def
forward
(
self
,
x
):
resid
=
torch
.
sqrt
(
x
)
# avoid having graph input be an arg to a pattern directly
x
=
resid
=
torch
.
relu
(
x
)
y
=
self
.
norm
[
0
](
x
)
x2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
input_scale
=
self
.
scale
[
0
])
x2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
input_scale
=
self
.
scale
[
0
]
)
# make sure resid is used for replacement to work
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
x3
=
self
.
fp8_linear
.
apply
(
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
input_scale
=
self
.
scale
[
1
])
x3
=
self
.
fp8_linear
.
apply
(
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
input_scale
=
self
.
scale
[
1
]
)
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
return
y3
def
ops_in_model_before
(
self
):
return
[
QUANT_OPS
[
self
.
key
]]
x4
=
self
.
fp8_linear
.
apply
(
y3
,
self
.
w
[
2
],
self
.
wscale
[
2
],
input_scale
=
self
.
scale
[
2
]
)
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
# use resid here
return
y4
def
ops_in_model_after
(
self
):
return
[
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
key
,
Fals
e
)],
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
key
,
Tru
e
)]
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
quant_
key
,
Tru
e
)],
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
quant_
key
,
Fals
e
)]
,
]
def
ops_in_model_before
(
self
):
return
(
[
QUANT_OPS
[
self
.
quant_key
]]
if
self
.
enable_quant_fp8_custom_op
else
[
torch
.
ops
.
aten
.
reciprocal
]
)
def
ops_in_model_before_partial
(
self
):
return
(
[
RMS_OP
,
RMS_ADD_OP
]
if
self
.
enable_rms_norm_custom_op
else
[
torch
.
ops
.
aten
.
rsqrt
]
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
257
])
@
pytest
.
mark
.
parametrize
(
"eps"
,
[
1e-5
,
1e-6
])
@
pytest
.
mark
.
parametrize
(
"static"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_rms_norm_custom_op"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_quant_fp8_custom_op"
,
[
True
,
False
])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@
pytest
.
mark
.
parametrize
(
"cuda_force_torch"
,
[
True
,
False
]
if
cutlass_fp8_supported
()
else
[
True
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Only test on CUDA and ROCm"
)
def
test_fusion_rmsnorm_quant
(
dtype
,
hidden_size
,
num_tokens
,
eps
,
static
,
cuda_force_torch
):
@
pytest
.
mark
.
parametrize
(
"cuda_force_torch"
,
[
True
,
False
]
if
cutlass_fp8_supported
()
else
[
True
]
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Only test on CUDA and ROCm"
)
def
test_fusion_rmsnorm_quant
(
dtype
,
hidden_size
,
num_tokens
,
eps
,
static
,
enable_rms_norm_custom_op
,
enable_quant_fp8_custom_op
,
cuda_force_torch
,
):
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
manual_seed
(
1
)
maybe_create_device_identity
()
# needed for certain non-cutlass fp8 paths
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
custom_ops
=
[
"+rms_norm"
,
"+quant_fp8"
],
pass_config
=
PassConfig
(
enable_fusion
=
True
,
enable_noop
=
True
),
))
custom_ops
=
[]
if
enable_rms_norm_custom_op
:
custom_ops
.
append
(
"+rms_norm"
)
if
enable_quant_fp8_custom_op
:
custom_ops
.
append
(
"+quant_fp8"
)
vllm_config
=
VllmConfig
(
model_config
=
ModelConfig
(
dtype
=
dtype
),
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
custom_ops
=
custom_ops
,
pass_config
=
PassConfig
(
enable_fusion
=
True
,
enable_noop
=
True
),
),
)
with
vllm
.
config
.
set_current_vllm_config
(
vllm_config
):
# Reshape pass is needed for the fusion pass to work
noop_pass
=
NoOpEliminationPass
(
vllm_config
)
...
...
@@ -108,31 +169,39 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cleanup_pass
=
PostCleanupPass
(
vllm_config
)
backend
=
TestBackend
(
noop_pass
,
fusion_pass
,
cleanup_pass
)
backend2
=
TestBackend
(
noop_pass
,
cleanup_pass
)
model
=
TestModel
(
hidden_size
,
eps
,
static
,
cuda_force_torch
)
# First dimension dynamic
x
=
torch
.
rand
(
num_tokens
,
hidden_size
)
torch
.
_dynamo
.
mark_dynamic
(
x
,
0
)
result
=
model
(
x
)
model_fused
=
torch
.
compile
(
model
,
backend
=
backend
)
result_fused
=
model_fused
(
x
)
model
2
=
torch
.
compile
(
model
,
backend
=
backend
)
result
2
=
model2
(
x
)
model
_unfused
=
torch
.
compile
(
model
,
backend
=
backend
2
)
result
_unfused
=
model_unfused
(
x
)
# Higher tol for dynamic, even higher for bfloat16
if
static
:
ATOL
,
RTOL
=
(
1e-3
,
1e-3
)
elif
dtype
==
torch
.
float16
:
if
dtype
==
torch
.
float16
:
ATOL
,
RTOL
=
(
2e-3
,
2e-3
)
else
:
ATOL
,
RTOL
=
(
1e-2
,
1e-2
)
torch
.
testing
.
assert_close
(
result
,
result
2
,
atol
=
ATOL
,
rtol
=
RTOL
)
torch
.
testing
.
assert_close
(
result
_fused
,
result
_unfused
,
atol
=
ATOL
,
rtol
=
RTOL
)
assert
fusion_pass
.
matched_count
==
2
# In pre-nodes, fp8 quant should be there and fused kernels should not
assert
fusion_pass
.
matched_count
==
3
backend
.
check_before_ops
(
model
.
ops_in_model_before
())
# In post-nodes, fused kernels should be there and fp8 quant should not
backend
.
check_before_ops
(
model
.
ops_in_model_before_partial
(),
fully_replaced
=
False
)
backend
.
check_after_ops
(
model
.
ops_in_model_after
())
# If RMSNorm custom op is disabled (native/torch impl used),
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if
not
enable_rms_norm_custom_op
:
n_add_nodes
=
lambda
g
:
sum
(
1
for
_
in
find_op_nodes
(
torch
.
ops
.
aten
.
add
,
g
))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert
n_add_nodes
(
backend
.
graph_pre_pass
)
==
7
assert
n_add_nodes
(
backend
.
graph_post_pass
)
==
2
tests/compile/test_fusion_all_reduce.py
View file @
006693ed
...
...
@@ -6,59 +6,66 @@ import pytest
import
torch
import
vllm.envs
as
envs
from
vllm._custom_ops
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
vllm.compilation.collective_fusion
import
AllReduceFusionPass
from
vllm.compilation.fix_functionalization
import
FixFunctionalizationPass
from
vllm.compilation.noop_elimination
import
NoOpEliminationPass
from
vllm.compilation.post_cleanup
import
PostCleanupPass
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
DeviceConfig
,
ModelConfig
,
PassConfig
,
VllmConfig
)
from
vllm.config
import
(
CompilationConfig
,
CompilationMode
,
DeviceConfig
,
ModelConfig
,
PassConfig
,
VllmConfig
,
set_current_vllm_config
,
)
from
vllm.distributed
import
tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
(
init_distributed_environment
,
initialize_model_parallel
)
from
vllm.distributed.parallel_state
import
(
init_distributed_environment
,
initialize_model_parallel
,
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
GroupShape
,
QuantFP8
)
Fp8LinearOp
,
GroupShape
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
update_environment_variables
from
vllm.utils
.system_utils
import
update_environment_variables
from
..utils
import
has_module_attribute
,
multi_gpu_test
from
.backend
import
TestBackend
class
TestAllReduceRMSNormModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
=
16
,
token_num
=
16
,
eps
=
1e-6
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
eps
=
eps
self
.
norm
=
RMSNorm
(
hidden_size
,
eps
)
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
i
in
range
(
4
)]
self
.
w
=
[
torch
.
rand
(
hidden_size
,
hidden_size
)
for
_
in
range
(
3
)]
def
forward
(
self
,
hidden_states
,
residual
):
view
=
hidden_states
.
reshape
(
-
1
,
self
.
hidden_size
)
all_reduce
=
tensor_model_parallel_all_reduce
(
view
)
norm
=
self
.
norm
(
all_reduce
)
return
norm
def
forward
(
self
,
x
):
# avoid having graph input be an arg to a pattern directly
z
=
torch
.
relu
(
x
)
x
=
resid
=
tensor_model_parallel_
all_reduce
(
z
)
y
=
self
.
norm
[
0
](
x
)
def
ops_in_model_before
(
self
):
return
[
torch
.
ops
.
vllm
.
all_reduce
.
default
]
z2
=
torch
.
mm
(
y
,
self
.
w
[
0
])
x2
=
tensor_model_parallel_
all_reduce
(
z2
)
def
ops_in_model_after
(
self
):
return
[
torch
.
ops
.
vllm
.
flashinfer_trtllm_fused_allreduce_norm
.
default
]
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
z3
=
torch
.
mm
(
y2
,
self
.
w
[
1
])
x3
=
tensor_model_parallel_all_reduce
(
z3
)
class
TestAllReduceFusedAddRMSNormModel
(
torch
.
nn
.
Module
):
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
def
__init__
(
self
,
hidden_size
=
16
,
token_num
=
16
,
eps
=
1e-6
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
eps
=
eps
self
.
norm
=
RMSNorm
(
hidden_size
,
eps
)
z4
=
torch
.
mm
(
y3
,
self
.
w
[
2
])
x4
=
tensor_model_parallel_all_reduce
(
z4
)
def
forward
(
self
,
hidden_states
,
residual
):
view
=
hidden_states
.
reshape
(
-
1
,
self
.
hidden_size
)
all_reduce
=
tensor_model_parallel_all_reduce
(
view
)
norm
,
_
=
self
.
norm
(
all_reduce
,
residual
)
return
norm
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
return
y4
def
ops_in_model_before
(
self
):
return
[
torch
.
ops
.
vllm
.
all_reduce
.
default
]
...
...
@@ -67,27 +74,53 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
return
[
torch
.
ops
.
vllm
.
flashinfer_trtllm_fused_allreduce_norm
.
default
]
class
TestAllReduceFusedAddRMSNormStaticQuantFP8Model
(
torch
.
nn
.
Module
):
class
TestAllReduceRMSNormStaticQuantFP8Model
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
=
16
,
token_num
=
16
,
eps
=
1e-6
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
eps
=
eps
self
.
norm
=
RMSNorm
(
hidden_size
,
eps
)
self
.
quant_fp8
=
QuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
)
self
.
scale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
self
.
output
=
torch
.
empty
((
token_num
,
hidden_size
),
dtype
=
torch
.
float32
)
def
forward
(
self
,
hidden_states
,
residual
):
view
=
hidden_states
.
reshape
(
-
1
,
self
.
hidden_size
)
all_reduce
=
tensor_model_parallel_all_reduce
(
view
)
norm_output
,
residual_output
=
self
.
norm
(
all_reduce
,
residual
)
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
self
.
output
,
norm_output
.
contiguous
(),
self
.
scale
)
return
self
.
output
,
residual_output
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
i
in
range
(
4
)]
self
.
wscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
self
.
w
=
[
torch
.
rand
(
hidden_size
,
hidden_size
)
.
to
(
dtype
=
current_platform
.
fp8_dtype
())
.
t
()
for
_
in
range
(
3
)
]
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
True
,
act_quant_group_shape
=
GroupShape
.
PER_TENSOR
,
)
self
.
scale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
def
forward
(
self
,
hidden_states
):
# avoid having graph input be an arg to a pattern directly
z
=
torch
.
relu
(
hidden_states
)
x
=
resid
=
tensor_model_parallel_all_reduce
(
z
)
y
=
self
.
norm
[
0
](
x
)
z2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
input_scale
=
self
.
scale
[
0
]
)
x2
=
tensor_model_parallel_all_reduce
(
z2
)
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
z3
=
self
.
fp8_linear
.
apply
(
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
input_scale
=
self
.
scale
[
1
]
)
x3
=
tensor_model_parallel_all_reduce
(
z3
)
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
z4
=
self
.
fp8_linear
.
apply
(
y3
,
self
.
w
[
2
],
self
.
wscale
[
2
],
input_scale
=
self
.
scale
[
2
]
)
x4
=
tensor_model_parallel_all_reduce
(
z4
)
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
# use resid here
return
y4
def
ops_in_model_after
(
self
):
return
[
torch
.
ops
.
vllm
.
flashinfer_trtllm_fused_allreduce_norm
.
default
]
...
...
@@ -96,35 +129,58 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
return
[
torch
.
ops
.
vllm
.
all_reduce
.
default
,
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
if
self
.
fp8_linear
.
quant_fp8
.
enabled
()
else
torch
.
ops
.
aten
.
reciprocal
.
default
,
]
class
TestAllReduceFusedAddRMSNormStaticQuantFP4Model
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
=
16
,
token_num
=
16
,
eps
=
1e-6
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
eps
=
eps
self
.
norm
=
RMSNorm
(
hidden_size
,
eps
)
self
.
scale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
self
.
output
=
torch
.
empty
((
token_num
,
hidden_size
),
dtype
=
torch
.
float32
)
round_up
=
lambda
x
,
y
:
(
x
+
y
-
1
)
//
y
*
y
rounded_m
=
round_up
(
token_num
,
128
)
scale_n
=
hidden_size
//
16
rounded_n
=
round_up
(
scale_n
,
4
)
self
.
output_scale
=
torch
.
empty
((
rounded_m
,
rounded_n
//
4
),
dtype
=
torch
.
int32
)
def
forward
(
self
,
hidden_states
,
residual
):
view
=
hidden_states
.
reshape
(
-
1
,
self
.
hidden_size
)
all_reduce
=
tensor_model_parallel_all_reduce
(
view
)
norm_output
,
residual_output
=
self
.
norm
(
all_reduce
,
residual
)
norm_output
=
norm_output
.
reshape
(
-
1
,
norm_output
.
shape
[
-
1
])
torch
.
ops
.
_C
.
scaled_fp4_quant
(
self
.
output
,
norm_output
,
self
.
output_scale
,
self
.
scale
)
return
self
.
output
,
residual_output
,
self
.
output_scale
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
i
in
range
(
4
)]
self
.
w
=
[
torch
.
rand
(
hidden_size
,
hidden_size
)
for
_
in
range
(
3
)]
self
.
agscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
wgscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
self
.
alpha
=
[
1
/
(
w
*
a
)
for
w
,
a
in
zip
(
wgscale
,
self
.
agscale
)]
wq_gen
,
wscale_gen
=
zip
(
*
(
scaled_fp4_quant
(
w
,
wg
)
for
w
,
wg
in
zip
(
self
.
w
,
wgscale
))
)
self
.
wq
,
self
.
wscale
=
list
(
wq_gen
),
list
(
wscale_gen
)
print
(
f
"
{
self
.
wq
=
}
,
{
self
.
wscale
=
}
"
)
def
forward
(
self
,
hidden_states
):
# avoid having graph input be an arg to a pattern directly
z
=
torch
.
relu
(
hidden_states
)
x
=
resid
=
tensor_model_parallel_all_reduce
(
z
)
y
=
self
.
norm
[
0
](
x
)
yq
,
y_scale
=
scaled_fp4_quant
(
y
,
self
.
agscale
[
0
])
z2
=
cutlass_scaled_fp4_mm
(
yq
,
self
.
wq
[
0
],
y_scale
,
self
.
wscale
[
0
],
self
.
alpha
[
0
],
out_dtype
=
y
.
dtype
)
x2
=
tensor_model_parallel_all_reduce
(
z2
)
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
yq2
,
y_scale2
=
scaled_fp4_quant
(
y2
,
self
.
agscale
[
1
])
z3
=
cutlass_scaled_fp4_mm
(
yq2
,
self
.
wq
[
1
],
y_scale2
,
self
.
wscale
[
1
],
self
.
alpha
[
1
],
out_dtype
=
y2
.
dtype
)
x3
=
tensor_model_parallel_all_reduce
(
z3
)
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
yq3
,
y_scale3
=
scaled_fp4_quant
(
y3
,
self
.
agscale
[
2
])
z4
=
cutlass_scaled_fp4_mm
(
yq3
,
self
.
wq
[
2
],
y_scale3
,
self
.
wscale
[
2
],
self
.
alpha
[
2
],
out_dtype
=
y3
.
dtype
)
x4
=
tensor_model_parallel_all_reduce
(
z4
)
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
# use resid here
return
y4
def
ops_in_model_after
(
self
):
return
[
torch
.
ops
.
vllm
.
flashinfer_trtllm_fused_allreduce_norm
.
default
]
...
...
@@ -132,54 +188,81 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def
ops_in_model_before
(
self
):
return
[
torch
.
ops
.
vllm
.
all_reduce
.
default
,
torch
.
ops
.
_C
.
scaled_fp4_quant
.
default
torch
.
ops
.
_C
.
scaled_fp4_quant
.
default
,
]
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"test_model"
,
"test_model
, enable_quant_fp8_custom_op
"
,
[
TestAllReduceRMSNormModel
,
TestAllReduce
FusedAddRMSNormModel
,
TestAllReduce
FusedAdd
RMSNormStaticQuantFP8Model
,
# TODO: Enable with torch==2.8.0
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model
,
]
)
(
TestAllReduceRMSNormModel
,
False
),
(
TestAllReduce
RMSNormStaticQuantFP8Model
,
True
)
,
(
TestAllReduceRMSNormStaticQuantFP8Model
,
False
),
(
TestAllReduceFusedAddRMSNormStaticQuantFP4Model
,
False
),
]
,
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
1
6
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
6
4
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
not
in
[
"cuda"
],
reason
=
"Only test on CUDA"
)
@
pytest
.
mark
.
parametrize
(
"enable_rms_norm_custom_op"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
not
in
[
"cuda"
],
reason
=
"Only test on CUDA"
)
@
pytest
.
mark
.
skipif
(
not
find_spec
(
"flashinfer"
)
or
not
has_module_attribute
(
"flashinfer.comm"
,
"trtllm_allreduce_fusion"
),
reason
=
"flashinfer is not found or flashinfer "
"is not compiled with trtllm_allreduce_fusion"
)
def
test_all_reduce_fusion_pass_replace
(
test_model
:
torch
.
nn
.
Module
,
batch_size
:
int
,
seq_len
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
):
"is not compiled with trtllm_allreduce_fusion"
,
)
def
test_all_reduce_fusion_pass_replace
(
test_model
:
torch
.
nn
.
Module
,
batch_size
:
int
,
seq_len
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
enable_rms_norm_custom_op
,
enable_quant_fp8_custom_op
,
):
num_processes
=
2
if
(
test_model
==
TestAllReduceFusedAddRMSNormStaticQuantFP4Model
and
not
current_platform
.
has_device_capability
(
100
)):
pytest
.
skip
(
"Skip as nvfp4 is only supported on "
"devices with compute capability 10.0 (Blackwell)"
)
if
(
test_model
==
TestAllReduceFusedAddRMSNormStaticQuantFP4Model
and
not
current_platform
.
has_device_capability
(
100
)
):
pytest
.
skip
(
"Skip as nvfp4 is only supported on "
"devices with compute capability 10.0 (Blackwell)"
)
def
run_torch_spawn
(
fn
,
nprocs
):
torch
.
multiprocessing
.
spawn
(
fn
,
args
=
(
num_processes
,
test_model
,
batch_size
,
seq_len
,
hidden_size
,
dtype
),
nprocs
=
nprocs
)
torch
.
multiprocessing
.
spawn
(
fn
,
args
=
(
num_processes
,
test_model
,
batch_size
,
seq_len
,
hidden_size
,
dtype
,
enable_rms_norm_custom_op
,
enable_quant_fp8_custom_op
,
),
nprocs
=
nprocs
,
)
run_torch_spawn
(
all_reduce_fusion_pass_on_test_model
,
num_processes
)
def
all_reduce_fusion_pass_on_test_model
(
local_rank
:
int
,
world_size
:
int
,
test_model_cls
:
torch
.
nn
.
Module
,
batch_size
:
int
,
seq_len
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
):
def
all_reduce_fusion_pass_on_test_model
(
local_rank
:
int
,
world_size
:
int
,
test_model_cls
:
torch
.
nn
.
Module
,
batch_size
:
int
,
seq_len
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
enable_rms_norm_custom_op
,
enable_quant_fp8_custom_op
,
):
current_platform
.
seed_everything
(
0
)
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
...
...
@@ -187,50 +270,63 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
torch
.
set_default_device
(
device
)
torch
.
set_default_dtype
(
dtype
)
update_environment_variables
({
'RANK'
:
str
(
local_rank
),
'LOCAL_RANK'
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
'12345'
,
})
update_environment_variables
(
{
"RANK"
:
str
(
local_rank
),
"LOCAL_RANK"
:
str
(
local_rank
),
"WORLD_SIZE"
:
str
(
world_size
),
"MASTER_ADDR"
:
"localhost"
,
"MASTER_PORT"
:
"12345"
,
}
)
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
custom_ops
=
[
"+rms_norm"
,
"+quant_fp8"
]))
custom_ops
=
[]
if
enable_rms_norm_custom_op
:
custom_ops
.
append
(
"+rms_norm"
)
if
enable_quant_fp8_custom_op
:
custom_ops
.
append
(
"+quant_fp8"
)
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
custom_ops
=
custom_ops
)
)
vllm_config
.
compilation_config
.
pass_config
=
PassConfig
(
enable_fi_allreduce_fusion
=
True
,
enable_noop
=
True
)
enable_fi_allreduce_fusion
=
True
,
enable_noop
=
True
)
vllm_config
.
device_config
=
DeviceConfig
(
device
=
torch
.
device
(
"cuda"
))
vllm_config
.
parallel_config
.
rank
=
local_rank
# Setup rank for debug path
# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
model_name
=
"nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config
.
model_config
=
ModelConfig
(
model
=
model_name
,
trust_remote_code
=
True
,
dtype
=
dtype
,
seed
=
42
)
all_reduce_fusion_pass
=
AllReduceFusionPass
(
vllm_config
)
noop_pass
=
NoOpEliminationPass
(
vllm_config
)
func_pass
=
FixFunctionalizationPass
(
vllm_config
)
cleanup_pass
=
PostCleanupPass
(
vllm_config
)
backend
=
TestBackend
(
all_reduce_fusion_pass
,
noop_pass
,
func_pass
,
cleanup_pass
)
token_num
=
batch_size
*
seq_len
model
=
test_model_cls
(
hidden_size
,
token_num
)
hidden_states
=
torch
.
randn
((
token_num
,
hidden_size
),
requires_grad
=
False
)
residual
=
torch
.
randn
((
token_num
,
hidden_size
),
requires_grad
=
False
)
compiled_model
=
torch
.
compile
(
model
,
backend
=
backend
)
compiled_model
(
hidden_states
,
residual
)
assert
all_reduce_fusion_pass
.
matched_count
==
1
backend
.
check_before_ops
(
model
.
ops_in_model_before
(),
fully_replaced
=
False
)
backend
.
check_after_ops
(
model
.
ops_in_model_after
())
del
all_reduce_fusion_pass
model_name
=
"RedHatAI/Llama-3.2-1B-Instruct-FP8"
vllm_config
.
model_config
=
ModelConfig
(
model
=
model_name
,
trust_remote_code
=
True
,
dtype
=
dtype
,
seed
=
42
)
with
set_current_vllm_config
(
vllm_config
):
all_reduce_fusion_pass
=
AllReduceFusionPass
(
vllm_config
)
noop_pass
=
NoOpEliminationPass
(
vllm_config
)
func_pass
=
FixFunctionalizationPass
(
vllm_config
)
cleanup_pass
=
PostCleanupPass
(
vllm_config
)
backend
=
TestBackend
(
noop_pass
,
all_reduce_fusion_pass
,
func_pass
,
cleanup_pass
)
token_num
=
batch_size
*
seq_len
model
=
test_model_cls
(
hidden_size
,
token_num
)
hidden_states
=
torch
.
randn
((
token_num
,
hidden_size
),
requires_grad
=
False
)
compiled_model
=
torch
.
compile
(
model
,
backend
=
backend
)
compiled_model
(
hidden_states
)
assert
all_reduce_fusion_pass
.
matched_count
==
4
,
(
f
"
{
all_reduce_fusion_pass
.
matched_count
=
}
"
)
backend
.
check_before_ops
(
model
.
ops_in_model_before
(),
fully_replaced
=
False
)
backend
.
check_after_ops
(
model
.
ops_in_model_after
())
del
all_reduce_fusion_pass
Prev
1
…
24
25
26
27
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment