Unverified Commit 44eea10f authored by xuebwang-amd's avatar xuebwang-amd Committed by GitHub
Browse files

[ROCm][Quantization] make quark ocp mx dtype parser robust for weight-only quantization (#36232)


Signed-off-by: default avatarxuebwang-amd <xuebwang@amd.com>
parent 8b6c6b95
...@@ -92,7 +92,8 @@ class QuarkMoEMethod(FusedMoEMethodBase): ...@@ -92,7 +92,8 @@ class QuarkMoEMethod(FusedMoEMethodBase):
rocm_aiter_ops.is_fused_moe_enabled() rocm_aiter_ops.is_fused_moe_enabled()
) )
if ( if (
input_config.get("dtype") == "fp8_e4m3" input_config is not None
and input_config.get("dtype") == "fp8_e4m3"
and not input_config.get("is_dynamic") and not input_config.get("is_dynamic")
and not emulate and not emulate
): ):
......
...@@ -176,7 +176,7 @@ class QuarkOCP_MX(QuarkScheme): ...@@ -176,7 +176,7 @@ class QuarkOCP_MX(QuarkScheme):
def __init__( def __init__(
self, self,
weight_quant_spec: dict[str, Any], weight_quant_spec: dict[str, Any],
input_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any] | None,
dynamic_mxfp4_quant: bool = False, dynamic_mxfp4_quant: bool = False,
): ):
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
...@@ -185,7 +185,13 @@ class QuarkOCP_MX(QuarkScheme): ...@@ -185,7 +185,13 @@ class QuarkOCP_MX(QuarkScheme):
self.input_quant_spec = input_quant_spec self.input_quant_spec = input_quant_spec
self.dynamic_mxfp4_quant = dynamic_mxfp4_quant self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp") self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp") self.input_dtype: str | None = None
if input_quant_spec is not None:
input_quant = input_quant_spec["dtype"]
if input_quant == "fp8_e4m3":
self.input_dtype = "fp8"
else:
self.input_dtype = input_quant.replace("fp", "mxfp")
self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype( self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
self.input_dtype, self.weight_dtype self.input_dtype, self.weight_dtype
...@@ -200,14 +206,21 @@ class QuarkOCP_MX(QuarkScheme): ...@@ -200,14 +206,21 @@ class QuarkOCP_MX(QuarkScheme):
dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "") dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "")
) )
if self.input_dtype == "mxfp4": if self.input_dtype is None:
self.quant_dequant_func: Callable[[torch.Tensor], torch.Tensor] = (
lambda x: x
) # no input Q/DQ for weight-only
elif self.input_dtype == "mxfp4":
self.quant_dequant_func = quant_dequant_mxfp4 self.quant_dequant_func = quant_dequant_mxfp4
else: else:
self.quant_dequant_func = partial( self.quant_dequant_func = partial(
quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "") quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "")
) )
self.static_input_scales = not input_quant_spec.get("is_dynamic") if input_quant_spec is None:
self.static_input_scales = False
else:
self.static_input_scales = not input_quant_spec.get("is_dynamic")
if self.static_input_scales: if self.static_input_scales:
raise NotImplementedError( raise NotImplementedError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment