Unverified commit b7361cc4, authored by Guoyuan Lin and committed by GitHub

[Fix] fix the issue encountered when inference LongCat-Flash/MTP EP MoE on b200 (#9916)

parent a96c5b5c
@@ -651,9 +651,6 @@ class LongcatFlashForCausalLM(nn.Module):
                 ).T
             else:
                 w = self_attn.kv_b_proj.weight
-            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-            # This may affect the accuracy of fp8 model.
-            # Fix deepseek v3 blockwise bmm by using deep_gemm
             use_deep_gemm_bmm = False
             if w.dtype in (
                 torch.float8_e4m3fn,
@@ -790,6 +787,9 @@ class LongcatFlashForCausalLM(nn.Module):
                 self.config.hidden_size / self.config.kv_lora_rank
             ) ** 0.5
+        # TODO(linguoyuan) EPMoE not support DEEPGEMM_BLACKWELL, DeepEP needs to be supported in the future
+        deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False
+
         if (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
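For context, the added assignment sits immediately before the `ENABLE_JIT_DEEPGEMM and DEEPGEMM_SCALE_UE8M0` guard, so the UE8M0 requantization branch below is never entered for this model. A minimal sketch of that effect, using a `SimpleNamespace` stand-in for `deep_gemm_wrapper` and a hypothetical `maybe_requant_ue8m0` helper (both illustrative, not sglang code):

```python
from types import SimpleNamespace

# Stand-in for sglang's deep_gemm_wrapper module-level flags
# (assumption: plain booleans, as the guard in the diff treats them).
deep_gemm_wrapper = SimpleNamespace(
    ENABLE_JIT_DEEPGEMM=True,
    DEEPGEMM_SCALE_UE8M0=True,
)


def maybe_requant_ue8m0(model_name: str) -> bool:
    """Mirror of the guard shown in the hunk: both flags must be truthy."""
    if (
        deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
        and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
    ):
        print(f"requantizing {model_name} weights with UE8M0 scales")
        return True
    print(f"skipping UE8M0 requantization for {model_name}")
    return False


maybe_requant_ue8m0("LongCat-Flash")  # before the fix: guard passes

# The added line forces the flag off (EPMoE does not yet support
# DEEPGEMM_BLACKWELL), so on B200 the requantization path is skipped.
deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False
maybe_requant_ue8m0("LongCat-Flash")  # after the fix: guard short-circuits
```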
@@ -804,24 +804,35 @@ class LongcatFlashForCausalLM(nn.Module):
         for layer_id in range(self.config.num_hidden_layers):
             layer = self.model.layers[layer_id]
             for i in range(2):
-                for module in [
-                    layer.self_attn[i].fused_qkv_a_proj_with_mqa,
-                    layer.self_attn[i].q_b_proj,
-                    layer.self_attn[i].kv_b_proj,
-                    layer.self_attn[i].o_proj,
-                ]:
-                    requant_weight_ue8m0_inplace(
-                        module.weight, module.weight_scale_inv, weight_block_size
-                    )
+                self_attn = layer.self_attn[i]
+                module_list = [
+                    self_attn.kv_b_proj,
+                    self_attn.o_proj,
+                ]
+
+                if self.config.q_lora_rank is not None:
+                    module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
+                    module_list.append(self_attn.q_b_proj)
+                else:
+                    module_list.append(self_attn.kv_a_proj_with_mqa)
+                    module_list.append(self_attn.q_proj)
+
+                for module in module_list:
+                    if hasattr(module, "weight_scale_inv"):
+                        requant_weight_ue8m0_inplace(
+                            module.weight, module.weight_scale_inv, weight_block_size
+                        )
+
                 mlp = layer.mlps[i]
                 assert isinstance(mlp, LongcatFlashMLP)
                 for module in [
                     mlp.gate_up_proj,
                     mlp.down_proj,
                 ]:
-                    requant_weight_ue8m0_inplace(
-                        module.weight, module.weight_scale_inv, weight_block_size
-                    )
+                    if hasattr(module, "weight_scale_inv"):
+                        requant_weight_ue8m0_inplace(
+                            module.weight, module.weight_scale_inv, weight_block_size
+                        )

         for layer_id in range(self.config.num_hidden_layers):
             experts = layer.mlp.experts
...
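The attention-side change above amounts to two independent guards: pick the projection set by whether the checkpoint uses a low-rank q projection (`q_lora_rank`), and only requantize modules that actually carry a `weight_scale_inv` tensor. A self-contained sketch of that pattern, with `_FakeProj`, the dict-based layer, and the stubbed `requant_weight_ue8m0_inplace` as illustrative stand-ins for sglang's real module objects and kernel call:

```python
from typing import Optional


class _FakeProj:
    """Illustrative stand-in for a linear module; only FP8 block-quantized
    weights carry a weight_scale_inv attribute."""

    def __init__(self, name: str, block_quantized: bool = True):
        self.name = name
        self.weight = f"{name}.weight"
        if block_quantized:
            self.weight_scale_inv = f"{name}.weight_scale_inv"


def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
    # Stub standing in for the real in-place requantization kernel.
    print(f"requant {weight} (block size {weight_block_size})")


def requant_self_attn(self_attn: dict, q_lora_rank: Optional[int], weight_block_size):
    # kv_b_proj and o_proj are always candidates.
    module_list = [self_attn["kv_b_proj"], self_attn["o_proj"]]

    # Pick the q-side projections that actually exist for this attention layout.
    if q_lora_rank is not None:
        module_list += [self_attn["fused_qkv_a_proj_with_mqa"], self_attn["q_b_proj"]]
    else:
        module_list += [self_attn["kv_a_proj_with_mqa"], self_attn["q_proj"]]

    for module in module_list:
        # Skip modules without a block-wise scale tensor (e.g., kept in BF16).
        if hasattr(module, "weight_scale_inv"):
            requant_weight_ue8m0_inplace(
                module.weight, module.weight_scale_inv, weight_block_size
            )


# Usage with a q_lora_rank-style layout where o_proj happens to be unquantized:
attn = {
    "kv_b_proj": _FakeProj("kv_b_proj"),
    "o_proj": _FakeProj("o_proj", block_quantized=False),
    "fused_qkv_a_proj_with_mqa": _FakeProj("fused_qkv_a_proj_with_mqa"),
    "q_b_proj": _FakeProj("q_b_proj"),
}
requant_self_attn(attn, q_lora_rank=512, weight_block_size=[128, 128])
```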
@@ -344,9 +344,6 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
                 ).T
             else:
                 w = self_attn.kv_b_proj.weight
-            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-            # This may affect the accuracy of fp8 model.
-            # Fix deepseek v3 blockwise bmm by using deep_gemm
             use_deep_gemm_bmm = False
             if w.dtype in (
                 torch.float8_e4m3fn,
@@ -480,24 +477,35 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
     def _weight_requant_ue8m0(self):
         weight_block_size = self.quant_config.weight_block_size
         layer = self.model.decoder
-        for module in [
-            layer.self_attn.fused_qkv_a_proj_with_mqa,
-            layer.self_attn.q_b_proj,
-            layer.self_attn.kv_b_proj,
-            layer.self_attn.o_proj,
-        ]:
-            requant_weight_ue8m0_inplace(
-                module.weight, module.weight_scale_inv, weight_block_size
-            )
+        self_attn = layer.self_attn
+        module_list = [
+            self_attn.kv_b_proj,
+            self_attn.o_proj,
+        ]
+
+        if self.config.q_lora_rank is not None:
+            module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
+            module_list.append(self_attn.q_b_proj)
+        else:
+            module_list.append(self_attn.kv_a_proj_with_mqa)
+            module_list.append(self_attn.q_proj)
+
+        for module in module_list:
+            if hasattr(module, "weight_scale_inv"):
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )
+
         mlp = layer.mlps
         assert isinstance(mlp, LongcatFlashMLP)
         for module in [
             mlp.gate_up_proj,
             mlp.down_proj,
         ]:
-            requant_weight_ue8m0_inplace(
-                module.weight, module.weight_scale_inv, weight_block_size
-            )
+            if hasattr(module, "weight_scale_inv"):
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...