"tests/vscode:/vscode.git/clone" did not exist on "7961486a9b749b1b60d8b6fd5fb7d61596a9b041"
Unverified commit 07bdabef, authored by Wang Xingran and committed by GitHub
Browse files

[Bugfix] Use 'sum' reduction instead of 'avg' in Async TP reduce-scatter (#33088)


Signed-off-by: Xingran Wang <wangxingran123456@outlook.com>
Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
Co-authored-by: Hongjian Zhang <hirokenovo@gmail.com>
parent a572baff
...@@ -53,7 +53,7 @@ class GEMMReduceScatterPattern(BasePattern): ...@@ -53,7 +53,7 @@ class GEMMReduceScatterPattern(BasePattern):
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter( gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
mul, mul,
mm_weight, mm_weight,
"avg", "sum",
scatter_dim=0, scatter_dim=0,
group_name=self.tp.device_group.group_name, group_name=self.tp.device_group.group_name,
) )
...@@ -150,7 +150,7 @@ class ScaledMMReduceScatterPattern(BasePattern): ...@@ -150,7 +150,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
mat2, mat2,
scale_a, scale_a,
scale_b, scale_b,
"avg", "sum",
scatter_dim, # orig_scatter_dim scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name, self.tp.device_group.group_name,
...@@ -285,7 +285,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern): ...@@ -285,7 +285,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
mat2, mat2,
scale_a, scale_a,
scale_b, scale_b,
"avg", "sum",
scatter_dim, # orig_scatter_dim scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name, self.tp.device_group.group_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment