Unverified commit e5ddeb04, authored by Charles Chen, committed by GitHub
Browse files

Quick fix for DeepGemm requant to also cover MTP. (#7378)

parent bdbb8d00
......@@ -1988,11 +1988,9 @@ class DeepseekV2ForCausalLM(nn.Module):
and hasattr(self.quant_config, "weight_block_size")
and self.quant_config.weight_block_size is not None
):
self._weight_requant_ue8m0()
self._weight_requant_ue8m0(is_nextn)
def _weight_requant_ue8m0(self):
if self.config.architectures[0] == "DeepseekV3ForCausalLMNextN":
return
def _weight_requant_ue8m0(self, is_nextn=False):
weight_block_size = self.quant_config.weight_block_size
moe_layers = list(
......@@ -2003,7 +2001,11 @@ class DeepseekV2ForCausalLM(nn.Module):
)
)
for layer_id in range(self.config.num_hidden_layers):
num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
for layer_id in range(num_hidden_layers):
if is_nextn:
layer = self.model.decoder
else:
layer = self.model.layers[layer_id]
for module in [
......@@ -2016,7 +2018,7 @@ class DeepseekV2ForCausalLM(nn.Module):
module.weight, module.weight_scale_inv, weight_block_size
)
if layer_id in moe_layers:
if layer_id in moe_layers or is_nextn:
shared_experts = getattr(layer.mlp, "shared_experts", None)
if shared_experts is not None:
for module in [
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment