Unverified Commit 858755c0 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] TE GEMM checkpointing policies (#2003)



* TE primitive checkpointing policies
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Remove batched gemm policy
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent 11ac24cf
......@@ -19,6 +19,10 @@ Variables are available in `transformer_engine.jax.sharding`.
* JOINED_AXES: The logical axis of non-defined dimension. It is usually not sharded.
Checkpointing
------------------------------------
When using checkpointing with Transformer Engine JAX, be aware of the checkpointing policy applied to your model. JAX checkpointing policies that filter on the `dot` primitive, such as `jax.checkpoint_policies.dots_with_no_batch_dims`, may not work with GEMMs provided by Transformer Engine, because Transformer Engine does not always use the `jax.lax.dot_general` primitive. Instead, use `transformer_engine.jax.checkpoint_policies.dots_and_te_gemms_with_no_batch_dims` or a similar policy designed to work with both Transformer Engine's GEMMs and `jax.lax.dot_general` GEMMs. You may also use any JAX policy that does not filter by primitive, such as `jax.checkpoint_policies.save_only_these_names` or `jax.checkpoint_policies.everything_saveable`.
Modules
------------------------------------
.. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType
......
......@@ -855,7 +855,7 @@ def fused_attn_thd(
return output
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14))
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
......@@ -872,6 +872,7 @@ def _fused_attn(
context_parallel_strategy: CPStrategy,
context_parallel_causal_load_balanced: bool,
context_parallel_axis: str,
context_checkpoint_name: str = "context",
):
output, _ = _fused_attn_fwd_rule(
qkv,
......@@ -889,6 +890,7 @@ def _fused_attn(
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
)
return output
......@@ -909,6 +911,7 @@ def _fused_attn_fwd_rule(
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name,
):
output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv,
......@@ -927,9 +930,9 @@ def _fused_attn_fwd_rule(
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
output = checkpoint_name(output, "context")
softmax_aux = checkpoint_name(softmax_aux, "context")
rng_state = checkpoint_name(rng_state, "context")
output = checkpoint_name(output, context_checkpoint_name)
softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name)
rng_state = checkpoint_name(rng_state, context_checkpoint_name)
return output, (
qkv,
bias,
......@@ -952,9 +955,11 @@ def _fused_attn_bwd_rule(
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name,
ctx,
dz,
):
del context_checkpoint_name
(
qkv,
bias,
......@@ -1012,6 +1017,7 @@ def fused_attn(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
context_checkpoint_name: str = "context",
):
"""
Perform cuDNN fused attention.
......@@ -1044,6 +1050,7 @@ def fused_attn(
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
......@@ -1116,6 +1123,7 @@ def fused_attn(
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
)
return output
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Checkpoint policies for Transformer Engine in JAX.
This module provides JAX checkpoint policies that are compatible with Transformer Engine's custom primitives.
"""
import jax
from .cpp_extensions.gemm import GemmPrimitive, GroupedGemmPrimitive
# Public API of this module: one raw policy predicate plus two composed
# policies that extend JAX's built-in dot-saving policies with TE GEMMs.
__all__ = [
"te_gemms_saveable",
"dots_and_te_gemms_with_no_batch_dims",
"checkpoint_dots_and_te_gemms",
]
def te_gemms_saveable(prim, *_, **__) -> bool:
    """Checkpoint policy: save the outputs of Transformer Engine GEMM primitives.

    Returns True for TE's (grouped) GEMM outer primitives and for JAX's
    ``scaled_matmul_wrapper``; extra positional/keyword arguments from the
    checkpoint-policy callback protocol are ignored.
    """
    matches_te_gemm = prim in (
        GemmPrimitive.outer_primitive,
        GroupedGemmPrimitive.outer_primitive,
    )
    # Workaround to include JAX's scaled_matmul until JAX checkpoint policies
    # for dots are updated to include it.
    matches_jax_scaled_matmul = prim.name == "scaled_matmul_wrapper"
    return matches_te_gemm or matches_jax_scaled_matmul
# Combined policy: saveable if either JAX's no-batch-dims dot policy or the
# TE GEMM policy above marks the primitive as saveable.
dots_and_te_gemms_with_no_batch_dims = jax.checkpoint_policies.save_from_both_policies(
jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
te_gemms_saveable,
)
# Combined policy: saveable if either JAX's checkpoint_dots policy or the
# TE GEMM policy above marks the primitive as saveable.
checkpoint_dots_and_te_gemms = jax.checkpoint_policies.save_from_both_policies(
jax.checkpoint_policies.checkpoint_dots,
te_gemms_saveable,
)
......@@ -940,6 +940,11 @@ class LayerNormMLP(TransformerEngineBase):
Indicate the logical axes of sharding constraint to the input of 2nd dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
ffn1_ckpt_name: str = "ffn1"
Checkpoint name for the output of the first fully-connected layer in the MLP block.
ffn2_ckpt_name: str = "ffn2"
Checkpoint name for the output of the second fully-connected layer in the MLP block.
Optimization parameters
-----------------------
......@@ -981,6 +986,8 @@ class LayerNormMLP(TransformerEngineBase):
layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None
dot_2_input_axes: Tuple[str, ...] = None
ffn1_ckpt_name: str = "ffn1"
ffn2_ckpt_name: str = "ffn2"
def __post_init__(self):
if self.transpose_batch_sequence:
......@@ -1150,9 +1157,6 @@ class LayerNormMLP(TransformerEngineBase):
bias_1 = None
bias_2 = None
ffn1_ckpt_name = "ffn1"
ffn2_ckpt_name = "ffn2"
if use_fused_layernorm_mlp:
out = layernorm_mlp(
y,
......@@ -1168,8 +1172,8 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes=self.dot_2_input_axes,
kernel_1_axes=self.kernel_axes_1,
kernel_2_axes=self.kernel_axes_2,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name,
ffn1_ckpt_name=self.ffn1_ckpt_name,
ffn2_ckpt_name=self.ffn2_ckpt_name,
activation_type=normalized_acts,
quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
)
......@@ -1251,7 +1255,7 @@ class LayerNormMLP(TransformerEngineBase):
if self.use_bias:
x += jnp.reshape(bias_1, bias_1_shape)
x = checkpoint_name(x, ffn1_ckpt_name)
x = checkpoint_name(x, self.ffn1_ckpt_name)
if is_act_implemented:
z = activation(x, normalized_acts)
else:
......@@ -1314,7 +1318,7 @@ class LayerNormMLP(TransformerEngineBase):
if self.use_bias:
out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, ffn2_ckpt_name)
out = checkpoint_name(out, self.ffn2_ckpt_name)
assert out.dtype == input_dtype
return out, ln_output # Output, layer_norm_output
......@@ -274,6 +274,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""
context_checkpoint_name: str = "context"
@nn.compact
def __call__(
......@@ -322,6 +323,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_checkpoint_name=self.context_checkpoint_name,
)
elif self.qkv_layout.is_kvpacked():
"""kvpacked format, treat
......@@ -348,6 +350,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_checkpoint_name=self.context_checkpoint_name,
)
elif self.qkv_layout.is_separate():
if self.transpose_batch_sequence:
......@@ -369,6 +372,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_checkpoint_name=self.context_checkpoint_name,
)
else:
raise ValueError(f"Unsupported {self.qkv_layout=}.")
......@@ -501,6 +505,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
Optimization parameters
-----------------------
......@@ -524,6 +529,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""
context_checkpoint_name: str = "context"
@nn.compact
def __call__(
......@@ -690,6 +696,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_checkpoint_name=self.context_checkpoint_name,
)(
query,
key,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment