Commit 8f66f64b authored by 王敏
Browse files

修复all2all报维度不匹配问题

parent 578d3e97
...@@ -180,7 +180,7 @@ class EPMoE(FusedMoE): ...@@ -180,7 +180,7 @@ class EPMoE(FusedMoE):
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
enable_eplb: bool = False, enable_eplb: bool = False,
num_redundant_experts: int = 0, num_redundant_experts: int = 0,
moe_permute_fusion: bool = False, moe_permute_fusion: bool = True,
moe_shared_expert_overlap: bool = False moe_shared_expert_overlap: bool = False
): ):
super().__init__(num_experts, top_k, hidden_size, super().__init__(num_experts, top_k, hidden_size,
......
...@@ -212,7 +212,6 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -212,7 +212,6 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
if self.use_all_gather: if self.use_all_gather:
# Gather is not supported for some devices such as TPUs. # Gather is not supported for some devices such as TPUs.
# Use all-gather instead. # Use all-gather instead.
num_global_tokens_per_expert = expert_parallel_all_gather(num_local_tokens_per_expert) \ num_global_tokens_per_expert = expert_parallel_all_gather(num_local_tokens_per_expert) \
.reshape(self.ep_size, self.tp_size, self.num_experts) \ .reshape(self.ep_size, self.tp_size, self.num_experts) \
.transpose(0, 1) .transpose(0, 1)
...@@ -327,7 +326,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -327,7 +326,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
###test############## ###test##############
#cuda_dtoh_stream.synchronize() #cuda_dtoh_stream.synchronize()
#cuda_dtoh_sync_event.synchronize() cuda_dtoh_sync_event.synchronize()
###test############## ###test##############
global_input_tokens = all_to_all( global_input_tokens = all_to_all(
...@@ -462,10 +461,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -462,10 +461,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.num_global_tokens_per_local_expert, record_stream=on_side_stream self.num_global_tokens_per_local_expert, record_stream=on_side_stream
) )
#cuda_dtoh_sync_event.record() cuda_dtoh_sync_event.record()
if point == self.cuda_sync_point: # if point == self.cuda_sync_point:
# Synchronize with the dtoh stream at self.cuda_sync_point. # # Synchronize with the dtoh stream at self.cuda_sync_point.
cuda_dtoh_stream.synchronize() # cuda_dtoh_stream.synchronize()
return tokens_per_expert return tokens_per_expert
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment