"components/vscode:/vscode.git/clone" did not exist on "20f1c5a3c0bec01a27efb133bcb7589c640f5a5a"
Commit 578d3e97 authored by 王敏
Browse files

修复all2all报维度不匹配问题

parent cde83ab0
......@@ -327,8 +327,7 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes):
output = input.new_empty(
size=[sum(output_split_sizes)] + list(input.size()[1:]),
dtype=input.dtype,
#device=torch.cuda.current_device(),
device=input.device,
device=torch.cuda.current_device()
)
torch.distributed.all_to_all_single(
......@@ -336,8 +335,7 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes):
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=True
group=group
)
return output
......@@ -180,7 +180,7 @@ class EPMoE(FusedMoE):
routed_scaling_factor: Optional[float] = None,
enable_eplb: bool = False,
num_redundant_experts: int = 0,
moe_permute_fusion: bool = True,
moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False
):
super().__init__(num_experts, top_k, hidden_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment