Unverified Commit a0754757 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Improve support and testing for direct recipe usage without autocast contexts (#2366)



* Refactor to avoid storing a global quantization config so direct recipe passing works as intended
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* fix use_split_accumulator for current scaling recipe
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* fix tests that pass direct recipe and were missing quantize meta set
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Revert "fix use_split_accumulator for current scaling recipe"

This reverts commit a74ab7df812ec0a069b1bdd208debb93ec25a900.
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* fix ci failures
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix amax_history post_init
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Update transformer_engine/jax/quantize/quantizer.py
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix ci failures
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* fix ci issue
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* address comments
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* make recipe assertion classes in test_recipe_characteristics not inherit from unittest.TestCase
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b88f727b
......@@ -40,6 +40,8 @@ from transformer_engine.jax.quantize import (
QuantizerFactory,
QuantizeLayout,
noop_quantizer_set,
QuantizeMetaSet,
QuantizeMeta,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
......@@ -1457,7 +1459,12 @@ class TestDense:
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)
quantizer_set = QuantizerFactory.create_set(
fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
)
n_iterations = 3 if recipe.delayed() else 1
with use_jax_gemm(enabled=with_jax_gemm):
......@@ -1516,7 +1523,12 @@ class TestFusedDense:
gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)
quantizer_set = QuantizerFactory.create_set(
fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
)
if norm_type == "layernorm":
beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
......@@ -1605,6 +1617,9 @@ class TestFusedDense:
quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2,
fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
)
if norm_type == "layernorm":
......
......@@ -23,7 +23,8 @@ from utils import EncoderLayer as RefEncoderLayer
from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.quantize import (
get_quantize_config,
get_global_quantize_recipe,
get_quantize_config_with_recipe,
ScalingMode,
is_fp8_available,
update_collections,
......@@ -358,7 +359,7 @@ class BaseRunner:
ref_params, test_params = self._sync_params(ref_params, test_params)
if get_quantize_config().is_fp8_enabled():
if get_quantize_config_with_recipe(get_global_quantize_recipe()).is_fp8_enabled():
for _ in range(4):
_, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
inputs,
......@@ -368,14 +369,24 @@ class BaseRunner:
test_layer,
)
if (
get_quantize_config().get_scaling_mode(TensorSource.X)
get_quantize_config_with_recipe(get_global_quantize_recipe()).get_scaling_mode(
TensorSource.X
)
== ScalingMode.DELAYED_TENSOR_SCALING
):
_, updated_quantize_meta = flax.core.pop(
updated_state[0], get_quantize_config().COLLECTION_NAME
updated_state[0],
get_quantize_config_with_recipe(
get_global_quantize_recipe()
).COLLECTION_NAME,
)
test_others = update_collections(
{get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others
{
get_quantize_config_with_recipe(
get_global_quantize_recipe()
).COLLECTION_NAME: updated_quantize_meta
},
test_others,
)
del updated_quantize_meta
del updated_state
......
This diff is collapsed.
......@@ -221,7 +221,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
"""
Helper function to manage primitive states by name without modifying environment variables.
Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
This helper is used in the get_quantize_config().initialize() methods.
This helper is used in the get_quantize_config_with_recipe().initialize() methods.
Args:
enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.
......
......@@ -38,12 +38,13 @@ from ..quantize import (
ScalingMode,
Quantizer,
GroupedQuantizer,
get_quantize_config,
QuantizerSet,
QuantizeLayout,
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
get_quantize_config_with_recipe,
get_global_quantize_recipe,
)
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
......@@ -1246,7 +1247,7 @@ def _te_gemm(
fuse_bias: bool = False,
fuse_gelu: bool = False,
grad: bool = False,
use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP,
use_split_accumulator: bool = None,
transpose_batch_sequence: bool = False,
collective_op: CollectiveOp = CollectiveOp.NONE,
) -> Tuple[jax.Array, ...]:
......@@ -1258,6 +1259,13 @@ def _te_gemm(
DeprecationWarning,
)
if use_split_accumulator is None:
# TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also
# use context of the GEMM type so we can decide between fprop, dgrad, and wgrad
use_split_accumulator = get_quantize_config_with_recipe(
get_global_quantize_recipe()
).FP8_2X_ACC_FPROP
# Prepare non-quantized GEMM operands
lhs_data = lhs
rhs_data = rhs
......@@ -1720,10 +1728,15 @@ def _jax_gemm(
assert (
rhs.scaling_mode == lhs.scaling_mode
), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
# TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also
# use context of the GEMM type so we can decide between fprop, dgrad, and wgrad
use_split_accumulator = get_quantize_config_with_recipe(
get_global_quantize_recipe()
).FP8_2X_ACC_FPROP
precision = (
jax.lax.Precision.HIGHEST
if get_quantize_config().FP8_2X_ACC_FPROP
else jax.lax.Precision.DEFAULT
jax.lax.Precision.HIGHEST if use_split_accumulator else jax.lax.Precision.DEFAULT
)
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
......
......@@ -820,7 +820,7 @@ def _quantize_dbias_impl(
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
scale = compute_scale_from_amax(amax, quantizer.q_dtype, margin=0.0)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale
# Make sure to reset amax to zeros for DelayedScaling
......@@ -1227,7 +1227,7 @@ def grouped_quantize(
)
grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
for i in range(n_groups):
tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0)
scale = scale.at[i].set(tmp_scale[0])
is_tensor_scaling = quantizer.scaling_mode in (
......
......@@ -27,7 +27,6 @@ from .quantize import (
with_sharding_constraint_by_logical_axes,
is_fp8_gemm_with_all_layouts_supported,
TensorUsage,
get_quantize_config,
)
......@@ -95,7 +94,7 @@ def dense(
if transpose_batch_sequence:
warnings.warn("transpose_batch_sequence is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled():
if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
......
......@@ -33,10 +33,11 @@ from ..cpp_extensions import (
)
from ..quantize import (
QuantizerFactory,
get_quantize_config,
get_global_quantize_recipe,
QuantizeMetaSet,
TensorSource,
get_quantize_config_with_recipe,
noop_quantizer_set,
)
PRNGKey = Any
......@@ -355,17 +356,17 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
Generate a set of FP8 meta for a GEMM.
"""
if fp8_recipe is None:
fp8_recipe = get_global_quantize_recipe()
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
collection_name = (
variable_collection
if variable_collection is not None
else get_quantize_config().COLLECTION_NAME
else quantize_config.COLLECTION_NAME
)
if fp8_recipe is None:
quantize_config = get_quantize_config()
else:
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
x_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.X, "x"
)
......@@ -492,7 +493,11 @@ class DenseGeneral(TransformerEngineBase):
self.dtype,
)
if not get_quantize_config().is_fp8_enabled():
quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
if quantizer_set == noop_quantizer_set:
kernel = kernel.astype(input_dtype)
if self.use_bias:
......@@ -505,9 +510,6 @@ class DenseGeneral(TransformerEngineBase):
else:
bias = None
quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
contract_ind = tuple(range(0, len(axis)))
y = dense(
inputs,
......@@ -712,7 +714,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
)
fuse_layernorm = (
get_quantize_config().is_fp8_enabled()
quantizer_set != noop_quantizer_set
and not self.return_layernorm_output
and self.enable_layernorm
)
......@@ -763,7 +765,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape,
self.dtype,
)
if not get_quantize_config().is_fp8_enabled():
if quantizer_set == noop_quantizer_set:
kernel = kernel.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
......@@ -1042,7 +1044,7 @@ class LayerNormMLP(TransformerEngineBase):
# TODO(Phuong): use fuse_layernorm for high-precision
# when NoOpQuantizer and Tensor are implemented
fuse_layernorm = (
get_quantize_config().is_fp8_enabled()
ffn1_quantizer_set != noop_quantizer_set
and not self.return_layernorm_output
and self.enable_layernorm
)
......@@ -1128,7 +1130,7 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype,
)
if not get_quantize_config().is_fp8_enabled():
if ffn1_quantizer_set == noop_quantizer_set:
kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1]
......@@ -1140,7 +1142,7 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2_shape,
self.dtype,
)
if not get_quantize_config().is_fp8_enabled():
if ffn2_quantizer_set == noop_quantizer_set:
kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
......
......@@ -23,7 +23,6 @@ from .quantize import (
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
TensorUsage,
get_quantize_config,
)
......@@ -73,7 +72,7 @@ def layernorm_dense(
- Quantization is applied to both the normalized input and kernel
"""
if not get_quantize_config().is_fp8_enabled():
if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
......
......@@ -28,7 +28,6 @@ from .quantize import (
QuantizerSet,
noop_quantizer_set,
TensorUsage,
get_quantize_config,
)
......@@ -114,7 +113,7 @@ def layernorm_mlp(
not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
if not get_quantize_config().is_fp8_enabled():
if quantizer_sets == (noop_quantizer_set, noop_quantizer_set):
input_dtype = x.dtype
kernel_1 = kernel_1.astype(input_dtype)
kernel_2 = kernel_2.astype(input_dtype)
......
......@@ -46,7 +46,7 @@ from .scaling_modes import ScalingMode
from .device_utils import get_device_compute_capability
__all__ = [
"get_quantize_config",
"get_global_quantize_recipe",
"get_quantize_config_with_recipe",
"autocast",
"fp8_autocast",
......@@ -475,7 +475,12 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
(self.AMAX_HISTORY_LEN,),
jnp.float32,
).value
return QuantizeMeta(scale=scale, amax_history=amax_history)
return QuantizeMeta(
margin=self.MARGIN,
amax_compute_algo=self.AMAX_COMPUTE_ALGO,
scale=scale,
amax_history=amax_history,
)
class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
......@@ -669,14 +674,6 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
)
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
def get_quantize_config():
"""Global instance of BaseQuantizeConfig set by autocast context."""
return _QUANTIZE_CONFIG
def get_quantize_config_class(
fp8_recipe: Recipe,
) -> Type[BaseQuantizeConfig]:
......@@ -687,6 +684,8 @@ def get_quantize_config_class(
Returns:
The quantization config class corresponding to the given recipe.
"""
if fp8_recipe is None:
return NoOpQuantizeConfig
if isinstance(fp8_recipe, DelayedScaling):
return DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, MXFP8BlockScaling):
......@@ -701,10 +700,23 @@ def get_quantize_config_class(
def get_quantize_config_with_recipe(fp8_recipe: Recipe):
"""Get the quantization configuration object based on the FP8 recipe."""
config = get_quantize_config_class(fp8_recipe)()
config.initialize_from_recipe(fp8_recipe)
if fp8_recipe is not None:
config.initialize_from_recipe(fp8_recipe)
return config
_GLOBAL_RECIPE: Optional[Recipe] = None
def get_global_quantize_recipe() -> Optional[Recipe]:
"""Get the global quantization recipe if set.
Returns:
The global quantization recipe or None if not set.
"""
return _GLOBAL_RECIPE
@contextmanager
def autocast(
enabled: bool = False,
......@@ -751,22 +763,21 @@ def autocast(
if recipe is None:
recipe = DelayedScaling()
global _QUANTIZE_CONFIG
global _GLOBAL_RECIPE
old_quantize_config = _QUANTIZE_CONFIG
old_global_recipe = _GLOBAL_RECIPE
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
_GLOBAL_RECIPE = None
try:
with global_shard_guard(mesh_resource):
if enabled:
_QUANTIZE_CONFIG = get_quantize_config_class(recipe)()
is_supported, reason = _QUANTIZE_CONFIG.is_supported()
_GLOBAL_RECIPE = recipe
is_supported, reason = get_quantize_config_class(_GLOBAL_RECIPE)().is_supported()
assert is_supported, reason
_QUANTIZE_CONFIG.initialize_from_recipe(recipe)
yield
finally:
_QUANTIZE_CONFIG = old_quantize_config
_GLOBAL_RECIPE = old_global_recipe
@contextmanager
......
......@@ -28,7 +28,7 @@ from .tensor import (
NoScaleTensor,
)
from .helper import (
get_quantize_config,
get_global_quantize_recipe,
get_quantize_config_with_recipe,
AmaxComputeAlgo,
TensorSource,
......@@ -50,7 +50,7 @@ __all__ = [
def compute_scale_from_amax(
amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None
amax: jnp.ndarray, q_dtype: jnp.dtype, margin: float, scale: Optional[jnp.ndarray] = None
) -> jnp.ndarray:
"""Compute scale from amax value.
......@@ -64,7 +64,7 @@ def compute_scale_from_amax(
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if scale is None:
scale = jnp.ones((1,))
sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
sf = (fp8_max / amax) / (2**margin)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}"
......@@ -223,6 +223,7 @@ class CurrentScaleQuantizer(Quantizer):
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_layout: Quantization axis (default: ROWWISE_COLWISE)
data_layout: Data layout string (default: "NT")
"""
scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
......@@ -254,8 +255,7 @@ class CurrentScaleQuantizer(Quantizer):
compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,))
fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
scale = compute_scale_from_amax(amax, self.q_dtype, margin=0.0)
scaled_x = x.data.astype(compute_dtype) * scale
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
......@@ -327,17 +327,23 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_layout: Quantization axis (default: ROWWISE_COLWISE)
data_layout: Data layout string (default: "NT")
margin: Margin value for scale computation
amax_compute_algo: Algorithm for computing amax
scale: Current scaling factor
amax_history: History of maximum absolute values
"""
scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
margin: float = 0.0
amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32)
)
amax_history: jnp.ndarray = field(default_factory=lambda: jnp.zeros((1024,), jnp.float32))
def __post_init__(self):
assert self.margin is not None, "margin must be specified"
assert self.amax_compute_algo is not None, "amax_compute_algo must be specified"
assert self.amax_history is not None, "amax_history must be specified"
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
......@@ -352,6 +358,8 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
self.q_layout,
self.data_layout,
self.checkpoint_name,
self.margin,
self.amax_compute_algo,
)
return (children, aux_data)
......@@ -407,12 +415,14 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Returns:
Updated AMAX history
"""
amax_history = amax_history.at[0].set(new_amax[0])
amax_history = amax_history.at[0].set(new_amax.reshape((1,))[0])
return amax_history
@staticmethod
@partial(jax.jit, static_argnums=(2,))
def _compute_scale(amax_history, scale, q_dtype):
@partial(jax.jit, static_argnums=(2, 3, 4))
def _compute_scale(
amax_history, scale, q_dtype, amax_compute_algo: AmaxComputeAlgo, margin: float
):
"""Compute new scale based on AMAX history.
Args:
......@@ -424,12 +434,12 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Updated scale value
"""
# 2. Calculate the current scale
if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
if amax_compute_algo is AmaxComputeAlgo.MAX:
amax = jnp.max(amax_history, axis=-1, keepdims=True)
else:
amax = amax_history[0:1]
return compute_scale_from_amax(amax, q_dtype, scale=scale)
return compute_scale_from_amax(amax, q_dtype, margin=margin, scale=scale)
@staticmethod
@jax.jit
......@@ -453,7 +463,9 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
new_amax: New maximum absolute value to add to history
"""
amax_history = self._update_amax_history(self.amax_history, new_amax)
self.scale = self._compute_scale(amax_history, self.scale, self.q_dtype)
self.scale = self._compute_scale(
amax_history, self.scale, self.q_dtype, self.amax_compute_algo, self.margin
)
self.amax_history = self._roll_and_reset_amax_history(amax_history)
......@@ -1124,6 +1136,7 @@ class QuantizerFactory:
bwd_dtype,
is_2x2x,
n_groups,
is_inference_mode=False,
checkpoint_name: Optional[str] = None,
**kwargs,
) -> QuantizerSet:
......@@ -1137,6 +1150,7 @@ class QuantizerFactory:
bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization
n_groups
is_inference_mode: Whether to create quantizers for inference mode. This option is not fully supported yet
checkpoint_name: Optional name for checkpointing quantizations
**kwargs: Additional arguments for quantizer initialization
......@@ -1149,7 +1163,7 @@ class QuantizerFactory:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
if kernel_scaling_mode.is_1d_block_scaling():
q_layout_kernel = QuantizeLayout.COLWISE
if get_quantize_config().INFERENCE_MODE:
if is_inference_mode:
q_layout_dgrad = None
if "quantize_meta_set" in kwargs:
......@@ -1206,10 +1220,10 @@ class QuantizerFactory:
Args:
n_quantizer_sets: Number of quantizer sets to create
scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode
fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE
bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X
scaling_mode: Scaling mode to use, default is get the scaling mode from the specified or global recipe
fwd_dtype: Data type for forward pass, default is get the fwd dtype from the specified or global recipe
bwd_dtype: Data type for backward pass, default is get the bwd dtype from the specified or global recipe
is_2x2x: Whether to use 2x2x quantization, default is determined based on the specified or global recipe
n_groups:
checkpoint_name: Optional name for checkpointing quantizations
fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set.
......@@ -1226,25 +1240,46 @@ class QuantizerFactory:
" scaling mode differs between x, kernel, and grad in the quantizer set."
)
# TODO(jberchtold): Currently this is a limitation because we only support automatically populating quantizer fields based on a given recipe when using Flax. In the generic quantizer logic, we cannot assume Flax is being used, so we require the user to provide the quantize_meta_set created by quantize_config.get_quantize_flax_meta() or the same data created by themselves if they are passing a recipe here directly.
assert (
fp8_recipe is None or "quantize_meta_set" in kwargs
), "When fp8_recipe is specified, quantize_meta_set must be provided in kwargs."
if fp8_recipe is None:
fp8_recipe = get_global_quantize_recipe()
if fp8_recipe is not None:
assert scaling_mode is None, (
"scaling_mode should not be specified when fp8_recipe is provided either directly"
" or through an autocast context."
)
assert fwd_dtype is None, (
"fwd_dtype should not be specified when fp8_recipe is provided either directly or"
" through an autocast context."
)
assert bwd_dtype is None, (
"bwd_dtype should not be specified when fp8_recipe is provided either directly or"
" through an autocast context."
)
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X)
kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD)
fwd_dtype = quantize_config.FWD_DTYPE
bwd_dtype = quantize_config.BWD_DTYPE
is_inference_mode = quantize_config.INFERENCE_MODE
else:
if scaling_mode is not None:
x_scaling_mode = scaling_mode
kernel_scaling_mode = scaling_mode
grad_scaling_mode = scaling_mode
else:
x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X)
kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD)
# TODO(jberchtold): make a way to explicitly pass a no scaling recipe here if we need other quantization config attributes in the future since NoOpQuantizeConfig already exists, we just can't use it here with direct recipe passing because we cannot differentiate between fp8_recipe=None meaning no recipe specified vs explicitly no quantization desired.
x_scaling_mode = ScalingMode.NO_SCALING
kernel_scaling_mode = ScalingMode.NO_SCALING
grad_scaling_mode = ScalingMode.NO_SCALING
is_inference_mode = False
fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE
bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE
if is_2x2x is None:
# TODO(Jeremy): check x, kernel, grad separately for 2x
if x_scaling_mode.is_1d_block_scaling():
......@@ -1253,7 +1288,6 @@ class QuantizerFactory:
is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
else: # NO_SCALING ignores is_2x2x for now
is_2x2x = False
is_inference_mode = get_quantize_config().INFERENCE_MODE
assert not is_inference_mode, "Inference mode is not supported yet!"
q_set = []
......@@ -1267,6 +1301,7 @@ class QuantizerFactory:
bwd_dtype=bwd_dtype,
is_2x2x=is_2x2x,
n_groups=n_groups,
is_inference_mode=is_inference_mode,
checkpoint_name=checkpoint_name,
**kwargs,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment