Unverified Commit a0754757 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Improve support and testing for direct recipe usage without autocast contexts (#2366)



* Refactor to avoid storing a global quantization config so direct recipe passing works as intended
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* fix use_split_accumulator for current scaling recipe
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* fix tests that pass direct recipe and were missing quantize meta set
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Revert "fix use_split_accumulator for current scaling recipe"

This reverts commit a74ab7df812ec0a069b1bdd208debb93ec25a900.
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* fix ci failures
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix amax_history post_init
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Update transformer_engine/jax/quantize/quantizer.py
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix ci failures
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* fix ci issue
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* address comments
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* make recipe assertion classes in test_recipe_characteristics not inherit from unittest.TestCase
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b88f727b
...@@ -40,6 +40,8 @@ from transformer_engine.jax.quantize import ( ...@@ -40,6 +40,8 @@ from transformer_engine.jax.quantize import (
QuantizerFactory, QuantizerFactory,
QuantizeLayout, QuantizeLayout,
noop_quantizer_set, noop_quantizer_set,
QuantizeMetaSet,
QuantizeMeta,
) )
from transformer_engine.jax.quantize import helper from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation from transformer_engine.jax.activation import activation
...@@ -1457,7 +1459,12 @@ class TestDense: ...@@ -1457,7 +1459,12 @@ class TestDense:
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe) quantizer_set = QuantizerFactory.create_set(
fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
)
n_iterations = 3 if recipe.delayed() else 1 n_iterations = 3 if recipe.delayed() else 1
with use_jax_gemm(enabled=with_jax_gemm): with use_jax_gemm(enabled=with_jax_gemm):
...@@ -1516,7 +1523,12 @@ class TestFusedDense: ...@@ -1516,7 +1523,12 @@ class TestFusedDense:
gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16) gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe) quantizer_set = QuantizerFactory.create_set(
fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
)
if norm_type == "layernorm": if norm_type == "layernorm":
beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16) beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
...@@ -1605,6 +1617,9 @@ class TestFusedDense: ...@@ -1605,6 +1617,9 @@ class TestFusedDense:
quantizer_sets = QuantizerFactory.create_set( quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2, n_quantizer_sets=2,
fp8_recipe=recipe, fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
) )
if norm_type == "layernorm": if norm_type == "layernorm":
......
...@@ -23,7 +23,8 @@ from utils import EncoderLayer as RefEncoderLayer ...@@ -23,7 +23,8 @@ from utils import EncoderLayer as RefEncoderLayer
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.quantize import ( from transformer_engine.jax.quantize import (
get_quantize_config, get_global_quantize_recipe,
get_quantize_config_with_recipe,
ScalingMode, ScalingMode,
is_fp8_available, is_fp8_available,
update_collections, update_collections,
...@@ -358,7 +359,7 @@ class BaseRunner: ...@@ -358,7 +359,7 @@ class BaseRunner:
ref_params, test_params = self._sync_params(ref_params, test_params) ref_params, test_params = self._sync_params(ref_params, test_params)
if get_quantize_config().is_fp8_enabled(): if get_quantize_config_with_recipe(get_global_quantize_recipe()).is_fp8_enabled():
for _ in range(4): for _ in range(4):
_, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)( _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
inputs, inputs,
...@@ -368,14 +369,24 @@ class BaseRunner: ...@@ -368,14 +369,24 @@ class BaseRunner:
test_layer, test_layer,
) )
if ( if (
get_quantize_config().get_scaling_mode(TensorSource.X) get_quantize_config_with_recipe(get_global_quantize_recipe()).get_scaling_mode(
TensorSource.X
)
== ScalingMode.DELAYED_TENSOR_SCALING == ScalingMode.DELAYED_TENSOR_SCALING
): ):
_, updated_quantize_meta = flax.core.pop( _, updated_quantize_meta = flax.core.pop(
updated_state[0], get_quantize_config().COLLECTION_NAME updated_state[0],
get_quantize_config_with_recipe(
get_global_quantize_recipe()
).COLLECTION_NAME,
) )
test_others = update_collections( test_others = update_collections(
{get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others {
get_quantize_config_with_recipe(
get_global_quantize_recipe()
).COLLECTION_NAME: updated_quantize_meta
},
test_others,
) )
del updated_quantize_meta del updated_quantize_meta
del updated_state del updated_state
......
This diff is collapsed.
...@@ -221,7 +221,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F ...@@ -221,7 +221,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
""" """
Helper function to manage primitive states by name without modifying environment variables. Helper function to manage primitive states by name without modifying environment variables.
Allows enabling specific primitives, disabling specific primitives, or disabling all primitives. Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
This helper is used in the get_quantize_config().initialize() methods. This helper is used in the get_quantize_config_with_recipe().initialize() methods.
Args: Args:
enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None. enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.
......
...@@ -38,12 +38,13 @@ from ..quantize import ( ...@@ -38,12 +38,13 @@ from ..quantize import (
ScalingMode, ScalingMode,
Quantizer, Quantizer,
GroupedQuantizer, GroupedQuantizer,
get_quantize_config,
QuantizerSet, QuantizerSet,
QuantizeLayout, QuantizeLayout,
noop_quantizer_set, noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv, apply_padding_to_scale_inv,
get_quantize_config_with_recipe,
get_global_quantize_recipe,
) )
from .misc import get_padded_spec, is_all_reduce_in_float32 from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import ( from ..sharding import (
...@@ -1246,7 +1247,7 @@ def _te_gemm( ...@@ -1246,7 +1247,7 @@ def _te_gemm(
fuse_bias: bool = False, fuse_bias: bool = False,
fuse_gelu: bool = False, fuse_gelu: bool = False,
grad: bool = False, grad: bool = False,
use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP, use_split_accumulator: bool = None,
transpose_batch_sequence: bool = False, transpose_batch_sequence: bool = False,
collective_op: CollectiveOp = CollectiveOp.NONE, collective_op: CollectiveOp = CollectiveOp.NONE,
) -> Tuple[jax.Array, ...]: ) -> Tuple[jax.Array, ...]:
...@@ -1258,6 +1259,13 @@ def _te_gemm( ...@@ -1258,6 +1259,13 @@ def _te_gemm(
DeprecationWarning, DeprecationWarning,
) )
if use_split_accumulator is None:
# TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also
# use context of the GEMM type so we can decide between fprop, dgrad, and wgrad
use_split_accumulator = get_quantize_config_with_recipe(
get_global_quantize_recipe()
).FP8_2X_ACC_FPROP
# Prepare non-quantized GEMM operands # Prepare non-quantized GEMM operands
lhs_data = lhs lhs_data = lhs
rhs_data = rhs rhs_data = rhs
...@@ -1720,10 +1728,15 @@ def _jax_gemm( ...@@ -1720,10 +1728,15 @@ def _jax_gemm(
assert ( assert (
rhs.scaling_mode == lhs.scaling_mode rhs.scaling_mode == lhs.scaling_mode
), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
# TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also
# use context of the GEMM type so we can decide between fprop, dgrad, and wgrad
use_split_accumulator = get_quantize_config_with_recipe(
get_global_quantize_recipe()
).FP8_2X_ACC_FPROP
precision = ( precision = (
jax.lax.Precision.HIGHEST jax.lax.Precision.HIGHEST if use_split_accumulator else jax.lax.Precision.DEFAULT
if get_quantize_config().FP8_2X_ACC_FPROP
else jax.lax.Precision.DEFAULT
) )
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
......
...@@ -820,7 +820,7 @@ def _quantize_dbias_impl( ...@@ -820,7 +820,7 @@ def _quantize_dbias_impl(
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
) )
scale = compute_scale_from_amax(amax, quantizer.q_dtype) scale = compute_scale_from_amax(amax, quantizer.q_dtype, margin=0.0)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale scale = quantizer.scale
# Make sure to reset amax to zeros for DelayedScaling # Make sure to reset amax to zeros for DelayedScaling
...@@ -1227,7 +1227,7 @@ def grouped_quantize( ...@@ -1227,7 +1227,7 @@ def grouped_quantize(
) )
grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
for i in range(n_groups): for i in range(n_groups):
tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype) tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0)
scale = scale.at[i].set(tmp_scale[0]) scale = scale.at[i].set(tmp_scale[0])
is_tensor_scaling = quantizer.scaling_mode in ( is_tensor_scaling = quantizer.scaling_mode in (
......
...@@ -27,7 +27,6 @@ from .quantize import ( ...@@ -27,7 +27,6 @@ from .quantize import (
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
TensorUsage, TensorUsage,
get_quantize_config,
) )
...@@ -95,7 +94,7 @@ def dense( ...@@ -95,7 +94,7 @@ def dense(
if transpose_batch_sequence: if transpose_batch_sequence:
warnings.warn("transpose_batch_sequence is not well tested, use with caution!") warnings.warn("transpose_batch_sequence is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled(): if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype input_dtype = x.dtype
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
......
...@@ -33,10 +33,11 @@ from ..cpp_extensions import ( ...@@ -33,10 +33,11 @@ from ..cpp_extensions import (
) )
from ..quantize import ( from ..quantize import (
QuantizerFactory, QuantizerFactory,
get_quantize_config, get_global_quantize_recipe,
QuantizeMetaSet, QuantizeMetaSet,
TensorSource, TensorSource,
get_quantize_config_with_recipe, get_quantize_config_with_recipe,
noop_quantizer_set,
) )
PRNGKey = Any PRNGKey = Any
...@@ -355,17 +356,17 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method ...@@ -355,17 +356,17 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
Generate a set of FP8 meta for a GEMM. Generate a set of FP8 meta for a GEMM.
""" """
if fp8_recipe is None:
fp8_recipe = get_global_quantize_recipe()
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
collection_name = ( collection_name = (
variable_collection variable_collection
if variable_collection is not None if variable_collection is not None
else get_quantize_config().COLLECTION_NAME else quantize_config.COLLECTION_NAME
) )
if fp8_recipe is None:
quantize_config = get_quantize_config()
else:
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
x_meta = quantize_config.get_quantize_flax_meta( x_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.X, "x" self, collection_name, postfix, TensorSource.X, "x"
) )
...@@ -492,7 +493,11 @@ class DenseGeneral(TransformerEngineBase): ...@@ -492,7 +493,11 @@ class DenseGeneral(TransformerEngineBase):
self.dtype, self.dtype,
) )
if not get_quantize_config().is_fp8_enabled(): quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
if quantizer_set == noop_quantizer_set:
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
if self.use_bias: if self.use_bias:
...@@ -505,9 +510,6 @@ class DenseGeneral(TransformerEngineBase): ...@@ -505,9 +510,6 @@ class DenseGeneral(TransformerEngineBase):
else: else:
bias = None bias = None
quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
y = dense( y = dense(
inputs, inputs,
...@@ -712,7 +714,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -712,7 +714,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
) )
fuse_layernorm = ( fuse_layernorm = (
get_quantize_config().is_fp8_enabled() quantizer_set != noop_quantizer_set
and not self.return_layernorm_output and not self.return_layernorm_output
and self.enable_layernorm and self.enable_layernorm
) )
...@@ -763,7 +765,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -763,7 +765,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape, kernel_shape,
self.dtype, self.dtype,
) )
if not get_quantize_config().is_fp8_enabled(): if quantizer_set == noop_quantizer_set:
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
...@@ -1042,7 +1044,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1042,7 +1044,7 @@ class LayerNormMLP(TransformerEngineBase):
# TODO(Phuong): use fuse_layernorm for high-precision # TODO(Phuong): use fuse_layernorm for high-precision
# when NoOpQuantizer and Tensor are implemented # when NoOpQuantizer and Tensor are implemented
fuse_layernorm = ( fuse_layernorm = (
get_quantize_config().is_fp8_enabled() ffn1_quantizer_set != noop_quantizer_set
and not self.return_layernorm_output and not self.return_layernorm_output
and self.enable_layernorm and self.enable_layernorm
) )
...@@ -1128,7 +1130,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1128,7 +1130,7 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype, self.dtype,
) )
if not get_quantize_config().is_fp8_enabled(): if ffn1_quantizer_set == noop_quantizer_set:
kernel_1 = kernel_1.astype(input_dtype) kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1] hidden_size = inputs.shape[-1]
...@@ -1140,7 +1142,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1140,7 +1142,7 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2_shape, kernel_2_shape,
self.dtype, self.dtype,
) )
if not get_quantize_config().is_fp8_enabled(): if ffn2_quantizer_set == noop_quantizer_set:
kernel_2 = kernel_2.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
......
...@@ -23,7 +23,6 @@ from .quantize import ( ...@@ -23,7 +23,6 @@ from .quantize import (
noop_quantizer_set, noop_quantizer_set,
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
TensorUsage, TensorUsage,
get_quantize_config,
) )
...@@ -73,7 +72,7 @@ def layernorm_dense( ...@@ -73,7 +72,7 @@ def layernorm_dense(
- Quantization is applied to both the normalized input and kernel - Quantization is applied to both the normalized input and kernel
""" """
if not get_quantize_config().is_fp8_enabled(): if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype input_dtype = x.dtype
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
......
...@@ -28,7 +28,6 @@ from .quantize import ( ...@@ -28,7 +28,6 @@ from .quantize import (
QuantizerSet, QuantizerSet,
noop_quantizer_set, noop_quantizer_set,
TensorUsage, TensorUsage,
get_quantize_config,
) )
...@@ -114,7 +113,7 @@ def layernorm_mlp( ...@@ -114,7 +113,7 @@ def layernorm_mlp(
not zero_centered_gamma not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
if not get_quantize_config().is_fp8_enabled(): if quantizer_sets == (noop_quantizer_set, noop_quantizer_set):
input_dtype = x.dtype input_dtype = x.dtype
kernel_1 = kernel_1.astype(input_dtype) kernel_1 = kernel_1.astype(input_dtype)
kernel_2 = kernel_2.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype)
......
...@@ -46,7 +46,7 @@ from .scaling_modes import ScalingMode ...@@ -46,7 +46,7 @@ from .scaling_modes import ScalingMode
from .device_utils import get_device_compute_capability from .device_utils import get_device_compute_capability
__all__ = [ __all__ = [
"get_quantize_config", "get_global_quantize_recipe",
"get_quantize_config_with_recipe", "get_quantize_config_with_recipe",
"autocast", "autocast",
"fp8_autocast", "fp8_autocast",
...@@ -475,7 +475,12 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -475,7 +475,12 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
(self.AMAX_HISTORY_LEN,), (self.AMAX_HISTORY_LEN,),
jnp.float32, jnp.float32,
).value ).value
return QuantizeMeta(scale=scale, amax_history=amax_history) return QuantizeMeta(
margin=self.MARGIN,
amax_compute_algo=self.AMAX_COMPUTE_ALGO,
scale=scale,
amax_history=amax_history,
)
class CurrentScalingQuantizeConfig(BaseQuantizeConfig): class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
...@@ -669,14 +674,6 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -669,14 +674,6 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
) )
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
def get_quantize_config():
"""Global instance of BaseQuantizeConfig set by autocast context."""
return _QUANTIZE_CONFIG
def get_quantize_config_class( def get_quantize_config_class(
fp8_recipe: Recipe, fp8_recipe: Recipe,
) -> Type[BaseQuantizeConfig]: ) -> Type[BaseQuantizeConfig]:
...@@ -687,6 +684,8 @@ def get_quantize_config_class( ...@@ -687,6 +684,8 @@ def get_quantize_config_class(
Returns: Returns:
The quantization config class corresponding to the given recipe. The quantization config class corresponding to the given recipe.
""" """
if fp8_recipe is None:
return NoOpQuantizeConfig
if isinstance(fp8_recipe, DelayedScaling): if isinstance(fp8_recipe, DelayedScaling):
return DelayedScalingQuantizeConfig return DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, MXFP8BlockScaling): if isinstance(fp8_recipe, MXFP8BlockScaling):
...@@ -701,10 +700,23 @@ def get_quantize_config_class( ...@@ -701,10 +700,23 @@ def get_quantize_config_class(
def get_quantize_config_with_recipe(fp8_recipe: Recipe): def get_quantize_config_with_recipe(fp8_recipe: Recipe):
"""Get the quantization configuration object based on the FP8 recipe.""" """Get the quantization configuration object based on the FP8 recipe."""
config = get_quantize_config_class(fp8_recipe)() config = get_quantize_config_class(fp8_recipe)()
config.initialize_from_recipe(fp8_recipe) if fp8_recipe is not None:
config.initialize_from_recipe(fp8_recipe)
return config return config
_GLOBAL_RECIPE: Optional[Recipe] = None
def get_global_quantize_recipe() -> Optional[Recipe]:
"""Get the global quantization recipe if set.
Returns:
The global quantization recipe or None if not set.
"""
return _GLOBAL_RECIPE
@contextmanager @contextmanager
def autocast( def autocast(
enabled: bool = False, enabled: bool = False,
...@@ -751,22 +763,21 @@ def autocast( ...@@ -751,22 +763,21 @@ def autocast(
if recipe is None: if recipe is None:
recipe = DelayedScaling() recipe = DelayedScaling()
global _QUANTIZE_CONFIG global _GLOBAL_RECIPE
old_quantize_config = _QUANTIZE_CONFIG old_global_recipe = _GLOBAL_RECIPE
_QUANTIZE_CONFIG = NoOpQuantizeConfig() _GLOBAL_RECIPE = None
try: try:
with global_shard_guard(mesh_resource): with global_shard_guard(mesh_resource):
if enabled: if enabled:
_QUANTIZE_CONFIG = get_quantize_config_class(recipe)() _GLOBAL_RECIPE = recipe
is_supported, reason = _QUANTIZE_CONFIG.is_supported() is_supported, reason = get_quantize_config_class(_GLOBAL_RECIPE)().is_supported()
assert is_supported, reason assert is_supported, reason
_QUANTIZE_CONFIG.initialize_from_recipe(recipe)
yield yield
finally: finally:
_QUANTIZE_CONFIG = old_quantize_config _GLOBAL_RECIPE = old_global_recipe
@contextmanager @contextmanager
......
...@@ -28,7 +28,7 @@ from .tensor import ( ...@@ -28,7 +28,7 @@ from .tensor import (
NoScaleTensor, NoScaleTensor,
) )
from .helper import ( from .helper import (
get_quantize_config, get_global_quantize_recipe,
get_quantize_config_with_recipe, get_quantize_config_with_recipe,
AmaxComputeAlgo, AmaxComputeAlgo,
TensorSource, TensorSource,
...@@ -50,7 +50,7 @@ __all__ = [ ...@@ -50,7 +50,7 @@ __all__ = [
def compute_scale_from_amax( def compute_scale_from_amax(
amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None amax: jnp.ndarray, q_dtype: jnp.dtype, margin: float, scale: Optional[jnp.ndarray] = None
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Compute scale from amax value. """Compute scale from amax value.
...@@ -64,7 +64,7 @@ def compute_scale_from_amax( ...@@ -64,7 +64,7 @@ def compute_scale_from_amax(
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32) fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if scale is None: if scale is None:
scale = jnp.ones((1,)) scale = jnp.ones((1,))
sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) sf = (fp8_max / amax) / (2**margin)
sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale)
assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}" assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}"
...@@ -223,6 +223,7 @@ class CurrentScaleQuantizer(Quantizer): ...@@ -223,6 +223,7 @@ class CurrentScaleQuantizer(Quantizer):
Attributes: Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_layout: Quantization axis (default: ROWWISE_COLWISE) q_layout: Quantization axis (default: ROWWISE_COLWISE)
data_layout: Data layout string (default: "NT")
""" """
scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
...@@ -254,8 +255,7 @@ class CurrentScaleQuantizer(Quantizer): ...@@ -254,8 +255,7 @@ class CurrentScaleQuantizer(Quantizer):
compute_dtype = jnp.float32 compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,))
fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32) scale = compute_scale_from_amax(amax, self.q_dtype, margin=0.0)
scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
scaled_x = x.data.astype(compute_dtype) * scale scaled_x = x.data.astype(compute_dtype) * scale
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
...@@ -327,17 +327,23 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -327,17 +327,23 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Attributes: Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_layout: Quantization axis (default: ROWWISE_COLWISE) q_layout: Quantization axis (default: ROWWISE_COLWISE)
data_layout: Data layout string (default: "NT")
margin: Margin value for scale computation
amax_compute_algo: Algorithm for computing amax
scale: Current scaling factor scale: Current scaling factor
amax_history: History of maximum absolute values amax_history: History of maximum absolute values
""" """
scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING margin: float = 0.0
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field( amax_history: jnp.ndarray = field(default_factory=lambda: jnp.zeros((1024,), jnp.float32))
default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32)
) def __post_init__(self):
assert self.margin is not None, "margin must be specified"
assert self.amax_compute_algo is not None, "amax_compute_algo must be specified"
assert self.amax_history is not None, "amax_history must be specified"
def tree_flatten(self): def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations. """Flatten the quantizer for JAX tree operations.
...@@ -352,6 +358,8 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -352,6 +358,8 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
self.q_layout, self.q_layout,
self.data_layout, self.data_layout,
self.checkpoint_name, self.checkpoint_name,
self.margin,
self.amax_compute_algo,
) )
return (children, aux_data) return (children, aux_data)
...@@ -407,12 +415,14 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -407,12 +415,14 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Returns: Returns:
Updated AMAX history Updated AMAX history
""" """
amax_history = amax_history.at[0].set(new_amax[0]) amax_history = amax_history.at[0].set(new_amax.reshape((1,))[0])
return amax_history return amax_history
@staticmethod @staticmethod
@partial(jax.jit, static_argnums=(2,)) @partial(jax.jit, static_argnums=(2, 3, 4))
def _compute_scale(amax_history, scale, q_dtype): def _compute_scale(
amax_history, scale, q_dtype, amax_compute_algo: AmaxComputeAlgo, margin: float
):
"""Compute new scale based on AMAX history. """Compute new scale based on AMAX history.
Args: Args:
...@@ -424,12 +434,12 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -424,12 +434,12 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Updated scale value Updated scale value
""" """
# 2. Calculate the current scale # 2. Calculate the current scale
if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: if amax_compute_algo is AmaxComputeAlgo.MAX:
amax = jnp.max(amax_history, axis=-1, keepdims=True) amax = jnp.max(amax_history, axis=-1, keepdims=True)
else: else:
amax = amax_history[0:1] amax = amax_history[0:1]
return compute_scale_from_amax(amax, q_dtype, scale=scale) return compute_scale_from_amax(amax, q_dtype, margin=margin, scale=scale)
@staticmethod @staticmethod
@jax.jit @jax.jit
...@@ -453,7 +463,9 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -453,7 +463,9 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
new_amax: New maximum absolute value to add to history new_amax: New maximum absolute value to add to history
""" """
amax_history = self._update_amax_history(self.amax_history, new_amax) amax_history = self._update_amax_history(self.amax_history, new_amax)
self.scale = self._compute_scale(amax_history, self.scale, self.q_dtype) self.scale = self._compute_scale(
amax_history, self.scale, self.q_dtype, self.amax_compute_algo, self.margin
)
self.amax_history = self._roll_and_reset_amax_history(amax_history) self.amax_history = self._roll_and_reset_amax_history(amax_history)
...@@ -1124,6 +1136,7 @@ class QuantizerFactory: ...@@ -1124,6 +1136,7 @@ class QuantizerFactory:
bwd_dtype, bwd_dtype,
is_2x2x, is_2x2x,
n_groups, n_groups,
is_inference_mode=False,
checkpoint_name: Optional[str] = None, checkpoint_name: Optional[str] = None,
**kwargs, **kwargs,
) -> QuantizerSet: ) -> QuantizerSet:
...@@ -1137,6 +1150,7 @@ class QuantizerFactory: ...@@ -1137,6 +1150,7 @@ class QuantizerFactory:
bwd_dtype: Data type for backward pass bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization is_2x2x: Whether to use 2x2x quantization
n_groups n_groups
is_inference_mode: Whether to create quantizers for inference mode. This option is not fully supported yet
checkpoint_name: Optional name for checkpointing quantizations checkpoint_name: Optional name for checkpointing quantizations
**kwargs: Additional arguments for quantizer initialization **kwargs: Additional arguments for quantizer initialization
...@@ -1149,7 +1163,7 @@ class QuantizerFactory: ...@@ -1149,7 +1163,7 @@ class QuantizerFactory:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
if kernel_scaling_mode.is_1d_block_scaling(): if kernel_scaling_mode.is_1d_block_scaling():
q_layout_kernel = QuantizeLayout.COLWISE q_layout_kernel = QuantizeLayout.COLWISE
if get_quantize_config().INFERENCE_MODE: if is_inference_mode:
q_layout_dgrad = None q_layout_dgrad = None
if "quantize_meta_set" in kwargs: if "quantize_meta_set" in kwargs:
...@@ -1206,10 +1220,10 @@ class QuantizerFactory: ...@@ -1206,10 +1220,10 @@ class QuantizerFactory:
Args: Args:
n_quantizer_sets: Number of quantizer sets to create n_quantizer_sets: Number of quantizer sets to create
scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode scaling_mode: Scaling mode to use, default is get the scaling mode from the specified or global recipe
fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE fwd_dtype: Data type for forward pass, default is get the fwd dtype from the specified or global recipe
bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE bwd_dtype: Data type for backward pass, default is get the bwd dtype from the specified or global recipe
is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X is_2x2x: Whether to use 2x2x quantization, default is determined based on the specified or global recipe
n_groups: n_groups:
checkpoint_name: Optional name for checkpointing quantizations checkpoint_name: Optional name for checkpointing quantizations
fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set. fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set.
...@@ -1226,25 +1240,46 @@ class QuantizerFactory: ...@@ -1226,25 +1240,46 @@ class QuantizerFactory:
" scaling mode differs between x, kernel, and grad in the quantizer set." " scaling mode differs between x, kernel, and grad in the quantizer set."
) )
# TODO(jberchtold): Currently this is a limitation because we only support automatically populating quantizer fields based on a given recipe when using Flax. In the generic quantizer logic, we cannot assume Flax is being used, so we require the user to provide the quantize_meta_set created by quantize_config.get_quantize_flax_meta() or the same data created by themselves if they are passing a recipe here directly.
assert (
fp8_recipe is None or "quantize_meta_set" in kwargs
), "When fp8_recipe is specified, quantize_meta_set must be provided in kwargs."
if fp8_recipe is None:
fp8_recipe = get_global_quantize_recipe()
if fp8_recipe is not None: if fp8_recipe is not None:
assert scaling_mode is None, (
"scaling_mode should not be specified when fp8_recipe is provided either directly"
" or through an autocast context."
)
assert fwd_dtype is None, (
"fwd_dtype should not be specified when fp8_recipe is provided either directly or"
" through an autocast context."
)
assert bwd_dtype is None, (
"bwd_dtype should not be specified when fp8_recipe is provided either directly or"
" through an autocast context."
)
quantize_config = get_quantize_config_with_recipe(fp8_recipe) quantize_config = get_quantize_config_with_recipe(fp8_recipe)
x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X) x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X)
kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL) kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD) grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD)
fwd_dtype = quantize_config.FWD_DTYPE fwd_dtype = quantize_config.FWD_DTYPE
bwd_dtype = quantize_config.BWD_DTYPE bwd_dtype = quantize_config.BWD_DTYPE
is_inference_mode = quantize_config.INFERENCE_MODE
else: else:
if scaling_mode is not None: if scaling_mode is not None:
x_scaling_mode = scaling_mode x_scaling_mode = scaling_mode
kernel_scaling_mode = scaling_mode kernel_scaling_mode = scaling_mode
grad_scaling_mode = scaling_mode grad_scaling_mode = scaling_mode
else: else:
x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X) # TODO(jberchtold): make a way to explicitly pass a no scaling recipe here if we need other quantization config attributes in the future since NoOpQuantizeConfig already exists, we just can't use it here with direct recipe passing because we cannot differentiate between fp8_recipe=None meaning no recipe specified vs explicitly no quantization desired.
kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL) x_scaling_mode = ScalingMode.NO_SCALING
grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD) kernel_scaling_mode = ScalingMode.NO_SCALING
grad_scaling_mode = ScalingMode.NO_SCALING
is_inference_mode = False
fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE
bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE
if is_2x2x is None: if is_2x2x is None:
# TODO(Jeremy): check x, kernel, grad separately for 2x # TODO(Jeremy): check x, kernel, grad separately for 2x
if x_scaling_mode.is_1d_block_scaling(): if x_scaling_mode.is_1d_block_scaling():
...@@ -1253,7 +1288,6 @@ class QuantizerFactory: ...@@ -1253,7 +1288,6 @@ class QuantizerFactory:
is_2x2x = not is_fp8_gemm_with_all_layouts_supported() is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
else: # NO_SCALING ignores is_2x2x for now else: # NO_SCALING ignores is_2x2x for now
is_2x2x = False is_2x2x = False
is_inference_mode = get_quantize_config().INFERENCE_MODE
assert not is_inference_mode, "Inference mode is not supported yet!" assert not is_inference_mode, "Inference mode is not supported yet!"
q_set = [] q_set = []
...@@ -1267,6 +1301,7 @@ class QuantizerFactory: ...@@ -1267,6 +1301,7 @@ class QuantizerFactory:
bwd_dtype=bwd_dtype, bwd_dtype=bwd_dtype,
is_2x2x=is_2x2x, is_2x2x=is_2x2x,
n_groups=n_groups, n_groups=n_groups,
is_inference_mode=is_inference_mode,
checkpoint_name=checkpoint_name, checkpoint_name=checkpoint_name,
**kwargs, **kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment