Unverified commit c898ab1b authored by zlsh80826, committed by GitHub
Browse files

[JAX] Add checkpoint_name for the recompute granularity control (#542)



Add checkpoint_name
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 92c1e500
......@@ -16,6 +16,7 @@ from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name
from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
......@@ -923,6 +924,8 @@ class LayerNormMLP(TransformerEngineBase):
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape)
x = checkpoint_name(x, 'ffn1')
activations = []
if is_geglu(self.activations):
z = geglu(x)
......@@ -957,4 +960,6 @@ class LayerNormMLP(TransformerEngineBase):
bias = bias.astype(self.dtype)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, 'ffn2')
return out, ln_output # Output, layer_norm_output
......@@ -20,6 +20,7 @@ from jax import dtypes
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
from jax.ad_checkpoint import checkpoint_name
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
......@@ -211,6 +212,8 @@ def core_attention(query: Array,
else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
attn_weights = checkpoint_name(attn_weights, 'logits')
attn_weights = _with_sharding_constraint(attn_weights,
(BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
......@@ -499,6 +502,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_JOINED_AXES, W_TP_AXES),
name='qkv',
dtype=self.dtype)(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
if not use_fused_attn:
query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
else:
......@@ -530,6 +534,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_JOINED_AXES, W_TP_AXES),
name='kv',
dtype=self.dtype)(inputs_kv)
kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
if not use_fused_attn:
key, value = jnp.split(kv_proj, [1], axis=-2)
else:
......@@ -574,6 +579,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
residual = ln_out
if not use_fused_attn:
query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj')
value = checkpoint_name(value, 'value_proj')
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
......@@ -706,6 +714,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
dtype=self.dtype,
float32_logits=self.float32_logits)
x = checkpoint_name(x, 'context')
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
attn_context_sharding_constraint = \
......@@ -724,6 +734,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_NO_SHARD_AXES,),
dtype=self.dtype,
name='out')(x)
out = checkpoint_name(out, 'out_proj')
return out, residual
......
......@@ -5,6 +5,7 @@
from enum import Enum
from functools import partial
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp
......@@ -91,6 +92,9 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context')
return output, (qkv, bias, softmax_aux, rng_state, output, squeezed_mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment