Commit 1e9ff2e7 authored by chenhw5
Browse files

fix deepgemm accuracy bug.

parent a8d6ba1e
...@@ -244,7 +244,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt ...@@ -244,7 +244,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt
and current_platform.has_device_capability(100) and current_platform.has_device_capability(100)
and scale_fmt == DeepGemmQuantScaleFMT.UE8M0 and scale_fmt == DeepGemmQuantScaleFMT.UE8M0
): ):
from deep_gemm import transform_sf_into_required_layout from deepgemm import transform_sf_into_required_layout
_q, _s = ref_with_scale_fmt( _q, _s = ref_with_scale_fmt(
E, E,
......
...@@ -41,7 +41,7 @@ from vllm.model_executor.layers.activation import SiluAndMul ...@@ -41,7 +41,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from lightop import fuse_silu_mul_quant_ep from lightop import fuse_silu_mul_quant_ep
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
if has_deep_gemm(): if has_deep_gemm():
from deep_gemm import m_grouped_w8a8_gemm_nt_masked from deepgemm import m_grouped_w8a8_gemm_nt_masked
else: else:
from lightop import m_grouped_w8a8_gemm_nt_masked from lightop import m_grouped_w8a8_gemm_nt_masked
...@@ -642,26 +642,28 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -642,26 +642,28 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m, expected_m,
) )
elif self.quant_config.use_int8_w8a8: elif self.quant_config.use_int8_w8a8:
# m_grouped_w8a8_gemm_nt_masked((a1q, a1q_scale),
# (w1, self.w1_scale),
# workspace1,
# expert_num_tokens,
# expected_m,
# )
# assert expert_num_tokens is not None
# a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
# m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
# (w2, self.w2_scale),
# output,
# expert_num_tokens,
# expected_m)
moe_grouped_gemm(a1q, w1, a1q_scale, self.w1_scale, expert_num_tokens, workspace1) m_grouped_w8a8_gemm_nt_masked((a1q, a1q_scale),
act_out = self.act_fn(workspace1) (w1, self.w1_scale),
a2q, a2q_scale = per_token_quant_int8(act_out) workspace1,
moe_grouped_gemm(a2q, w2, a2q_scale, self.w2_scale, expert_num_tokens, output) expert_num_tokens,
expected_m,
)
assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
(w2, self.w2_scale),
output,
expert_num_tokens,
expected_m)
# moe_grouped_gemm(a1q, w1, a1q_scale, self.w1_scale, expert_num_tokens, workspace1)
# act_out = self.act_fn(workspace1)
# a2q, a2q_scale = per_token_quant_int8(act_out)
# moe_grouped_gemm(a2q, w2, a2q_scale, self.w2_scale, expert_num_tokens, output)
else: else:
raise ValueError(f"Unsupported dtype {self.quant_config.quant_dtype}") raise ValueError(f"Unsupported dtype {self.quant_config.quant_dtype}")
...@@ -39,7 +39,7 @@ from vllm.utils.import_utils import has_deep_gemm ...@@ -39,7 +39,7 @@ from vllm.utils.import_utils import has_deep_gemm
from lightop import fuse_silu_mul_quant from lightop import fuse_silu_mul_quant
if has_deep_gemm(): if has_deep_gemm():
from deep_gemm import m_grouped_i8_gemm_nt_contiguous from deepgemm import m_grouped_i8_gemm_nt_contiguous
else: else:
from lightop import m_grouped_w8a8_gemm_nt_contig_asm as m_grouped_i8_gemm_nt_contiguous from lightop import m_grouped_w8a8_gemm_nt_contig_asm as m_grouped_i8_gemm_nt_contiguous
......
...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported, FusedMoEConfig) FusedMoeWeightScaleSupported, FusedMoEConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.utils.w8a8_utils import( from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight) get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight, weight8bit_nt_kpack2_marlin1)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig, int8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig, int8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
...@@ -370,28 +370,29 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -370,28 +370,29 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if not self.use_deepep: w1_marlin_list = []
w1_marlin_list = [] for ii in range(layer.w13_weight.shape[0]):
for ii in range(layer.w13_weight.shape[0]): if not self.use_deepep:
if not self.use_deepep: w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii]) else:
else: #w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii]) w1_marlin_in = weight8bit_nt_kpack2_marlin1(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in) w1_marlin_list.append(w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0) w1_marlin = torch.stack(w1_marlin_list, dim=0)
del w1_marlin_list del w1_marlin_list
w2_marlin_list = [] w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]): for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepep: if not self.use_deepep:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii]) w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else: else:
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii]) #w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in) w2_marlin_in = weight8bit_nt_kpack2_marlin1(layer.w2_weight[ii])
w2_marlin = torch.stack(w2_marlin_list, dim=0) w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False) layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def apply( def apply(
self, self,
......
...@@ -43,6 +43,30 @@ def w8a8_nt_kpack2_marlin_weight(w8a8_w, # [size_n, size_k// 2 ] ...@@ -43,6 +43,30 @@ def w8a8_nt_kpack2_marlin_weight(w8a8_w, # [size_n, size_k// 2 ]
w8a8_w = w8a8_w.reshape((size_n // k_tile, size_k * k_tile)) w8a8_w = w8a8_w.reshape((size_n // k_tile, size_k * k_tile))
return w8a8_w return w8a8_w
def weight8bit_nt_kpack2_marlin1(weight,  # original note: [size_n, size_k // 2] -- TODO confirm
                                 k_tile: int = 16,
                                 k_tile1: int = 4,
                                 n_tile: int = 16,
                                 n_tile1: int = 16):
    """Re-tile an 8-bit weight tensor into a blocked layout.

    The last two dimensions ``(size_n, size_k)`` are split into
    ``(size_n // (n_tile * n_tile1), n_tile1, n_tile)`` and
    ``(size_k // (k_tile * k_tile1), k_tile1, k_tile)``, reordered to
    ``(n_block, k_block, n_tile1, k_tile1, n_tile, k_tile)`` and flattened to
    ``(size_n // k_tile, size_k * k_tile)``.  A leading expert dimension, if
    present, is left in place.

    Args:
        weight: 2-D ``(size_n, size_k)`` or 3-D ``(E, size_n, size_k)`` tensor
            whose elements are one byte wide.
        k_tile: Inner tile width along the last dimension.
        k_tile1: Number of ``k_tile`` tiles grouped per block along the last
            dimension.
        n_tile: Inner tile height along the second-to-last dimension.
        n_tile1: Number of ``n_tile`` tiles grouped per block along the
            second-to-last dimension.

    Returns:
        A contiguous tensor of shape ``(size_n // k_tile, size_k * k_tile)``,
        with the leading ``E`` dimension preserved for 3-D input.

    Raises:
        AssertionError: If the element size is not one byte, or if ``size_n``
            / ``size_k`` are not multiples of the full block sizes.
        ValueError: If ``weight`` is neither 2-D nor 3-D.
    """
    assert weight.element_size() == 1, "weight 必须是 8 bit 类型"

    rank = weight.dim()
    if rank not in (2, 3):
        # Previously this fell through to `return q` with `q` unbound.
        raise ValueError(
            f"weight must be a 2-D or 3-D tensor, got {rank} dimension(s)")

    size_n, size_k = weight.shape[-2:]
    n_block = n_tile * n_tile1
    k_block = k_tile * k_tile1
    # The reshape below needs whole blocks, not just whole tiles; checking the
    # block sizes up front turns an opaque reshape failure into a clear one.
    assert size_n % n_block == 0 and size_k % k_block == 0, "k_tile / n_tile 必须能整除对应维度"

    # Optional leading expert dimension is carried through untouched.
    lead = tuple(weight.shape[:-2])
    offset = len(lead)

    q = weight.reshape(lead + (size_n // n_block, n_tile1, n_tile,
                               size_k // k_block, k_tile1, k_tile))
    # (n_block, n_tile1, n_tile, k_block, k_tile1, k_tile)
    #   -> (n_block, k_block, n_tile1, k_tile1, n_tile, k_tile)
    order = tuple(range(offset)) + tuple(offset + axis
                                         for axis in (0, 3, 1, 4, 2, 5))
    q = q.permute(order).contiguous()
    return q.reshape(lead + (size_n // k_tile, size_k * k_tile))
def sparse_cutlass_supported() -> bool: def sparse_cutlass_supported() -> bool:
if not current_platform.is_cuda(): if not current_platform.is_cuda():
......
...@@ -413,8 +413,8 @@ def has_deep_ep() -> bool: ...@@ -413,8 +413,8 @@ def has_deep_ep() -> bool:
def has_deep_gemm() -> bool: def has_deep_gemm() -> bool:
"""Whether the optional `deep_gemm` package is available.""" """Whether the optional `deepgemm` package is available."""
return _has_module("deep_gemm") return _has_module("deepgemm")
def has_triton_kernels() -> bool: def has_triton_kernels() -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment