Unverified Commit 7c73ceb5 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Quantization] add marlin w4a8/w8a8 check (#31061)


Signed-off-by: default avatarJinzhen Lin <jinzhen.ljz@antgroup.com>
parent ae0770fa
...@@ -594,9 +594,15 @@ def apply_awq_marlin_linear( ...@@ -594,9 +594,15 @@ def apply_awq_marlin_linear(
a_scales = None a_scales = None
if input_dtype == torch.int8: if input_dtype == torch.int8:
assert quant_type == scalar_types.uint4, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn: elif input_dtype == torch.float8_e4m3fn:
assert quant_type == scalar_types.uint4, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
...@@ -649,9 +655,15 @@ def apply_rtn_marlin_linear( ...@@ -649,9 +655,15 @@ def apply_rtn_marlin_linear(
a_scales = None a_scales = None
if input_dtype == torch.int8: if input_dtype == torch.int8:
assert quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn: elif input_dtype == torch.float8_e4m3fn:
assert quant_type == scalar_types.uint4b8, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
......
...@@ -154,6 +154,12 @@ def prepare_fp4_layer_for_marlin( ...@@ -154,6 +154,12 @@ def prepare_fp4_layer_for_marlin(
) )
is_nvfp4 = hasattr(layer, "weight_scale_2") is_nvfp4 = hasattr(layer, "weight_scale_2")
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
elif input_dtype != torch.float8_e4m3fn:
raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
group_size = 16 if is_nvfp4 else 32 group_size = 16 if is_nvfp4 else 32
part_size_n = layer.output_size_per_partition part_size_n = layer.output_size_per_partition
...@@ -231,6 +237,12 @@ def prepare_moe_fp4_layer_for_marlin( ...@@ -231,6 +237,12 @@ def prepare_moe_fp4_layer_for_marlin(
) )
is_nvfp4 = hasattr(layer, "w13_weight_scale_2") is_nvfp4 = hasattr(layer, "w13_weight_scale_2")
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
elif input_dtype != torch.float8_e4m3fn:
raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
group_size = 16 if is_nvfp4 else 32 group_size = 16 if is_nvfp4 else 32
e = layer.num_experts e = layer.num_experts
......
...@@ -99,6 +99,8 @@ def prepare_fp8_layer_for_marlin( ...@@ -99,6 +99,8 @@ def prepare_fp8_layer_for_marlin(
"be used leveraging the Marlin kernel. This may degrade " "be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads." "performance for compute-heavy workloads."
) )
if input_dtype is not None and input_dtype.itemsize == 1:
raise RuntimeError("Marlin W8A8 is not supported.")
part_size_n = layer.output_size_per_partition part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition part_size_k = layer.input_size_per_partition
...@@ -206,6 +208,8 @@ def prepare_moe_fp8_layer_for_marlin( ...@@ -206,6 +208,8 @@ def prepare_moe_fp8_layer_for_marlin(
"be used leveraging the Marlin kernel. This may degrade " "be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads." "performance for compute-heavy workloads."
) )
if input_dtype is not None and input_dtype.itemsize == 1:
raise RuntimeError("Marlin W8A8 is not supported.")
e = layer.num_experts e = layer.num_experts
k = layer.hidden_size k = layer.hidden_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment