Commit 6efaf21a authored by chenhw5's avatar chenhw5 Committed by zhangzbb
Browse files

[BUGFIX]修复deepgemm算子导致的GLM5 W8A8精度问题。

parent fbfe20c6
......@@ -244,7 +244,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt
and current_platform.has_device_capability(100)
and scale_fmt == DeepGemmQuantScaleFMT.UE8M0
):
from deep_gemm import transform_sf_into_required_layout
from deepgemm import transform_sf_into_required_layout
_q, _s = ref_with_scale_fmt(
E,
......
......@@ -36,7 +36,7 @@ from vllm.utils.import_utils import has_deep_gemm
from lightop import fuse_silu_mul_quant_ep
if has_deep_gemm():
from deep_gemm import m_grouped_w8a8_gemm_nt_masked
from deepgemm import m_grouped_w8a8_gemm_nt_masked
else:
from lightop import m_grouped_w8a8_gemm_nt_masked
......@@ -481,5 +481,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
output,
expert_num_tokens,
expected_m)
# moe_grouped_gemm(a1q, w1, a1q_scale, self.w1_scale, expert_num_tokens, workspace1)
# act_out = self.act_fn(workspace1)
# a2q, a2q_scale = per_token_quant_int8(act_out)
# moe_grouped_gemm(a2q, w2, a2q_scale, self.w2_scale, expert_num_tokens, output)
else:
raise ValueError(f"Unsupported dtype {self.quant_config.quant_dtype}")
......@@ -39,7 +39,7 @@ from vllm.utils.import_utils import has_deep_gemm
from lightop import fuse_silu_mul_quant
if has_deep_gemm():
from deep_gemm import m_grouped_i8_gemm_nt_contiguous
from deepgemm import m_grouped_i8_gemm_nt_contiguous
else:
from lightop import m_grouped_w8a8_gemm_nt_contig_asm as m_grouped_i8_gemm_nt_contiguous
......
......@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported, FusedMoEConfig)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight)
get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight, weight8bit_nt_kpack2_marlin1)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig, int8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
......@@ -375,7 +375,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
if not self.use_deepep:
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else:
w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
w1_marlin_in = weight8bit_nt_kpack2_marlin1(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
......@@ -385,7 +385,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
if not self.use_deepep:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
w2_marlin_in = weight8bit_nt_kpack2_marlin1(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
......
......@@ -43,6 +43,30 @@ def w8a8_nt_kpack2_marlin_weight(w8a8_w, # [size_n, size_k// 2 ]
w8a8_w = w8a8_w.reshape((size_n // k_tile, size_k * k_tile))
return w8a8_w
def weight8bit_nt_kpack2_marlin1(weight, # [size_n, size_k// 2 ]
k_tile=16,
k_tile1=4,
n_tile=16,
n_tile1=16):
assert weight.element_size() == 1, "weight 必须是 8 bit 类型"
if weight.dim() == 2:
size_n, size_k = weight.shape
assert size_n % k_tile == 0 and size_k % n_tile == 0, "k_tile / n_tile 必须能整除对应维度"
q = weight.reshape((size_n // (n_tile*n_tile1), n_tile1, n_tile, size_k // (k_tile*k_tile1), k_tile1, k_tile))
# q = q.permute((0, 2, 1, 3)).contiguous()
q = q.permute((0, 3, 1, 4, 2, 5)).contiguous()
q = q.reshape((size_n // k_tile, size_k * k_tile))
elif weight.dim() == 3:
E, size_n, size_k = weight.shape
assert size_n % n_tile == 0 and size_k % k_tile == 0, "k_tile / n_tile 必须能整除对应维度"
q = weight.reshape((E, size_n // (n_tile*n_tile1), n_tile1, n_tile, size_k // (k_tile*k_tile1), k_tile1, k_tile))
q = q.permute((0, 1, 4, 2, 5, 3, 6)).contiguous()
q = q.reshape((E, size_n // k_tile, size_k * k_tile))
return q
def sparse_cutlass_supported() -> bool:
if not current_platform.is_cuda():
......
......@@ -413,8 +413,8 @@ def has_deep_ep() -> bool:
def has_deep_gemm() -> bool:
"""Whether the optional `deep_gemm` package is available."""
return _has_module("deep_gemm")
"""Whether the optional `deepgemm` package is available."""
return _has_module("deepgemm")
def has_triton_kernels() -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment