Unverified commit 10042057, authored by Mengqing Cao, committed by GitHub

[MTP] Refactor mtp predictor to avoid d2h operation (#27643)


Signed-off-by: MengqingCao <cmq0113@163.com>
parent ba33e883
...
@@ -97,7 +97,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
     ) -> torch.Tensor:
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
-        inputs_embeds[positions == 0] = 0
+        inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
         inputs_embeds = self.enorm(inputs_embeds)
         previous_hidden_states = self.hnorm(previous_hidden_states)
...
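
The change above swaps an in-place boolean-mask assignment for an out-of-place `torch.where`. Per the commit title, the goal is to avoid a device-to-host (d2h) operation. Boolean-mask indexing can require the backend to learn how many elements the mask selects, which may force a synchronization on some devices, whereas `torch.where` is a pure elementwise select. The sketch below only checks that the two forms produce the same values on CPU. The helper name `mask_position_zero` and the toy tensor sizes are assumptions made for illustration and are not part of the patch.

```python
import torch


def mask_position_zero(
    inputs_embeds: torch.Tensor, positions: torch.Tensor
) -> torch.Tensor:
    """Zero the embedding rows whose position is 0 (hypothetical helper)."""
    # positions has shape [num_tokens]; unsqueeze(-1) gives [num_tokens, 1],
    # so the condition broadcasts across the hidden dimension.
    return torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)


# Toy sizes chosen for illustration: 4 tokens, hidden size 3.
positions = torch.tensor([0, 1, 2, 0])
inputs_embeds = torch.arange(12, dtype=torch.float32).reshape(4, 3) + 1.0

# Old form: boolean-mask assignment, applied to a copy so the input survives.
expected = inputs_embeds.clone()
expected[positions == 0] = 0

# New form: elementwise select that returns a new tensor.
result = mask_position_zero(inputs_embeds, positions)

assert torch.equal(result, expected)
```

One behavioral difference to keep in mind: the old form mutates `inputs_embeds` in place, while the new form rebinds the local name to a fresh tensor, so any caller that relied on the in-place side effect would no longer see it.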