Commit 29523973 authored by yangql
Browse files

修复ep的auto模式的崩溃bug (Fix the crash in the expert-parallel "deepep_auto" mode)

parent 8bfc2d65
...@@ -50,7 +50,7 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -50,7 +50,7 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if self._current_phase == "prefill": if self._current_phase == "prefill":
#rint("************prefill***********") #rint("************prefill***********")
return self.ll_prepare_finalize return self.ll_prepare_finalize
else: else:
# print("attn_metadata.num_decode_tokens",attn_metadata.num_decode_tokens) # print("attn_metadata.num_decode_tokens",attn_metadata.num_decode_tokens)
return self.ht__prepare_finalize return self.ht__prepare_finalize
#return self.ht_prepare_finalize #return self.ht_prepare_finalize
......
...@@ -911,7 +911,7 @@ class FusedMoE(torch.nn.Module): ...@@ -911,7 +911,7 @@ class FusedMoE(torch.nn.Module):
self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels): or self.moe_parallel_config.use_deepep_ll_kernels or self.moe_parallel_config.use_deepep_auto_kernels):
self.batched_hidden_states = torch.zeros( self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size), (moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype, dtype=moe.in_dtype,
......
...@@ -957,7 +957,8 @@ class DeepseekV2Model(nn.Module): ...@@ -957,7 +957,8 @@ class DeepseekV2Model(nn.Module):
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment