Unverified commit 05d68643, authored by ElizaWszola, committed by GitHub
Browse files

[Kernel] Zero point support in fused MarlinMoE kernel + AWQ Fused MoE (#8973)


Co-authored-by: Dipika <dipikasikka1@gmail.com>
Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
parent 0dcc8cbe
...@@ -557,14 +557,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -557,14 +557,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x, x,
layer.w13_qweight, layer.w13_qweight,
layer.w2_qweight, layer.w2_qweight,
layer.w13_scales,
layer.w2_scales,
router_logits, router_logits,
layer.w13_g_idx,
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale=layer.w13_scales, g_idx1=layer.w13_g_idx,
w2_scale=layer.w2_scales, g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.quant_config.quant_type.size_bits, num_bits=self.quant_config.quant_type.size_bits,
).to(orig_dtype) ).to(orig_dtype)
...@@ -208,6 +208,7 @@ def marlin_moe_permute_scales( ...@@ -208,6 +208,7 @@ def marlin_moe_permute_scales(
device=s.device, device=s.device,
dtype=s.dtype, dtype=s.dtype,
) )
for e in range(num_experts): for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
return output return output
...@@ -258,6 +259,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, ...@@ -258,6 +259,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return marlin_zp return marlin_zp
def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                                  size_n: int, num_bits: int):
    """Apply ``awq_to_marlin_zero_points`` to every expert slice of a tensor.

    The first dimension of ``q_zp_packed`` is treated as the expert axis.
    Each slice along that axis is passed to ``awq_to_marlin_zero_points``
    together with ``size_k``, ``size_n`` and ``num_bits``, and the result is
    stored at the same expert index of a freshly allocated tensor.

    Args:
        q_zp_packed: Packed zero points; must have at least three dimensions.
        size_k: Forwarded unchanged to ``awq_to_marlin_zero_points``.
        size_n: Forwarded unchanged to ``awq_to_marlin_zero_points``.
        num_bits: Forwarded unchanged to ``awq_to_marlin_zero_points``.

    Returns:
        A new tensor whose shape is the first three dimensions of
        ``q_zp_packed``, on the same device and with the same dtype.
    """
    in_shape = q_zp_packed.shape
    converted = torch.empty(
        (in_shape[0], in_shape[1], in_shape[2]),
        device=q_zp_packed.device,
        dtype=q_zp_packed.dtype,
    )
    # The conversion helper works on a single expert, so walk the expert
    # axis and fill the preallocated result slot by slot.
    for expert_idx, expert_zp in enumerate(q_zp_packed):
        converted[expert_idx] = awq_to_marlin_zero_points(
            expert_zp, size_k, size_n, num_bits)
    return converted
def apply_gptq_marlin_linear( def apply_gptq_marlin_linear(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
......
...@@ -23,7 +23,9 @@ def get_model_architecture( ...@@ -23,7 +23,9 @@ def get_model_architecture(
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] mixtral_supported = [
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
]
if (model_config.quantization is not None if (model_config.quantization is not None
and model_config.quantization not in mixtral_supported and model_config.quantization not in mixtral_supported
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment