Unverified Commit 6a4e871e authored by jberchtold-nvidia and committed by GitHub
Browse files

[JAX] Support custom recipe and custom collection name when creating quantizer sets (#2059)



* Support setting collection name for quantizer set Flax variables in TransformerEngineBase flax module
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Support creating quantizer set from a recipe directly
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix debug error format string in gemm.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent f947e703
......@@ -1090,7 +1090,7 @@ def _jax_gemm(
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}")
raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}")
lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
......
......@@ -337,21 +337,28 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
Base class of transformer engine
"""
def generate_quantizer_set(self, postfix: str = ""):
def generate_quantizer_set(
self, postfix: str = "", variable_collection: str = None, fp8_recipe=None
):
"""
Generate a set of FP8 meta for a GEMM.
"""
def generate_quantize_meta(quantizer_name: str):
collection_name = (
variable_collection
if variable_collection is not None
else QuantizeConfig.COLLECTION_NAME
)
scale = self.variable(
QuantizeConfig.COLLECTION_NAME,
collection_name,
f"{quantizer_name}{postfix}_scale",
jnp.ones,
(1,),
jnp.float32,
).value
amax_history = self.variable(
QuantizeConfig.COLLECTION_NAME,
collection_name,
f"{quantizer_name}{postfix}_amax_history",
jnp.zeros,
(QuantizeConfig.AMAX_HISTORY_LEN,),
......@@ -368,7 +375,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
else:
kwargs = {}
quantizer_set = QuantizerFactory.create_set(**kwargs)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs)
return quantizer_set
......
......@@ -16,12 +16,14 @@ import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe
from .scaling_modes import ScalingMode
from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .helper import (
QuantizeConfig,
AmaxComputeAlgo,
_get_scaling_mode,
)
from .device_utils import is_fp8_gemm_with_all_layouts_supported
......@@ -878,11 +880,12 @@ class QuantizerFactory:
@staticmethod
def create_set(
n_quantizer_sets: int = 1,
scaling_mode: ScalingMode = None,
scaling_mode: Optional[ScalingMode] = None,
fwd_dtype: jnp.dtype = None,
bwd_dtype: jnp.dtype = None,
is_2x2x: bool = None,
n_groups: int = None,
fp8_recipe: Optional[recipe.Recipe] = None,
**kwargs,
) -> tuple[Union[tuple[Quantizer], None]]:
"""Create one or more sets of quantizers.
......@@ -894,11 +897,24 @@ class QuantizerFactory:
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
n_groups:
fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set.
**kwargs: Additional arguments for quantizer initialization
Returns:
A single quantizer set or tuple of quantizer sets
"""
assert scaling_mode is None or fp8_recipe is None, (
"Cannot specify both scaling_mode and fp8_recipe when creating a quantizer set. Scaling"
" mode can be specified directly via the scaling_mode parameter or indirectly via"
" recipe. Recipe is preferred as it will support additional recipes in future where"
" scaling mode differs between x, kernel, and grad in the quantizer set."
)
if fp8_recipe is not None:
# TODO(jberchtold): once recipe and scaling mode are decoupled update this logic
scaling_mode = _get_scaling_mode(fp8_recipe)
else:
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE
bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment