n_quantizer_sets: Number of quantizer sets to create
-scaling_mode: Scaling mode to use, default is QuantizeConfig.SCALING_MODE
-fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE
-bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
-is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
+scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode
+fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE
+bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE
+is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X
n_groups: Number of quantizer groups to create (used for grouped quantization)
fp8_recipe: Recipe to use for quantization. The scaling mode can be specified directly via the scaling_mode parameter or indirectly via this recipe. The recipe is preferred, as it will support additional recipes in the future where the scaling mode differs between x, kernel, and grad within the quantizer set.
**kwargs: Additional arguments for quantizer initialization
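For orientation, a minimal sketch of how these arguments might be passed to the factory. The method name (`QuantizerFactory.create_set`), its return shape, and the `jnp` float8 dtypes are assumptions for illustration, not taken from this diff:

```python
import jax.numpy as jnp
from transformer_engine.jax.quantize import QuantizerFactory

# Hypothetical call: the method name, signature, and return shape are
# assumptions based on the argument list documented above.
quantizer_set = QuantizerFactory.create_set(
    n_quantizer_sets=1,           # one set of (x, kernel, grad) quantizers
    fwd_dtype=jnp.float8_e4m3fn,  # forward-pass quantization dtype
    bwd_dtype=jnp.float8_e5m2,    # backward-pass quantization dtype
    is_2x2x=True,                 # keep both rowwise and colwise copies
)
```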
...
...
@@ -912,27 +946,44 @@ class QuantizerFactory:
)
if fp8_recipe is not None:
# TODO(jberchtold): once recipe and scaling mode are decoupled update this logic
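Until that decoupling lands, the branch guarded by this check presumably resolves the recipe to a single scaling mode shared by the whole quantizer set. A hedged sketch of such a mapping follows; the recipe classes come from `transformer_engine.common.recipe`, while the `ScalingMode` member names, the import paths, and the mapping itself are assumptions:

```python
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling
from transformer_engine.jax.quantize import ScalingMode

def _scaling_mode_from_recipe(fp8_recipe):
    # Hypothetical coupling: today one recipe implies one scaling mode that
    # is shared by the x, kernel, and grad quantizers in a set.
    if isinstance(fp8_recipe, DelayedScaling):
        return ScalingMode.DELAYED_TENSOR_SCALING
    if isinstance(fp8_recipe, MXFP8BlockScaling):
        return ScalingMode.MXFP8_1D_SCALING
    raise ValueError(f"Unsupported fp8_recipe: {fp8_recipe!r}")
```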
@@ -396,7 +480,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
The quantize layout for the tensor usage
"""
# If we need to support 1x1x for inference in the future
-# if QuantizeConfig.INFERENCE_MODE:
+# if get_quantize_config().INFERENCE_MODE:
# assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!")
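If the 1x1x inference path is ever enabled, the guard sketched in the comment above might read as follows. The assertion mirrors the commented-out code; treating `INFERENCE_MODE` as a plain boolean attribute and the import paths are assumptions:

```python
from transformer_engine.jax.quantize import TensorUsage, get_quantize_config

def _check_usage_for_1x1x(usage):
    # Hypothetical guard from the commented-out code above: in
    # MXFP8_1D_SCALING 1x1x (FWD only) mode, no transposed usage is needed,
    # so transposed usages are rejected outright.
    if get_quantize_config().INFERENCE_MODE:
        assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (
            f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x "
            "(FWD only) mode so no transposed usage is needed!"
        )
```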