Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
287f527f
Unverified
Commit
287f527f
authored
Jul 30, 2025
by
cascade
Committed by
GitHub
Jul 30, 2025
Browse files
[Feature] Add async tensor parallelism for scaled mm (#20155)
Signed-off-by:
cascade812
<
cascade812@outlook.com
>
parent
f12d9256
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
381 additions
and
8 deletions
+381
-8
tests/compile/test_async_tp.py
tests/compile/test_async_tp.py
+138
-5
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+242
-2
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+1
-1
No files found.
tests/compile/test_async_tp.py
View file @
287f527f
...
@@ -22,6 +22,8 @@ from ..utils import (compare_two_settings, create_new_process_for_each_test,
...
@@ -22,6 +22,8 @@ from ..utils import (compare_two_settings, create_new_process_for_each_test,
multi_gpu_test
)
multi_gpu_test
)
from
.backend
import
TestBackend
from
.backend
import
TestBackend
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
prompts
=
[
prompts
=
[
"Hello, my name is"
,
"Hello, my name is"
,
"The president of the United States is"
,
"The president of the United States is"
,
...
@@ -32,9 +34,10 @@ prompts = [
...
@@ -32,9 +34,10 @@ prompts = [
class
TestMMRSModel
(
torch
.
nn
.
Module
):
class
TestMMRSModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
=
16
):
def
__init__
(
self
,
hidden_size
=
16
,
dtype
=
torch
.
float16
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
dtype
=
dtype
self
.
gate_proj
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
self
.
gate_proj
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
(
self
.
hidden_size
*
2
,
hidden_size
)),
(
self
.
hidden_size
*
2
,
hidden_size
)),
requires_grad
=
False
)
requires_grad
=
False
)
...
@@ -64,9 +67,10 @@ class TestMMRSModel(torch.nn.Module):
...
@@ -64,9 +67,10 @@ class TestMMRSModel(torch.nn.Module):
class
TestAGMMModel
(
torch
.
nn
.
Module
):
class
TestAGMMModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
=
16
):
def
__init__
(
self
,
hidden_size
=
16
,
dtype
=
torch
.
float16
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
dtype
=
dtype
self
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
self
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
(
hidden_size
,
hidden_size
)),
(
hidden_size
,
hidden_size
)),
requires_grad
=
False
)
requires_grad
=
False
)
...
@@ -91,8 +95,125 @@ class TestAGMMModel(torch.nn.Module):
...
@@ -91,8 +95,125 @@ class TestAGMMModel(torch.nn.Module):
return
[
torch
.
ops
.
symm_mem
.
fused_all_gather_matmul
.
default
]
return
[
torch
.
ops
.
symm_mem
.
fused_all_gather_matmul
.
default
]
class _BaseScaledMMModel(torch.nn.Module):
    """Common state for the FP8 scaled-matmul test models.

    Attributes set here:
        hidden_size: side length of the square weight.
        dtype: stored as-is; subclasses in this file pass it on as the
            matmul output dtype.
        weight: uninitialized FP8 tensor of shape
            (hidden_size, hidden_size), kept as the transposed view of a
            contiguous allocation.
        scale_b: float32 tensor of ones with shape (1, hidden_size).
    """

    def __init__(self, hidden_size=16, dtype=torch.float16):
        super().__init__()
        self.hidden_size = hidden_size
        self.dtype = dtype

        # Allocate row-major first, then expose the transposed view.
        weight_storage = torch.empty([hidden_size, hidden_size],
                                     dtype=FP8_DTYPE)
        self.weight = weight_storage.contiguous().transpose(0, 1)

        # Scale for the second matmul operand (one value per column).
        self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
class TestScaledMMRSModel(_BaseScaledMMModel):
    """Test model: ``torch._scaled_mm`` followed by a reduce-scatter."""

    def forward(self, input: torch.Tensor):
        """
        Cast the input to FP8, run ``torch._scaled_mm`` against the stored
        weight and reduce-scatter the product along dim 0, so that both ops
        show up in the FX graph.
        """
        quantized = input.to(FP8_DTYPE)
        # One scale per input row, all set to 1.
        row_scales = torch.ones(input.shape[0], 1, dtype=torch.float32)
        product = torch._scaled_mm(quantized,
                                   self.weight,
                                   scale_a=row_scales,
                                   scale_b=self.scale_b,
                                   out_dtype=self.dtype)
        return tensor_model_parallel_reduce_scatter(product, dim=0)

    def ops_in_model_before(self):
        """Ops listed for the graph before the pass runs."""
        return [torch.ops.vllm.reduce_scatter.default]

    def ops_in_model_after(self):
        """Ops listed for the graph after the pass runs."""
        return [
            torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default
        ]
class TestAGScaledMMModel(_BaseScaledMMModel):
    """Test model: all-gather followed by ``torch._scaled_mm``."""

    def forward(self, input: torch.Tensor):
        """
        Cast the input to FP8, all-gather it along dim 0 and feed the
        gathered tensor to ``torch._scaled_mm``, so that both ops show up
        in the FX graph.
        """
        # Cast to FP8 before gathering.
        quantized = input.to(FP8_DTYPE)
        gathered = tensor_model_parallel_all_gather(quantized, dim=0)

        # One scale per gathered row, all set to 1.
        row_scales = torch.ones(gathered.shape[0], 1, dtype=torch.float32)
        product = torch._scaled_mm(gathered,
                                   self.weight,
                                   scale_a=row_scales,
                                   scale_b=self.scale_b,
                                   out_dtype=self.dtype)
        return product

    def ops_in_model_before(self):
        """Ops listed for the graph before the pass runs."""
        return [torch.ops.vllm.all_gather.default]

    def ops_in_model_after(self):
        """Ops listed for the graph after the pass runs."""
        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
    """Test model: ``cutlass_scaled_mm`` followed by a reduce-scatter."""

    def forward(self, input: torch.Tensor):
        """
        Cast the input to FP8, run the out-variant ``cutlass_scaled_mm``
        custom op into a freshly allocated buffer and reduce-scatter that
        buffer along dim 0, so that both ops show up in the FX graph.
        """
        quantized = input.to(FP8_DTYPE)
        # One scale per input row, all set to 1.
        row_scales = torch.ones(input.shape[0], 1, dtype=torch.float32)

        # The op writes into its first argument, so allocate the output.
        out_shape = (quantized.shape[0], self.weight.shape[1])
        mm_out = torch.empty(out_shape,
                             dtype=self.dtype,
                             device=input.device)
        torch.ops._C.cutlass_scaled_mm(mm_out, quantized, self.weight,
                                       row_scales, self.scale_b, None)
        return tensor_model_parallel_reduce_scatter(mm_out, dim=0)

    def ops_in_model_before(self):
        """Ops listed for the graph before the pass runs."""
        return [torch.ops.vllm.reduce_scatter.default]

    def ops_in_model_after(self):
        """Ops listed for the graph after the pass runs."""
        return [
            torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default
        ]
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
    """Test model: all-gather followed by ``cutlass_scaled_mm``."""

    def forward(self, input: torch.Tensor):
        """
        Cast the input to FP8, all-gather it along dim 0 and run the
        out-variant ``cutlass_scaled_mm`` custom op on the gathered tensor,
        so that both ops show up in the FX graph.
        """
        # Cast to FP8 before gathering.
        quantized = input.to(FP8_DTYPE)
        gathered = tensor_model_parallel_all_gather(quantized, dim=0)

        # One scale per gathered row, all set to 1.
        row_scales = torch.ones(gathered.shape[0], 1, dtype=torch.float32)

        # The op writes into its first argument, so allocate the output.
        out_shape = (gathered.shape[0], self.weight.shape[1])
        mm_out = torch.empty(out_shape,
                             dtype=self.dtype,
                             device=gathered.device)
        torch.ops._C.cutlass_scaled_mm(mm_out, gathered, self.weight,
                                       row_scales, self.scale_b, None)
        return mm_out

    def ops_in_model_before(self):
        """Ops listed for the graph before the pass runs."""
        return [torch.ops.vllm.all_gather.default]

    def ops_in_model_after(self):
        """Ops listed for the graph after the pass runs."""
        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
@
multi_gpu_test
(
num_gpus
=
2
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"test_model"
,
[
TestMMRSModel
,
TestAGMMModel
])
@
pytest
.
mark
.
parametrize
(
"test_model"
,
[
TestMMRSModel
,
TestAGMMModel
,
TestScaledMMRSModel
,
TestAGScaledMMModel
,
TestCutlassScaledMMRSModel
,
TestAGCutlassScaledMMModel
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
16
])
...
@@ -101,6 +222,14 @@ class TestAGMMModel(torch.nn.Module):
...
@@ -101,6 +222,14 @@ class TestAGMMModel(torch.nn.Module):
reason
=
"Only test on CUDA"
)
reason
=
"Only test on CUDA"
)
def
test_async_tp_pass_replace
(
test_model
:
str
,
batch_size
:
int
,
seq_len
:
int
,
def
test_async_tp_pass_replace
(
test_model
:
str
,
batch_size
:
int
,
seq_len
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
):
hidden_size
:
int
,
dtype
:
torch
.
dtype
):
if
test_model
in
(
TestScaledMMRSModel
,
TestAGScaledMMModel
,
TestCutlassScaledMMRSModel
,
TestAGCutlassScaledMMModel
)
and
dtype
==
torch
.
float16
:
pytest
.
skip
(
"Only bf16 high precision output types are supported for "
\
"per-token (row-wise) scaling"
)
num_processes
=
2
num_processes
=
2
def
run_torch_spawn
(
fn
,
nprocs
):
def
run_torch_spawn
(
fn
,
nprocs
):
...
@@ -155,7 +284,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
...
@@ -155,7 +284,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
async_tp_pass
=
AsyncTPPass
(
vllm_config
)
async_tp_pass
=
AsyncTPPass
(
vllm_config
)
backend
=
TestBackend
(
async_tp_pass
)
backend
=
TestBackend
(
async_tp_pass
)
model
=
test_model_cls
(
hidden_size
)
model
=
test_model_cls
(
hidden_size
,
dtype
)
# Pass dtype to model constructor
hidden_states
=
torch
.
randn
((
batch_size
*
seq_len
,
hidden_size
),
hidden_states
=
torch
.
randn
((
batch_size
*
seq_len
,
hidden_size
),
dtype
=
dtype
,
dtype
=
dtype
,
...
@@ -174,7 +304,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
...
@@ -174,7 +304,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
()
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"meta-llama/Llama-3.2-1B-Instruct"
])
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"meta-llama/Llama-3.2-1B-Instruct"
,
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"async_tp_enabled"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"async_tp_enabled"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"distributed_backend"
,
[
"mp"
])
@
pytest
.
mark
.
parametrize
(
"distributed_backend"
,
[
"mp"
])
...
...
vllm/compilation/collective_fusion.py
View file @
287f527f
...
@@ -15,10 +15,13 @@ from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
...
@@ -15,10 +15,13 @@ from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
.vllm_inductor_pass
import
VllmInductorPass
from
.vllm_inductor_pass
import
VllmInductorPass
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
if
find_spec
(
"flashinfer"
):
if
find_spec
(
"flashinfer"
):
try
:
try
:
import
flashinfer.comm
as
flashinfer_comm
import
flashinfer.comm
as
flashinfer_comm
...
@@ -28,7 +31,6 @@ if find_spec("flashinfer"):
...
@@ -28,7 +31,6 @@ if find_spec("flashinfer"):
flashinfer_comm
=
None
flashinfer_comm
=
None
else
:
else
:
flashinfer_comm
=
None
flashinfer_comm
=
None
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -118,6 +120,230 @@ class AllGatherGEMMPattern(BasePattern):
...
@@ -118,6 +120,230 @@ class AllGatherGEMMPattern(BasePattern):
pm
.
fwd_only
,
pm_pass
)
pm
.
fwd_only
,
pm_pass
)
class ScaledMMReduceScatterPattern(BasePattern):
    """Rewrite ``aten._scaled_mm`` + ``vllm.reduce_scatter`` into
    ``symm_mem.fused_scaled_matmul_reduce_scatter``."""

    def get_inputs(self):
        """Build the tensors handed to ``pm.register_replacement``.

        Returns, in order: a (16, 16) FP8 input, a (16, 16) FP8 weight
        (transposed view of a contiguous tensor), a (16, 1) float32
        ``scale_a`` and a (1, 16) float32 ``scale_b``, all on
        ``self.device``.
        """
        fp8_kwargs = dict(device=self.device, dtype=FP8_DTYPE)
        fp32_kwargs = dict(device=self.device, dtype=torch.float32)

        lhs = torch.empty([16, 16], **fp8_kwargs)
        rhs = torch.empty([16, 16], **fp8_kwargs).contiguous().transpose(
            0, 1)
        lhs_scale = torch.empty([16, 1], **fp32_kwargs)
        rhs_scale = torch.empty([1, 16], **fp32_kwargs)
        return [lhs, rhs, lhs_scale, rhs_scale]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(input: torch.Tensor, mat2: torch.Tensor,
                    scale_a: torch.Tensor,
                    scale_b: torch.Tensor) -> torch.Tensor:
            # Unfused form: scaled matmul, then reduce-scatter along dim 0.
            product = torch.ops.aten._scaled_mm.default(
                input,
                mat2=mat2,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype)
            return torch.ops.vllm.reduce_scatter.default(
                product,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)

        def replacement(input: torch.Tensor, mat2: torch.Tensor,
                        scale_a: torch.Tensor,
                        scale_b: torch.Tensor) -> torch.Tensor:
            # NOTE(review): the matched op is vllm.reduce_scatter while the
            # fused op is given reduce op "avg" -- confirm the two agree.
            return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim=0,
                out_dtype=self.dtype,
                group_name=self.tp.device_group.group_name,
            )

        example_inputs = self.get_inputs()
        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
class AllGatherScaledMMPattern(BasePattern):
    """Rewrite ``vllm.all_gather`` + ``aten._scaled_mm`` into
    ``symm_mem.fused_all_gather_scaled_matmul``."""

    def get_inputs(self):
        """Build the tensors handed to ``pm.register_replacement``.

        Returns, in order: an (8, 16) FP8 local input, a (16, 16) FP8
        weight (transposed view of a contiguous tensor), a float32
        ``scale_a`` with ``8 * self.tp_size`` rows and one column, and a
        (1, 16) float32 ``scale_b``, all on ``self.device``.
        """
        fp8_kwargs = dict(device=self.device, dtype=FP8_DTYPE)
        fp32_kwargs = dict(device=self.device, dtype=torch.float32)

        local_x = torch.empty([8, 16], **fp8_kwargs)
        rhs = torch.empty([16, 16], **fp8_kwargs).contiguous().transpose(
            0, 1)

        # scale_a has one row per row of the gathered input.
        gathered_rows = local_x.shape[0] * self.tp_size
        lhs_scale = torch.empty([gathered_rows, 1], **fp32_kwargs)
        rhs_scale = torch.empty([1, 16], **fp32_kwargs)
        return [local_x, rhs, lhs_scale, rhs_scale]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused form: gather along dim 0, then scaled matmul.
            gathered = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)
            return torch.ops.aten._scaled_mm.default(gathered,
                                                     mat2=weight,
                                                     scale_a=scale_a,
                                                     scale_b=scale_b,
                                                     bias=None,
                                                     scale_result=None,
                                                     out_dtype=self.dtype)

        def replacement(x: torch.Tensor, weight: torch.Tensor,
                        scale_a: torch.Tensor,
                        scale_b: torch.Tensor) -> torch.Tensor:
            # The fused op takes per-matmul lists; a single matmul is
            # requested here. The gathered input it also returns is unused.
            _, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            # NOTE(review): the second result is returned as-is although
            # the annotation says torch.Tensor -- confirm this is intended.
            return mm_outputs

        example_inputs = self.get_inputs()
        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
class CutlassScaledMMReduceScatterPattern(BasePattern):
    """Rewrite functionalized ``_C.cutlass_scaled_mm`` +
    ``vllm.reduce_scatter`` into
    ``symm_mem.fused_scaled_matmul_reduce_scatter``."""

    def get_inputs(self):
        """Build the tensors handed to ``pm.register_replacement``.

        Returns, in order: a (16, 16) FP8 input, a (16, 16) FP8 weight
        (transposed view of a contiguous tensor), a (16, 1) float32
        ``scale_a``, a (1, 16) float32 ``scale_b`` and a (16, 16) output
        buffer of ``self.dtype``, all on ``self.device``.
        """
        fp8_kwargs = dict(device=self.device, dtype=FP8_DTYPE)
        fp32_kwargs = dict(device=self.device, dtype=torch.float32)

        lhs = torch.empty([16, 16], **fp8_kwargs)
        rhs = torch.empty([16, 16], **fp8_kwargs).contiguous().transpose(
            0, 1)
        lhs_scale = torch.empty([16, 1], **fp32_kwargs)
        rhs_scale = torch.empty([1, 16], **fp32_kwargs)
        out_buffer = torch.empty([16, 16],
                                 device=self.device,
                                 dtype=self.dtype)
        return [lhs, rhs, lhs_scale, rhs_scale, out_buffer]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor,
                    scale_a: torch.Tensor, scale_b: torch.Tensor,
                    cutlass_mm_output: torch.Tensor) -> torch.Tensor:
            # Unfused form: the out-variant op wrapped by
            # auto_functionalized; element 1 of its result feeds the
            # reduce-scatter.
            functionalized = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=cutlass_mm_output,
                a=input,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None)
            return torch.ops.vllm.reduce_scatter.default(
                functionalized[1],
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)

        def replacement(input: torch.Tensor, mat2: torch.Tensor,
                        scale_a: torch.Tensor, scale_b: torch.Tensor,
                        cutlass_mm_output: torch.Tensor) -> torch.Tensor:
            # cutlass_mm_output is accepted but not used by the fused call.
            # NOTE(review): the matched op is vllm.reduce_scatter while the
            # fused op is given reduce op "avg" -- confirm the two agree.
            return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim=0,
                out_dtype=self.dtype,
                group_name=self.tp.device_group.group_name,
            )

        example_inputs = self.get_inputs()
        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
class AllGatherCutlassScaledMMPattern(BasePattern):
    """Rewrite ``vllm.all_gather`` + functionalized
    ``_C.cutlass_scaled_mm`` into
    ``symm_mem.fused_all_gather_scaled_matmul``."""

    def get_inputs(self):
        """Build the tensors handed to ``pm.register_replacement``.

        Returns, in order: an (8, 16) FP8 local input, a (16, 16) FP8
        weight (transposed view of a contiguous tensor), a float32
        ``scale_a`` with ``8 * self.tp_size`` rows and one column, a
        (1, 16) float32 ``scale_b`` and an output buffer of ``self.dtype``
        with ``8 * self.tp_size`` rows and ``weight.shape[1]`` columns,
        all on ``self.device``.
        """
        fp8_kwargs = dict(device=self.device, dtype=FP8_DTYPE)
        fp32_kwargs = dict(device=self.device, dtype=torch.float32)

        local_x = torch.empty([8, 16], **fp8_kwargs)
        rhs = torch.empty([16, 16], **fp8_kwargs).contiguous().transpose(
            0, 1)

        # scale_a and the output have one row per gathered input row.
        gathered_rows = local_x.shape[0] * self.tp_size
        lhs_scale = torch.empty([gathered_rows, 1], **fp32_kwargs)
        rhs_scale = torch.empty([1, 16], **fp32_kwargs)

        out_cols = rhs.shape[1]
        out_buffer = torch.empty([gathered_rows, out_cols],
                                 device=self.device,
                                 dtype=self.dtype)
        return [local_x, rhs, lhs_scale, rhs_scale, out_buffer]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused form: gather along dim 0, then the out-variant op
            # wrapped by auto_functionalized; element 1 of its result is
            # returned.
            gathered = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)
            functionalized = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=gathered,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None)
            return functionalized[1]

        def replacement(x: torch.Tensor, weight: torch.Tensor,
                        scale_a: torch.Tensor, scale_b: torch.Tensor,
                        output: torch.Tensor) -> torch.Tensor:
            # `output` is accepted but not used by the fused call. The
            # fused op takes per-matmul lists; a single matmul is requested
            # here and the gathered input it also returns is unused.
            _, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            # NOTE(review): the second result is returned as-is although
            # the annotation says torch.Tensor -- confirm this is intended.
            return mm_outputs

        example_inputs = self.get_inputs()
        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
class
AsyncTPPass
(
VllmInductorPass
):
class
AsyncTPPass
(
VllmInductorPass
):
def
__init__
(
self
,
config
:
VllmConfig
):
def
__init__
(
self
,
config
:
VllmConfig
):
...
@@ -133,6 +359,20 @@ class AsyncTPPass(VllmInductorPass):
...
@@ -133,6 +359,20 @@ class AsyncTPPass(VllmInductorPass):
AllGatherGEMMPattern
(
self
.
model_dtype
,
AllGatherGEMMPattern
(
self
.
model_dtype
,
self
.
device
).
register
(
self
.
patterns
)
self
.
device
).
register
(
self
.
patterns
)
# These fusions are enabled only for bfloat16 models because
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
# only supports bfloat16 as the output dtype.
if
self
.
model_dtype
==
torch
.
bfloat16
:
ScaledMMReduceScatterPattern
(
self
.
model_dtype
,
self
.
device
).
register
(
self
.
patterns
)
AllGatherScaledMMPattern
(
self
.
model_dtype
,
self
.
device
).
register
(
self
.
patterns
)
CutlassScaledMMReduceScatterPattern
(
self
.
model_dtype
,
self
.
device
).
register
(
self
.
patterns
)
AllGatherCutlassScaledMMPattern
(
self
.
model_dtype
,
self
.
device
).
register
(
self
.
patterns
)
def
is_applicable_for_shape
(
self
,
shape
:
Optional
[
int
])
->
bool
:
def
is_applicable_for_shape
(
self
,
shape
:
Optional
[
int
])
->
bool
:
# only do replace for specific shapes
# only do replace for specific shapes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -142,7 +382,7 @@ class AsyncTPPass(VllmInductorPass):
...
@@ -142,7 +382,7 @@ class AsyncTPPass(VllmInductorPass):
self
.
begin
()
self
.
begin
()
self
.
dump_graph
(
graph
,
"before_async_tp_pass"
)
self
.
dump_graph
(
graph
,
"before_async_tp_pass"
)
count
=
self
.
patterns
.
apply
(
graph
)
count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
count
)
logger
.
debug
(
"Replaced %s patterns
with async TP pass.
"
,
count
)
self
.
dump_graph
(
graph
,
"after_async_tp_pass"
)
self
.
dump_graph
(
graph
,
"after_async_tp_pass"
)
self
.
end_and_log
()
self
.
end_and_log
()
...
...
vllm/compilation/sequence_parallelism.py
View file @
287f527f
...
@@ -477,6 +477,6 @@ class SequenceParallelismPass(VllmInductorPass):
...
@@ -477,6 +477,6 @@ class SequenceParallelismPass(VllmInductorPass):
self
.
begin
()
self
.
begin
()
self
.
dump_graph
(
graph
,
"before_sequence_parallelism_pass"
)
self
.
dump_graph
(
graph
,
"before_sequence_parallelism_pass"
)
count
=
self
.
patterns
.
apply
(
graph
)
count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
count
)
logger
.
debug
(
"Replaced %s patterns
with sequence parallelism
"
,
count
)
self
.
dump_graph
(
graph
,
"after_sequence_parallelism_pass"
)
self
.
dump_graph
(
graph
,
"after_sequence_parallelism_pass"
)
self
.
end_and_log
()
self
.
end_and_log
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment