Unverified commit 12eb02e9, authored by fzyzcjy and committed by GitHub

Change bf16 to fp8 for some gemms in attention for DeepSeek ckpt v2 (#11805)

parent 002d0373
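The new FP8 path is opt-in via an environment variable. A minimal sketch of enabling it before model load, assuming get_bool_env_var in sglang.srt.utils treats "1"/"true" as enabled:

    import os

    # Opt in to FP8 (instead of bf16) GEMMs for q_b_proj in DeepSeek MLA attention.
    # Set this before weights are loaded, since the re-quantization happens in
    # load_weights / post_load_weights (see the deepseek_v2 hunks below).
    os.environ["SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"] = "1"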
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.utils import is_sm100_supported, offloader
+from sglang.srt.utils import ceil_div, is_sm100_supported, offloader
 
 try:
     from vllm import _custom_ops as ops
@@ -441,23 +441,53 @@ def _requant_weight_ue8m0(
         torch.bfloat16,
     )
 
-    weight_dequant_flat = weight_dequant.view((-1, k))
-    out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
-    out_w = out_w_flat.view(weight.shape)
-    out_s = out_s_flat.view(weight_scale_inv.shape)
-
-    # NOTE copy and modified from DeepGEMM
-    def _transform_scale(sf, mn: int):
-        import deep_gemm.utils.layout
-
-        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-        sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
-        return sf
-
-    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
-
+    out_w, out_s = quant_weight_ue8m0(
+        weight_dequant=weight_dequant,
+        weight_block_size=weight_block_size,
+    )
+    out_s = _transform_scale_ue8m0(out_s, mn=out_w.shape[-2])
     return out_w, out_s
 
 
+def quant_weight_ue8m0(
+    weight_dequant: torch.Tensor,
+    weight_block_size: List[int],
+):
+    assert weight_block_size == [128, 128]
+    assert (
+        weight_dequant.dtype == torch.bfloat16
+    ), f"{weight_dequant.dtype=} {weight_dequant.shape=}"
+
+    *batch_dims, n, k = weight_dequant.shape
+
+    weight_dequant_flat = weight_dequant.view((-1, k))
+    out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
+
+    out_w = out_w_flat.view((*batch_dims, n, k))
+    out_s = out_s_flat.view(
+        (
+            *batch_dims,
+            ceil_div(n, weight_block_size[0]),
+            ceil_div(k, weight_block_size[1]),
+        )
+    )
+    return out_w, out_s
+
+
+def transform_scale_ue8m0_inplace(param, mn):
+    param.data = _transform_scale_ue8m0(param.data, mn=mn)
+
+
+# NOTE copy and modified from DeepGEMM
+def _transform_scale_ue8m0(sf, mn):
+    import deep_gemm.utils.layout
+
+    sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+    sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
+    return sf
+
+
 # COPIED FROM DeepGEMM
...
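For orientation: per_block_cast_to_fp8 (the DeepGEMM copy referenced above) emits one scale per 128x128 weight tile, and the UE8M0 flavor constrains each scale to a power of two so DeepGEMM can later store it as a packed 8-bit exponent. A rough, self-contained sketch of that idea only, not the actual kernel; the FP8_E4M3_MAX constant and the ceil-rounding of the exponent are assumptions:

    import torch

    FP8_E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn (assumption)

    def per_block_cast_to_fp8_sketch(w: torch.Tensor, block: int = 128):
        # Quantize a 2D weight into FP8 with one power-of-two (UE8M0-style) scale
        # per (block x block) tile. Hypothetical helper for illustration only.
        n, k = w.shape
        n_blk, k_blk = -(-n // block), -(-k // block)  # ceil_div
        padded = torch.zeros(n_blk * block, k_blk * block, dtype=torch.float32, device=w.device)
        padded[:n, :k] = w.float()
        tiles = padded.view(n_blk, block, k_blk, block)
        amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)
        scale = torch.exp2(torch.ceil(torch.log2(amax / FP8_E4M3_MAX)))
        q = (tiles / scale).to(torch.float8_e4m3fn)
        return q.view(padded.shape)[:n, :k], scale.view(n_blk, k_blk)

The real helper additionally produces the scale layout DeepGEMM expects; the point here is just the per-tile amax, power-of-two scale, FP8 cast pipeline.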
@@ -94,7 +94,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     channel_quant_to_tensor_quant,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
+    quant_weight_ue8m0,
     requant_weight_ue8m0_inplace,
+    transform_scale_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -1098,7 +1100,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_lora_rank,
                 self.num_heads * self.qk_head_dim,
                 bias=False,
-                quant_config=quant_config,
+                quant_config=self._get_q_b_proj_quant_config(quant_config),
                 prefix=add_prefix("q_b_proj", prefix),
                 tp_rank=attn_tp_rank,
                 tp_size=attn_tp_size,
@@ -2393,6 +2395,17 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    @staticmethod
+    def _get_q_b_proj_quant_config(quant_config):
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            # refer to real DeepSeek V3 quant config
+            return Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            return quant_config
+
 
 class DeepseekV2DecoderLayer(nn.Module):
@@ -3130,6 +3143,13 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self._weight_requant_ue8m0(is_nextn)
 
+        # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+        if (
+            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN")
+        ):
+            self._transform_scale_ue8m0(is_nextn)
+
         if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
             self._transform_scale_nextn_moe_ue8m0()
@@ -3198,6 +3218,25 @@ class DeepseekV2ForCausalLM(nn.Module):
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
 
+    # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+    def _transform_scale_ue8m0(self, is_nextn=False):
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
+
+            module_list = []
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.q_b_proj)
+
+            for module in module_list:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
     # TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
     def _transform_scale_nextn_moe_ue8m0(self):
         layer = self.model.decoder
@@ -3235,6 +3274,8 @@ class DeepseekV2ForCausalLM(nn.Module):
         else:
             raise ValueError("num_nextn_predict_layers is not in the config")
 
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            weights = self._quant_attn_to_fp8_ue8m0(weights, is_nextn=is_nextn)
+
         if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
             weights = self._quant_nextn_moe_to_fp8_ue8m0(
                 weights, nextn_layer_id=nextn_layer_id
@@ -3469,6 +3510,30 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 
+    def _quant_attn_to_fp8_ue8m0(self, weights, is_nextn):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in trange(
+            self.config.num_hidden_layers + int(is_nextn),
+            desc="quant attn to fp8 ue8m0",
+        ):
+            for stem in [
+                # may put tensors like `o_proj` here for DeepSeek FP4 ckpt v1
+                "q_b_proj",
+            ]:
+                partial_name = f"model.layers.{layer_id}.self_attn.{stem}"
+                original_weight = weights_dict[f"{partial_name}.weight"]
+                out_w, out_s = quant_weight_ue8m0(
+                    original_weight, weight_block_size=weight_block_size
+                )
+                weights_dict[f"{partial_name}.weight"] = out_w
+                weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
     # TODO avoid code dup
     def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
         weights_dict = dict(weights)
...
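Putting the two files together, the load-time flow for each q_b_proj under the new flag is roughly: quantize the bf16 weight into FP8 plus a [ceil(n/128), ceil(k/128)] scale tensor in load_weights, then (on DeepGEMM UE8M0 builds) repack those scales in post_load_weights. A small sketch using the helpers added in this commit; the Parameter wrapping is an assumption about how the scale is stored on the module:

    import torch
    from sglang.srt.layers.quantization.fp8_utils import (
        quant_weight_ue8m0,
        transform_scale_ue8m0_inplace,
    )

    def requant_q_b_proj_sketch(bf16_weight: torch.Tensor):
        # Step 1 (load_weights): replace the bf16 tensor with an FP8 weight and
        # a [ceil(n / 128), ceil(k / 128)] block-scale tensor.
        out_w, out_s = quant_weight_ue8m0(bf16_weight, weight_block_size=[128, 128])

        # Step 2 (post_load_weights, only when DeepGEMM runs with UE8M0 scales):
        # repack the scales into DeepGEMM's MN-major packed UE8M0 layout.
        weight_scale_inv = torch.nn.Parameter(out_s, requires_grad=False)
        transform_scale_ue8m0_inplace(weight_scale_inv, mn=out_w.shape[-2])
        return out_w, weight_scale_inv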