"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7298bdd8177c16eadb74f6166327f5984fd8c69d"
Unverified commit 12eb02e9, authored by fzyzcjy and committed by GitHub

Change bf16 to fp8 for some gemms in attention for DeepSeek ckpt v2 (#11805)

parent 002d0373
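The new FP8 path is opt-in: it is gated by the SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN environment variable introduced in this diff. A minimal sketch of turning it on (hypothetical usage, not part of the commit; the launch command and model path are placeholders):

    # Hypothetical sketch: enable the FP8 q_b_proj GEMM before the model is constructed.
    import os
    os.environ["SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"] = "1"  # read via get_bool_env_var in the diff below
    # then launch as usual, e.g.: python -m sglang.launch_server --model-path <deepseek-v3-ckpt-v2>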
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.utils import is_sm100_supported, offloader
+from sglang.srt.utils import ceil_div, is_sm100_supported, offloader
 
 try:
     from vllm import _custom_ops as ops
@@ -441,23 +441,53 @@ def _requant_weight_ue8m0(
         torch.bfloat16,
     )
 
+    out_w, out_s = quant_weight_ue8m0(
+        weight_dequant=weight_dequant,
+        weight_block_size=weight_block_size,
+    )
+    out_s = _transform_scale_ue8m0(out_s, mn=out_w.shape[-2])
+    return out_w, out_s
+
+
+def quant_weight_ue8m0(
+    weight_dequant: torch.Tensor,
+    weight_block_size: List[int],
+):
+    assert weight_block_size == [128, 128]
+    assert (
+        weight_dequant.dtype == torch.bfloat16
+    ), f"{weight_dequant.dtype=} {weight_dequant.shape=}"
+
+    *batch_dims, n, k = weight_dequant.shape
+
     weight_dequant_flat = weight_dequant.view((-1, k))
     out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
-    out_w = out_w_flat.view(weight.shape)
-    out_s = out_s_flat.view(weight_scale_inv.shape)
+    out_w = out_w_flat.view((*batch_dims, n, k))
+    out_s = out_s_flat.view(
+        (
+            *batch_dims,
+            ceil_div(n, weight_block_size[0]),
+            ceil_div(k, weight_block_size[1]),
+        )
+    )
+    return out_w, out_s
 
-    # NOTE copy and modified from DeepGEMM
-    def _transform_scale(sf, mn: int):
-        import deep_gemm.utils.layout
 
-        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-        sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
-        return sf
+def transform_scale_ue8m0_inplace(param, mn):
+    param.data = _transform_scale_ue8m0(param.data, mn=mn)
 
-    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
 
-    return out_w, out_s
+# NOTE copy and modified from DeepGEMM
+def _transform_scale_ue8m0(sf, mn):
+    import deep_gemm.utils.layout
+
+    sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+    sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
+    return sf
 
 
 # COPIED FROM DeepGEMM
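A shape sketch of the new quant_weight_ue8m0 helper (illustrative only; it assumes the function is importable from sglang.srt.layers.quantization.fp8_utils, as the model-side import hunk below shows, and the tensor sizes are made up):

    # Illustrative shape check, not part of the commit.
    import torch
    from sglang.srt.layers.quantization.fp8_utils import quant_weight_ue8m0

    w = torch.randn(512, 384, dtype=torch.bfloat16, device="cuda")
    out_w, out_s = quant_weight_ue8m0(w, weight_block_size=[128, 128])
    assert out_w.shape == (512, 384)  # FP8 weight keeps the original shape
    assert out_s.shape == (4, 3)      # one scale per 128x128 block: ceil_div(512, 128) x ceil_div(384, 128)

The resulting scales are then repacked into DeepGEMM's MN-major packed UE8M0 layout by _transform_scale_ue8m0 (or in place via transform_scale_ue8m0_inplace) before the FP8 GEMM consumes them.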
@@ -94,7 +94,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     channel_quant_to_tensor_quant,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
+    quant_weight_ue8m0,
     requant_weight_ue8m0_inplace,
+    transform_scale_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -1098,7 +1100,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_lora_rank,
                 self.num_heads * self.qk_head_dim,
                 bias=False,
-                quant_config=quant_config,
+                quant_config=self._get_q_b_proj_quant_config(quant_config),
                 prefix=add_prefix("q_b_proj", prefix),
                 tp_rank=attn_tp_rank,
                 tp_size=attn_tp_size,
@@ -2393,6 +2395,17 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    @staticmethod
+    def _get_q_b_proj_quant_config(quant_config):
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            # refer to real DeepSeek V3 quant config
+            return Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            return quant_config
+
 
 class DeepseekV2DecoderLayer(nn.Module):
@@ -3130,6 +3143,13 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self._weight_requant_ue8m0(is_nextn)
 
+        # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+        if (
+            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN")
+        ):
+            self._transform_scale_ue8m0(is_nextn)
+
         if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
             self._transform_scale_nextn_moe_ue8m0()
@@ -3198,6 +3218,25 @@ class DeepseekV2ForCausalLM(nn.Module):
                 module.weight, module.weight_scale_inv, weight_block_size
             )
 
+    # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+    def _transform_scale_ue8m0(self, is_nextn=False):
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
+
+            module_list = []
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.q_b_proj)
+
+            for module in module_list:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
     # TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
     def _transform_scale_nextn_moe_ue8m0(self):
         layer = self.model.decoder
@@ -3235,6 +3274,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             else:
                 raise ValueError("num_nextn_predict_layers is not in the config")
 
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            weights = self._quant_attn_to_fp8_ue8m0(weights, is_nextn=is_nextn)
+
         if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
             weights = self._quant_nextn_moe_to_fp8_ue8m0(
                 weights, nextn_layer_id=nextn_layer_id
@@ -3469,6 +3510,30 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 
+    def _quant_attn_to_fp8_ue8m0(self, weights, is_nextn):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in trange(
+            self.config.num_hidden_layers + int(is_nextn),
+            desc="quant attn to fp8 ue8m0",
+        ):
+            for stem in [
+                # may put tensors like `o_proj` here for DeepSeek FP4 ckpt v1
+                "q_b_proj",
+            ]:
+                partial_name = f"model.layers.{layer_id}.self_attn.{stem}"
+                original_weight = weights_dict[f"{partial_name}.weight"]
+                out_w, out_s = quant_weight_ue8m0(
+                    original_weight, weight_block_size=weight_block_size
+                )
+                weights_dict[f"{partial_name}.weight"] = out_w
+                weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
     # TODO avoid code dup
     def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
         weights_dict = dict(weights)
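Putting the model-side pieces together: with SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN set, q_b_proj is built with the block-wise Fp8Config from _get_q_b_proj_quant_config, load_weights rewrites its bf16 checkpoint tensors into FP8 weight / weight_scale_inv pairs via _quant_attn_to_fp8_ue8m0, and post_load_weights repacks those scales for DeepGEMM via _transform_scale_ue8m0. A hedged sketch of the checkpoint-stream rewrite (the tensor name follows the diff; the shape is illustrative):

    # Hedged sketch of what _quant_attn_to_fp8_ue8m0 does to one q_b_proj entry.
    import torch
    from sglang.srt.layers.quantization.fp8_utils import quant_weight_ue8m0

    weights = [
        ("model.layers.0.self_attn.q_b_proj.weight",
         torch.randn(512, 384, dtype=torch.bfloat16, device="cuda")),
        # ... every other checkpoint tensor passes through unchanged ...
    ]
    weights_dict = dict(weights)

    name = "model.layers.0.self_attn.q_b_proj.weight"
    out_w, out_s = quant_weight_ue8m0(weights_dict[name], weight_block_size=[128, 128])
    weights_dict[name] = out_w                 # FP8 weight, same shape as before
    weights_dict[name + "_scale_inv"] = out_s  # 128x128 block scales
    weights = list(weights_dict.items())       # handed back to the regular FP8 weight loader

After this rewrite the layer loads like a natively FP8-serialized DeepSeek V3 checkpoint, which is why _get_q_b_proj_quant_config returns Fp8Config(is_checkpoint_fp8_serialized=True, weight_block_size=[128, 128]).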