Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
71ea614d
Unverified
Commit
71ea614d
authored
May 23, 2025
by
cascade
Committed by
GitHub
May 23, 2025
Browse files
[Feature]Add async tensor parallelism using compilation pass (#17882)
Signed-off-by:
cascade812
<
cascade812@outlook.com
>
parent
4c611348
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
472 additions
and
56 deletions
+472
-56
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/compile/backend.py
tests/compile/backend.py
+18
-0
tests/compile/test_async_tp.py
tests/compile/test_async_tp.py
+248
-0
tests/compile/test_fusion.py
tests/compile/test_fusion.py
+17
-19
tests/compile/test_sequence_parallelism.py
tests/compile/test_sequence_parallelism.py
+18
-29
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+126
-0
vllm/compilation/pass_manager.py
vllm/compilation/pass_manager.py
+3
-0
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+5
-4
vllm/compilation/vllm_inductor_pass.py
vllm/compilation/vllm_inductor_pass.py
+2
-1
vllm/config.py
vllm/config.py
+10
-1
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+24
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
71ea614d
...
@@ -316,6 +316,7 @@ steps:
...
@@ -316,6 +316,7 @@ steps:
-
pytest -v -s compile/test_fusion.py
-
pytest -v -s compile/test_fusion.py
-
pytest -v -s compile/test_silu_mul_quant_fusion.py
-
pytest -v -s compile/test_silu_mul_quant_fusion.py
-
pytest -v -s compile/test_sequence_parallelism.py
-
pytest -v -s compile/test_sequence_parallelism.py
-
pytest -v -s compile/test_async_tp.py
-
label
:
PyTorch Fullgraph Smoke Test
# 9min
-
label
:
PyTorch Fullgraph Smoke Test
# 9min
mirror_hardwares
:
[
amdexperimental
,
amdproduction
]
mirror_hardwares
:
[
amdexperimental
,
amdproduction
]
...
...
tests/compile/backend.py
View file @
71ea614d
...
@@ -5,6 +5,8 @@ from typing import Callable, Union
...
@@ -5,6 +5,8 @@ from typing import Callable, Union
from
torch
import
fx
from
torch
import
fx
from
vllm.compilation.fx_utils
import
(
find_specified_fn
,
find_specified_fn_maybe
)
from
vllm.compilation.inductor_pass
import
InductorPass
from
vllm.compilation.inductor_pass
import
InductorPass
from
vllm.config
import
get_current_vllm_config
from
vllm.config
import
get_current_vllm_config
...
@@ -44,3 +46,19 @@ class TestBackend:
...
@@ -44,3 +46,19 @@ class TestBackend:
self
.
graph_post_pass
=
deepcopy
(
graph
)
self
.
graph_post_pass
=
deepcopy
(
graph
)
# assign by reference, will reflect the final state of the graph
# assign by reference, will reflect the final state of the graph
self
.
final_graph
=
graph
self
.
final_graph
=
graph
def check_before_ops(self,
                     ops,
                     find_fn=find_specified_fn,
                     find_fn_maybe=find_specified_fn_maybe,
                     ops_fully_replaced=True):
    """Check the ops that the pass under test is expected to rewrite.

    For every entry of ``ops``, ``find_fn`` is run against the nodes of
    the pre-pass graph; its return value is ignored (presumably it fails
    when nothing matches -- confirm in ``fx_utils``).

    Args:
        ops: Iterable of op targets expected in the pre-pass graph.
        find_fn: Lookup called as ``find_fn(nodes, op)``.
        find_fn_maybe: Lookup called as ``find_fn_maybe(nodes, op)``;
            its result is compared against ``None``.
        ops_fully_replaced: When true, additionally assert that
            ``find_fn_maybe`` returns ``None`` for each op on the
            post-pass graph.
    """
    for target in ops:
        find_fn(self.graph_pre_pass.nodes, target)
        if not ops_fully_replaced:
            # Caller allows the op to survive the pass.
            continue
        assert find_fn_maybe(self.graph_post_pass.nodes, target) is None
def check_after_ops(self,
                    ops,
                    find_fn=find_specified_fn,
                    find_fn_maybe=find_specified_fn_maybe):
    """Check the ops that the pass under test is expected to introduce.

    For every entry of ``ops``, ``find_fn`` is run against the nodes of
    the post-pass graph (return value ignored), and ``find_fn_maybe``
    must return ``None`` for the same op on the pre-pass graph.

    Args:
        ops: Iterable of op targets expected only after the pass.
        find_fn: Lookup called as ``find_fn(nodes, op)``.
        find_fn_maybe: Lookup called as ``find_fn_maybe(nodes, op)``.
    """
    for target in ops:
        find_fn(self.graph_post_pass.nodes, target)
        # The op must be new: it may not already exist before the pass.
        assert find_fn_maybe(self.graph_pre_pass.nodes, target) is None
tests/compile/test_async_tp.py
0 → 100644
View file @
71ea614d
# SPDX-License-Identifier: Apache-2.0
import
json
import
pytest
import
torch
import
vllm.envs
as
envs
from
vllm.compilation.collective_fusion
import
AsyncTPPass
from
vllm.config
import
(
CompilationConfig
,
DeviceConfig
,
ModelConfig
,
PassConfig
,
VllmConfig
)
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_reduce_scatter
)
from
vllm.distributed.parallel_state
import
(
init_distributed_environment
,
initialize_model_parallel
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
update_environment_variables
from
..models.registry
import
HF_EXAMPLE_MODELS
from
..utils
import
(
compare_two_settings
,
create_new_process_for_each_test
,
multi_gpu_test
)
from
.backend
import
TestBackend
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
class TestMMRSModel(torch.nn.Module):
    """Toy module whose forward is a matmul followed by a reduce-scatter."""

    def __init__(self, hidden_size=16):
        super().__init__()
        self.hidden_size = hidden_size
        # Frozen weight of shape (2 * hidden_size, hidden_size).
        weight = torch.empty((self.hidden_size * 2, hidden_size))
        self.gate_proj = torch.nn.Parameter(weight, requires_grad=False)
        # Initialize weights
        torch.nn.init.normal_(self.gate_proj, std=0.02)

    def forward(self, hidden_states):
        """
        Run reshape -> mm -> tensor-parallel reduce-scatter (dim 0) so that
        the mm + reduce scatter pair shows up in the FX graph.
        """
        # Flatten to 2-D with hidden_size columns.
        flat = hidden_states.reshape(-1, self.hidden_size)

        # Multiply by the transposed weight.
        product = torch.mm(flat, self.gate_proj.permute(1, 0))
        return tensor_model_parallel_reduce_scatter(product, dim=0)

    def ops_in_model_before(self):
        """Ops the test looks for in the graph before the pass runs."""
        return [torch.ops.vllm.reduce_scatter.default]

    def ops_in_model_after(self):
        """Ops the test looks for in the graph after the pass runs."""
        return [torch.ops.symm_mem.fused_matmul_reduce_scatter.default]
class TestAGMMModel(torch.nn.Module):
    """Toy module whose forward is an all-gather followed by a matmul."""

    def __init__(self, hidden_size=16):
        super().__init__()
        self.hidden_size = hidden_size
        # Frozen square weight of shape (hidden_size, hidden_size).
        square = torch.empty((hidden_size, hidden_size))
        self.weight = torch.nn.Parameter(square, requires_grad=False)
        # Initialize weights
        torch.nn.init.normal_(self.weight, std=0.02)

    def forward(self, hidden_states):
        """
        Run reshape -> tensor-parallel all-gather (dim 0) -> mm so that the
        all gather + mm pair shows up in the FX graph.
        """
        # Flatten to 2-D with hidden_size columns.
        flat = hidden_states.reshape(-1, self.hidden_size)
        gathered = tensor_model_parallel_all_gather(flat, dim=0)

        # Multiply the gathered rows by the transposed weight.
        return torch.mm(gathered, self.weight.permute(1, 0))

    def ops_in_model_before(self):
        """Ops the test looks for in the graph before the pass runs."""
        return [torch.ops.vllm.all_gather.default]

    def ops_in_model_after(self):
        """Ops the test looks for in the graph after the pass runs."""
        return [torch.ops.symm_mem.fused_all_gather_matmul.default]
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_async_tp_pass_replace(test_model: type[torch.nn.Module],
                               batch_size: int, seq_len: int,
                               hidden_size: int, dtype: torch.dtype):
    """Spawn two workers that each run the AsyncTP pass on ``test_model``.

    Args:
        test_model: Model *class* (not a name) that each worker
            instantiates; supplied by the ``test_model`` parametrization.
        batch_size: Forwarded to the worker.
        seq_len: Forwarded to the worker.
        hidden_size: Forwarded to the worker.
        dtype: Forwarded to the worker.
    """
    num_processes = 2

    # need to use torch.mp.spawn otherwise will have problems with
    # torch.distributed and cuda.
    # spawn() calls the worker as fn(process_index, *args), so the process
    # index becomes the worker's first positional argument.
    torch.multiprocessing.spawn(async_tp_pass_on_test_model,
                                args=(num_processes, test_model, batch_size,
                                      seq_len, hidden_size, dtype),
                                nprocs=num_processes)
def async_tp_pass_on_test_model(local_rank: int, world_size: int,
                                test_model_cls: type[torch.nn.Module],
                                batch_size: int, seq_len: int,
                                hidden_size: int, dtype: torch.dtype):
    """Per-process worker: compile a toy model through ``AsyncTPPass``
    and check which ops appear before and after the pass.

    Args:
        local_rank: Index of this process; selects ``cuda:{local_rank}``
            and is exported as RANK / LOCAL_RANK.
        world_size: Number of processes; also used as the tensor-parallel
            size.
        test_model_cls: Model class, instantiated as
            ``test_model_cls(hidden_size)``; must provide
            ``ops_in_model_before()`` and ``ops_in_model_after()``.
        batch_size: With ``seq_len``, determines the number of input rows
            (``batch_size * seq_len``).
        seq_len: See ``batch_size``.
        hidden_size: Number of input columns and model hidden size.
        dtype: Default dtype, input dtype and model-config dtype.
    """
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # NOTE(review): the rendezvous port is fixed; confirm it cannot clash
    # with other tests running on the same host.
    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12345',
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # configure vllm config for AsyncTPPass
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_async_tp=True, ), )
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
    vllm_config.model_config = ModelConfig(model=model_name,
                                           task="auto",
                                           tokenizer=model_name,
                                           tokenizer_mode="auto",
                                           trust_remote_code=True,
                                           dtype=dtype,
                                           seed=42)

    async_tp_pass = AsyncTPPass(vllm_config)
    backend = TestBackend(async_tp_pass)

    model = test_model_cls(hidden_size)

    hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                dtype=dtype,
                                requires_grad=False)

    compiled_model = torch.compile(model, backend=backend)
    compiled_model(hidden_states)

    # In pre-nodes, all gather or reduce scatter should exist. With
    # ops_fully_replaced=False they are not required to be absent from the
    # post-pass graph.
    backend.check_before_ops(model.ops_in_model_before(),
                             ops_fully_replaced=False)

    # In post-nodes, fused_matmul_reduce_scatter or
    # fused_all_gather_matmul should exist, and should not be present in
    # the pre-pass graph.
    backend.check_after_ops(model.ops_in_model_after())
@create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])
@pytest.mark.parametrize("eager_mode", [False, True])
def test_async_tp_pass_correctness(
    model_id: str,
    tp_size: int,
    async_tp_enabled: bool,
    distributed_backend: str,
    eager_mode: bool,
    num_gpus_available: int,
):
    """Compare a run with the async-TP compilation config against a plain
    tensor-parallel run of the same model via ``compare_two_settings``.

    The test is skipped when the model's transformers-version or
    availability checks fail, or when fewer than ``tp_size`` GPUs are
    available.
    """
    hf_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    hf_info.check_transformers_version(on_fail="skip")
    hf_info.check_available_online(on_fail="skip")

    pp_size = 1
    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")

    # CLI arguments shared by both runs.
    common_args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if eager_mode:
        common_args.append("--enforce-eager")

    compilation_config = {
        'level': 3,
        'compile_sizes': [2, 4, 8],
        'splitting_ops': [],
        'pass_config': {
            'enable_async_tp': async_tp_enabled
        },
    }

    # Both runs deliberately share one and the same environment dict.
    tp_env = {
        "VLLM_USE_V1": "1",
    }
    async_tp_env = tp_env

    # Run under test: TP plus the compilation config enabling async TP.
    async_tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation_config",
        json.dumps(compilation_config),
    ]

    # Baseline run: TP only, no compilation config.
    # NOTE(review): the baseline hard-codes "mp" instead of using
    # distributed_backend; identical for the current parametrization.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    compare_two_settings(model_id,
                         async_tp_args,
                         tp_args,
                         async_tp_env,
                         tp_env,
                         method="generate")
tests/compile/test_fusion.py
View file @
71ea614d
...
@@ -29,6 +29,10 @@ class TestModel(torch.nn.Module):
...
@@ -29,6 +29,10 @@ class TestModel(torch.nn.Module):
self
.
cutlass_fp8_enabled
=
cutlass_fp8_enabled
self
.
cutlass_fp8_enabled
=
cutlass_fp8_enabled
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
_
in
range
(
3
)]
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
_
in
range
(
3
)]
self
.
wscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
2
)]
self
.
wscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
2
)]
self
.
key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
static
=
static
,
per_tensor
=
static
,
symmetric
=
True
)
if
static
:
if
static
:
self
.
scale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
2
)]
self
.
scale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
2
)]
else
:
else
:
...
@@ -59,6 +63,15 @@ class TestModel(torch.nn.Module):
...
@@ -59,6 +63,15 @@ class TestModel(torch.nn.Module):
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
return
y3
return
y3
def
ops_in_model_before
(
self
):
return
[
QUANT_OPS
[
self
.
key
]]
def
ops_in_model_after
(
self
):
return
[
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
key
,
False
)],
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
key
,
True
)]
]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
64
,
3392
,
4096
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
64
,
3392
,
4096
])
...
@@ -107,25 +120,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
...
@@ -107,25 +120,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
torch
.
testing
.
assert_close
(
result
,
result2
,
atol
=
ATOL
,
rtol
=
RTOL
)
torch
.
testing
.
assert_close
(
result
,
result2
,
atol
=
ATOL
,
rtol
=
RTOL
)
# Check substitution worked
pre_nodes
=
backend
.
graph_pre_pass
.
nodes
post_nodes
=
backend
.
graph_post_pass
.
nodes
# static is per-tensor, dynamic is per-token
key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
static
=
static
,
per_tensor
=
static
,
symmetric
=
True
)
rms_quant
=
FUSED_OPS
[
FusedRMSQuantKey
(
key
,
False
)]
add_rms_quant
=
FUSED_OPS
[
FusedRMSQuantKey
(
key
,
True
)]
fp8_quant
=
QUANT_OPS
[
key
]
# In pre-nodes, fp8 quant should be there and fused kernels should not
# In pre-nodes, fp8 quant should be there and fused kernels should not
assert
find_auto_fn_maybe
(
pre_nodes
,
rms_quant
)
is
None
backend
.
check_before_ops
(
model
.
ops_in_model_before
(),
find_auto_fn
,
assert
find_auto_fn_maybe
(
pre_nodes
,
add_rms_quant
)
is
None
find_auto_fn_maybe
)
find_auto_fn
(
pre_nodes
,
fp8_quant
)
# In post-nodes, fused kernels should be there and fp8 quant should not
# In post-nodes, fused kernels should be there and fp8 quant should not
find_auto_fn
(
post_nodes
,
rms_quant
)
backend
.
check_after_ops
(
model
.
ops_in_model_after
(),
find_auto_fn
,
find_auto_fn
(
post_nodes
,
add_rms_quant
)
find_auto_fn_maybe
)
assert
find_auto_fn_maybe
(
post_nodes
,
fp8_quant
)
is
None
tests/compile/test_sequence_parallelism.py
View file @
71ea614d
...
@@ -5,9 +5,7 @@ import torch
...
@@ -5,9 +5,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.compilation.fix_functionalization
import
FixFunctionalizationPass
from
vllm.compilation.fix_functionalization
import
FixFunctionalizationPass
from
vllm.compilation.fx_utils
import
(
find_auto_fn
,
find_auto_fn_maybe
,
from
vllm.compilation.fx_utils
import
find_auto_fn
,
find_auto_fn_maybe
,
is_func
find_specified_fn
,
find_specified_fn_maybe
,
is_func
)
from
vllm.compilation.sequence_parallelism
import
SequenceParallelismPass
from
vllm.compilation.sequence_parallelism
import
SequenceParallelismPass
from
vllm.config
import
(
CompilationConfig
,
DeviceConfig
,
ModelConfig
,
from
vllm.config
import
(
CompilationConfig
,
DeviceConfig
,
ModelConfig
,
PassConfig
,
VllmConfig
)
PassConfig
,
VllmConfig
)
...
@@ -21,17 +19,6 @@ from vllm.utils import update_environment_variables
...
@@ -21,17 +19,6 @@ from vllm.utils import update_environment_variables
from
..utils
import
multi_gpu_test
from
..utils
import
multi_gpu_test
from
.backend
import
TestBackend
from
.backend
import
TestBackend
OPS_IN_MODEL_BEFORE
=
[
torch
.
ops
.
vllm
.
all_reduce
.
default
,
]
OPS_IN_MODEL_AFTER
=
[
torch
.
ops
.
vllm
.
reduce_scatter
.
default
,
torch
.
ops
.
vllm
.
all_gather
.
default
,
]
OPS_IN_MODEL
=
[
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
]
prompts
=
[
prompts
=
[
"Hello, my name is"
,
"Hello, my name is"
,
"The president of the United States is"
,
"The president of the United States is"
,
...
@@ -78,6 +65,18 @@ class TestModel(torch.nn.Module):
...
@@ -78,6 +65,18 @@ class TestModel(torch.nn.Module):
return
norm_output
,
residual_output
return
norm_output
,
residual_output
def
ops_in_model_before
(
self
):
return
[
torch
.
ops
.
vllm
.
all_reduce
.
default
]
def
ops_in_model_after
(
self
):
return
[
torch
.
ops
.
vllm
.
reduce_scatter
.
default
,
torch
.
ops
.
vllm
.
all_gather
.
default
]
def
ops_in_model
(
self
):
return
[
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
]
@
multi_gpu_test
(
num_gpus
=
2
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
...
@@ -156,26 +155,16 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
...
@@ -156,26 +155,16 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
compiled_model_func
=
torch
.
compile
(
model
,
backend
=
backend_func
)
compiled_model_func
=
torch
.
compile
(
model
,
backend
=
backend_func
)
compiled_model_func
(
hidden_states
,
residual
)
compiled_model_func
(
hidden_states
,
residual
)
# Check substitution worked
pre_nodes
=
backend_no_func
.
graph_pre_pass
.
nodes
post_nodes
=
backend_no_func
.
graph_post_pass
.
nodes
# In pre-nodes, all reduce should be there,
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
# reduce scatter and all gather should not
for
op
in
OPS_IN_MODEL_BEFORE
:
backend_no_func
.
check_before_ops
(
model
.
ops_in_model_before
())
find_specified_fn
(
pre_nodes
,
op
)
for
op
in
OPS_IN_MODEL_AFTER
:
assert
find_specified_fn_maybe
(
pre_nodes
,
op
)
is
None
# In post-nodes, reduce scatter and all gather should be there,
# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
# all reduce should not
for
op
in
OPS_IN_MODEL_AFTER
:
backend_no_func
.
check_after_ops
(
model
.
ops_in_model_after
())
find_specified_fn
(
post_nodes
,
op
)
for
op
in
OPS_IN_MODEL_BEFORE
:
assert
find_specified_fn_maybe
(
post_nodes
,
op
)
is
None
# check if the functionalization pass is applied
# check if the functionalization pass is applied
for
op
in
OPS_IN_MODEL
:
for
op
in
model
.
ops_in_model
()
:
find_auto_fn
(
backend_no_func
.
graph_post_pass
.
nodes
,
op
)
find_auto_fn
(
backend_no_func
.
graph_post_pass
.
nodes
,
op
)
assert
find_auto_fn_maybe
(
backend_func
.
graph_post_pass
.
nodes
,
assert
find_auto_fn_maybe
(
backend_func
.
graph_post_pass
.
nodes
,
op
)
is
None
# noqa: E501
op
)
is
None
# noqa: E501
...
@@ -183,7 +172,7 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
...
@@ -183,7 +172,7 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
# make sure the ops were all de-functionalized
# make sure the ops were all de-functionalized
found
=
dict
()
found
=
dict
()
for
node
in
backend_func
.
graph_post_pass
.
nodes
:
for
node
in
backend_func
.
graph_post_pass
.
nodes
:
for
op
in
OPS_IN_MODEL
:
for
op
in
model
.
ops_in_model
()
:
if
is_func
(
node
,
op
):
if
is_func
(
node
,
op
):
found
[
op
]
=
True
found
[
op
]
=
True
assert
all
(
found
[
op
]
for
op
in
OPS_IN_MODEL
)
assert
all
(
found
[
op
]
for
op
in
model
.
ops_in_model
()
)
vllm/compilation/collective_fusion.py
0 → 100644
View file @
71ea614d
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
torch
import
torch._inductor.pattern_matcher
as
pm
import
torch.fx
as
fx
from
torch._inductor.pattern_matcher
import
PatternMatcherPass
from
torch.distributed._symmetric_memory
import
enable_symm_mem_for_group
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tp_group
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_world_size
)
from
vllm.logger
import
init_logger
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
class BasePattern:
    """Shared state for the collective-fusion patterns in this module.

    Stores the dtype and device used to build example inputs, plus the
    tensor-parallel group and its world size as returned by
    ``get_tp_group()`` and ``get_tensor_model_parallel_world_size()``.
    """

    def __init__(self, dtype: torch.dtype, device: str):
        self.dtype, self.device = dtype, device
        # Looked up once at construction time and reused by subclasses.
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
class GEMMReduceScatterPattern(BasePattern):
    """Rewrite ``aten.mm`` + ``vllm.reduce_scatter`` (dim 0) into
    ``symm_mem.fused_matmul_reduce_scatter``."""

    def get_inputs(self):
        """Return example ``[mul, mm_weight]`` tensors of shapes
        ``[16, 4]`` and ``[4, 4]`` used to register the pattern."""
        tensor_opts = {"device": self.device, "dtype": self.dtype}
        return [
            torch.empty([16, 4], **tensor_opts),
            torch.empty([4, 4], **tensor_opts),
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor):
            # Unfused form: matmul, then reduce-scatter over the TP group.
            return torch.ops.vllm.reduce_scatter.default(
                torch.ops.aten.mm.default(mul, mm_weight),
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor):
            # NOTE(review): the reduce op passed here is "avg"; confirm it
            # matches the reduction performed by vllm.reduce_scatter.
            return torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

        example_inputs = self.get_inputs()
        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
class AllGatherGEMMPattern(BasePattern):
    """Rewrite ``vllm.all_gather`` (dim 0) + ``aten.mm`` into
    ``symm_mem.fused_all_gather_matmul``."""

    def get_inputs(self):
        """Return example ``[x, weight]`` tensors, both of shape
        ``[4, 4]``, used to register the pattern."""
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused form: gather x over the TP group, then matmul.
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)

            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(x: torch.Tensor,
                        weight: torch.Tensor) -> list[torch.Tensor]:
            # The fused op returns (gathered input, list of matmul
            # outputs); only the matmul outputs replace the pattern's
            # result, so the gathered input is discarded.
            _, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
class AsyncTPPass(VllmInductorPass):
    """Inductor pass applying the GEMM + reduce-scatter and
    all-gather + GEMM fusion patterns defined in this module."""

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        tp_group_name = get_tp_group().device_group.group_name
        enable_symm_mem_for_group(tp_group_name)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass")

        # Build and register each pattern in turn, in this fixed order.
        for pattern_cls in (GEMMReduceScatterPattern, AllGatherGEMMPattern):
            pattern_cls(self.model_dtype, self.device).register(self.patterns)

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        """Return True only when ``shape`` is not None and is divisible by
        the tensor-parallel world size."""
        # only do replace for specific shapes
        tp_size = get_tensor_model_parallel_world_size()
        if shape is None:
            return False
        return shape % tp_size == 0

    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to ``graph``, dumping the graph
        before and after and logging the replacement count."""
        self.begin()
        self.dump_graph(graph, "before_async_tp_pass")
        replaced = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", replaced)
        self.dump_graph(graph, "after_async_tp_pass")
        self.end_and_log()
vllm/compilation/pass_manager.py
View file @
71ea614d
...
@@ -6,6 +6,7 @@ from vllm.config import VllmConfig
...
@@ -6,6 +6,7 @@ from vllm.config import VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
.activation_quant_fusion
import
ActivationQuantFusionPass
from
.activation_quant_fusion
import
ActivationQuantFusionPass
from
.collective_fusion
import
AsyncTPPass
from
.fix_functionalization
import
FixFunctionalizationPass
from
.fix_functionalization
import
FixFunctionalizationPass
from
.fusion
import
FusionPass
from
.fusion
import
FusionPass
from
.inductor_pass
import
CustomGraphPass
,
InductorPass
,
get_pass_context
from
.inductor_pass
import
CustomGraphPass
,
InductorPass
,
get_pass_context
...
@@ -54,6 +55,8 @@ class PostGradPassManager(CustomGraphPass):
...
@@ -54,6 +55,8 @@ class PostGradPassManager(CustomGraphPass):
if
self
.
pass_config
.
enable_sequence_parallelism
:
if
self
.
pass_config
.
enable_sequence_parallelism
:
self
.
passes
+=
[
SequenceParallelismPass
(
config
)]
self
.
passes
+=
[
SequenceParallelismPass
(
config
)]
if
self
.
pass_config
.
enable_async_tp
:
self
.
passes
+=
[
AsyncTPPass
(
config
)]
self
.
fix_functionalization
=
FixFunctionalizationPass
(
config
)
self
.
fix_functionalization
=
FixFunctionalizationPass
(
config
)
...
...
vllm/compilation/sequence_parallelism.py
View file @
71ea614d
...
@@ -243,24 +243,25 @@ class SequenceParallelismPass(VllmInductorPass):
...
@@ -243,24 +243,25 @@ class SequenceParallelismPass(VllmInductorPass):
pass_name
=
"sequence_parallelism_pass"
)
pass_name
=
"sequence_parallelism_pass"
)
for
epsilon
in
[
1e-5
,
1e-6
]:
for
epsilon
in
[
1e-5
,
1e-6
]:
EmbeddingAllReduceRMSNormPattern
(
EmbeddingAllReduceRMSNormPattern
(
epsilon
,
self
.
dtype
,
self
.
device
).
register
(
self
.
patterns
)
epsilon
,
self
.
model_
dtype
,
self
.
device
).
register
(
self
.
patterns
)
MiddleAllReduceRMSNormPattern
(
epsilon
,
self
.
dtype
,
MiddleAllReduceRMSNormPattern
(
epsilon
,
self
.
model_
dtype
,
self
.
device
).
register
(
self
.
patterns
)
self
.
device
).
register
(
self
.
patterns
)
LastAllReduceRMSNormPattern
(
epsilon
,
self
.
dtype
,
LastAllReduceRMSNormPattern
(
epsilon
,
self
.
model_
dtype
,
self
.
device
).
register
(
self
.
patterns
)
self
.
device
).
register
(
self
.
patterns
)
# WARNING: This is a hack to clear the pattern matcher cache
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
# and allow multiple values of epsilon.
torch
.
_inductor
.
pattern_matcher
.
_seen_patterns
.
clear
()
torch
.
_inductor
.
pattern_matcher
.
_seen_patterns
.
clear
()
def
is_applicable_for_shape
(
self
,
shape
:
Optional
[
int
])
->
bool
:
def
is_applicable_for_shape
(
self
,
shape
:
Optional
[
int
])
->
bool
:
# only do replace for specific shapes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
return
shape
is
not
None
and
shape
%
tp_size
==
0
return
shape
is
not
None
and
shape
%
tp_size
==
0
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
):
self
.
begin
()
self
.
dump_graph
(
graph
,
"before_sequence_parallelism_pass"
)
self
.
dump_graph
(
graph
,
"before_sequence_parallelism_pass"
)
count
=
self
.
patterns
.
apply
(
graph
)
count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
count
)
logger
.
debug
(
"Replaced %s patterns"
,
count
)
self
.
dump_graph
(
graph
,
"after_sequence_parallelism_pass"
)
self
.
dump_graph
(
graph
,
"after_sequence_parallelism_pass"
)
self
.
end_and_log
()
vllm/compilation/vllm_inductor_pass.py
View file @
71ea614d
...
@@ -26,7 +26,8 @@ class VllmInductorPass(InductorPass):
...
@@ -26,7 +26,8 @@ class VllmInductorPass(InductorPass):
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
):
self
.
pass_config
=
config
.
compilation_config
.
pass_config
self
.
pass_config
=
config
.
compilation_config
.
pass_config
self
.
dtype
=
config
.
model_config
.
dtype
if
config
.
model_config
else
None
self
.
model_dtype
=
config
.
model_config
.
dtype
if
config
.
model_config
\
else
None
self
.
device
=
config
.
device_config
.
device
if
config
.
device_config
\
self
.
device
=
config
.
device_config
.
device
if
config
.
device_config
\
else
None
else
None
self
.
pass_name
=
self
.
__class__
.
__name__
self
.
pass_name
=
self
.
__class__
.
__name__
...
...
vllm/config.py
View file @
71ea614d
...
@@ -3652,6 +3652,8 @@ class PassConfig:
...
@@ -3652,6 +3652,8 @@ class PassConfig:
"""Whether to enable the custom no-op elimination pass."""
"""Whether to enable the custom no-op elimination pass."""
enable_sequence_parallelism
:
bool
=
False
enable_sequence_parallelism
:
bool
=
False
"""Whether to enable sequence parallelism."""
"""Whether to enable sequence parallelism."""
enable_async_tp
:
bool
=
False
"""Whether to enable async TP."""
def
uuid
(
self
):
def
uuid
(
self
):
"""
"""
...
@@ -3661,7 +3663,8 @@ class PassConfig:
...
@@ -3661,7 +3663,8 @@ class PassConfig:
compilation.
compilation.
"""
"""
include
=
{
include
=
{
"enable_fusion"
,
"enable_noop"
,
"enable_sequence_parallelism"
"enable_fusion"
,
"enable_noop"
,
"enable_sequence_parallelism"
,
"enable_async_tp"
}
}
dict_
=
{
k
:
v
for
k
,
v
in
asdict
(
self
).
items
()
if
k
in
include
}
dict_
=
{
k
:
v
for
k
,
v
in
asdict
(
self
).
items
()
if
k
in
include
}
return
InductorPass
.
hash_dict
(
dict_
)
return
InductorPass
.
hash_dict
(
dict_
)
...
@@ -4274,6 +4277,12 @@ class VllmConfig:
...
@@ -4274,6 +4277,12 @@ class VllmConfig:
if
self
.
compilation_config
is
None
:
if
self
.
compilation_config
is
None
:
self
.
compilation_config
=
CompilationConfig
()
self
.
compilation_config
=
CompilationConfig
()
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
if
self
.
compilation_config
.
pass_config
.
enable_async_tp
:
self
.
compilation_config
.
pass_config
.
enable_sequence_parallelism
=
\
True
if
self
.
compilation_config
.
pass_config
.
enable_sequence_parallelism
:
if
self
.
compilation_config
.
pass_config
.
enable_sequence_parallelism
:
self
.
compilation_config
.
custom_ops
.
append
(
"+rms_norm"
)
self
.
compilation_config
.
custom_ops
.
append
(
"+rms_norm"
)
if
envs
.
VLLM_USE_V1
and
self
.
model_config
is
not
None
and
\
if
envs
.
VLLM_USE_V1
and
self
.
model_config
is
not
None
and
\
...
...
vllm/distributed/parallel_state.py
View file @
71ea614d
...
@@ -120,7 +120,7 @@ def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
...
@@ -120,7 +120,7 @@ def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
group
=
_groups
[
group_name
]()
group
=
_groups
[
group_name
]()
if
group
is
None
:
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
return
group
.
reduce_scatter
(
tensor
,
dim
)
return
group
.
_
reduce_scatter
_out_place
(
tensor
,
dim
)
def
reduce_scatter_fake
(
tensor
:
torch
.
Tensor
,
dim
:
int
,
world_size
:
int
,
def
reduce_scatter_fake
(
tensor
:
torch
.
Tensor
,
dim
:
int
,
world_size
:
int
,
...
@@ -136,7 +136,7 @@ def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
...
@@ -136,7 +136,7 @@ def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
group
=
_groups
[
group_name
]()
group
=
_groups
[
group_name
]()
if
group
is
None
:
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
return
group
.
all_gather
(
tensor
,
dim
)
return
group
.
_
all_gather
_out_place
(
tensor
,
dim
)
def
all_gather_fake
(
tensor
:
torch
.
Tensor
,
dim
:
int
,
world_size
:
int
,
def
all_gather_fake
(
tensor
:
torch
.
Tensor
,
dim
:
int
,
world_size
:
int
,
...
@@ -161,6 +161,7 @@ if supports_custom_op():
...
@@ -161,6 +161,7 @@ if supports_custom_op():
op_func
=
reduce_scatter
,
op_func
=
reduce_scatter
,
mutates_args
=
[],
mutates_args
=
[],
fake_impl
=
reduce_scatter_fake
,
fake_impl
=
reduce_scatter_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
)
direct_register_custom_op
(
direct_register_custom_op
(
...
@@ -168,6 +169,7 @@ if supports_custom_op():
...
@@ -168,6 +169,7 @@ if supports_custom_op():
op_func
=
all_gather
,
op_func
=
all_gather
,
mutates_args
=
[],
mutates_args
=
[],
fake_impl
=
all_gather_fake
,
fake_impl
=
all_gather_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
)
...
@@ -367,6 +369,16 @@ class GroupCoordinator:
...
@@ -367,6 +369,16 @@ class GroupCoordinator:
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
if
self
.
use_custom_op_call
:
return
torch
.
ops
.
vllm
.
all_gather
(
input_
,
dim
,
world_size
,
group_name
=
self
.
unique_name
)
else
:
return
self
.
_all_gather_out_place
(
input_
,
dim
)
def
_all_gather_out_place
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
)
->
torch
.
Tensor
:
return
self
.
device_communicator
.
all_gather
(
input_
,
dim
)
return
self
.
device_communicator
.
all_gather
(
input_
,
dim
)
def
reduce_scatter
(
self
,
def
reduce_scatter
(
self
,
...
@@ -379,6 +391,16 @@ class GroupCoordinator:
...
@@ -379,6 +391,16 @@ class GroupCoordinator:
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
if
self
.
use_custom_op_call
:
return
torch
.
ops
.
vllm
.
reduce_scatter
(
input_
,
dim
,
world_size
,
group_name
=
self
.
unique_name
)
else
:
return
self
.
_reduce_scatter_out_place
(
input_
,
dim
)
def
_reduce_scatter_out_place
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
)
->
torch
.
Tensor
:
return
self
.
device_communicator
.
reduce_scatter
(
input_
,
dim
)
return
self
.
device_communicator
.
reduce_scatter
(
input_
,
dim
)
def
gather
(
self
,
def
gather
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment