"src/vscode:/vscode.git/clone" did not exist on "48664d62b8e9f70d03b1be4059c1464a3b167f85"
Unverified Commit 5b1afa78 authored by fzyzcjy, committed by GitHub

Re-quantize DeepSeek model weights to support DeepGEMM new input format (#7156)

parent c49c1d92
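For context: DeepGEMM's new input format stores each per-block scale as UE8M0, i.e. an unsigned power-of-two scale factor (8-bit exponent, no mantissa), so the existing block-quantized FP8 weights have to be re-quantized with their scales rounded up to powers of two. A minimal PyTorch sketch of that rounding, using hypothetical per-block abs-max values (not taken from the diff):

import torch

amax = torch.tensor([0.30, 1.70, 100.0])  # hypothetical per-block abs-max values
scale = amax / 448.0                       # 448 is the largest normal float8_e4m3fn value
ue8m0_scale = torch.pow(2.0, torch.ceil(torch.log2(scale)))
print(ue8m0_scale)  # every entry is a power of two >= the original scale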
# COPIED FROM DeepGEMM
def align(x: int, y: int) -> int:
    return ceil_div(x, y) * y


# COPIED FROM DeepGEMM
def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y
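A quick sanity check of these two helpers, with hypothetical values:

assert ceil_div(300, 128) == 3
assert align(300, 128) == 384  # 300 rounded up to the next multiple of 128
assert align(256, 128) == 256  # already aligned, unchanged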
@@ -4,6 +4,7 @@ from typing import Callable, List, Optional, Tuple
import torch
from sglang.math_utils import align
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.utils import is_sm100_supported
@@ -390,6 +391,66 @@ def block_quant_dequant(
    return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)


def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
    assert isinstance(weight, torch.nn.Parameter)
    assert isinstance(weight_scale_inv, torch.nn.Parameter)
    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
        weight, weight_scale_inv, weight_block_size
    )


def _requant_weight_ue8m0(
    weight: torch.Tensor,
    weight_scale_inv: torch.Tensor,
    weight_block_size: List[int],
):
    assert weight_block_size == [128, 128]

    *_, n, k = weight.shape

    weight_dequant = block_quant_dequant(
        weight,
        weight_scale_inv,
        weight_block_size,
        torch.bfloat16,
    )

    weight_dequant_flat = weight_dequant.view((-1, k))
    out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)

    out_w = out_w_flat.view(weight.shape)
    out_s = out_s_flat.view(weight_scale_inv.shape)

    # NOTE copied and modified from DeepGEMM
    def _transform_scale(sf, mn: int):
        import deep_gemm.utils.layout

        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
        sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
        return sf

    out_s = _transform_scale(out_s, mn=out_w.shape[-2])

    return out_w, out_s
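A hedged usage sketch of the new helper on a dummy weight/scale pair; the shapes are hypothetical, and it assumes a CUDA device plus the deep_gemm package, which _transform_scale imports:

import torch
from sglang.srt.layers.quantization.fp8_utils import requant_weight_ue8m0_inplace

n, k = 512, 1024
weight = torch.nn.Parameter(
    torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn), requires_grad=False
)
weight_scale_inv = torch.nn.Parameter(
    torch.rand(n // 128, k // 128, device="cuda"), requires_grad=False
)
requant_weight_ue8m0_inplace(weight, weight_scale_inv, [128, 128])
# weight.data now holds the re-scaled FP8 values; weight_scale_inv.data holds the
# UE8M0 scales packed into the layout DeepGEMM expects.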
# COPIED FROM DeepGEMM
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2

    m, n = x.shape
    x_padded = torch.zeros(
        (align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = ceil_to_ue8m0(x_amax / 448.0)
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
        x_view.size(0), x_view.size(2)
    )


# COPIED FROM DeepGEMM
def ceil_to_ue8m0(x: torch.Tensor):
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
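A small shape and dtype check for the two copied kernels, with a hypothetical input size (the first dimension deliberately not a multiple of 128):

x = torch.randn(300, 256)
w_q, scales = per_block_cast_to_fp8(x)
print(w_q.shape, w_q.dtype)        # torch.Size([300, 256]) torch.float8_e4m3fn
print(scales.shape, scales.dtype)  # torch.Size([3, 2]) torch.float32
print(ceil_to_ue8m0(torch.tensor([0.3, 0.6, 4.0])))  # tensor([0.5000, 1.0000, 4.0000])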
@@ -66,6 +66,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
    block_quant_to_tensor_quant,
    channel_quant_to_tensor_quant,
    normalize_e4m3fn_to_e4m3fnuz,
    requant_weight_ue8m0_inplace,
)
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
@@ -1935,6 +1936,61 @@ class DeepseekV2ForCausalLM(nn.Module):
                self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                self_attn.use_deep_gemm_bmm = True

        if False:  # TODO (pr-chain)
            self._weight_requant_ue8m0()

    def _weight_requant_ue8m0(self):
        weight_block_size = self.quant_config.weight_block_size

        moe_layers = list(
            range(
                self.config.first_k_dense_replace,
                self.config.num_hidden_layers,
                self.config.moe_layer_freq,
            )
        )

        for layer_id in range(self.config.num_hidden_layers):
            layer = self.model.layers[layer_id]

            for module in [
                layer.self_attn.fused_qkv_a_proj_with_mqa,
                layer.self_attn.q_b_proj,
                layer.self_attn.kv_b_proj,
                layer.self_attn.o_proj,
            ]:
                requant_weight_ue8m0_inplace(
                    module.weight, module.weight_scale_inv, weight_block_size
                )

            if layer_id in moe_layers:
                shared_experts = layer.mlp.shared_experts
                for module in [
                    shared_experts.gate_up_proj,
                    shared_experts.down_proj,
                ]:
                    requant_weight_ue8m0_inplace(
                        module.weight, module.weight_scale_inv, weight_block_size
                    )

                experts = layer.mlp.experts
                if isinstance(experts, DeepEPMoE):
                    for w in [
                        experts.w13_weight_fp8,
                        experts.w2_weight_fp8,
                    ]:
                        requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
            else:
                mlp = layer.mlp
                assert isinstance(mlp, DeepseekV2MLP)
                for module in [
                    mlp.gate_up_proj,
                    mlp.down_proj,
                ]:
                    requant_weight_ue8m0_inplace(
                        module.weight, module.weight_scale_inv, weight_block_size
                    )
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
        if is_nextn: