Unverified Commit 7c73ceb5 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Quantization] add marlin w4a8/w8a8 check (#31061)


Signed-off-by: default avatarJinzhen Lin <jinzhen.ljz@antgroup.com>
parent ae0770fa
......@@ -594,9 +594,15 @@ def apply_awq_marlin_linear(
a_scales = None
if input_dtype == torch.int8:
assert quant_type == scalar_types.uint4, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert quant_type == scalar_types.uint4, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
......@@ -649,9 +655,15 @@ def apply_rtn_marlin_linear(
a_scales = None
if input_dtype == torch.int8:
assert quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert quant_type == scalar_types.uint4b8, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
......
......@@ -154,6 +154,12 @@ def prepare_fp4_layer_for_marlin(
)
is_nvfp4 = hasattr(layer, "weight_scale_2")
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
elif input_dtype != torch.float8_e4m3fn:
raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
group_size = 16 if is_nvfp4 else 32
part_size_n = layer.output_size_per_partition
......@@ -231,6 +237,12 @@ def prepare_moe_fp4_layer_for_marlin(
)
is_nvfp4 = hasattr(layer, "w13_weight_scale_2")
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
elif input_dtype != torch.float8_e4m3fn:
raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
group_size = 16 if is_nvfp4 else 32
e = layer.num_experts
......
......@@ -99,6 +99,8 @@ def prepare_fp8_layer_for_marlin(
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
if input_dtype is not None and input_dtype.itemsize == 1:
raise RuntimeError("Marlin W8A8 is not supported.")
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
......@@ -206,6 +208,8 @@ def prepare_moe_fp8_layer_for_marlin(
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
if input_dtype is not None and input_dtype.itemsize == 1:
raise RuntimeError("Marlin W8A8 is not supported.")
e = layer.num_experts
k = layer.hidden_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment