Unverified commit b7361cc4 authored by Guoyuan Lin, committed by GitHub

[Fix] Fix the issue encountered when running inference with LongCat-Flash/MTP EP MoE on B200 (#9916)

parent a96c5b5c
@@ -651,9 +651,6 @@ class LongcatFlashForCausalLM(nn.Module):
).T
else:
w = self_attn.kv_b_proj.weight
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
# This may affect the accuracy of fp8 model.
# Fix deepseek v3 blockwise bmm by using deep_gemm
use_deep_gemm_bmm = False
if w.dtype in (
@@ -790,6 +787,9 @@ class LongcatFlashForCausalLM(nn.Module):
self.config.hidden_size / self.config.kv_lora_rank
) ** 0.5
# TODO(linguoyuan) EPMoE not support DEEPGEMM_BLACKWELL, DeepEP needs to be supported in the future
deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False
if (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
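
A minimal, self-contained sketch of the effect of the two lines added in the hunk above: forcing DEEPGEMM_SCALE_UE8M0 to False before the check means the UE8M0 requant path is not taken even when JIT DeepGEMM is enabled. The SimpleNamespace below is only an illustrative stand-in for the real deep_gemm_wrapper module.

from types import SimpleNamespace

# Illustrative stand-in for deep_gemm_wrapper (an assumption, not the sglang module).
deep_gemm_wrapper = SimpleNamespace(ENABLE_JIT_DEEPGEMM=True, DEEPGEMM_SCALE_UE8M0=True)

# The workaround from this commit: EP MoE has no DEEPGEMM_BLACKWELL support yet,
# so the UE8M0 scale format is disabled up front on B200.
deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False

if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
    print("would requantize weights to UE8M0 scales")  # not reached after the fix
else:
    print("UE8M0 requantization skipped")
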
@@ -804,24 +804,35 @@ class LongcatFlashForCausalLM(nn.Module):
for layer_id in range(self.config.num_hidden_layers):
layer = self.model.layers[layer_id]
for i in range(2):
for module in [
layer.self_attn[i].fused_qkv_a_proj_with_mqa,
layer.self_attn[i].q_b_proj,
layer.self_attn[i].kv_b_proj,
layer.self_attn[i].o_proj,
]:
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
self_attn = layer.self_attn[i]
module_list = [
self_attn.kv_b_proj,
self_attn.o_proj,
]
if self.config.q_lora_rank is not None:
module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
module_list.append(self_attn.q_b_proj)
else:
module_list.append(self_attn.kv_a_proj_with_mqa)
module_list.append(self_attn.q_proj)
for module in module_list:
if hasattr(module, "weight_scale_inv"):
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
mlp = layer.mlps[i]
assert isinstance(mlp, LongcatFlashMLP)
for module in [
mlp.gate_up_proj,
mlp.down_proj,
]:
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
if hasattr(module, "weight_scale_inv"):
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
for layer_id in range(self.config.num_hidden_layers):
experts = layer.mlp.experts
......
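
The requant hunk above replaces a fixed list of attention projections with one that depends on q_lora_rank and guards every call with hasattr(module, "weight_scale_inv"), so modules without a blockwise FP8 scale are skipped. Below is a minimal, self-contained sketch of that pattern; DummyProj, DummyAttn, and the print-only requant_weight_ue8m0_inplace are hypothetical placeholders, not the repository's classes.

class DummyProj:
    def __init__(self, blockwise_fp8):
        self.weight = "weight"
        if blockwise_fp8:
            # Only blockwise-quantized FP8 weights carry a scale tensor.
            self.weight_scale_inv = "scale_inv"

class DummyAttn:
    def __init__(self, q_lora_rank):
        self.kv_b_proj = DummyProj(True)
        self.o_proj = DummyProj(True)
        if q_lora_rank is not None:
            self.fused_qkv_a_proj_with_mqa = DummyProj(True)
            self.q_b_proj = DummyProj(True)
        else:
            self.kv_a_proj_with_mqa = DummyProj(False)
            self.q_proj = DummyProj(True)

def requant_weight_ue8m0_inplace(weight, scale_inv, block_size):
    # Placeholder for the real in-place requant kernel.
    print(f"requant {weight} with {scale_inv}, block size {block_size}")

def requant_attention(self_attn, q_lora_rank, weight_block_size):
    module_list = [self_attn.kv_b_proj, self_attn.o_proj]
    # Which q/kv input projections exist depends on whether a low-rank q projection is used.
    if q_lora_rank is not None:
        module_list += [self_attn.fused_qkv_a_proj_with_mqa, self_attn.q_b_proj]
    else:
        module_list += [self_attn.kv_a_proj_with_mqa, self_attn.q_proj]
    for module in module_list:
        # Skip modules that were not quantized with a blockwise scale.
        if hasattr(module, "weight_scale_inv"):
            requant_weight_ue8m0_inplace(
                module.weight, module.weight_scale_inv, weight_block_size
            )

requant_attention(DummyAttn(q_lora_rank=None), q_lora_rank=None, weight_block_size=[128, 128])
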
@@ -344,9 +344,6 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
).T
else:
w = self_attn.kv_b_proj.weight
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
# This may affect the accuracy of fp8 model.
# Fix deepseek v3 blockwise bmm by using deep_gemm
use_deep_gemm_bmm = False
if w.dtype in (
torch.float8_e4m3fn,
@@ -480,24 +477,35 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
def _weight_requant_ue8m0(self):
weight_block_size = self.quant_config.weight_block_size
layer = self.model.decoder
for module in [
layer.self_attn.fused_qkv_a_proj_with_mqa,
layer.self_attn.q_b_proj,
layer.self_attn.kv_b_proj,
layer.self_attn.o_proj,
]:
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
self_attn = layer.self_attn
module_list = [
self_attn.kv_b_proj,
self_attn.o_proj,
]
if self.config.q_lora_rank is not None:
module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
module_list.append(self_attn.q_b_proj)
else:
module_list.append(self_attn.kv_a_proj_with_mqa)
module_list.append(self_attn.q_proj)
for module in module_list:
if hasattr(module, "weight_scale_inv"):
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
mlp = layer.mlps
assert isinstance(mlp, LongcatFlashMLP)
for module in [
mlp.gate_up_proj,
mlp.down_proj,
]:
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
if hasattr(module, "weight_scale_inv"):
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......