Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e6327c9b
Unverified
Commit
e6327c9b
authored
Jun 23, 2025
by
cascade
Committed by
GitHub
Jun 23, 2025
Browse files
[Feature] Support sequence parallelism for static fp8 quantization (#19181)
Signed-off-by:
cascade812
<
cascade812@outlook.com
>
parent
d0132f02
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
534 additions
and
198 deletions
+534
-198
tests/compile/test_sequence_parallelism.py
tests/compile/test_sequence_parallelism.py
+144
-17
tests/distributed/test_sequence_parallel.py
tests/distributed/test_sequence_parallel.py
+52
-56
tests/models/registry.py
tests/models/registry.py
+2
-1
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+2
-2
vllm/compilation/pass_manager.py
vllm/compilation/pass_manager.py
+4
-4
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+328
-114
vllm/config.py
vllm/config.py
+2
-4
No files found.
tests/compile/test_sequence_parallelism.py
View file @
e6327c9b
...
@@ -6,7 +6,9 @@ import torch
...
@@ -6,7 +6,9 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.compilation.fix_functionalization
import
FixFunctionalizationPass
from
vllm.compilation.fix_functionalization
import
FixFunctionalizationPass
from
vllm.compilation.fusion
import
FusionPass
from
vllm.compilation.fx_utils
import
find_auto_fn
,
find_auto_fn_maybe
,
is_func
from
vllm.compilation.fx_utils
import
find_auto_fn
,
find_auto_fn_maybe
,
is_func
from
vllm.compilation.noop_elimination
import
NoOpEliminationPass
from
vllm.compilation.sequence_parallelism
import
SequenceParallelismPass
from
vllm.compilation.sequence_parallelism
import
SequenceParallelismPass
from
vllm.config
import
(
CompilationConfig
,
DeviceConfig
,
ModelConfig
,
from
vllm.config
import
(
CompilationConfig
,
DeviceConfig
,
ModelConfig
,
PassConfig
,
VllmConfig
)
PassConfig
,
VllmConfig
)
...
@@ -14,12 +16,15 @@ from vllm.distributed import tensor_model_parallel_all_reduce
...
@@ -14,12 +16,15 @@ from vllm.distributed import tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
(
init_distributed_environment
,
from
vllm.distributed.parallel_state
import
(
init_distributed_environment
,
initialize_model_parallel
)
initialize_model_parallel
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
update_environment_variables
from
vllm.utils
import
update_environment_variables
from
..utils
import
multi_gpu_test
from
..utils
import
multi_gpu_test
from
.backend
import
TestBackend
from
.backend
import
TestBackend
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
prompts
=
[
prompts
=
[
"Hello, my name is"
,
"Hello, my name is"
,
"The president of the United States is"
,
"The president of the United States is"
,
...
@@ -30,13 +35,16 @@ prompts = [
...
@@ -30,13 +35,16 @@ prompts = [
class
TestModel
(
torch
.
nn
.
Module
):
class
TestModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
=
16
,
intermediate_size
=
32
):
def
__init__
(
self
,
hidden_size
=
16
,
intermediate_size
=
32
,
vllm_config
:
VllmConfig
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
intermediate_size
=
intermediate_size
self
.
gate_proj
=
torch
.
nn
.
Parameter
(
self
.
gate_proj
=
torch
.
nn
.
Parameter
(
torch
.
empty
((
intermediate_size
,
hidden_size
)))
torch
.
empty
((
intermediate_size
,
hidden_size
)))
self
.
norm
=
RMSNorm
(
hidden
_size
,
1e-05
)
self
.
norm
=
RMSNorm
(
intermediate
_size
,
1e-05
)
# Initialize weights
# Initialize weights
torch
.
nn
.
init
.
normal_
(
self
.
gate_proj
,
std
=
0.02
)
torch
.
nn
.
init
.
normal_
(
self
.
gate_proj
,
std
=
0.02
)
...
@@ -79,32 +87,138 @@ class TestModel(torch.nn.Module):
...
@@ -79,32 +87,138 @@ class TestModel(torch.nn.Module):
return
[
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
]
return
[
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
]
class
TestQuantModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
=
16
,
intermediate_size
=
32
,
vllm_config
:
VllmConfig
=
None
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
vllm_config
=
vllm_config
self
.
gate_proj
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
(
intermediate_size
,
hidden_size
)),
requires_grad
=
False
)
self
.
norm
=
RMSNorm
(
intermediate_size
,
1e-05
)
# Initialize weights
torch
.
nn
.
init
.
normal_
(
self
.
gate_proj
,
std
=
0.02
)
self
.
fp8_linear
=
Fp8LinearOp
(
cutlass_fp8_supported
=
True
,
use_per_token_if_dynamic
=
False
)
self
.
scale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
# Create a weight that is compatible with torch._scaled_mm,
# which expects a column-major layout.
self
.
w
=
torch
.
rand
(
hidden_size
,
intermediate_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
self
.
wscale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
def
forward
(
self
,
hidden_states
,
residual
):
"""
Forward pass implementing the operations in the FX graph
Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer
Returns:
Tuple containing the output tensor
"""
# Reshape input
view
=
hidden_states
.
reshape
(
-
1
,
self
.
hidden_size
)
#matrix multiplication
permute
=
self
.
gate_proj
.
permute
(
1
,
0
)
mm
=
torch
.
mm
(
view
,
permute
)
# Tensor parallel all-reduce
all_reduce
=
tensor_model_parallel_all_reduce
(
mm
)
# layer normalization
norm_output
,
residual_output
=
self
.
norm
(
all_reduce
,
residual
)
# for static input quantization
# self.fp8_linear is initialized with use_per_token_if_dynamic=False
fp8_linear_result
=
self
.
fp8_linear
.
apply
(
norm_output
,
self
.
w
,
self
.
wscale
,
input_scale
=
self
.
scale
.
to
(
norm_output
.
device
))
return
fp8_linear_result
,
residual_output
def
ops_in_model_before
(
self
):
ops_to_remove
=
[
torch
.
ops
.
vllm
.
all_reduce
.
default
]
# Always removed by SP
# The following are only removed if fusion happens
if
self
.
vllm_config
and
self
.
vllm_config
.
compilation_config
\
.
pass_config
.
enable_fusion
:
ops_to_remove
.
extend
([
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
,
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
,
])
return
ops_to_remove
def
ops_in_model_after
(
self
):
ops_to_add
=
[
torch
.
ops
.
vllm
.
reduce_scatter
.
default
,
torch
.
ops
.
vllm
.
all_gather
.
default
]
# The following is only added if fusion happens
if
self
.
vllm_config
and
self
.
vllm_config
.
compilation_config
\
.
pass_config
.
enable_fusion
:
ops_to_add
.
append
(
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
.
default
)
return
ops_to_add
def
ops_in_model
(
self
):
if
self
.
vllm_config
and
self
.
vllm_config
.
compilation_config
\
.
pass_config
.
enable_fusion
:
# If fusion happens, the fused op is the one
# we check for (de)functionalization
return
[
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
.
default
]
# noqa: E501
else
:
# If no fusion, the original ops are checked
return
[
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
,
# TODO functionalization pass does not handle this yet
# torch.ops._C.static_scaled_fp8_quant.default,
]
@
multi_gpu_test
(
num_gpus
=
2
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"test_model_cls"
,
[
TestModel
,
TestQuantModel
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"enable_fusion"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
not
in
[
"cuda"
],
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
not
in
[
"cuda"
],
reason
=
"Only test on CUDA"
)
reason
=
"Only test on CUDA"
)
def
test_sequence_parallelism_pass
(
batch_size
:
int
,
seq_len
:
int
,
def
test_sequence_parallelism_pass
(
test_model_cls
:
type
[
torch
.
nn
.
Module
],
hidden_size
:
int
,
dtype
:
torch
.
dtype
):
batch_size
:
int
,
seq_len
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
enable_fusion
:
bool
):
num_processes
=
2
num_processes
=
2
def
run_torch_spawn
(
fn
,
nprocs
):
def
run_torch_spawn
(
fn
,
nprocs
):
# need to use torch.mp.spawn otherwise will have problems with
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
# torch.distributed and cuda
torch
.
multiprocessing
.
spawn
(
fn
,
torch
.
multiprocessing
.
spawn
(
fn
,
args
=
(
num_processes
,
batch_size
,
seq_len
,
args
=
(
num_processes
,
test_model_cls
,
hidden_size
,
dtype
),
batch_size
,
seq_len
,
hidden_size
,
dtype
,
enable_fusion
),
nprocs
=
nprocs
)
nprocs
=
nprocs
)
run_torch_spawn
(
sequence_parallelism_pass_on_test_model
,
num_processes
)
run_torch_spawn
(
sequence_parallelism_pass_on_test_model
,
num_processes
)
def
sequence_parallelism_pass_on_test_model
(
local_rank
:
int
,
world_size
:
int
,
def
sequence_parallelism_pass_on_test_model
(
batch_size
:
int
,
seq_len
:
int
,
local_rank
:
int
,
world_size
:
int
,
hidden_size
:
int
,
test_model_cls
:
type
[
torch
.
nn
.
Module
],
batch_size
:
int
,
seq_len
:
int
,
dtype
:
torch
.
dtype
):
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
enable_fusion
:
bool
):
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
...
@@ -127,26 +241,39 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
...
@@ -127,26 +241,39 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for SequenceParallelismPass
# configure vllm config for SequenceParallelismPass
vllm_config
=
VllmConfig
()
vllm_config
=
VllmConfig
()
vllm_config
.
compilation_config
=
CompilationConfig
(
pass_config
=
PassConfig
(
vllm_config
.
compilation_config
=
CompilationConfig
(
pass_config
=
PassConfig
(
enable_sequence_parallelism
=
True
))
enable_sequence_parallelism
=
True
,
enable_fusion
=
enable_fusion
,
enable_noop
=
True
))
# NoOp needed for fusion
vllm_config
.
device_config
=
DeviceConfig
(
device
=
torch
.
device
(
"cuda"
))
vllm_config
.
device_config
=
DeviceConfig
(
device
=
torch
.
device
(
"cuda"
))
# this is a fake model name to construct the model config
# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
# in the vllm_config, it's not really used.
model
=
"nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
model
_name
=
"nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config
.
model_config
=
ModelConfig
(
model
=
model
,
vllm_config
.
model_config
=
ModelConfig
(
model
=
model
_name
,
task
=
"auto"
,
task
=
"auto"
,
tokenizer
=
model
,
tokenizer
=
model
_name
,
tokenizer_mode
=
"auto"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
dtype
=
dtype
,
dtype
=
dtype
,
seed
=
42
)
seed
=
42
)
sequence_parallelism_pass
=
SequenceParallelismPass
(
vllm_config
)
sequence_parallelism_pass
=
SequenceParallelismPass
(
vllm_config
)
backend_no_func
=
TestBackend
(
sequence_parallelism_pass
)
noop_pass
=
NoOpEliminationPass
(
vllm_config
)
func_pass
=
FixFunctionalizationPass
(
vllm_config
)
func_pass
=
FixFunctionalizationPass
(
vllm_config
)
backend_func
=
TestBackend
(
sequence_parallelism_pass
,
func_pass
)
model
=
TestModel
(
hidden_size
,
hidden_size
*
2
)
passes_for_backend
=
[
noop_pass
,
sequence_parallelism_pass
]
if
enable_fusion
:
fusion_pass
=
FusionPass
.
instance
(
vllm_config
)
passes_for_backend
.
append
(
fusion_pass
)
backend_no_func
=
TestBackend
(
*
passes_for_backend
)
backend_func
=
TestBackend
(
*
passes_for_backend
,
func_pass
)
model
=
test_model_cls
(
hidden_size
,
hidden_size
*
2
,
vllm_config
=
vllm_config
)
hidden_states
=
torch
.
randn
((
batch_size
*
seq_len
,
hidden_size
),
hidden_states
=
torch
.
randn
((
batch_size
*
seq_len
,
hidden_size
),
dtype
=
dtype
)
dtype
=
dtype
)
residual
=
torch
.
randn
((
batch_size
*
seq_len
,
hidden_size
),
dtype
=
dtype
)
residual
=
torch
.
randn
((
batch_size
*
seq_len
,
hidden_size
),
dtype
=
dtype
)
...
...
tests/distributed/test_sequence_parallel.py
View file @
e6327c9b
...
@@ -28,7 +28,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
...
@@ -28,7 +28,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
class
ParallelSetup
(
NamedTuple
):
class
ParallelSetup
(
NamedTuple
):
tp_size
:
int
tp_size
:
int
pp_size
:
int
pp_size
:
int
sp_
enable
d
:
bool
enable
_fusion
:
bool
eager_mode
:
bool
eager_mode
:
bool
chunked_prefill
:
bool
chunked_prefill
:
bool
...
@@ -67,49 +67,18 @@ class SPTestSettings:
...
@@ -67,49 +67,18 @@ class SPTestSettings:
task
:
TaskOption
=
"auto"
,
task
:
TaskOption
=
"auto"
,
load_format
:
Optional
[
str
]
=
None
,
load_format
:
Optional
[
str
]
=
None
,
):
):
parallel_setups
=
[]
for
eager_mode_val
in
[
False
,
True
]:
for
pp_multiplier
in
[
1
,
2
]:
for
chunked_prefill_val
in
[
False
,
True
]:
parallel_setups
.
append
(
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
pp_multiplier
*
pp_base
,
enable_fusion
=
False
,
eager_mode
=
eager_mode_val
,
chunked_prefill
=
chunked_prefill_val
))
return
SPTestSettings
(
return
SPTestSettings
(
parallel_setups
=
[
parallel_setups
=
parallel_setups
,
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
pp_base
,
sp_enabled
=
True
,
eager_mode
=
False
,
chunked_prefill
=
False
),
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
pp_base
,
sp_enabled
=
True
,
eager_mode
=
False
,
chunked_prefill
=
True
),
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
pp_base
,
sp_enabled
=
True
,
eager_mode
=
True
,
chunked_prefill
=
False
),
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
pp_base
,
sp_enabled
=
True
,
eager_mode
=
True
,
chunked_prefill
=
True
),
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
2
*
pp_base
,
sp_enabled
=
True
,
eager_mode
=
False
,
chunked_prefill
=
False
),
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
2
*
pp_base
,
sp_enabled
=
True
,
eager_mode
=
False
,
chunked_prefill
=
True
),
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
2
*
pp_base
,
sp_enabled
=
True
,
eager_mode
=
True
,
chunked_prefill
=
False
),
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
2
*
pp_base
,
sp_enabled
=
True
,
eager_mode
=
True
,
chunked_prefill
=
True
)
],
distributed_backends
=
[
"mp"
,
"ray"
],
distributed_backends
=
[
"mp"
,
"ray"
],
vllm_major_versions
=
[
"1"
,
"1"
],
vllm_major_versions
=
[
"1"
,
"1"
],
task
=
task
,
task
=
task
,
...
@@ -126,19 +95,44 @@ class SPTestSettings:
...
@@ -126,19 +95,44 @@ class SPTestSettings:
multi_node_only
:
bool
=
False
,
multi_node_only
:
bool
=
False
,
load_format
:
Optional
[
str
]
=
None
,
load_format
:
Optional
[
str
]
=
None
,
):
):
parallel_setups
=
[]
for
eager_mode_val
in
[
False
,
True
]:
for
pp_multiplier
in
[
1
,
2
]:
for
chunked_prefill_val
in
[
False
,
True
]:
parallel_setups
.
append
(
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
pp_multiplier
*
pp_base
,
enable_fusion
=
False
,
eager_mode
=
eager_mode_val
,
chunked_prefill
=
chunked_prefill_val
))
return
SPTestSettings
(
return
SPTestSettings
(
parallel_setups
=
[
parallel_setups
=
parallel_setups
,
distributed_backends
=
[
"mp"
,
"ray"
],
vllm_major_versions
=
[
"1"
,
"1"
],
task
=
task
,
test_options
=
SPTestOptions
(
multi_node_only
=
multi_node_only
,
load_format
=
load_format
),
)
@
staticmethod
def
fp8_quant
(
*
,
tp_base
:
int
=
2
,
pp_base
:
int
=
1
,
task
:
TaskOption
=
"auto"
,
multi_node_only
:
bool
=
False
,
load_format
:
Optional
[
str
]
=
None
,
):
parallel_setups
=
[]
for
fusion_val
in
[
False
,
True
]:
parallel_setups
.
append
(
ParallelSetup
(
tp_size
=
tp_base
,
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
pp_base
,
pp_size
=
pp_base
,
sp_enabled
=
True
,
enable_fusion
=
fusion_val
,
eager_mode
=
False
,
eager_mode
=
True
,
chunked_prefill
=
False
),
chunked_prefill
=
False
))
ParallelSetup
(
tp_size
=
tp_base
,
return
SPTestSettings
(
pp_size
=
2
*
pp_base
,
parallel_setups
=
parallel_setups
,
sp_enabled
=
True
,
eager_mode
=
False
,
chunked_prefill
=
False
),
],
distributed_backends
=
[
"mp"
,
"ray"
],
distributed_backends
=
[
"mp"
,
"ray"
],
vllm_major_versions
=
[
"1"
,
"1"
],
vllm_major_versions
=
[
"1"
,
"1"
],
task
=
task
,
task
=
task
,
...
@@ -171,7 +165,7 @@ def _compare_sp(
...
@@ -171,7 +165,7 @@ def _compare_sp(
(
(
tp_size
,
tp_size
,
pp_size
,
pp_size
,
sp_
enable
d
,
enable
_fusion
,
eager_mode
,
eager_mode
,
chunked_prefill
,
chunked_prefill
,
)
=
parallel_setup
)
=
parallel_setup
...
@@ -240,9 +234,9 @@ def _compare_sp(
...
@@ -240,9 +234,9 @@ def _compare_sp(
'compile_sizes'
:
[
4
,
8
],
'compile_sizes'
:
[
4
,
8
],
'splitting_ops'
:
[],
'splitting_ops'
:
[],
'pass_config'
:
{
'pass_config'
:
{
'enable_sequence_parallelism'
:
sp_enabled
,
'enable_sequence_parallelism'
:
True
,
'enable_fusion'
:
enable_fusion
,
'enable_noop'
:
True
,
'enable_noop'
:
True
,
'enable_fusion'
:
True
,
},
},
}
}
...
@@ -291,12 +285,14 @@ def _compare_sp(
...
@@ -291,12 +285,14 @@ def _compare_sp(
SP_TEXT_GENERATION_MODELS
=
{
SP_TEXT_GENERATION_MODELS
=
{
# [Decoder-only]
# [Decoder-only]
"meta-llama/Llama-3.2-1B-Instruct"
:
SPTestSettings
.
fast
(),
"meta-llama/Llama-3.2-1B-Instruct"
:
SPTestSettings
.
fast
(),
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
:
SPTestSettings
.
fp8_quant
(),
}
}
SP_TEST_MODELS
=
[
SP_TEST_MODELS
=
[
# TODO support other models
# TODO support other models
# [LANGUAGE GENERATION]
# [LANGUAGE GENERATION]
"meta-llama/Llama-3.2-1B-Instruct"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
]
]
...
...
tests/models/registry.py
View file @
e6327c9b
...
@@ -193,7 +193,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -193,7 +193,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
extras
=
{
"tiny"
:
"ai21labs/Jamba-tiny-dev"
}),
# noqa: E501
extras
=
{
"tiny"
:
"ai21labs/Jamba-tiny-dev"
}),
# noqa: E501
"LlamaForCausalLM"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-1B-Instruct"
,
"LlamaForCausalLM"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-1B-Instruct"
,
extras
=
{
"guard"
:
"meta-llama/Llama-Guard-3-1B"
,
# noqa: E501
extras
=
{
"guard"
:
"meta-llama/Llama-Guard-3-1B"
,
# noqa: E501
"hermes"
:
"NousResearch/Hermes-3-Llama-3.1-8B"
}),
# noqa: E501
"hermes"
:
"NousResearch/Hermes-3-Llama-3.1-8B"
,
# noqa: E501
"fp8"
:
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
}),
# noqa: E501
"LLaMAForCausalLM"
:
_HfExamplesInfo
(
"decapoda-research/llama-7b-hf"
,
"LLaMAForCausalLM"
:
_HfExamplesInfo
(
"decapoda-research/llama-7b-hf"
,
is_available_online
=
False
),
is_available_online
=
False
),
"MambaForCausalLM"
:
_HfExamplesInfo
(
"state-spaces/mamba-130m-hf"
),
"MambaForCausalLM"
:
_HfExamplesInfo
(
"state-spaces/mamba-130m-hf"
),
...
...
vllm/compilation/fusion.py
View file @
e6327c9b
...
@@ -345,8 +345,8 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
...
@@ -345,8 +345,8 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
# 0 is always None
# 0 is always None
fused_return_mapping
=
{
1
:
(
quant_node
,
1
),
2
:
(
rms_node
,
2
)}
fused_return_mapping
=
{
1
:
(
quant_node
,
1
),
2
:
(
rms_node
,
2
)}
self
.
insert_fused_node
(
fused_return_mapping
,
self
.
insert_fused_node
(
fused_return_mapping
,
epsilon
=
rms_node
.
kwargs
[
"epsilon"
]
,
**
kwargs
,
**
kwargs
)
epsilon
=
rms_node
.
kwargs
[
"epsilon"
]
)
class
RMSNormDynamicQuantPattern
(
RMSNormQuantPattern
):
class
RMSNormDynamicQuantPattern
(
RMSNormQuantPattern
):
...
...
vllm/compilation/pass_manager.py
View file @
e6327c9b
...
@@ -51,15 +51,15 @@ class PostGradPassManager(CustomGraphPass):
...
@@ -51,15 +51,15 @@ class PostGradPassManager(CustomGraphPass):
if
self
.
pass_config
.
enable_noop
:
if
self
.
pass_config
.
enable_noop
:
self
.
passes
+=
[
NoOpEliminationPass
(
config
)]
self
.
passes
+=
[
NoOpEliminationPass
(
config
)]
if
self
.
pass_config
.
enable_fusion
:
self
.
passes
+=
[
FusionPass
.
instance
(
config
)]
self
.
passes
+=
[
ActivationQuantFusionPass
(
config
)]
if
self
.
pass_config
.
enable_sequence_parallelism
:
if
self
.
pass_config
.
enable_sequence_parallelism
:
self
.
passes
+=
[
SequenceParallelismPass
(
config
)]
self
.
passes
+=
[
SequenceParallelismPass
(
config
)]
if
self
.
pass_config
.
enable_async_tp
:
if
self
.
pass_config
.
enable_async_tp
:
self
.
passes
+=
[
AsyncTPPass
(
config
)]
self
.
passes
+=
[
AsyncTPPass
(
config
)]
if
self
.
pass_config
.
enable_fusion
:
self
.
passes
+=
[
FusionPass
.
instance
(
config
)]
self
.
passes
+=
[
ActivationQuantFusionPass
(
config
)]
if
self
.
pass_config
.
enable_attn_fusion
:
if
self
.
pass_config
.
enable_attn_fusion
:
self
.
passes
+=
[
AttnFusionPass
(
config
)]
self
.
passes
+=
[
AttnFusionPass
(
config
)]
...
...
vllm/compilation/sequence_parallelism.py
View file @
e6327c9b
...
@@ -12,91 +12,142 @@ from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
...
@@ -12,91 +12,142 @@ from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
.vllm_inductor_pass
import
VllmInductorPass
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
AllReduceRMSNormPattern
:
class
_RMSNormAndQuantOpHelper
:
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
quant_op
:
Optional
[
torch
.
_ops
.
OpOverload
]
=
None
,
**
kwargs
):
self
.
epsilon
=
epsilon
self
.
epsilon
=
epsilon
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
device
=
device
self
.
device
=
device
self
.
quant_op
=
quant_op
class
EmbeddingAllReduceRMSNormPattern
(
AllReduceRMSNormPattern
):
def
_functional_rmsnorm
(
self
,
result_buffer
,
input_tensor
,
weight_tensor
):
return
torch
.
ops
.
higher_order
.
auto_functionalized
(
torch
.
ops
.
_C
.
rms_norm
.
default
,
result
=
result_buffer
,
input
=
input_tensor
,
weight
=
weight_tensor
,
epsilon
=
self
.
epsilon
)
def
_functional_fused_add_rmsnorm
(
self
,
input_tensor
,
residual_tensor
,
weight_tensor
):
return
torch
.
ops
.
higher_order
.
auto_functionalized
(
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
,
input
=
input_tensor
,
residual
=
residual_tensor
,
weight
=
weight_tensor
,
epsilon
=
self
.
epsilon
)
def
_functional_rmsnorm_then_quant
(
self
,
rmsnorm_result_buffer
,
quant_result_buffer
,
input_tensor
,
weight_tensor
,
scale_tensor
):
if
self
.
quant_op
is
None
:
raise
RuntimeError
(
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
)
rmsnorm_out_tuple
=
self
.
_functional_rmsnorm
(
rmsnorm_result_buffer
,
input_tensor
,
weight_tensor
)
quant_out_tuple
=
torch
.
ops
.
higher_order
.
auto_functionalized
(
self
.
quant_op
,
result
=
quant_result_buffer
,
input
=
rmsnorm_out_tuple
[
1
],
scale
=
scale_tensor
)
return
quant_out_tuple
def
_functional_fused_add_rmsnorm_then_quant
(
self
,
quant_result_buffer
,
input_tensor
,
residual_tensor
,
weight_tensor
,
scale_tensor
):
if
self
.
quant_op
is
None
:
raise
RuntimeError
(
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
)
fused_add_rmsnorm_out_tuple
=
self
.
_functional_fused_add_rmsnorm
(
input_tensor
,
residual_tensor
,
weight_tensor
)
quant_out_tuple
=
torch
.
ops
.
higher_order
.
auto_functionalized
(
self
.
quant_op
,
result
=
quant_result_buffer
,
input
=
fused_add_rmsnorm_out_tuple
[
1
],
scale
=
scale_tensor
)
return
quant_out_tuple
,
fused_add_rmsnorm_out_tuple
[
2
]
class
_SequenceParallelPatternHelper
(
_RMSNormAndQuantOpHelper
):
"""Helper for sequence parallelism patterns."""
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
quant_op
:
Optional
[
torch
.
_ops
.
OpOverload
]
=
None
,
**
kwargs
):
super
().
__init__
(
epsilon
,
dtype
,
device
,
quant_op
=
quant_op
,
**
kwargs
)
self
.
tp_group
=
get_tp_group
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
def
_all_reduce
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
tensor_model_parallel_all_reduce
(
x
)
def
_reduce_scatter
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
reduce_scatter
.
default
(
x
,
dim
=
0
,
world_size
=
self
.
tp_size
,
group_name
=
self
.
tp_group
.
unique_name
)
def
_all_gather
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
all_gather
.
default
(
x
,
dim
=
0
,
world_size
=
self
.
tp_size
,
group_name
=
self
.
tp_group
.
unique_name
)
class
FirstAllReduceRMSNormPattern
(
_SequenceParallelPatternHelper
):
def
get_inputs
(
self
):
def
get_inputs
(
self
):
arg2_1
=
torch
.
empty
([
16
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
input
=
torch
.
empty
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
mul_6
=
torch
.
tensor
([[
3
,
7
,
1
,
4
,
9
,
2
,
5
,
0
]],
device
=
self
.
device
,
dtype
=
torch
.
long
)
unsqueeze
=
torch
.
rand
([
1
,
8
,
1
],
device
=
self
.
device
,
\
dtype
=
self
.
dtype
)
>
0.5
full_default
=
torch
.
zeros
([
1
,
8
,
4
],
device
=
self
.
device
,
\
dtype
=
self
.
dtype
)
permute
=
torch
.
empty
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
permute
=
torch
.
empty
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
arg3_1
=
torch
.
empty
([
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
arg3_1
=
torch
.
empty
([
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
return
[
arg2_1
,
mul_6
,
unsqueeze
,
full_defaul
t
,
permute
,
arg3_1
]
return
[
inpu
t
,
permute
,
arg3_1
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
def
pattern
(
arg2_1
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
mul_6
:
torch
.
Tensor
,
unsqueeze
:
torch
.
Tensor
,
full_default
:
torch
.
Tensor
,
permute
:
torch
.
Tensor
,
permute
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
):
):
embedding
=
torch
.
ops
.
aten
.
embedding
.
default
(
arg2_1
,
mul_6
)
all_reduce
=
self
.
_all_reduce
(
input
)
where
=
torch
.
ops
.
aten
.
where
.
self
(
unsqueeze
,
full_default
,
rmsnorm
=
self
.
_functional_rmsnorm
(
permute
,
all_reduce
,
arg3_1
)
embedding
)
all_reduce
=
tensor_model_parallel_all_reduce
(
where
)
rmsnorm
=
torch
.
ops
.
higher_order
.
auto_functionalized
(
torch
.
ops
.
_C
.
rms_norm
.
default
,
result
=
permute
,
input
=
all_reduce
,
weight
=
arg3_1
,
epsilon
=
self
.
epsilon
,
)
return
rmsnorm
[
1
],
all_reduce
return
rmsnorm
[
1
],
all_reduce
def
replacement
(
def
replacement
(
arg2_1
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
mul_6
:
torch
.
Tensor
,
unsqueeze
:
torch
.
Tensor
,
full_default
:
torch
.
Tensor
,
permute
:
torch
.
Tensor
,
permute
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
):
):
embedding
=
torch
.
ops
.
aten
.
embedding
.
default
(
arg2_1
,
mul_6
)
reduce_scatter
=
self
.
_reduce_scatter
(
input
)
where
=
torch
.
ops
.
aten
.
where
.
self
(
unsqueeze
,
full_default
,
embedding
)
tp
=
get_tp_group
()
tp_size
=
get_tensor_model_parallel_world_size
()
reduce_scatter
=
torch
.
ops
.
vllm
.
reduce_scatter
.
default
(
where
,
dim
=
0
,
world_size
=
tp_size
,
group_name
=
tp
.
unique_name
)
rmsnorm_result
=
torch
.
empty_like
(
reduce_scatter
)
rmsnorm_result
=
torch
.
empty_like
(
reduce_scatter
)
rmsnorm
=
torch
.
ops
.
higher_order
.
auto_functionalized
(
rmsnorm
=
self
.
_functional_rmsnorm
(
rmsnorm_result
,
reduce_scatter
,
torch
.
ops
.
_C
.
rms_norm
.
default
,
arg3_1
)
result
=
rmsnorm_result
,
input
=
reduce_scatter
,
weight
=
arg3_1
,
epsilon
=
self
.
epsilon
,
)
all_gather
=
torch
.
ops
.
vllm
.
all_gather
.
default
(
all_gather
=
self
.
_all_gather
(
rmsnorm
[
1
])
rmsnorm
[
1
],
dim
=
0
,
world_size
=
tp_size
,
group_name
=
tp
.
unique_name
)
return
all_gather
,
reduce_scatter
return
all_gather
,
reduce_scatter
...
@@ -104,7 +155,7 @@ class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern):
...
@@ -104,7 +155,7 @@ class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern):
pm
.
fwd_only
,
pm_pass
)
pm
.
fwd_only
,
pm_pass
)
class
MiddleAllReduceRMSNormPattern
(
AllReduceRMSNorm
Pattern
):
class
MiddleAllReduceRMSNormPattern
(
_SequenceParallel
Pattern
Helper
):
def
get_inputs
(
self
):
def
get_inputs
(
self
):
mm_1
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
mm_1
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
...
@@ -127,16 +178,9 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
...
@@ -127,16 +178,9 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
mm_1
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
rms_norm_weights
:
torch
.
Tensor
,
rms_norm_weights
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
all_reduce
=
tensor_model_parallel_all_reduce
(
mm_1
)
all_reduce
=
self
.
_all_reduce
(
mm_1
)
rmsnorm
=
self
.
_functional_fused_add_rmsnorm
(
rmsnorm
=
torch
.
ops
.
higher_order
.
auto_functionalized
(
all_reduce
,
residual
,
rms_norm_weights
)
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
,
input
=
all_reduce
,
residual
=
residual
,
weight
=
rms_norm_weights
,
epsilon
=
self
.
epsilon
,
)
return
rmsnorm
[
1
],
rmsnorm
[
2
]
return
rmsnorm
[
1
],
rmsnorm
[
2
]
def
replacement
(
def
replacement
(
...
@@ -144,32 +188,17 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
...
@@ -144,32 +188,17 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
mm_1
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
rms_norm_weights
:
torch
.
Tensor
,
rms_norm_weights
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
tp
=
get_tp_group
()
reduce_scatter
=
self
.
_reduce_scatter
(
mm_1
)
tp_size
=
get_tensor_model_parallel_world_size
()
rmsnorm
=
self
.
_functional_fused_add_rmsnorm
(
reduce_scatter
=
torch
.
ops
.
vllm
.
reduce_scatter
.
default
(
reduce_scatter
,
residual
,
rms_norm_weights
)
mm_1
,
dim
=
0
,
world_size
=
tp_size
,
group_name
=
tp
.
unique_name
)
all_gather
=
self
.
_all_gather
(
rmsnorm
[
1
])
# TODO is it possible to extract epsilon from somewhere
rmsnorm
=
torch
.
ops
.
higher_order
.
auto_functionalized
(
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
,
input
=
reduce_scatter
,
residual
=
residual
,
weight
=
rms_norm_weights
,
epsilon
=
self
.
epsilon
,
)
all_gather
=
torch
.
ops
.
vllm
.
all_gather
.
default
(
rmsnorm
[
1
],
dim
=
0
,
world_size
=
tp_size
,
group_name
=
tp
.
unique_name
)
return
all_gather
,
rmsnorm
[
2
]
return
all_gather
,
rmsnorm
[
2
]
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
pm
.
fwd_only
,
pm_pass
)
class
LastAllReduceRMSNormPattern
(
AllReduceRMSNorm
Pattern
):
class
LastAllReduceRMSNormPattern
(
_SequenceParallel
Pattern
Helper
):
def
get_inputs
(
self
):
def
get_inputs
(
self
):
mm_1
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
mm_1
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
...
@@ -192,16 +221,9 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
...
@@ -192,16 +221,9 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
mm_1
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
rms_norm_weights
:
torch
.
Tensor
,
rms_norm_weights
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
all_reduce
=
tensor_model_parallel_all_reduce
(
mm_1
)
all_reduce
=
self
.
_all_reduce
(
mm_1
)
rmsnorm
=
self
.
_functional_fused_add_rmsnorm
(
rmsnorm
=
torch
.
ops
.
higher_order
.
auto_functionalized
(
all_reduce
,
residual
,
rms_norm_weights
)
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
,
input
=
all_reduce
,
residual
=
residual
,
weight
=
rms_norm_weights
,
epsilon
=
self
.
epsilon
,
)
return
rmsnorm
[
1
]
return
rmsnorm
[
1
]
def
replacement
(
def
replacement
(
...
@@ -209,26 +231,185 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
...
@@ -209,26 +231,185 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
mm_1
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
rms_norm_weights
:
torch
.
Tensor
,
rms_norm_weights
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
tp
=
get_tp_group
()
reduce_scatter
=
self
.
_reduce_scatter
(
mm_1
)
tp_size
=
get_tensor_model_parallel_world_size
()
rmsnorm
=
self
.
_functional_fused_add_rmsnorm
(
reduce_scatter
=
torch
.
ops
.
vllm
.
reduce_scatter
.
default
(
reduce_scatter
,
residual
,
rms_norm_weights
)
mm_1
,
dim
=
0
,
world_size
=
tp_size
,
group_name
=
tp
.
unique_name
)
normalized
=
self
.
_all_gather
(
rmsnorm
[
1
])
return
normalized
# TODO is it possible to extract epsilon from somewhere
rmsnorm
=
torch
.
ops
.
higher_order
.
auto_functionalized
(
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
,
pm
.
fwd_only
,
pm_pass
)
input
=
reduce_scatter
,
residual
=
residual
,
weight
=
rms_norm_weights
,
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
epsilon
=
self
.
epsilon
,
)
class
FirstAllReduceRMSNormStaticFP8Pattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
op
:
torch
.
_ops
.
OpOverload
):
super
().
__init__
(
epsilon
,
dtype
,
device
,
quant_op
=
op
)
def
get_inputs
(
self
):
input
=
torch
.
zeros
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
rmsnorm_result
=
torch
.
empty
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
quant_result
=
torch
.
empty
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
weight
=
torch
.
empty
([
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
scale
=
torch
.
tensor
(
1.0
,
device
=
self
.
device
,
dtype
=
torch
.
float32
)
return
[
input
,
rmsnorm_result
,
quant_result
,
weight
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
input
:
torch
.
Tensor
,
rmsnorm_result
:
torch
.
Tensor
,
quant_result
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
all_reduce
=
self
.
_all_reduce
(
input
)
static_fp8
=
self
.
_functional_rmsnorm_then_quant
(
rmsnorm_result
,
quant_result
,
all_reduce
,
weight
,
scale
)
return
static_fp8
[
1
],
all_reduce
def
replacement
(
input
:
torch
.
Tensor
,
rmsnorm_result
:
torch
.
Tensor
,
quant_result
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
reduce_scatter
=
self
.
_reduce_scatter
(
input
)
rmsnorm_result
=
torch
.
empty_like
(
reduce_scatter
,
dtype
=
rmsnorm_result
.
dtype
)
quant_result
=
torch
.
empty_like
(
rmsnorm_result
,
# Output of RMSNorm
dtype
=
quant_result
.
dtype
)
static_fp8
=
self
.
_functional_rmsnorm_then_quant
(
rmsnorm_result
,
quant_result
,
reduce_scatter
,
weight
,
scale
)
all_gather
=
self
.
_all_gather
(
static_fp8
[
1
])
return
all_gather
,
reduce_scatter
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
class
MiddleAllReduceRMSNormStaticFP8Pattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
op
:
torch
.
_ops
.
OpOverload
):
super
().
__init__
(
epsilon
,
dtype
,
device
,
quant_op
=
op
)
def
get_inputs
(
self
):
mm_1
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
residual
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
rms_norm_weights
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
result
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
scale
=
torch
.
empty
([
1
,
1
],
device
=
self
.
device
,
dtype
=
torch
.
float32
)
return
[
result
,
residual
,
mm_1
,
rms_norm_weights
,
scale
,
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
result
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
rms_norm_weights
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
all_reduce
=
self
.
_all_reduce
(
mm_1
)
static_fp8
,
rmsnorm_residual_out
=
self
.
_functional_fused_add_rmsnorm_then_quant
(
# noqa: E501
result
,
all_reduce
,
residual
,
rms_norm_weights
,
scale
)
return
static_fp8
[
1
],
rmsnorm_residual_out
def
replacement
(
result
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
rms_norm_weights
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
reduce_scatter
=
self
.
_reduce_scatter
(
mm_1
)
quant_result_buf
=
torch
.
empty_like
(
reduce_scatter
,
dtype
=
result
.
dtype
)
static_fp8
,
rmsnorm_residual_out
=
self
.
_functional_fused_add_rmsnorm_then_quant
(
# noqa: E501
quant_result_buf
,
reduce_scatter
,
residual
,
rms_norm_weights
,
scale
)
all_gather
=
self
.
_all_gather
(
static_fp8
[
1
])
return
all_gather
,
rmsnorm_residual_out
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
class
LastAllReduceRMSNormStaticFP8Pattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
op
:
torch
.
_ops
.
OpOverload
):
super
().
__init__
(
epsilon
,
dtype
,
device
,
quant_op
=
op
)
def
get_inputs
(
self
):
mm_1
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
residual
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
rms_norm_weights
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
result
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
FP8_DTYPE
)
scale
=
torch
.
empty
([
1
,
1
],
device
=
self
.
device
,
dtype
=
torch
.
float32
)
return
[
result
,
residual
,
mm_1
,
rms_norm_weights
,
scale
,
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
normalized
=
torch
.
ops
.
vllm
.
all_gather
.
default
(
def
pattern
(
rmsnorm
[
1
],
result
:
torch
.
Tensor
,
dim
=
0
,
residual
:
torch
.
Tensor
,
world_size
=
tp_size
,
mm_1
:
torch
.
Tensor
,
group_name
=
tp
.
unique_name
)
rms_norm_weights
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
all_reduce
=
self
.
_all_reduce
(
mm_1
)
static_fp8
,
_
=
self
.
_functional_fused_add_rmsnorm_then_quant
(
result
,
all_reduce
,
residual
,
rms_norm_weights
,
scale
)
return
static_fp8
[
1
]
def
replacement
(
result
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
mm_1
:
torch
.
Tensor
,
rms_norm_weights
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
reduce_scatter
=
self
.
_reduce_scatter
(
mm_1
)
quant_result_buf
=
torch
.
empty_like
(
reduce_scatter
,
dtype
=
result
.
dtype
)
static_fp8
,
_
=
self
.
_functional_fused_add_rmsnorm_then_quant
(
quant_result_buf
,
reduce_scatter
,
residual
,
rms_norm_weights
,
scale
)
normalized
=
self
.
_all_gather
(
static_fp8
[
1
])
return
normalized
return
normalized
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
...
@@ -236,21 +417,54 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
...
@@ -236,21 +417,54 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
class
SequenceParallelismPass
(
VllmInductorPass
):
class
SequenceParallelismPass
(
VllmInductorPass
):
"""
This pass enables sequence parallelism for models.
It identifies patterns where an AllReduce operation is followed by
an RMSNorm (or RMSNorm and then Quantization) operation.
These patterns are replaced with a ReduceScatter operation, followed by
a local RMSNorm/Quantization, and then an AllGather operation.
The general transformation is:
Input -> AllReduce -> RMSNorm -> Output
becomes
Input -> ReduceScatter -> RMSNorm -> AllGather -> Output
While this pass itself does not directly yield performance improvements,
it lays the groundwork for subsequent fusion passes, such as
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
significantly reduce communication overhead and improve overall model
performance.
"""
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
pass_name
=
"sequence_parallelism_pass"
)
pass_name
=
"sequence_parallelism_pass"
)
for
epsilon
in
[
1e-5
,
1e-6
]:
for
epsilon
in
[
1e-5
,
1e-6
]:
EmbeddingAllReduceRMSNormPattern
(
# RMSNorm + Static FP8 quantization patterns
epsilon
,
self
.
model_dtype
,
self
.
device
).
register
(
self
.
patterns
)
fp8_quant_op
=
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
FirstAllReduceRMSNormStaticFP8Pattern
(
epsilon
,
self
.
model_dtype
,
self
.
device
,
fp8_quant_op
).
register
(
self
.
patterns
)
MiddleAllReduceRMSNormStaticFP8Pattern
(
epsilon
,
self
.
model_dtype
,
self
.
device
,
fp8_quant_op
).
register
(
self
.
patterns
)
LastAllReduceRMSNormStaticFP8Pattern
(
epsilon
,
self
.
model_dtype
,
self
.
device
,
fp8_quant_op
).
register
(
self
.
patterns
)
# Normal RMSNorm patterns
FirstAllReduceRMSNormPattern
(
epsilon
,
self
.
model_dtype
,
self
.
device
).
register
(
self
.
patterns
)
MiddleAllReduceRMSNormPattern
(
epsilon
,
self
.
model_dtype
,
MiddleAllReduceRMSNormPattern
(
epsilon
,
self
.
model_dtype
,
self
.
device
).
register
(
self
.
patterns
)
self
.
device
).
register
(
self
.
patterns
)
LastAllReduceRMSNormPattern
(
epsilon
,
self
.
model_dtype
,
LastAllReduceRMSNormPattern
(
epsilon
,
self
.
model_dtype
,
self
.
device
).
register
(
self
.
patterns
)
self
.
device
).
register
(
self
.
patterns
)
# WARNING: This is a hack to clear the pattern matcher cache
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
# and allow multiple values of epsilon.
torch
.
_inductor
.
pattern_matcher
.
_seen_patterns
.
clear
()
torch
.
_inductor
.
pattern_matcher
.
_seen_patterns
.
clear
()
...
...
vllm/config.py
View file @
e6327c9b
...
@@ -3802,11 +3802,11 @@ class PassConfig:
...
@@ -3802,11 +3802,11 @@ class PassConfig:
its own stages (before, after, maybe in-between)."""
its own stages (before, after, maybe in-between)."""
dump_graph_dir
:
Path
=
Path
(
"."
)
dump_graph_dir
:
Path
=
Path
(
"."
)
"""Directory to dump the graphs."""
"""Directory to dump the graphs."""
enable_fusion
:
bool
=
True
enable_fusion
:
bool
=
field
(
default_factory
=
lambda
:
not
envs
.
VLLM_USE_V1
)
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
enable_attn_fusion
:
bool
=
False
enable_attn_fusion
:
bool
=
False
"""Whether to enable the custom attention+quant fusion pass."""
"""Whether to enable the custom attention+quant fusion pass."""
enable_noop
:
bool
=
True
enable_noop
:
bool
=
field
(
default_factory
=
lambda
:
not
envs
.
VLLM_USE_V1
)
"""Whether to enable the custom no-op elimination pass."""
"""Whether to enable the custom no-op elimination pass."""
enable_sequence_parallelism
:
bool
=
False
enable_sequence_parallelism
:
bool
=
False
"""Whether to enable sequence parallelism."""
"""Whether to enable sequence parallelism."""
...
@@ -4451,8 +4451,6 @@ class VllmConfig:
...
@@ -4451,8 +4451,6 @@ class VllmConfig:
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
# is set to True, full CUDA graphs will be used.
# is set to True, full CUDA graphs will be used.
self
.
compilation_config
.
cudagraph_num_of_warmups
=
1
self
.
compilation_config
.
cudagraph_num_of_warmups
=
1
self
.
compilation_config
.
pass_config
.
enable_fusion
=
False
self
.
compilation_config
.
pass_config
.
enable_noop
=
False
self
.
compilation_config
.
level
=
CompilationLevel
.
PIECEWISE
self
.
compilation_config
.
level
=
CompilationLevel
.
PIECEWISE
self
.
compilation_config
.
set_splitting_ops_for_v1
()
self
.
compilation_config
.
set_splitting_ops_for_v1
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment