Unverified commit 183d9f96, authored by HAI, committed by GitHub

DeepSeek: enable non-block-quant FP8 quantizations (#6638)

parent 63195028
@@ -57,6 +57,7 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
     per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
@@ -101,6 +102,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
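Context for the new `_is_fp8_fnuz` flag: ROCm MI300-class GPUs expose FP8 as `torch.float8_e4m3fnuz` rather than the OCP `torch.float8_e4m3fn` format, and several branches below key off that distinction. A minimal sketch of what such a platform probe can boil down to, assuming it keys off the HIP build and the gfx architecture string (the real helper lives in `sglang.srt.layers.quantization.fp8_kernel` and may differ):

```python
import torch

def is_fp8_fnuz_sketch() -> bool:
    # Illustrative only: ROCm gfx94x (MI300-class) devices use the
    # float8_e4m3fnuz encoding; CUDA and CPU use OCP float8_e4m3fn.
    if torch.version.hip is not None and torch.cuda.is_available():
        arch = torch.cuda.get_device_properties(0).gcnArchName  # e.g. "gfx942:sramecc+:xnack-"
        return "gfx94" in arch
    return False
```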
@@ -684,7 +686,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_kc = None
         self.w_vc = None
-        self.w_scale = None
+        self.w_scale = 1.0
         self.w_scale_k = None
         self.w_scale_v = None
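The default for `self.w_scale` changes from `None` to `1.0`: the bf16 bmm fallbacks below multiply `w_kc`/`w_vc` by `w_scale`, and with the new platform-based dispatch that expression can now be reached before any requantization scale has been attached, so a neutral multiplicative default keeps it valid. A toy illustration with made-up shapes; only the `w_scale` handling mirrors the module above:

```python
import torch

# Hypothetical sizes, just to exercise the expression used in the module above.
num_heads, seq, head_dim, kv_lora_rank = 8, 4, 64, 128
w_kc = torch.randn(num_heads, head_dim, kv_lora_rank, dtype=torch.bfloat16)
q_nope = torch.randn(num_heads, seq, head_dim, dtype=torch.bfloat16)

w_scale = 1.0  # neutral default: an unquantized weight needs no rescaling
# The same expression also works once w_scale is replaced by a real
# per-tensor dequantization scale for an FP8 checkpoint.
q_nope_out = torch.bmm(q_nope, w_kc * w_scale)
print(q_nope_out.shape)  # torch.Size([8, 4, 128])
```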
@@ -948,8 +950,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 expected_m,
             )
             q_nope_out = q_nope_out[:, :expected_m, :]
-        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
-            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+        elif _is_hip:
+            # TODO(haishaw): add bmm_fp8 to ROCm
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
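The branch condition moves from a weight-dtype check to a platform check: on HIP the FP8 weight is upcast to bf16 and rescaled as shown above, while the CUDA FP8 branch (not shown in this hunk) relies on the `bmm_fp8` kernel imported earlier. For reference, a hedged sketch of what per-tensor FP8 quantization and dequantization amount to arithmetically; this is not sglang's kernel, just the math the fused path relies on:

```python
import torch

def per_tensor_fp8_quant(x: torch.Tensor):
    # One scale for the whole tensor: map the max magnitude onto the FP8 range.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    scale = (x.abs().amax().float() / fp8_max).clamp(min=1e-12)
    x_fp8 = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, scale

x = torch.randn(16, 128, dtype=torch.bfloat16)
x_fp8, scale = per_tensor_fp8_quant(x)
# Dequantization is a single multiply; a fused FP8 batched GEMM applies the
# scales inside the matmul instead of materializing this tensor.
x_deq = x_fp8.to(torch.bfloat16) * scale
```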
@@ -1000,8 +1002,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 expected_m,
             )
             attn_bmm_output = attn_bmm_output[:, :expected_m, :]
-        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
-            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+        elif _is_hip:
+            # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
@@ -1052,8 +1054,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        if self.w_kc.dtype == torch.float8_e4m3fnuz:
-            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+        if _is_hip:
+            # TODO(haishaw): add bmm_fp8 to ROCm
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
@@ -1186,8 +1188,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
-        if self.w_vc.dtype == torch.float8_e4m3fnuz:
-            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+        if _is_hip:
+            # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
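All four attention hunks above route the HIP path through the bf16 fallback; the weight-loading hunk below additionally converts OCP `e4m3fn` checkpoint weights to the `e4m3fnuz` encoding when `_is_fp8_fnuz` is set, via `normalize_e4m3fn_to_e4m3fnuz`. As a hedged sketch of the usual conversion trick (not necessarily sglang's exact implementation): the same bit pattern denotes half the value in `e4m3fnuz`, so the dequant scale is doubled, and the `0x80` pattern (negative zero in `e4m3fn`, NaN in `e4m3fnuz`) is zeroed out.

```python
import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(weight: torch.Tensor, weight_scale: torch.Tensor):
    # Sketch only; the real helper may differ in details.
    assert weight.dtype == torch.float8_e4m3fn
    bits = weight.view(torch.int8).clone()
    # 0x80 encodes -0.0 in e4m3fn but NaN in e4m3fnuz, so map it to +0.0.
    bits[bits == -128] = 0
    weight_fnuz = bits.view(torch.float8_e4m3fnuz)
    # With identical bits, an e4m3fnuz value is half the e4m3fn value
    # (exponent bias 8 vs. 7), so double the dequantization scale.
    return weight_fnuz, weight_scale * 2.0
```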
@@ -1749,46 +1751,56 @@ class DeepseekV2ForCausalLM(nn.Module):
                 torch.float8_e4m3fn,
                 torch.float8_e4m3fnuz,
             ):
-                if hasattr(self.quant_config, "weight_block_size"):
+                if (
+                    hasattr(self.quant_config, "weight_block_size")
+                    and self.quant_config.weight_block_size is not None
+                ):
                     weight_block_size = self.quant_config.weight_block_size
-                    if weight_block_size is not None:
-                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                        if _is_hip:
-                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                                weight=w,
-                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                                input_scale=None,
-                            )
-                        else:
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                        if (
-                            _is_cuda
-                            and weight_block_size[0] == 128
-                            and weight_block_size[1] == 128
-                            and model_dtype == torch.bfloat16
-                        ):
-                            if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
-                                "SGL_USE_DEEPGEMM_BMM", "false"
-                            ):
-                                block_scale = weight_scale
-                                use_deep_gemm_bmm = True
-                            else:
-                                w = block_quant_dequant(
-                                    weight,
-                                    weight_scale,
-                                    weight_block_size,
-                                    model_dtype,
-                                )
-                        else:
-                            w, scale = block_quant_to_tensor_quant(
-                                weight, weight_scale, weight_block_size
-                            )
-                            self_attn.w_scale = scale
+                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                    if _is_fp8_fnuz:
+                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                            weight=w,
+                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                            input_scale=None,
+                        )
+                    else:
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                    if (
+                        _is_cuda
+                        and weight_block_size[0] == 128
+                        and weight_block_size[1] == 128
+                        and model_dtype == torch.bfloat16
+                    ):
+                        if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
+                            "SGL_USE_DEEPGEMM_BMM", "false"
+                        ):
+                            block_scale = weight_scale
+                            use_deep_gemm_bmm = True
+                        else:
+                            w = block_quant_dequant(
+                                weight,
+                                weight_scale,
+                                weight_block_size,
+                                model_dtype,
+                            )
+                    else:
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
                 else:
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_fp8_fnuz:
+                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                            weight=w,
+                            weight_scale=self_attn.kv_b_proj.weight_scale,
+                            input_scale=None,
+                        )
+                    else:
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale
                     w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                     self_attn.w_scale = scale
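The new `else` branch is what the commit title is about: checkpoints quantized with per-channel (non-block) FP8 scales are now fnuz-normalized when `_is_fp8_fnuz` is set and then folded into a single per-tensor scale via `channel_quant_to_tensor_quant`, so the existing per-tensor `bmm_fp8`/bf16 paths can consume `kv_b_proj`. A rough sketch of what such a channel-to-tensor requantization involves; this is illustrative, and the real helper in sglang may differ in layout and rounding details:

```python
import torch

def channel_quant_to_tensor_quant_sketch(w_fp8: torch.Tensor, channel_scale: torch.Tensor):
    # w_fp8: FP8 weight with one scale per output channel (channel_scale broadcastable).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    w = w_fp8.to(torch.float32) * channel_scale                 # dequantize per channel
    tensor_scale = (w.abs().amax() / fp8_max).clamp(min=1e-12)  # one scale for the tensor
    w_q = (w / tensor_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_q, tensor_scale

# Example with made-up shapes: 512 output channels, 128 inputs.
w = torch.randn(512, 128)
ch_scale = w.abs().amax(dim=1, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w / ch_scale).to(torch.float8_e4m3fn)
w_q, s = channel_quant_to_tensor_quant_sketch(w_fp8, ch_scale)
```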