Commit 578d3e97 authored by 王敏
Browse files

修复all2all报维度不匹配问题

parent cde83ab0
...@@ -327,8 +327,7 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes): ...@@ -327,8 +327,7 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes):
output = input.new_empty( output = input.new_empty(
size=[sum(output_split_sizes)] + list(input.size()[1:]), size=[sum(output_split_sizes)] + list(input.size()[1:]),
dtype=input.dtype, dtype=input.dtype,
#device=torch.cuda.current_device(), device=torch.cuda.current_device()
device=input.device,
) )
torch.distributed.all_to_all_single( torch.distributed.all_to_all_single(
...@@ -336,8 +335,7 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes): ...@@ -336,8 +335,7 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes):
input, input,
output_split_sizes=output_split_sizes, output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes, input_split_sizes=input_split_sizes,
group=group, group=group
async_op=True
) )
return output return output
...@@ -180,7 +180,7 @@ class EPMoE(FusedMoE): ...@@ -180,7 +180,7 @@ class EPMoE(FusedMoE):
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
enable_eplb: bool = False, enable_eplb: bool = False,
num_redundant_experts: int = 0, num_redundant_experts: int = 0,
moe_permute_fusion: bool = True, moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False moe_shared_expert_overlap: bool = False
): ):
super().__init__(num_experts, top_k, hidden_size, super().__init__(num_experts, top_k, hidden_size,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment