Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a25f2ade
Unverified
Commit
a25f2ade
authored
Oct 11, 2025
by
Angela Yi
Committed by
GitHub
Oct 11, 2025
Browse files
[compile] Add patched_fused_scaled_matmul_reduce_scatter (#26604)
Signed-off-by:
angelayi
<
yiangela7@gmail.com
>
parent
d0bed837
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
119 additions
and
6 deletions
+119
-6
tests/compile/test_async_tp.py
tests/compile/test_async_tp.py
+22
-4
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+2
-2
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+95
-0
No files found.
tests/compile/test_async_tp.py
View file @
a25f2ade
...
...
@@ -142,7 +142,7 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
return
[
torch
.
ops
.
vllm
.
reduce_scatter
.
default
]
def
ops_in_model_after
(
self
):
return
[
torch
.
ops
.
symm_mem
.
fused_scaled_matmul_reduce_scatter
.
default
]
return
[
torch
.
ops
.
vllm
.
patched_
fused_scaled_matmul_reduce_scatter
.
default
]
class
TestAGScaledMMModel
(
_BaseScaledMMModel
):
...
...
@@ -195,7 +195,7 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
return
[
torch
.
ops
.
vllm
.
reduce_scatter
.
default
]
def
ops_in_model_after
(
self
):
return
[
torch
.
ops
.
symm_mem
.
fused_scaled_matmul_reduce_scatter
.
default
]
return
[
torch
.
ops
.
vllm
.
patched_
fused_scaled_matmul_reduce_scatter
.
default
]
class
TestAGCutlassScaledMMModel
(
_BaseScaledMMModel
):
...
...
@@ -243,9 +243,15 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dynamic"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
not
in
[
"cuda"
],
reason
=
"Only test on CUDA"
)
def
test_async_tp_pass_replace
(
test_model
:
str
,
batch_size
:
int
,
seq_len
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
test_model
:
str
,
batch_size
:
int
,
seq_len
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
dynamic
:
bool
,
):
if
(
test_model
...
...
@@ -269,7 +275,15 @@ def test_async_tp_pass_replace(
# torch.distributed and cuda
torch
.
multiprocessing
.
spawn
(
fn
,
args
=
(
num_processes
,
test_model
,
batch_size
,
seq_len
,
hidden_size
,
dtype
),
args
=
(
num_processes
,
test_model
,
batch_size
,
seq_len
,
hidden_size
,
dtype
,
dynamic
,
),
nprocs
=
nprocs
,
)
...
...
@@ -284,6 +298,7 @@ def async_tp_pass_on_test_model(
seq_len
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
dynamic
:
bool
,
):
current_platform
.
seed_everything
(
0
)
...
...
@@ -331,6 +346,9 @@ def async_tp_pass_on_test_model(
(
batch_size
*
seq_len
,
hidden_size
),
dtype
=
dtype
,
requires_grad
=
False
)
if
dynamic
:
torch
.
_dynamo
.
mark_dynamic
(
hidden_states
,
0
)
compiled_model
=
torch
.
compile
(
model
,
backend
=
backend
)
compiled_model
(
hidden_states
)
...
...
vllm/compilation/collective_fusion.py
View file @
a25f2ade
...
...
@@ -172,7 +172,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape
=
[
*
input
.
shape
[:
-
1
],
mat2
.
shape
[
1
]]
scatter_dim
=
0
gemm_rs
=
torch
.
ops
.
symm_mem
.
fused_scaled_matmul_reduce_scatter
(
gemm_rs
=
torch
.
ops
.
vllm
.
patched_
fused_scaled_matmul_reduce_scatter
(
input
,
mat2
,
scale_a
,
...
...
@@ -307,7 +307,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape
=
[
*
input
.
shape
[:
-
1
],
mat2
.
shape
[
1
]]
scatter_dim
=
0
gemm_rs
=
torch
.
ops
.
symm_mem
.
fused_scaled_matmul_reduce_scatter
(
gemm_rs
=
torch
.
ops
.
vllm
.
patched_
fused_scaled_matmul_reduce_scatter
(
input
,
mat2
,
scale_a
,
...
...
vllm/distributed/parallel_state.py
View file @
a25f2ade
...
...
@@ -37,6 +37,8 @@ from unittest.mock import patch
import
torch
import
torch.distributed
import
torch.distributed._functional_collectives
as
funcol
import
torch.distributed._symmetric_memory
from
torch.distributed
import
Backend
,
ProcessGroup
from
typing_extensions
import
deprecated
...
...
@@ -159,6 +161,90 @@ def all_gather_fake(
return
torch
.
empty
(
new_shape
,
dtype
=
tensor
.
dtype
,
device
=
tensor
.
device
)
def
patched_fused_scaled_matmul_reduce_scatter_fake
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
A_scale
:
torch
.
Tensor
,
B_scale
:
torch
.
Tensor
,
reduce_op
:
str
,
orig_scatter_dim
:
int
,
scatter_dim_after_maybe_reshape
:
int
,
group_name
:
str
,
output_shape
:
list
[
int
],
bias
:
torch
.
Tensor
|
None
=
None
,
result_scale
:
torch
.
Tensor
|
None
=
None
,
out_dtype
:
torch
.
dtype
|
None
=
None
,
use_fast_accum
:
bool
=
False
,
)
->
torch
.
Tensor
:
# Copied from
# https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189
if
A_scale
.
numel
()
>
1
:
if
A_scale
.
shape
[:
-
1
]
!=
A
.
shape
[:
-
1
]:
raise
ValueError
(
"For row-wise scaling, the leading dims of A_scale "
"must match the leading dims of A "
f
"(A shape:
{
A
.
shape
}
, A_scale shape:
{
A_scale
.
shape
}
)"
)
A_scale
=
A_scale
.
flatten
(
0
,
-
2
).
contiguous
()
elif
A_scale
.
numel
()
!=
1
:
raise
ValueError
(
"Invalid A_scale shape "
f
"(A shape:
{
A
.
shape
}
, A_scale shape:
{
A_scale
.
shape
}
)"
)
C
=
torch
.
_scaled_mm
(
A
.
flatten
(
0
,
-
2
).
contiguous
(),
B
,
A_scale
,
B_scale
,
bias
,
result_scale
,
out_dtype
,
use_fast_accum
,
)
C
=
C
.
view
(
*
output_shape
[:
-
1
],
B
.
shape
[
1
])
res
=
funcol
.
reduce_scatter_tensor
(
C
,
reduce_op
,
orig_scatter_dim
,
# need original scatter dim for 3D+ output tensor here
group_name
,
)
res
=
funcol
.
wait_tensor
(
res
)
return
res
def
patched_fused_scaled_matmul_reduce_scatter
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
A_scale
:
torch
.
Tensor
,
B_scale
:
torch
.
Tensor
,
reduce_op
:
str
,
orig_scatter_dim
:
int
,
scatter_dim_after_maybe_reshape
:
int
,
group_name
:
str
,
output_shape
:
list
[
int
],
bias
:
torch
.
Tensor
|
None
=
None
,
result_scale
:
torch
.
Tensor
|
None
=
None
,
out_dtype
:
torch
.
dtype
|
None
=
None
,
use_fast_accum
:
bool
=
False
,
)
->
torch
.
Tensor
:
return
torch
.
ops
.
symm_mem
.
fused_scaled_matmul_reduce_scatter
(
A
,
B
,
A_scale
,
B_scale
,
reduce_op
,
orig_scatter_dim
,
scatter_dim_after_maybe_reshape
,
group_name
,
output_shape
,
bias
,
result_scale
,
out_dtype
,
use_fast_accum
,
)
if
supports_custom_op
():
direct_register_custom_op
(
op_name
=
"all_reduce"
,
...
...
@@ -178,6 +264,15 @@ if supports_custom_op():
fake_impl
=
all_gather_fake
,
)
# TODO: Remove this once the pytorch fix
# (https://github.com/pytorch/pytorch/pull/165086) gets released,
# in either 2.9.1 or 2.10
direct_register_custom_op
(
op_name
=
"patched_fused_scaled_matmul_reduce_scatter"
,
op_func
=
patched_fused_scaled_matmul_reduce_scatter
,
fake_impl
=
patched_fused_scaled_matmul_reduce_scatter_fake
,
)
class
GroupCoordinator
:
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment