Unverified Commit 12eb02e9 authored by fzyzcjy, committed by GitHub

Change bf16 to fp8 for some gemms in attention for DeepSeek ckpt v2 (#11805)

parent 002d0373
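This change is opt-in via the SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN environment variable referenced throughout the hunks below. A minimal sketch of toggling it in-process before the model is built (illustrative; exporting the variable in the shell before launching the server has the same effect):

import os

# Illustrative only: the flag is read via get_bool_env_var during attention-layer
# construction and weight post-processing, so it must be set before model load.
os.environ["SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"] = "1"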
@@ -5,7 +5,7 @@ import torch
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
from sglang.srt.utils import is_sm100_supported, offloader
from sglang.srt.utils import ceil_div, is_sm100_supported, offloader
try:
from vllm import _custom_ops as ops
@@ -441,24 +441,54 @@ def _requant_weight_ue8m0(
torch.bfloat16,
)
out_w, out_s = quant_weight_ue8m0(
weight_dequant=weight_dequant,
weight_block_size=weight_block_size,
)
out_s = _transform_scale_ue8m0(out_s, mn=out_w.shape[-2])
return out_w, out_s
def quant_weight_ue8m0(
weight_dequant: torch.Tensor,
weight_block_size: List[int],
):
assert weight_block_size == [128, 128]
assert (
weight_dequant.dtype == torch.bfloat16
), f"{weight_dequant.dtype=} {weight_dequant.shape=}"
*batch_dims, n, k = weight_dequant.shape
weight_dequant_flat = weight_dequant.view((-1, k))
out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
out_w = out_w_flat.view(weight.shape)
out_s = out_s_flat.view(weight_scale_inv.shape)
out_w = out_w_flat.view((*batch_dims, n, k))
out_s = out_s_flat.view(
(
*batch_dims,
ceil_div(n, weight_block_size[0]),
ceil_div(k, weight_block_size[1]),
)
)
return out_w, out_s
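A minimal usage sketch of quant_weight_ue8m0 above (the shape is illustrative, roughly matching DeepSeek V3's q_b_proj; a CUDA device and the [128, 128] block size asserted above are assumed):

import torch

# Hypothetical bf16 weight of shape (n, k); for DeepSeek V3 q_b_proj this is
# roughly (num_heads * qk_head_dim, q_lora_rank) = (24576, 1536).
w_bf16 = torch.randn(24576, 1536, dtype=torch.bfloat16, device="cuda")
w_fp8, w_scale = quant_weight_ue8m0(w_bf16, weight_block_size=[128, 128])
# w_fp8 keeps the (n, k) shape; w_scale holds one scale per 128x128 block,
# i.e. shape (ceil(n / 128), ceil(k / 128)) = (192, 12).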
# NOTE copy and modified from DeepGEMM
def _transform_scale(sf, mn: int):
def transform_scale_ue8m0_inplace(param, mn):
param.data = _transform_scale_ue8m0(param.data, mn=mn)
# NOTE copy and modified from DeepGEMM
def _transform_scale_ue8m0(sf, mn):
import deep_gemm.utils.layout
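# Expand the block scales to one scale row per output row (each group of 128 rows
# shares a scale), then pack them into DeepGEMM's MN-major, TMA-aligned UE8M0 layout.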
sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
return sf
out_s = _transform_scale(out_s, mn=out_w.shape[-2])
return out_w, out_s
# COPIED FROM DeepGEMM
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
......
@@ -94,7 +94,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
channel_quant_to_tensor_quant,
input_to_float8,
normalize_e4m3fn_to_e4m3fnuz,
quant_weight_ue8m0,
requant_weight_ue8m0_inplace,
transform_scale_ue8m0_inplace,
)
from sglang.srt.layers.quantization.int8_utils import (
block_dequant as int8_block_dequant,
@@ -1098,7 +1100,7 @@ class DeepseekV2AttentionMLA(nn.Module):
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
quant_config=self._get_q_b_proj_quant_config(quant_config),
prefix=add_prefix("q_b_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
@@ -2393,6 +2395,17 @@ class DeepseekV2AttentionMLA(nn.Module):
output, _ = self.o_proj(attn_output)
return output
@staticmethod
def _get_q_b_proj_quant_config(quant_config):
if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
# refer to real DeepSeek V3 quant config
return Fp8Config(
is_checkpoint_fp8_serialized=True,
weight_block_size=[128, 128],
)
else:
return quant_config
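In effect, setting the flag makes q_b_proj behave as if the checkpoint carried block-wise FP8 weights for that projection, regardless of the model-level quant config. An illustrative call (not part of this commit):

cfg = DeepseekV2AttentionMLA._get_q_b_proj_quant_config(quant_config=None)
# With SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN set, cfg is
# Fp8Config(is_checkpoint_fp8_serialized=True, weight_block_size=[128, 128]);
# otherwise the original quant_config (here None) is returned unchanged.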
class DeepseekV2DecoderLayer(nn.Module):
@@ -3130,6 +3143,13 @@ class DeepseekV2ForCausalLM(nn.Module):
):
self._weight_requant_ue8m0(is_nextn)
# TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
if (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
and get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN")
):
self._transform_scale_ue8m0(is_nextn)
if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
self._transform_scale_nextn_moe_ue8m0()
@@ -3198,6 +3218,25 @@ class DeepseekV2ForCausalLM(nn.Module):
module.weight, module.weight_scale_inv, weight_block_size
)
# TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
def _transform_scale_ue8m0(self, is_nextn=False):
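# In this mode only q_b_proj is re-quantized to FP8 (see _quant_attn_to_fp8_ue8m0),
# so only its scale tensor needs repacking into DeepGEMM's UE8M0 layout.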
num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
for layer_id in range(num_hidden_layers):
if is_nextn:
layer = self.model.decoder
else:
layer = self.model.layers[layer_id]
module_list = []
if self.config.q_lora_rank is not None:
module_list.append(layer.self_attn.q_b_proj)
for module in module_list:
transform_scale_ue8m0_inplace(
module.weight_scale_inv, mn=module.weight.shape[-2]
)
# TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
def _transform_scale_nextn_moe_ue8m0(self):
layer = self.model.decoder
@@ -3235,6 +3274,8 @@ class DeepseekV2ForCausalLM(nn.Module):
else:
raise ValueError("num_nextn_predict_layers is not in the config")
if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
weights = self._quant_attn_to_fp8_ue8m0(weights, is_nextn=is_nextn)
if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
weights = self._quant_nextn_moe_to_fp8_ue8m0(
weights, nextn_layer_id=nextn_layer_id
@@ -3469,6 +3510,30 @@ class DeepseekV2ForCausalLM(nn.Module):
self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
def _quant_attn_to_fp8_ue8m0(self, weights, is_nextn):
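# q_b_proj is stored in bf16 in this checkpoint (quant_weight_ue8m0 asserts this);
# re-quantize it to FP8 with 128x128 block scales so it matches the Fp8Config
# selected by _get_q_b_proj_quant_config.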
weights_dict = dict(weights)
# temporarily only supports DeepSeek V3/R1
weight_block_size = [128, 128]
for layer_id in trange(
self.config.num_hidden_layers + int(is_nextn),
desc="quant attn to fp8 ue8m0",
):
for stem in [
# may put tensors like `o_proj` here for DeepSeek FP4 ckpt v1
"q_b_proj",
]:
partial_name = f"model.layers.{layer_id}.self_attn.{stem}"
original_weight = weights_dict[f"{partial_name}.weight"]
out_w, out_s = quant_weight_ue8m0(
original_weight, weight_block_size=weight_block_size
)
weights_dict[f"{partial_name}.weight"] = out_w
weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
return list(weights_dict.items())
# TODO avoid code dup
def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
weights_dict = dict(weights)
......