Unverified Commit 858755c0 authored by jberchtold-nvidia, committed by GitHub

[JAX] TE GEMM checkpointing policies (#2003)



* TE primitive checkpointing policies
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Remove batched gemm policy
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 11ac24cf
@@ -19,6 +19,10 @@ Variables are available in `transformer_engine.jax.sharding`.
 * JOINED_AXES: The logical axis of non-defined dimension. It is usually not sharded.
+
+Checkpointing
+------------------------------------
+When using checkpointing with Transformer Engine JAX, be aware of the checkpointing policy applied to your model. Any JAX checkpointing policy that filters on the `dot` primitive, such as `jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims`, may not work with GEMMs provided by Transformer Engine, because they do not always lower to the `jax.lax.dot_general` primitive. Instead, use `transformer_engine.jax.checkpoint_policies.dots_and_te_gemms_with_no_batch_dims` or similar policies that are designed to work with both Transformer Engine's GEMMs and `jax.lax.dot_general` GEMMs. You may also use any JAX policy that does not filter by primitive, such as `jax.checkpoint_policies.save_only_these_names` or `jax.checkpoint_policies.everything_saveable`.
 Modules
 ------------------------------------
 .. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType
...
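
To make the new docs section concrete, here is a minimal sketch of applying the TE-aware policy with `jax.checkpoint`. The block, shapes, and weights below are illustrative stand-ins; only the policy import reflects this commit.

import jax
import jax.numpy as jnp
from transformer_engine.jax.checkpoint_policies import dots_and_te_gemms_with_no_batch_dims

def mlp_block(x, w1, w2):
    # Stand-in block; a real model would route these GEMMs through TE layers.
    h = jax.nn.gelu(jnp.dot(x, w1))
    return jnp.dot(h, w2)

# Save non-batched GEMM outputs (jax.lax.dot_general and TE's GEMM primitives)
# in the forward pass instead of rematerializing them in the backward pass.
mlp_remat = jax.checkpoint(mlp_block, policy=dots_and_te_gemms_with_no_batch_dims)

x = jnp.ones((4, 16))
w1, w2 = jnp.ones((16, 32)), jnp.ones((32, 8))
grads = jax.grad(lambda w: mlp_remat(x, w, w2).sum())(w1)
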
@@ -855,7 +855,7 @@ def fused_attn_thd(
     return output
 
 
-@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14))
+@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
 def _fused_attn(
     qkv: Tuple[jnp.ndarray, ...],
     bias: Optional[jnp.ndarray],
@@ -872,6 +872,7 @@ def _fused_attn(
     context_parallel_strategy: CPStrategy,
     context_parallel_causal_load_balanced: bool,
     context_parallel_axis: str,
+    context_checkpoint_name: str = "context",
 ):
     output, _ = _fused_attn_fwd_rule(
         qkv,
@@ -889,6 +890,7 @@ def _fused_attn(
         context_parallel_strategy,
         context_parallel_causal_load_balanced,
         context_parallel_axis,
+        context_checkpoint_name=context_checkpoint_name,
     )
     return output
@@ -909,6 +911,7 @@ def _fused_attn_fwd_rule(
     context_parallel_strategy,
     context_parallel_causal_load_balanced,
     context_parallel_axis,
+    context_checkpoint_name,
 ):
     output, softmax_aux, rng_state = tex.fused_attn_fwd(
         qkv,
@@ -927,9 +930,9 @@ def _fused_attn_fwd_rule(
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
     )
-    output = checkpoint_name(output, "context")
-    softmax_aux = checkpoint_name(softmax_aux, "context")
-    rng_state = checkpoint_name(rng_state, "context")
+    output = checkpoint_name(output, context_checkpoint_name)
+    softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name)
+    rng_state = checkpoint_name(rng_state, context_checkpoint_name)
     return output, (
         qkv,
         bias,
@@ -952,9 +955,11 @@ def _fused_attn_bwd_rule(
     context_parallel_strategy,
     context_parallel_causal_load_balanced,
     context_parallel_axis,
+    context_checkpoint_name,
     ctx,
     dz,
 ):
+    del context_checkpoint_name
     (
         qkv,
         bias,
@@ -1012,6 +1017,7 @@ def fused_attn(
     context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
+    context_checkpoint_name: str = "context",
 ):
     """
     Perform cuDNN fused attention.
@@ -1044,6 +1050,7 @@ def fused_attn(
         context_parallel_causal_load_balanced (bool):
             Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
         context_parallel_axis (str): The name of the context parallel axis.
+        context_checkpoint_name (str): The checkpoint name applied to the fused attention outputs in the custom VJP forward pass.
 
     Returns:
         (jnp.ndarray): The output tensor from the fused attention.
@@ -1116,6 +1123,7 @@ def fused_attn(
         context_parallel_strategy=context_parallel_strategy,
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
+        context_checkpoint_name=context_checkpoint_name,
     )
     return output
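
A sketch of how the new `context_checkpoint_name` knob interacts with name-based policies: `fused_attn` tags its forward outputs with `checkpoint_name`, so a policy such as `jax.checkpoint_policies.save_only_these_names` can save exactly those tensors. The attention function below is a plain-JAX stand-in, not TE's fused kernel, and the tag "ctx_layer0" is a hypothetical per-layer name.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def attn_like(q, k, v):
    scores = jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(q.shape[-1])
    context = jnp.einsum("bqk,bkd->bqd", jax.nn.softmax(scores), v)
    # Mirrors fused_attn(..., context_checkpoint_name="ctx_layer0"), which
    # applies this tag to its output, softmax_aux, and rng_state tensors.
    return checkpoint_name(context, "ctx_layer0")

# Save only tensors tagged "ctx_layer0"; rematerialize everything else.
policy = jax.checkpoint_policies.save_only_these_names("ctx_layer0")
attn_remat = jax.checkpoint(attn_like, policy=policy)

q = k = v = jnp.ones((2, 8, 16))
dq = jax.grad(lambda q: attn_remat(q, k, v).sum())(q)
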
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Checkpoint policies for Transformer Engine in JAX.

This module provides JAX checkpoint policies that are compatible with Transformer Engine's custom primitives.
"""
import jax

from .cpp_extensions.gemm import GemmPrimitive, GroupedGemmPrimitive

__all__ = [
    "te_gemms_saveable",
    "dots_and_te_gemms_with_no_batch_dims",
    "checkpoint_dots_and_te_gemms",
]


def te_gemms_saveable(prim, *_, **__) -> bool:
    """Checkpoint policy for Transformer Engine GEMMs."""
    is_te_gemm = prim in {GemmPrimitive.outer_primitive, GroupedGemmPrimitive.outer_primitive}
    # Workaround to include JAX's scaled_matmul until JAX checkpoint policies for dots are
    # updated to include it.
    is_jax_scaled_matmul = prim.name == "scaled_matmul_wrapper"
    return is_te_gemm or is_jax_scaled_matmul


dots_and_te_gemms_with_no_batch_dims = jax.checkpoint_policies.save_from_both_policies(
    jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
    te_gemms_saveable,
)

checkpoint_dots_and_te_gemms = jax.checkpoint_policies.save_from_both_policies(
    jax.checkpoint_policies.checkpoint_dots,
    te_gemms_saveable,
)
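
For reference, a short usage sketch of this new module: the exported policies drop straight into `jax.checkpoint`, and `te_gemms_saveable` can be composed with other JAX policies the same way the module builds its own exports. The name "context" here refers to the default checkpoint name used by fused attention above.

import jax
from transformer_engine.jax.checkpoint_policies import (
    checkpoint_dots_and_te_gemms,
    te_gemms_saveable,
)

# Use the combined dots + TE GEMMs policy directly...
policy = checkpoint_dots_and_te_gemms

# ...or compose te_gemms_saveable with a name-based policy, mirroring how this
# module combines policies via save_from_both_policies.
custom_policy = jax.checkpoint_policies.save_from_both_policies(
    te_gemms_saveable,
    jax.checkpoint_policies.save_only_these_names("context"),
)
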
@@ -940,6 +940,11 @@ class LayerNormMLP(TransformerEngineBase):
         Indicate the logical axes of sharding constraint to the input of 2nd dot, like
         (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
         sharding constraint.
+    ffn1_ckpt_name: str = "ffn1"
+        Checkpoint name for the output of the first fully-connected layer in the MLP block.
+    ffn2_ckpt_name: str = "ffn2"
+        Checkpoint name for the output of the second fully-connected layer in the MLP block.
+
 
 Optimization parameters
 -----------------------
@@ -981,6 +986,8 @@ class LayerNormMLP(TransformerEngineBase):
     layernorm_input_axes: Tuple[str, ...] = None
     dot_1_input_axes: Tuple[str, ...] = None
     dot_2_input_axes: Tuple[str, ...] = None
+    ffn1_ckpt_name: str = "ffn1"
+    ffn2_ckpt_name: str = "ffn2"
 
     def __post_init__(self):
         if self.transpose_batch_sequence:
@@ -1150,9 +1157,6 @@ class LayerNormMLP(TransformerEngineBase):
         bias_1 = None
         bias_2 = None
 
-        ffn1_ckpt_name = "ffn1"
-        ffn2_ckpt_name = "ffn2"
-
         if use_fused_layernorm_mlp:
             out = layernorm_mlp(
                 y,
@@ -1168,8 +1172,8 @@ class LayerNormMLP(TransformerEngineBase):
                 dot_2_input_axes=self.dot_2_input_axes,
                 kernel_1_axes=self.kernel_axes_1,
                 kernel_2_axes=self.kernel_axes_2,
-                ffn1_ckpt_name=ffn1_ckpt_name,
-                ffn2_ckpt_name=ffn2_ckpt_name,
+                ffn1_ckpt_name=self.ffn1_ckpt_name,
+                ffn2_ckpt_name=self.ffn2_ckpt_name,
                 activation_type=normalized_acts,
                 quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
             )
@@ -1251,7 +1255,7 @@ class LayerNormMLP(TransformerEngineBase):
             if self.use_bias:
                 x += jnp.reshape(bias_1, bias_1_shape)
 
-            x = checkpoint_name(x, ffn1_ckpt_name)
+            x = checkpoint_name(x, self.ffn1_ckpt_name)
 
             if is_act_implemented:
                 z = activation(x, normalized_acts)
             else:
@@ -1314,7 +1318,7 @@ class LayerNormMLP(TransformerEngineBase):
             if self.use_bias:
                 out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
 
-            out = checkpoint_name(out, ffn2_ckpt_name)
+            out = checkpoint_name(out, self.ffn2_ckpt_name)
 
             assert out.dtype == input_dtype
             return out, ln_output  # output, layer_norm_output
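
A sketch of what the new fields enable: per-layer checkpoint names on `LayerNormMLP`, paired with a name-based policy. The layer names are hypothetical, and constructor arguments other than the two new fields are illustrative and abbreviated.

import jax
from transformer_engine.jax.flax import LayerNormMLP

# LayerNormMLP tags its two FFN outputs with checkpoint_name(...) using these
# fields, so distinct per-layer names let a policy target specific layers.
mlp = LayerNormMLP(
    intermediate_dim=128,
    ffn1_ckpt_name="ffn1_layer0",
    ffn2_ckpt_name="ffn2_layer0",
)

# Save only this layer's FFN outputs; rematerialize everything else.
policy = jax.checkpoint_policies.save_only_these_names("ffn1_layer0", "ffn2_layer0")
remat_apply = jax.checkpoint(mlp.apply, policy=policy)
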
@@ -274,6 +274,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     max_segments_per_seq: Optional[int] = 1
     context_parallel_causal_load_balanced: bool = False
     context_parallel_axis: str = ""
+    context_checkpoint_name: str = "context"
 
     @nn.compact
     def __call__(
@@ -322,6 +323,7 @@ class _FusedDotProductAttention(nn.Module):
                 max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
+                context_checkpoint_name=self.context_checkpoint_name,
             )
         elif self.qkv_layout.is_kvpacked():
             """kvpacked format, treat
@@ -348,6 +350,7 @@ class _FusedDotProductAttention(nn.Module):
                 max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
+                context_checkpoint_name=self.context_checkpoint_name,
             )
         elif self.qkv_layout.is_separate():
             if self.transpose_batch_sequence:
@@ -369,6 +372,7 @@ class _FusedDotProductAttention(nn.Module):
                 max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
+                context_checkpoint_name=self.context_checkpoint_name,
             )
         else:
             raise ValueError(f"Unsupported {self.qkv_layout=}.")
@@ -501,6 +505,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     context_parallel_causal_load_balanced (bool):
         Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
     context_parallel_axis (str): The name of the context parallel axis.
+    context_checkpoint_name (str): The checkpoint name applied to the fused attention output in the forward pass.
 
     Optimization parameters
     -----------------------
@@ -524,6 +529,7 @@ class DotProductAttention(nn.Module):
     max_segments_per_seq: Optional[int] = 1
     context_parallel_causal_load_balanced: bool = False
     context_parallel_axis: str = ""
+    context_checkpoint_name: str = "context"
 
     @nn.compact
     def __call__(
@@ -690,6 +696,7 @@ class DotProductAttention(nn.Module):
             max_segments_per_seq=self.max_segments_per_seq,
             context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
             context_parallel_axis=self.context_parallel_axis,
+            context_checkpoint_name=self.context_checkpoint_name,
         )(
             query,
             key,
...
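
And the analogous knob at the Flax module level, sketched below; the layer name and the constructor arguments other than `context_checkpoint_name` are illustrative.

import jax
from transformer_engine.jax.flax import DotProductAttention

# Give each layer's fused attention output a distinct checkpoint name so a
# name-based policy can keep, e.g., only selected layers' context tensors.
attn = DotProductAttention(
    head_dim=64,
    num_attention_heads=8,
    context_checkpoint_name="context_layer0",
)
policy = jax.checkpoint_policies.save_only_these_names("context_layer0")
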