Unverified Commit fbd88728 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Bugfix] Fix DeepSeek MTP (#22934)


Signed-off-by: default avatarBenjamin Chislett <benjamin.chislett@centml.ai>
parent 070da660
...@@ -158,14 +158,13 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -158,14 +158,13 @@ class DeepSeekMTP(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
previous_hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, hidden_states = self.model(input_ids, positions, hidden_states,
previous_hidden_states, inputs_embeds, inputs_embeds, spec_step_idx)
spec_step_idx)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -213,13 +212,15 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -213,13 +212,15 @@ class DeepSeekMTP(nn.Module, SupportsPP):
# for mlp.experts[0].gate_gate_up_proj, which breaks load. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict): if (("mlp.experts." in name) and name not in params_dict):
continue continue
name = name.replace(weight_name, param_name) name_mapped = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal # QKV fusion is optional, fall back to normal
# weight loading if it's not enabled # weight loading if it's not enabled
if ((param_name == "fused_qkv_a_proj") if ((param_name == "fused_qkv_a_proj")
and name not in params_dict): and name_mapped not in params_dict):
continue continue
else:
name = name_mapped
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
......
...@@ -180,14 +180,13 @@ class Glm4MoeMTP(nn.Module, SupportsPP): ...@@ -180,14 +180,13 @@ class Glm4MoeMTP(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
previous_hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, hidden_states = self.model(input_ids, positions, hidden_states,
previous_hidden_states, inputs_embeds, inputs_embeds, spec_step_idx)
spec_step_idx)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -164,15 +164,14 @@ class MiMoMTP(nn.Module): ...@@ -164,15 +164,14 @@ class MiMoMTP(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
previous_hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
assert spec_step_idx == 0, "mimo_mtp only support predict one token now" assert spec_step_idx == 0, "mimo_mtp only support predict one token now"
hidden_states = self.model(input_ids, positions, hidden_states = self.model(input_ids, positions, hidden_states,
previous_hidden_states, inputs_embeds, inputs_embeds, spec_step_idx)
spec_step_idx)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment