Commit f505d366 authored by 王敏's avatar 王敏
Browse files

[fix]解决mori ep dp>1时cudagraph卡住问题

parent f0f159a4
...@@ -229,10 +229,10 @@ class EPMoE(FusedMoE): ...@@ -229,10 +229,10 @@ class EPMoE(FusedMoE):
] ]
self.use_shared_expert = False self.use_shared_expert = False
self.token_dispatcher = MoEAlltoAllTokenDispatcher( # self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.local_num_experts, self.local_expert_indices, # self.local_num_experts, self.local_expert_indices,
config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher", # config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
) # )
self.shared_expert_overlap = moe_shared_expert_overlap self.shared_expert_overlap = moe_shared_expert_overlap
self.shared_experts = None self.shared_experts = None
...@@ -245,7 +245,6 @@ class EPMoE(FusedMoE): ...@@ -245,7 +245,6 @@ class EPMoE(FusedMoE):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.max_num_inp_token_per_rank = vllm_config.scheduler_config.max_num_seqs self.max_num_inp_token_per_rank = vllm_config.scheduler_config.max_num_seqs
self.mori_op = self.get_mori_op() self.mori_op = self.get_mori_op()
self.first = True
def get_mori_op(self): def get_mori_op(self):
...@@ -293,8 +292,8 @@ class EPMoE(FusedMoE): ...@@ -293,8 +292,8 @@ class EPMoE(FusedMoE):
if self.shared_experts is None: if self.shared_experts is None:
self.shared_experts = shared_experts self.shared_experts = shared_experts
if self.shared_expert_overlap: # if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(self.shared_experts) # self.token_dispatcher.set_shared_experts(self.shared_experts)
def create_quant_method(self, moe, quant_config, prefix): def create_quant_method(self, moe, quant_config, prefix):
# Note: get_quant_method will look at the layer's local_num_experts # Note: get_quant_method will look at the layer's local_num_experts
...@@ -308,7 +307,7 @@ class EPMoE(FusedMoE): ...@@ -308,7 +307,7 @@ class EPMoE(FusedMoE):
return quant_method return quant_method
def sync(self): def sync(self):
#torch.cuda.synchronize() torch.cuda.synchronize()
dist.barrier() dist.barrier()
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
...@@ -383,30 +382,9 @@ class EPMoE(FusedMoE): ...@@ -383,30 +382,9 @@ class EPMoE(FusedMoE):
topk_weights, topk_weights,
scales, scales,
topk_ids, topk_ids,
layer_idx=int(self.layer_name.split('.')[2])
) )
#self.sync() #self.sync()
# expect_m = topk_ids.shape[0] * self.ep_size
# dispatch_output_clip = dispatch_output[:expect_m]
# dispatch_weights_clip = dispatch_weights[:expect_m]
# dispatch_indices_clip = dispatch_indices[:expect_m]
# dispatch_scales_clip = dispatch_scales[:expect_m]
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output_clip,
# topk_weights=dispatch_weights_clip,
# topk_ids=dispatch_indices_clip,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales_clip if self.use_int8_dispatch else None
# #routed_scaling_factor=self.routed_scaling_factor,
# )
expert_output = self.quant_method.apply_ep( expert_output = self.quant_method.apply_ep(
layer=self, layer=self,
...@@ -419,7 +397,7 @@ class EPMoE(FusedMoE): ...@@ -419,7 +397,7 @@ class EPMoE(FusedMoE):
apply_router_weight_on_input=self.apply_router_weight_on_input, apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
num_local_tokens=dispatch_recv_num_token, num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0], config_select_bs=hidden_states.shape[0]*self.ep_size/self.dp_size,
scales=dispatch_scales if self.use_int8_dispatch else None scales=dispatch_scales if self.use_int8_dispatch else None
#routed_scaling_factor=self.routed_scaling_factor, #routed_scaling_factor=self.routed_scaling_factor,
) )
...@@ -431,8 +409,6 @@ class EPMoE(FusedMoE): ...@@ -431,8 +409,6 @@ class EPMoE(FusedMoE):
#self.sync() #self.sync()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
# shared_output = ( # shared_output = (
# self.maybe_all_reduce_tensor_model_parallel( # self.maybe_all_reduce_tensor_model_parallel(
# shared_output)) # shared_output))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment