Unverified Commit 67d63d02 authored by jberchtold-nvidia, committed by GitHub

[JAX] Support for checkpointing quantizations (#2356)



* Support for checkpointing quantizations
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Add jaxpr test for quant checkpoint name
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Revert "Support for checkpointing quantizations"

This reverts commit f7b784940369d0da2a77c57fa6ea744e883c5832.
Signed-off-by: JAX Toolbox <jax@nvidia.com>

* Checkpoint quantizations
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* revert other files
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* move checkpointing to VJPs
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix ci failure
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: JAX Toolbox <jax@nvidia.com>
Co-authored-by: JAX Toolbox <jax@nvidia.com>
parent 9440b76a
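Usage sketch (not part of the commit, included for review context): the new `quantization_checkpoint_name` field tags the quantized tensors saved for the backward pass with `jax.ad_checkpoint.checkpoint_name`, so a rematerialization policy can decide whether to keep or recompute them. The snippet below is a minimal sketch modeled on the tests in this diff; the recipe choice, shapes, and rng handling are placeholder assumptions rather than part of the change.

```python
# Hedged sketch only: mirrors the test setup below and shows the intended
# pairing of quantization_checkpoint_name with a JAX remat policy.
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe  # recipe choice is a placeholder

with te.autocast(enabled=True, recipe=recipe.DelayedScaling(), mesh_resource=te.MeshResource()):
    model = te_flax.LayerNormMLP(
        layernorm_type="rmsnorm",
        return_layernorm_output=False,
        intermediate_dropout_rate=0.0,
        dtype=jnp.bfloat16,
        quantization_checkpoint_name="quantization",  # new field added by this PR
    )
    x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16)
    rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)}
    var_collect = model.init({"params": jax.random.PRNGKey(3), **rngs}, x)

    # Keep only tensors tagged with the checkpoint name across the
    # forward/backward boundary; everything else may be rematerialized.
    policy = jax.checkpoint_policies.save_only_these_names("quantization")

    def loss_fn(x):
        fwd = jax.checkpoint(
            lambda x_: jnp.sum(model.apply(var_collect, x_, rngs=rngs)), policy=policy
        )
        return fwd(x)

    loss, dx = jax.value_and_grad(loss_fn)(x)
```

With `save_only_these_names("quantization")`, the tagged quantized activations and kernels survive the remat boundary while untagged intermediates are recomputed.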
......@@ -263,23 +263,16 @@ class TestFP8Functions(unittest.TestCase):
class TestJaxprAndHlo:
"""Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations."""
@pytest_parametrize_wrapper(
"quantization_recipe",
[
quantization_recipe
for quantization_recipe in SUPPORTED_RECIPES
if isinstance(quantization_recipe, NVFP4BlockScaling)
],
)
def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
"""Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantizaton."""
def _generate_jaxpr_for_layernorm_mlp_fwd_bwd(self, quantization_recipe, ln_mlp_kwargs=None):
"""Generates the jaxpr for a forward and backward pass of LayerNormMLP under the given quantization recipe."""
ln_mlp_kwargs = ln_mlp_kwargs or {}
with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()):
model = te_flax.LayerNormMLP(
layernorm_type="rmsnorm",
return_layernorm_output=False,
intermediate_dropout_rate=0.0,
dtype=jnp.bfloat16,
**ln_mlp_kwargs,
)
var_collect = model.init(
......@@ -292,29 +285,83 @@ class TestJaxprAndHlo:
x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16)
rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)}
jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)
rht_amax_eqns = [
eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
]
assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"
def assert_param(index, tensor_name, expected_value: bool):
if expected_value:
assert rht_amax_eqns[index].params["produce_regular_amax"] == True, (
f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
" reuse of amax as this tensor does not have a previous operation to fuse"
" with"
)
else:
assert rht_amax_eqns[index].params["produce_regular_amax"] == False, (
f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
" reuse of amax"
)
assert_param(0, "fwd ln+q", False)
assert_param(1, "fwd act+q", False)
# No previous op before incoming dgrad in the backward so amax is not reused
assert_param(2, "bwd dgrad", True)
assert_param(3, "bwd dact+q", False)
return jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)
@pytest_parametrize_wrapper(
"quantization_recipe",
[
quantization_recipe
for quantization_recipe in SUPPORTED_RECIPES
if isinstance(quantization_recipe, NVFP4BlockScaling)
],
)
def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
"""Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantizaton."""
jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe)
rht_amax_eqns = [
eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
]
assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"
def assert_param(index, tensor_name, expected_value: bool):
if expected_value:
assert rht_amax_eqns[index].params["produce_regular_amax"] == True, (
f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
" reuse of amax as this tensor does not have a previous operation to fuse"
" with"
)
else:
assert rht_amax_eqns[index].params["produce_regular_amax"] == False, (
f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
" reuse of amax"
)
assert_param(0, "fwd ln+q", False)
assert_param(1, "fwd act+q", False)
# No previous op before incoming dgrad in the backward so amax is not reused
assert_param(2, "bwd dgrad", True)
assert_param(3, "bwd dact+q", False)
@pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper(
"quantization_checkpoint_name",
[None, "quantization", "some_arbitrary_user_checkpoint_name"],
)
def test_recipe_supports_quantization_checkpointing(
self, quantization_recipe, quantization_checkpoint_name
):
"""Tests that all supported quantization recipes correctly use checkpoint_name."""
kwargs = {
"quantization_checkpoint_name": quantization_checkpoint_name,
}
jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe, kwargs)
checkpoint_name_eqns = [
eqn
for eqn in jaxpr.jaxpr.eqns
if eqn.primitive.name == "name" and eqn.params["name"] == quantization_checkpoint_name
]
if quantization_checkpoint_name is None:
assert len(checkpoint_name_eqns) == 0, (
"Expected 0 checkpoint_name eqns when quantization_checkpoint_name is None, got"
f" {len(checkpoint_name_eqns)}"
)
return
# 12 checkpointed values:
# - Fwd pass:
# - Input RMSNorm+Q -> 3 possible output tensors that will be used in the backward
# - Kernel Q -> 3 possible output tensors that will be used in the backward
# - Input Activation+Q -> 3 possible output tensors that will be used in the backward
# - Kernel Q -> 3 possible output tensors that will be used in the backward
expected_checkpoint_eqn_count = 12
assert len(checkpoint_name_eqns) == expected_checkpoint_eqn_count, (
f"Expected {expected_checkpoint_eqn_count} checkpoint_name eqns when"
f" quantization_checkpoint_name is set, got {len(checkpoint_name_eqns)}"
)
......@@ -19,6 +19,7 @@ from . import cpp_extensions as tex
from .cpp_extensions.amax import AmaxScope
from .quantize import (
ScaledTensorFactory,
ScaledTensor,
ScalingMode,
QuantizeLayout,
QuantizerSet,
......@@ -227,8 +228,8 @@ def _dense_fwd_rule(
output += jnp.reshape(bias, bias_new_shape)
ctx = (
casted_x.get_tensor(usage=TensorUsage.LHS_TRANS),
casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS),
casted_x.get_tensor(usage=TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x),
casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel),
x.shape,
kernel.shape,
use_bias,
......@@ -529,8 +530,12 @@ def _grouped_dense_fwd_rule(
ctx = (
group_sizes,
ctx_x,
ctx_kernel,
ctx_x.checkpoint(quantizer_set.x) if isinstance(ctx_x, ScaledTensor) else ctx_x,
(
ctx_kernel.checkpoint(quantizer_set.kernel)
if isinstance(ctx_kernel, ScaledTensor)
else ctx_kernel
),
x.shape,
kernel.shape,
use_bias,
......
......@@ -6,7 +6,7 @@ Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional
import numpy as np
import jax.numpy as jnp
......@@ -345,7 +345,11 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
"""
def generate_quantizer_set(
self, postfix: str = "", variable_collection: str = None, fp8_recipe=None
self,
postfix: str = "",
variable_collection: str = None,
quantization_checkpoint_name: Optional[str] = None,
fp8_recipe=None,
):
"""
Generate a set of FP8 meta for a GEMM.
......@@ -375,7 +379,9 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
quantizer_set = QuantizerFactory.create_set(
fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set
fp8_recipe=fp8_recipe,
quantize_meta_set=quantize_meta_set,
checkpoint_name=quantization_checkpoint_name,
)
return quantizer_set
......@@ -424,6 +430,8 @@ class DenseGeneral(TransformerEngineBase):
The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
quantization_checkpoint_name: Optional[str], default = None
Name used to tag quantized tensors for activation checkpointing via jax.ad_checkpoint.checkpoint_name.
"""
features: Union[Iterable[int], int]
......@@ -439,6 +447,7 @@ class DenseGeneral(TransformerEngineBase):
dtype: DType = jnp.float32
input_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False
quantization_checkpoint_name: Optional[str] = None
def __post_init__(self):
if self.kernel_init is None:
......@@ -496,7 +505,9 @@ class DenseGeneral(TransformerEngineBase):
else:
bias = None
quantizer_set = self.generate_quantizer_set()
quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
contract_ind = tuple(range(0, len(axis)))
y = dense(
inputs,
......@@ -628,6 +639,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
value or None. When None is set, then no scaling is applied.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
quantization_checkpoint_name: Optional[str], default = None
Name used to tag quantized tensors for activation checkpointing via jax.ad_checkpoint.checkpoint_name.
"""
features: Union[Iterable[int], int]
......@@ -654,6 +667,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None
transpose_batch_sequence: bool = False
quantization_checkpoint_name: Optional[str] = None
def __post_init__(self):
if self.kernel_init is None:
......@@ -693,7 +707,9 @@ class LayerNormDenseGeneral(TransformerEngineBase):
input_dtype = inputs.dtype
ln_output = None
quantizer_set = self.generate_quantizer_set()
quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
fuse_layernorm = (
get_quantize_config().is_fp8_enabled()
......@@ -941,6 +957,8 @@ class LayerNormMLP(TransformerEngineBase):
The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
quantization_checkpoint_name: Optional[str], default = None
Name used to tag quantized tensors for activation checkpointing via jax.ad_checkpoint.checkpoint_name.
"""
intermediate_dim: int = 2048
......@@ -976,6 +994,7 @@ class LayerNormMLP(TransformerEngineBase):
ffn1_ckpt_name: str = "ffn1"
ffn2_ckpt_name: str = "ffn2"
transpose_batch_sequence: bool = False
quantization_checkpoint_name: Optional[str] = None
def __post_init__(self):
if self.kernel_init is None:
......@@ -1010,8 +1029,12 @@ class LayerNormMLP(TransformerEngineBase):
"""
assert self.axis == -1, "Only support axis == -1 at this moment"
ffn1_quantizer_set = self.generate_quantizer_set("_0")
ffn2_quantizer_set = self.generate_quantizer_set("_1")
ffn1_quantizer_set = self.generate_quantizer_set(
"_0", quantization_checkpoint_name=self.quantization_checkpoint_name
)
ffn2_quantizer_set = self.generate_quantizer_set(
"_1", quantization_checkpoint_name=self.quantization_checkpoint_name
)
input_dtype = inputs.dtype
ln_output = None
......
......@@ -236,8 +236,8 @@ def _layernorm_dense_fwd_rule(
output += jnp.reshape(bias, bias_new_shape)
ctx = (
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel.get_tensor(TensorUsage.RHS_TRANS),
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x),
casted_kernel.get_tensor(TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel),
x.shape,
kernel.shape,
mu,
......
......@@ -390,11 +390,11 @@ def _layernorm_mlp_fwd_rule(
rsigma,
gamma,
beta,
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn1_quantizer_set.x),
casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn1_quantizer_set.kernel),
dot_1_output,
casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
casted_act_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn2_quantizer_set.x),
casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn2_quantizer_set.kernel),
x_contracting_dims,
k_contracting_dims,
kernel_1.shape,
......
......@@ -83,12 +83,15 @@ class Quantizer(ABC):
q_dtype: The data type for quantized values
scaling_mode: The scaling mode to use for quantization
q_layout: The quantization axis (row-wise, column-wise, or both)
data_layout: The data layout string (e.g., "NT")
checkpoint_name: Optional name for checkpointing quantization state
"""
q_dtype: jnp.dtype
scaling_mode: ScalingMode
q_layout: QuantizeLayout
data_layout: str
checkpoint_name: Optional[str] = None
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
......@@ -97,7 +100,13 @@ class Quantizer(ABC):
Tuple of (children, aux_data) for tree operations
"""
children = ()
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout)
aux_data = (
self.q_dtype,
self.scaling_mode,
self.q_layout,
self.data_layout,
self.checkpoint_name,
)
return (children, aux_data)
@classmethod
......@@ -337,7 +346,13 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Tuple of (children, aux_data) for tree operations
"""
children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout)
aux_data = (
self.q_dtype,
self.scaling_mode,
self.q_layout,
self.data_layout,
self.checkpoint_name,
)
return (children, aux_data)
def _quantize_func(
......@@ -588,7 +603,14 @@ class NVFP4Quantizer(Quantizer):
Tuple of (children, aux_data) for tree operations
"""
children = (self.stochastic_rounding_rng_state,)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.use_rht)
aux_data = (
self.q_dtype,
self.scaling_mode,
self.q_layout,
self.data_layout,
self.checkpoint_name,
self.use_rht,
)
return (children, aux_data)
@classmethod
......@@ -867,7 +889,14 @@ class GroupedQuantizer(Quantizer):
Tuple of (children, aux_data) for tree operations
"""
children = (self.quantizers,)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.n_groups)
aux_data = (
self.q_dtype,
self.scaling_mode,
self.q_layout,
self.data_layout,
self.checkpoint_name,
self.n_groups,
)
return (children, aux_data)
def __post_init__(self):
......@@ -1041,6 +1070,7 @@ class QuantizerFactory:
q_dtype: jnp.dtype = None,
q_layout: QuantizeLayout = None,
n_groups: int = None,
checkpoint_name: Optional[str] = None,
**kwargs,
) -> Quantizer:
"""Create one or more quantizers with specified parameters.
......@@ -1052,6 +1082,7 @@ class QuantizerFactory:
q_layout: Quantization axis
flatten_axis: The quantization axis for the tensor
n_groups: Number of quantizers if GroupedQuantizer
checkpoint_name: Optional name for checkpointing quantizations
**kwargs: Additional arguments for quantizer initialization
Returns:
......@@ -1075,7 +1106,11 @@ class QuantizerFactory:
for _ in range(n_quantizers):
quantizers.append(
quantizer_type(
q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
checkpoint_name=checkpoint_name,
**kwargs,
)
)
return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)
......@@ -1089,6 +1124,7 @@ class QuantizerFactory:
bwd_dtype,
is_2x2x,
n_groups,
checkpoint_name: Optional[str] = None,
**kwargs,
) -> QuantizerSet:
"""Create a set of quantizers for forward and backward passes.
......@@ -1101,6 +1137,7 @@ class QuantizerFactory:
bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization
n_groups: Number of quantizers if GroupedQuantizer
checkpoint_name: Optional name for checkpointing quantizations
**kwargs: Additional arguments for quantizer initialization
Returns:
......@@ -1123,12 +1160,32 @@ class QuantizerFactory:
else:
args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, x_scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x)
q_x = QuantizerFactory.create(
1,
x_scaling_mode,
fwd_dtype,
q_layout_x,
n_groups,
checkpoint_name=checkpoint_name,
**args_x,
)
q_kernel = QuantizerFactory.create(
1, kernel_scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel
1,
kernel_scaling_mode,
fwd_dtype,
q_layout_kernel,
n_groups,
checkpoint_name=checkpoint_name,
**args_kernel,
)
q_dgrad = QuantizerFactory.create(
1, grad_scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad
1,
grad_scaling_mode,
bwd_dtype,
q_layout_dgrad,
n_groups,
checkpoint_name=checkpoint_name,
**args_grad,
)
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
......@@ -1140,6 +1197,7 @@ class QuantizerFactory:
bwd_dtype: jnp.dtype = None,
is_2x2x: bool = None,
n_groups: int = None,
checkpoint_name: Optional[str] = None,
# TODO(jberchtold): rename fp8_recipe to quantization_recipe
fp8_recipe: Optional[recipe.Recipe] = None,
**kwargs,
......@@ -1153,6 +1211,7 @@ class QuantizerFactory:
bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X
n_groups: Number of quantizers if GroupedQuantizer
checkpoint_name: Optional name for checkpointing quantizations
fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set.
**kwargs: Additional arguments for quantizer initialization
......@@ -1208,6 +1267,7 @@ class QuantizerFactory:
bwd_dtype=bwd_dtype,
is_2x2x=is_2x2x,
n_groups=n_groups,
checkpoint_name=checkpoint_name,
**kwargs,
)
)
......
......@@ -14,6 +14,7 @@ from abc import ABC, abstractmethod
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from jax.ad_checkpoint import checkpoint_name as jax_checkpoint_name
from .scaling_modes import ScalingMode, TensorUsage
......@@ -89,6 +90,17 @@ class AbstractBaseTensor(ABC):
The tensor with applied sharding constraints
"""
@abstractmethod
def checkpoint(self, quantizer):
"""Checkpoints the tensor with the given quantizer's checkpoint name if available.
Args:
quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied.
Returns:
The checkpointed tensor
"""
@dataclass
class AbstractBaseTensor1x(AbstractBaseTensor):
......@@ -150,6 +162,18 @@ class NoScaleTensor(AbstractBaseTensor1x):
amax=self.amax,
)
def checkpoint(self, quantizer):
"""Checkpoints the tensor with the given quantizer's checkpoint name if available.
Args:
quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied.
Returns:
The checkpointed tensor
"""
assert quantizer is None, "NoScaleTensor does not support quantization."
return self
class ScaledTensor(ABC):
"""Abstract base class for scaled tensors."""
......@@ -317,6 +341,20 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
has_rht_applied=self.has_rht_applied,
)
def checkpoint(self, quantizer):
"""Checkpoints the tensor with the given quantizer's checkpoint name if available.
Args:
quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied.
Returns:
The checkpointed tensor
"""
if quantizer is None or quantizer.checkpoint_name is None:
return self
return jax_checkpoint_name(self, name=quantizer.checkpoint_name)
@register_pytree_node_class
@dataclass
......@@ -420,6 +458,20 @@ class GroupedScaledTensor1x(ScaledTensor1x):
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
raise NotImplementedError
def checkpoint(self, quantizer):
"""Checkpoints the tensor with the given quantizer's checkpoint name if available.
Args:
quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied.
Returns:
The checkpointed tensor
"""
if quantizer is None or quantizer.checkpoint_name is None:
return self
return jax_checkpoint_name(self, name=quantizer.checkpoint_name)
@register_pytree_node_class
@dataclass
......@@ -496,6 +548,9 @@ class ScaledTensor2x(AbstractBaseTensor, ScaledTensor):
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
def checkpoint(self, quantizer):
raise NotImplementedError
@dataclass
class ScaledTensorFactory:
......