Unverified Commit 442513c5 authored by jberchtold-nvidia, committed by GitHub

[JAX] Add tutorial for integrating TE/JAX quantization into an existing framework (#2423)



* Tutorial for integrating TE/JAX quantization into an existing framework
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* add todos
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* support nvfp4 sr rng key, move wrapper module into TE itself, fix bfloat16 cast
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* update docstrings
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix QKV proj and out proj in Flax example transformer layer
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Use fused attention in quickstart_jax example
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remat policy
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* add tutorial to docs
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* update title
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* remove unused dtype from TE DPA module
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix notebook title
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Add explanation of flax module wrapper
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5c2f2ff5
...@@ -53,7 +53,7 @@ ...@@ -53,7 +53,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 26, "execution_count": null,
"id": "d5284a38", "id": "d5284a38",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -67,7 +67,7 @@ ...@@ -67,7 +67,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 27, "execution_count": 2,
"id": "a4d1cfdc", "id": "a4d1cfdc",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -142,6 +142,9 @@ ...@@ -142,6 +142,9 @@
" \n", " \n",
" # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n", " # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n",
" x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n", " x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n",
"\n",
" # Output projection\n",
" x = nn.Dense(features=self.hidden_size, use_bias=True)(x)\n",
" \n", " \n",
" x = res + x\n", " x = res + x\n",
" \n", " \n",
...@@ -171,7 +174,7 @@ ...@@ -171,7 +174,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 28, "execution_count": 3,
"id": "8b44649d", "id": "8b44649d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -192,7 +195,7 @@ ...@@ -192,7 +195,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 29, "execution_count": 4,
"id": "e44ed26d", "id": "e44ed26d",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -201,7 +204,7 @@ ...@@ -201,7 +204,7 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Pure Flax FlaxTransformerLayer initialized successfully!\n", "Pure Flax FlaxTransformerLayer initialized successfully!\n",
"Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n" "Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n"
] ]
} }
], ],
...@@ -222,7 +225,7 @@ ...@@ -222,7 +225,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 5,
"id": "de91af7a", "id": "de91af7a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -248,7 +251,7 @@ ...@@ -248,7 +251,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": 6,
"id": "037bc8d9", "id": "037bc8d9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -256,7 +259,7 @@ ...@@ -256,7 +259,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 18.546080589294434 ms\n" "Mean time: 19.258604049682617 ms\n"
] ]
} }
], ],
...@@ -270,8 +273,8 @@ ...@@ -270,8 +273,8 @@
" variables=params,\n", " variables=params,\n",
" input=x,\n", " input=x,\n",
" output_grad=dy,\n", " output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs={\"dropout\": dropout_key},\n",
")" ")"
] ]
}, },
...@@ -304,7 +307,7 @@ ...@@ -304,7 +307,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 32, "execution_count": 7,
"id": "bed20d6b", "id": "bed20d6b",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -323,14 +326,11 @@ ...@@ -323,14 +326,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 33, "execution_count": 8,
"id": "56105579", "id": "56105579",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention\n",
"\n",
"\n",
"class TEUnfusedMLP(nn.Module):\n", "class TEUnfusedMLP(nn.Module):\n",
" hidden_size : int\n", " hidden_size : int\n",
" ffn_hidden_size: int\n", " ffn_hidden_size: int\n",
...@@ -360,11 +360,7 @@ ...@@ -360,11 +360,7 @@
" x: jnp.ndarray,\n", " x: jnp.ndarray,\n",
" attention_mask: Optional[jnp.ndarray] = None,\n", " attention_mask: Optional[jnp.ndarray] = None,\n",
" deterministic: bool = False\n", " deterministic: bool = False\n",
" ) -> jnp.ndarray:\n", " ) -> jnp.ndarray: \n",
" # Create causal mask if not provided\n",
" if attention_mask is None:\n",
" attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
" \n",
" res = x\n", " res = x\n",
" x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n", " x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
"\n", "\n",
...@@ -376,14 +372,19 @@ ...@@ -376,14 +372,19 @@
" # Attention - either TE or Flax implementation\n", " # Attention - either TE or Flax implementation\n",
" if self.use_te_attention:\n", " if self.use_te_attention:\n",
" # Use TE's DotProductAttention\n", " # Use TE's DotProductAttention\n",
" attention = TEDotProductAttention(\n", " attention = te_flax.DotProductAttention(\n",
" head_dim=self.kv_channels,\n", " head_dim=self.kv_channels,\n",
" num_attention_heads=self.num_attention_heads,\n", " num_attention_heads=self.num_attention_heads,\n",
" num_gqa_groups=self.num_attention_heads, # No GQA\n", " num_gqa_groups=self.num_attention_heads, # No GQA\n",
" attention_dropout=self.attention_dropout,\n", " attention_dropout=self.attention_dropout,\n",
" attn_mask_type='causal',\n", " attn_mask_type='causal',\n",
" )\n", " )\n",
" x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n", " x = attention(\n",
" q, k, v,\n",
" # Causal mask does not need an explicit instatiated mask as specialized kernels exist to handle it\n",
" sequence_descriptor=None, \n",
" deterministic=deterministic\n",
" )\n",
" # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n", " # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
" x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n", " x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n",
" x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n", " x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n",
...@@ -393,6 +394,10 @@ ...@@ -393,6 +394,10 @@
" q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n", " q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n",
" k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n", " k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n",
" v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n", " v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n",
"\n",
" # Create causal mask if not provided\n",
" if attention_mask is None:\n",
" attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
" \n", " \n",
" attention = nn.MultiHeadDotProductAttention(\n", " attention = nn.MultiHeadDotProductAttention(\n",
" num_heads=self.num_attention_heads,\n", " num_heads=self.num_attention_heads,\n",
...@@ -428,7 +433,7 @@ ...@@ -428,7 +433,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 34, "execution_count": 9,
"id": "4b67511f", "id": "4b67511f",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -436,7 +441,7 @@ ...@@ -436,7 +441,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 16.375374794006348 ms\n" "Mean time: 16.003193855285645 ms\n"
] ]
} }
], ],
...@@ -455,8 +460,8 @@ ...@@ -455,8 +460,8 @@
" variables=te_params, # Ensure the correct `params` is passed\n", " variables=te_params, # Ensure the correct `params` is passed\n",
" input=x,\n", " input=x,\n",
" output_grad=dy,\n", " output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs={\"dropout\": dropout_key},\n",
")\n" ")\n"
] ]
}, },
...@@ -470,7 +475,7 @@ ...@@ -470,7 +475,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 35, "execution_count": 10,
"id": "5146cd99", "id": "5146cd99",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -478,23 +483,9 @@ ...@@ -478,23 +483,9 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:634: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n", " warnings.warn(\n",
"/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:742: UserWarning: Fused attention is not enabled because there is no available kernel.\n", "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
"Fall back to the unfused attention.\n",
"Please try to update the cuDNN and TE to the latest version.\n",
"self.dtype=<class 'jax.numpy.float32'>\n",
"qkv_layout=<QKVLayout.BSHD_BSHD_BSHD: <NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: 9>>\n",
"attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
"attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
"self.attention_dropout=0.1\n",
"self.num_attention_heads=32\n",
"self.num_gqa_groups=32\n",
"seqlen_q=2048\n",
"seqlen_kv=2048\n",
"head_dim_qk=128\n",
"head_dim_v=128\n",
"\n",
" warnings.warn(\n" " warnings.warn(\n"
] ]
}, },
...@@ -502,7 +493,7 @@ ...@@ -502,7 +493,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 12.403340339660645 ms\n" "Mean time: 8.897695541381836 ms\n"
] ]
} }
], ],
...@@ -520,8 +511,8 @@ ...@@ -520,8 +511,8 @@
" variables=te_params, # Ensure the correct `params` is passed\n", " variables=te_params, # Ensure the correct `params` is passed\n",
" input=x,\n", " input=x,\n",
" output_grad=dy,\n", " output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs={\"dropout\": dropout_key},\n",
")" ")"
] ]
}, },
...@@ -553,7 +544,7 @@ ...@@ -553,7 +544,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36, "execution_count": 11,
"id": "c2eee376", "id": "c2eee376",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -565,15 +556,27 @@ ...@@ -565,15 +556,27 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 37, "execution_count": 12,
"id": "de96827c", "id": "de96827c",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 9.396424293518066 ms\n" "Mean time: 5.651178359985352 ms\n"
] ]
} }
], ],
...@@ -589,9 +592,9 @@ ...@@ -589,9 +592,9 @@
" variables=te_unfused_params, # Ensure the correct `params` is passed\n", " variables=te_unfused_params, # Ensure the correct `params` is passed\n",
" input=x,\n", " input=x,\n",
" output_grad=dy,\n", " output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe}\n", " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe},\n",
" rngs={\"dropout\": dropout_key},\n",
")" ")"
] ]
}, },
...@@ -626,7 +629,7 @@ ...@@ -626,7 +629,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 38, "execution_count": 13,
"id": "11203785", "id": "11203785",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -659,7 +662,7 @@ ...@@ -659,7 +662,7 @@
" q, k, v = jnp.split(qkv, 3, axis=3)\n", " q, k, v = jnp.split(qkv, 3, axis=3)\n",
"\n", "\n",
" # Attention using TE's DotProductAttention\n", " # Attention using TE's DotProductAttention\n",
" attention = TEDotProductAttention(\n", " attention = te_flax.DotProductAttention(\n",
" head_dim=self.kv_channels,\n", " head_dim=self.kv_channels,\n",
" num_attention_heads=self.num_attention_heads,\n", " num_attention_heads=self.num_attention_heads,\n",
" num_gqa_groups=self.num_attention_heads, \n", " num_gqa_groups=self.num_attention_heads, \n",
...@@ -697,15 +700,27 @@ ...@@ -697,15 +700,27 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39, "execution_count": 14,
"id": "6b0c705e", "id": "6b0c705e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 9.145426750183105 ms\n" "Mean time: 5.493879318237305 ms\n"
] ]
} }
], ],
...@@ -726,9 +741,9 @@ ...@@ -726,9 +741,9 @@
" variables=te_fused_params,\n", " variables=te_fused_params,\n",
" input=x,\n", " input=x,\n",
" output_grad=dy,\n", " output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe}\n", " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe},\n",
" rngs={\"dropout\": dropout_key},\n",
")" ")"
] ]
}, },
...@@ -742,35 +757,11 @@ ...@@ -742,35 +757,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 40, "execution_count": 15,
"id": "b2aaa8ef", "id": "b2aaa8ef",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:742: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
"Fall back to the unfused attention.\n",
"Please try to update the cuDNN and TE to the latest version.\n",
"self.dtype=<class 'jax.numpy.float32'>\n",
"qkv_layout=<QKVLayout.BS3HD: <NVTE_QKV_Layout.NVTE_BS3HD: 5>>\n",
"attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
"attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
"self.attention_dropout=0.1\n",
"self.num_attention_heads=32\n",
"self.num_gqa_groups=32\n",
"seqlen_q=2048\n",
"seqlen_kv=2048\n",
"head_dim_qk=128\n",
"head_dim_v=128\n",
"\n",
" warnings.warn(\n"
]
}
],
"source": [ "source": [
"\n",
"te_transformer = te_flax.TransformerLayer(\n", "te_transformer = te_flax.TransformerLayer(\n",
" hidden_size=hidden_size,\n", " hidden_size=hidden_size,\n",
" mlp_hidden_size=ffn_hidden_size, \n", " mlp_hidden_size=ffn_hidden_size, \n",
...@@ -782,7 +773,7 @@ ...@@ -782,7 +773,7 @@
" intermediate_dropout=0.0,\n", " intermediate_dropout=0.0,\n",
" enable_relative_embedding=False,\n", " enable_relative_embedding=False,\n",
" self_attn_bias_type='no_bias',\n", " self_attn_bias_type='no_bias',\n",
" hidden_dropout=0.0\n", " hidden_dropout=0.0,\n",
")\n", ")\n",
"\n", "\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n", "with te.autocast(enabled=True, recipe=fp8_recipe):\n",
...@@ -792,7 +783,7 @@ ...@@ -792,7 +783,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 41, "execution_count": 20,
"id": "b9cdbf22", "id": "b9cdbf22",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -800,7 +791,7 @@ ...@@ -800,7 +791,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 9.020795822143555 ms\n" "Mean time: 5.334172248840332 ms\n"
] ]
} }
], ],
...@@ -811,9 +802,9 @@ ...@@ -811,9 +802,9 @@
" variables=te_transformer_params,\n", " variables=te_transformer_params,\n",
" input=x,\n", " input=x,\n",
" output_grad=dy,\n", " output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe }\n", " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
" rngs={\"dropout\": dropout_key},\n",
")" ")"
] ]
} }
......
...@@ -19,12 +19,12 @@ def speedometer( ...@@ -19,12 +19,12 @@ def speedometer(
variables: Any, variables: Any,
input: jnp.ndarray, input: jnp.ndarray,
output_grad: jnp.ndarray, output_grad: jnp.ndarray,
dropout_key: jax.random.PRNGKey,
model_init_fn: Callable = None, model_init_fn: Callable = None,
forward_kwargs: dict = {}, forward_kwargs: dict = {},
autocast_kwargs: Optional[dict] = None, autocast_kwargs: Optional[dict] = None,
timing_iters: int = 50, timing_iters: int = 50,
warmup_iters: int = 50, warmup_iters: int = 50,
rngs: Dict[str, jax.random.PRNGKey] = None,
) -> None: ) -> None:
"""Measure average runtime for a JAX module """Measure average runtime for a JAX module
Perform forward and backward passes . Perform forward and backward passes .
...@@ -33,19 +33,21 @@ def speedometer( ...@@ -33,19 +33,21 @@ def speedometer(
autocast_kwargs = {"enabled": False} autocast_kwargs = {"enabled": False}
model_init_fn = None model_init_fn = None
if rngs is None:
rngs = {}
train_step_fn = create_train_step_fn(model_apply_fn, autocast_kwargs, forward_kwargs) train_step_fn = create_train_step_fn(model_apply_fn, autocast_kwargs, forward_kwargs)
# Warm up runs # Warm up runs
key = dropout_key
for _ in range(warmup_iters): for _ in range(warmup_iters):
key, step_key = jax.random.split(key) rngs, step_rngs = _split_step_rngs(rngs)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key) loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_rngs)
# Timing runs # Timing runs
start = time.time() start = time.time()
for _ in range(timing_iters): for _ in range(timing_iters):
key, step_key = jax.random.split(key) rngs, step_rngs = _split_step_rngs(rngs)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key) loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_rngs)
end = time.time() end = time.time()
print(f"Mean time: {(end - start) * 1000 / timing_iters} ms") print(f"Mean time: {(end - start) * 1000 / timing_iters} ms")
...@@ -63,8 +65,12 @@ def create_train_step_fn( ...@@ -63,8 +65,12 @@ def create_train_step_fn(
if forward_kwargs is None: if forward_kwargs is None:
forward_kwargs = {} forward_kwargs = {}
def loss_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key): def loss_fn(
rngs = {"dropout": dropout_key} variables: Any,
inp: jnp.ndarray,
grad_target: jnp.ndarray,
rngs: Dict[str, jax.random.PRNGKey],
):
with te.autocast(**autocast_kwargs): with te.autocast(**autocast_kwargs):
# Forward Pass: Apply the model using current parameters and variables # Forward Pass: Apply the model using current parameters and variables
call_kwargs = {**forward_kwargs, "rngs": rngs} call_kwargs = {**forward_kwargs, "rngs": rngs}
...@@ -84,3 +90,16 @@ def create_train_step_fn( ...@@ -84,3 +90,16 @@ def create_train_step_fn(
# JIT-compile the fwd_bwd_fn # JIT-compile the fwd_bwd_fn
return jax.jit(fwd_bwd_fn) return jax.jit(fwd_bwd_fn)
def _split_step_rngs(
rngs: Dict[str, jax.random.PRNGKey],
) -> Tuple[Dict[str, jax.random.PRNGKey], Dict[str, jax.random.PRNGKey]]:
"""Splits each RNG in the rngs dictionary for a new step."""
step_rngs = {}
new_rngs = {}
for name, key in rngs.items():
new_key, step_key = jax.random.split(key)
new_rngs[name] = new_key
step_rngs[name] = step_key
return new_rngs, step_rngs
{
"cells": [
{
"cell_type": "markdown",
"id": "962d87bb",
"metadata": {},
"source": [
"\n",
"\n",
"# JAX: Integrating TE into an existing framework\n",
"\n",
"This tutorial will cover how to integrate TransformerEngine into an existing JAX model framework, such as [MaxText's TE integration](https://github.com/AI-Hypercomputer/maxtext/blob/ed517cf80d9aa81f76e236c5516dacebfe39e96d/src/MaxText/layers/quantizations.py#L753) or your own model framework. \n"
]
},
{
"cell_type": "markdown",
"id": "b36876bb",
"metadata": {},
"source": [
"Let's start with a standard JAX+Flax Transformer layer"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d5284a38",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax import linen as nn\n",
"import quickstart_jax_utils as utils\n",
"from typing import Optional"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a4d1cfdc",
"metadata": {},
"outputs": [],
"source": [
"class FlaxMLP(nn.Module):\n",
" \"\"\"Feed-forward network in Transformer layer\n",
" Built with plain Flax modules.\n",
" \"\"\"\n",
" hidden_size: int\n",
" ffn_hidden_size: int\n",
" dot_general_cls: callable = lambda: None\n",
"\n",
" @nn.compact\n",
" def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n",
" x = nn.Dense(features=self.ffn_hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
" x = nn.gelu(x, approximate=True) # equivalent to tanh approximation\n",
" x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
" return x\n",
"\n",
"class FlaxTransformerLayer(nn.Module):\n",
" \"\"\"Basic Transformer layer using plain Flax modules\"\"\"\n",
" hidden_size: int\n",
" ffn_hidden_size: int\n",
" num_attention_heads: int\n",
" layernorm_eps: float = 1e-5\n",
" attention_dropout: float = 0.1\n",
" dot_general_cls: callable = lambda: None\n",
" \n",
" def setup(self):\n",
" self.kv_channels = self.hidden_size // self.num_attention_heads\n",
"\n",
" @nn.compact\n",
" def __call__(\n",
" self, \n",
" x: jnp.ndarray, \n",
" attention_mask: Optional[jnp.ndarray] = None,\n",
" deterministic: bool = False\n",
" ) -> jnp.ndarray:\n",
" # Create causal mask if not provided\n",
" if attention_mask is None:\n",
" attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
" \n",
" res = x\n",
" x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
" \n",
" # Fused QKV projection\n",
" qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n",
" \n",
" # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n",
" # which is the correct format for dot_product_attention\n",
" \n",
" # Apply dot product attention\n",
" # Note: dot_product_attention expects mask to be broadcastable to \n",
" # [batch, num_heads, q_length, kv_length], but attention_mask from \n",
" # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n",
" \n",
" # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n",
" dropout_rng = None\n",
" if not deterministic and self.attention_dropout > 0:\n",
" dropout_rng = self.make_rng('dropout')\n",
" \n",
" # See quickstart_jax.ipynb for details on using TE's faster fused attention\n",
" x = nn.dot_product_attention(\n",
" query=q,\n",
" key=k,\n",
" value=v,\n",
" mask=attention_mask,\n",
" dropout_rng=dropout_rng,\n",
" dropout_rate=self.attention_dropout,\n",
" deterministic=deterministic,\n",
" broadcast_dropout=True,\n",
" )\n",
" \n",
" # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n",
" x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n",
"\n",
" # Output projection\n",
" x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
" \n",
" x = res + x\n",
" \n",
" # Second residual connection\n",
" res = x\n",
" x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
" \n",
" # MLP\n",
" mlp = FlaxMLP(\n",
" hidden_size=self.hidden_size,\n",
" ffn_hidden_size=self.ffn_hidden_size,\n",
" dot_general_cls=self.dot_general_cls,\n",
" )\n",
" x = mlp(x)\n",
" \n",
" return x + res\n"
]
},
{
"cell_type": "markdown",
"id": "db16bf70",
"metadata": {},
"source": [
"We've exposed `dot_general_cls` here so we can test out different GEMM implementations later. By default, Flax's `nn.Dense` will use JAX's GEMM `jax.lax.dot_general` when `dot_general` is `None`."
]
},
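{
"cell_type": "markdown",
"id": "dot-general-default-note",
"metadata": {},
"source": [
"As a tiny illustration of that default (an editor's sketch, not executed here), the two layers below behave identically, since leaving `dot_general` as `None` makes `nn.Dense` fall back to `jax.lax.dot_general`:\n",
"\n",
"```python\n",
"dense_default = nn.Dense(features=4096)\n",
"dense_explicit = nn.Dense(features=4096, dot_general=None)  # same as the default\n",
"```"
]
},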
{
"cell_type": "markdown",
"id": "fbc3510b",
"metadata": {},
"source": [
"## Testing Performance\n",
"\n",
"Now let's test the performance of our FlaxTransformerLayer:\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8b44649d",
"metadata": {},
"outputs": [],
"source": [
"# Layer configuration\n",
"hidden_size = 4096\n",
"sequence_length = 2048\n",
"batch_size = 4\n",
"ffn_hidden_size = 16384\n",
"num_attention_heads = 32\n",
"dtype = jnp.bfloat16\n",
"\n",
"# Synthetic data\n",
"key, dropout_key = jax.random.split(jax.random.PRNGKey(42))\n",
"x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n",
"dy = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e44ed26d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pure Flax FlaxTransformerLayer initialized successfully!\n",
"Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n"
]
}
],
"source": [
"# Initialize the FlaxTransformerLayer\n",
"flax_transformer = FlaxTransformerLayer(\n",
" hidden_size=hidden_size,\n",
" ffn_hidden_size=ffn_hidden_size,\n",
" num_attention_heads=num_attention_heads,\n",
")\n",
"\n",
"# Initialize parameters\n",
"params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n",
"print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "de91af7a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape: (4, 2048, 4096)\n",
"Output shape: (4, 2048, 4096)\n",
"Output dtype: float32\n",
"Forward pass completed successfully!\n"
]
}
],
"source": [
"# Example usage of forward pass\n",
"y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)\n",
"print(f\"Input shape: {x.shape}\")\n",
"print(f\"Output shape: {y.shape}\")\n",
"print(f\"Output dtype: {y.dtype}\")\n",
"print(\"Forward pass completed successfully!\")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "037bc8d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 18.83516788482666 ms\n"
]
}
],
"source": [
"import importlib\n",
"import quickstart_jax_utils\n",
"importlib.reload(quickstart_jax_utils)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=flax_transformer.apply,\n",
" variables=params,\n",
" input=x,\n",
" output_grad=dy,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs={\"dropout\": dropout_key},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5e9310c9",
"metadata": {},
"source": [
"# Transformer Engine"
]
},
{
"cell_type": "markdown",
"id": "1f8e213e",
"metadata": {},
"source": [
"TransformerEngine/JAX is currently using Flax Linen. However, it is easily compatible with Flax NNX or Haiku.\n",
"* [Use Flax NNX and Linen together](https://flax.readthedocs.io/en/latest/guides/bridge_guide.html)\n",
"* [Haiku and Flax interop](https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html)\n",
"\n",
"Additionally, with the tutorial below, no model parameters need to be managed by TransformerEngine. You can keep all your existing model parameters, initialization, and sharding the same. The only change required is to call TE's dot_general_cls instead of the default Dense dot_general implementation. TE's dot_general_cls is a small module that performs a quantized dense VJP and stores some small recipe-specific state."
]
},
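{
"cell_type": "markdown",
"id": "dot-general-swap-sketch",
"metadata": {},
"source": [
"For orientation, here is a minimal sketch of what that swap looks like on a single `nn.Dense` (an editor's illustration based on the `make_dot_general_cls` API introduced below, not a cell executed in this notebook):\n",
"\n",
"```python\n",
"from flax import linen as nn\n",
"from transformer_engine.common.recipe import DelayedScaling\n",
"from transformer_engine.jax import flax as te_flax\n",
"\n",
"te_dot_general_cls = te_flax.make_dot_general_cls(DelayedScaling())\n",
"\n",
"# The dense layer keeps its own kernel and bias; only the GEMM implementation changes.\n",
"dense = nn.Dense(features=4096, dot_general=te_dot_general_cls())\n",
"```"
]
},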
{
"cell_type": "markdown",
"id": "4477d4e9",
"metadata": {},
"source": [
"Now we'll select a recipe. `DelayedScaling` and `CurrentScaling` use per-tensor scaling and are supported on Hopper and Blackwell. `MXFP8BlockScaling` and `NVFP4BlockScaling` use block scaling or a combination of both per-tensor and block scaling and are supported on Blackwell.\n",
"\n",
"If you would like to customize the recipe further, various options can be changed by passing args to the recipe's constructor."
]
},
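{
"cell_type": "markdown",
"id": "recipe-args-sketch",
"metadata": {},
"source": [
"For example, a customized `DelayedScaling` might look like the sketch below. The constructor arguments shown are assumptions based on `transformer_engine.common.recipe` and may vary between TE versions; consult the recipe docstrings for the authoritative list.\n",
"\n",
"```python\n",
"from transformer_engine.common.recipe import DelayedScaling, Format\n",
"\n",
"# Assumed arguments, for illustration only.\n",
"custom_recipe = DelayedScaling(\n",
"    margin=0,\n",
"    fp8_format=Format.HYBRID,    # E4M3 forward, E5M2 backward\n",
"    amax_history_len=1024,\n",
"    amax_compute_algo=\"max\",\n",
")\n",
"```"
]
},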
{
"cell_type": "code",
"execution_count": 7,
"id": "5ddf41e7",
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, NVFP4BlockScaling\n",
"from transformer_engine.jax import flax as te_flax \n",
"\n",
"# Choose a quantization recipe. This can be modified to any of the recipes imported above.\n",
"quantization_recipe = DelayedScaling()\n",
"\n",
"te_dot_general_cls = te_flax.make_dot_general_cls(quantization_recipe)\n",
"\n",
"rngs = {'dropout': dropout_key}\n",
"if isinstance(quantization_recipe, NVFP4BlockScaling):\n",
" # The NVFP4 recipe requires a Flax RNG for stochastic rounding\n",
" rngs['sr_rng'] = jax.random.PRNGKey(0)\n"
]
},
{
"cell_type": "markdown",
"id": "c8769655",
"metadata": {},
"source": [
"Now using this quantized dense in our model is as simple as passing in `dot_general_fn=te_dot_general`. Let's try it out!\n",
"\n",
"<div class=\"alert alert-warning\">\n",
"\n",
"<b>Important: Remat Policy</b>\n",
"\n",
"TE's quantization uses specialized TE quantized GEMM primitives. If you are using any built-in JAX checkpoint policies that look for JAX GEMMs (dots), such as `jax.checkpoint_policies.checkpoint_dots`, please replace the policy with `transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms` or similar policies to ensure TE's quantized GEMM primitives are checkpointed correctly.\n",
"\n",
"If this is not performed, TE GEMMs will be rematerialized introducing an incorrect performance comparison.\n",
"\n",
"</div>"
]
},
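{
"cell_type": "markdown",
"id": "remat-policy-sketch",
"metadata": {},
"source": [
"A minimal sketch of that policy swap (an editor's illustration; the wrapper `layer_fn` is made up, and only the policy name comes from TE):\n",
"\n",
"```python\n",
"import jax\n",
"import transformer_engine.jax as te\n",
"\n",
"def layer_fn(variables, x):\n",
"    return flax_transformer.apply(variables, x, attention_mask=None, deterministic=True)\n",
"\n",
"# Use TE's policy, in place of e.g. jax.checkpoint_policies.checkpoint_dots,\n",
"# so that TE's quantized GEMM outputs are saved rather than rematerialized.\n",
"remat_layer_fn = jax.checkpoint(\n",
"    layer_fn,\n",
"    policy=te.checkpoint_policies.checkpoint_dots_and_te_gemms,\n",
")\n",
"```"
]
},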
{
"cell_type": "code",
"execution_count": 12,
"id": "8407d2ea",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pure Flax FlaxTransformerLayer initialized successfully!\n",
"Parameter shapes: {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}\n",
"Additional state: {'_overwrite_with_gradient': {'FlaxMLP_0': {'TEWrapper_dot_general_0': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}, 'TEWrapper_dot_general_1': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}}, 'TEWrapper_dot_general_0': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}, 'TEWrapper_dot_general_1': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}}}\n"
]
}
],
"source": [
"# Initialize the FlaxTransformerLayer\n",
"flax_transformer = FlaxTransformerLayer(\n",
" hidden_size=hidden_size,\n",
" ffn_hidden_size=ffn_hidden_size,\n",
" num_attention_heads=num_attention_heads,\n",
" dot_general_cls=te_dot_general_cls,\n",
")\n",
"\n",
"# Initialize parameters\n",
"var_collect = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n",
"print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, var_collect['params'])}\")\n",
"print(f\"Additional state: {jax.tree_util.tree_map(lambda x: x.shape, {k: v for k, v in var_collect.items() if k != 'params'})}\")"
]
},
{
"cell_type": "markdown",
"id": "abe27237",
"metadata": {},
"source": [
"If using a recipe that stores additional state, such as `DelayedScaling`, you'll see this additional state stored as Flax variables. It is important to maintain and pass the whole state of Flax variables `var_collect` across training steps, not just the model params, for proper usage of stateful recipes like `DelayedScaling`.\n",
"\n",
"For example, above inside `Additional state: ` you'll see the `amax_history` of each quantization which is used to compute the per-tensor scale in the `DelayedScaling` recipe."
]
},
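{
"cell_type": "markdown",
"id": "carry-state-sketch",
"metadata": {},
"source": [
"As a rough, hedged sketch of what carrying the whole state can look like in a training loop: the outline below builds on the `create_train_step_fn` helper from `quickstart_jax_utils`; `num_steps` and `optimizer_update` are made up for illustration, and the treatment of the `_overwrite_with_gradient` collection (whose returned \"gradients\" are assumed to hold the updated quantizer state) should be verified against your own training setup.\n",
"\n",
"```python\n",
"train_step_fn = quickstart_jax_utils.create_train_step_fn(\n",
"    flax_transformer.apply,\n",
"    {\"enabled\": False},  # no autocast needed here; the recipe is baked into te_dot_general_cls\n",
"    {\"attention_mask\": None, \"deterministic\": False},\n",
")\n",
"\n",
"for _ in range(num_steps):  # num_steps is illustrative\n",
"    loss, (param_grads, other_grads) = train_step_fn(var_collect, x, dy, rngs)\n",
"    new_params = optimizer_update(var_collect[\"params\"], param_grads)  # your optimizer of choice\n",
"    # Carry the full state forward: updated params plus the recipe's quantizer state.\n",
"    var_collect = {**var_collect, \"params\": new_params, **other_grads}\n",
"```"
]
},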
{
"cell_type": "markdown",
"id": "5ab72935",
"metadata": {},
"source": [
"The reason we need `te_dot_general_cls` as a Flax module instead of a module-less function like `jax.lax.dot_general` is for some quantization recipes to track internal state separate from model parameters.\n",
"\n",
"Flax modules can manage 3 things:\n",
"1. Model parameters/weights, e.g. your Dense \"kernel\", \"bias\", etc.\n",
"2. RNGs for dropout, stochastic rounding, etc.\n",
"3. Flax variables. These are additional state variables that are used across training steps but are distinct from model params in that you don't take gradients or optimize them. Currently, we only use this for DelayedScaling's amax_history state\n",
"\n",
"With the simplest quantization integration shown in this tutorial, we want users to keep their existing model param setup so they don't need to worry about preserving the sharding, init distribution, etc.. So we don't need point 1 since we don't do model param creation in this codepath with dot_general_cls, but we still do need `te_dot_general_cls()` to produce a Flax module since we potentially need to do points 2 or 3 which need to be in a Flax module."
]
},
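{
"cell_type": "markdown",
"id": "flax-state-illustration",
"metadata": {},
"source": [
"To make points 2 and 3 concrete, here is a generic (non-TE) Flax sketch; the module name and the \"stats\" collection are made up for illustration:\n",
"\n",
"```python\n",
"class StatefulOp(nn.Module):\n",
"    @nn.compact\n",
"    def __call__(self, x):\n",
"        key = self.make_rng(\"dropout\")                      # point 2: an RNG stream managed by Flax\n",
"        calls = self.variable(\"stats\", \"calls\", lambda: 0)  # point 3: non-param state in a custom collection\n",
"        calls.value = calls.value + 1                        # requires mutable=[\"stats\"] when calling apply()\n",
"        return x * jax.random.bernoulli(key, 0.9, x.shape)\n",
"```"
]
},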
{
"cell_type": "code",
"execution_count": 13,
"id": "3b6b344b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape: (4, 2048, 4096)\n",
"Output shape: (4, 2048, 4096)\n",
"Output dtype: float32\n",
"Forward pass completed successfully!\n"
]
}
],
"source": [
"# Example usage of forward pass\n",
"y = flax_transformer.apply(var_collect, x, attention_mask=None, deterministic=True, rngs=rngs)\n",
"print(f\"Input shape: {x.shape}\")\n",
"print(f\"Output shape: {y.shape}\")\n",
"print(f\"Output dtype: {y.dtype}\")\n",
"print(\"Forward pass completed successfully!\")\n"
]
},
{
"cell_type": "markdown",
"id": "d178f247",
"metadata": {},
"source": [
"Now let's measure the performance!"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5cc6c2a7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 10.553865432739258 ms\n"
]
}
],
"source": [
"import importlib\n",
"import quickstart_jax_utils\n",
"importlib.reload(quickstart_jax_utils)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=flax_transformer.apply,\n",
" variables=var_collect,\n",
" input=x,\n",
" output_grad=dy,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs=rngs,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
...@@ -48,6 +48,7 @@ Transformer Engine documentation ...@@ -48,6 +48,7 @@ Transformer Engine documentation
examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
examples/te_gemma/tutorial_generation_gemma_with_te.ipynb examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
examples/onnx/onnx_export.ipynb examples/onnx/onnx_export.ipynb
examples/te_jax_integration.ipynb
.. toctree:: .. toctree::
:hidden: :hidden:
......
...@@ -26,7 +26,7 @@ from transformer_engine.jax.attention import ( ...@@ -26,7 +26,7 @@ from transformer_engine.jax.attention import (
CPStrategy, CPStrategy,
SequenceDescriptor, SequenceDescriptor,
) )
from ..sharding import with_sharding_constraint_by_logical_axes, HEAD_AXES from ..sharding import with_sharding_constraint_by_logical_axes, HEAD_AXES, is_mesh_available
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .misc import ( from .misc import (
...@@ -3288,7 +3288,7 @@ register_primitive(FusedRingAttnStripedBwdPrimitive) ...@@ -3288,7 +3288,7 @@ register_primitive(FusedRingAttnStripedBwdPrimitive)
def _maybe_context_parallel_axis(cp_axis: str): def _maybe_context_parallel_axis(cp_axis: str):
if not cp_axis: if not cp_axis and is_mesh_available():
gmr = global_mesh_resource() gmr = global_mesh_resource()
if gmr is not None: if gmr is not None:
cp_axis = gmr.cp_resource cp_axis = gmr.cp_resource
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
"""Transformer Engine bindings for JAX""" """Transformer Engine bindings for JAX"""
from .module import DenseGeneral, LayerNorm from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP from .module import LayerNormDenseGeneral, LayerNormMLP
from .module import wrap_function_in_te_state_module, make_dot_general_cls
from .transformer import extend_logical_axis_rules from .transformer import extend_logical_axis_rules
from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType from .transformer import TransformerLayer, TransformerLayerType
...@@ -13,6 +14,8 @@ __all__ = [ ...@@ -13,6 +14,8 @@ __all__ = [
"LayerNorm", "LayerNorm",
"LayerNormDenseGeneral", "LayerNormDenseGeneral",
"LayerNormMLP", "LayerNormMLP",
"wrap_function_in_te_state_module",
"make_dot_general_cls",
"extend_logical_axis_rules", "extend_logical_axis_rules",
"DotProductAttention", "DotProductAttention",
"MultiHeadAttention", "MultiHeadAttention",
......
...@@ -1354,3 +1354,87 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1354,3 +1354,87 @@ class LayerNormMLP(TransformerEngineBase):
assert out.dtype == input_dtype assert out.dtype == input_dtype
return out, ln_output # Output, layer_norm_output return out, ln_output # Output, layer_norm_output
def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] = None):
"""Wraps the given function `f` to support TransformerEngine quantization.
This function does a couple of things:
1. Wraps the given function in a Flax linen module. This module does not store any Flax parameters
but can store Flax variables for quantizers if required by the recipe.
2. When the wrapper is called, it passes an additional first argument, 'generate_quantizer_set', to the given function `f`. 'generate_quantizer_set' is a function that can be called to generate a TransformerEngine/JAX quantizer set object for use with TransformerEngine/JAX APIs. The quantizers it generates are based on the `quantization_recipe` passed to this wrapper.
Args:
f: The function to wrap. The first argument must be 'generate_quantizer_set'.
quantization_recipe: The quantization recipe used when generating quantizer sets.
name: The name of this wrapped operation. If unspecified, will use `f.__name__`.
Returns:
A Flax linen module that wraps the given function.
"""
import transformer_engine.jax as te
class TEWrapper(te.flax.module.TransformerEngineBase):
"""Wrapper Flax module for TransformerEngine quantization support."""
def generate_quantizer_set(self, postfix: str = ""):
OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
return super().generate_quantizer_set(
postfix=postfix,
variable_collection=OVERWRITE_WITH_GRADIENT,
fp8_recipe=quantization_recipe,
)
@nn.compact
def __call__(self, *args, **kwargs):
return f(self.generate_quantizer_set, *args, **kwargs)
TEWrapper.__name__ = f"TEWrapper_{name if name else f.__name__}"
return TEWrapper
def make_dot_general_cls(quantization_recipe):
"""Creates a Flax module class that performs a dot_general operation with the arguments x and kernel using the given quantization recipe.
This is intended for usage when you already have model parameters initialized and sharded for the kernel weights and you want to replace the GEMM implementation with TE's quantized GEMM using a given recipe.
For example,
```
te_dot_general_cls = make_dot_general_cls(DelayedScaling())
dense = nn.Dense(..., dot_general=te_dot_general_cls())
```
If you would like a drop-in replacement for nn.Dense that manages the model weights itself, please use TE's DenseGeneral module.
Args:
quantization_recipe: The quantization recipe to use for the dot_general operation.
Returns:
A Flax module class that performs a dot_general operation with the given quantization recipe.
"""
import transformer_engine.jax as te
from transformer_engine.common.recipe import NVFP4BlockScaling
def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs):
"""Performs a dot_general operation using TransformerEngine with quantization."""
del kwargs # Unused
contracting_dims, batch_dims = dims
assert batch_dims == ((), ()), "Batch dimensions must be empty for TransformerEngine dot."
quantizer_set = generate_quantizer_set()
if isinstance(quantization_recipe, NVFP4BlockScaling):
# NVFP4 RHT requires inputs to be in bfloat16
x = x.astype(jnp.bfloat16)
kernel = kernel.astype(jnp.bfloat16)
return te.dense.dense(
x,
kernel,
contracting_dims=contracting_dims,
quantizer_set=quantizer_set,
)
return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general")
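Editor's aside (not part of this commit): a minimal sketch of how `wrap_function_in_te_state_module` might be used directly, based only on the signatures added above; the function name `my_quantized_dot` and its contracting dimensions are illustrative.

```python
import transformer_engine.jax as te
from transformer_engine.common.recipe import DelayedScaling

def my_quantized_dot(generate_quantizer_set, x, kernel):
    # `generate_quantizer_set` is injected by the wrapper and builds quantizers for the recipe.
    quantizer_set = generate_quantizer_set()
    return te.dense.dense(
        x,
        kernel,
        contracting_dims=((x.ndim - 1,), (0,)),
        quantizer_set=quantizer_set,
    )

# Produces a Flax module class whose recipe state (e.g. amax histories) is stored as Flax variables.
MyQuantizedDot = te.flax.wrap_function_in_te_state_module(my_quantized_dot, DelayedScaling())
```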
...@@ -121,7 +121,6 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- ...@@ -121,7 +121,6 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attention_dropout: float = 0.0 attention_dropout: float = 0.0
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32
float32_logits: bool = False float32_logits: bool = False
scale_factor: Optional[float] = None scale_factor: Optional[float] = None
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
...@@ -294,7 +293,6 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me ...@@ -294,7 +293,6 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
attention_dropout: float = 0.0 attention_dropout: float = 0.0
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32
qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
scale_factor: Optional[float] = None scale_factor: Optional[float] = None
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
...@@ -600,11 +598,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -600,11 +598,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
(``'zero sink'`` and ``'learnable sink'``). (``'zero sink'`` and ``'learnable sink'``).
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
""" """
head_dim: int head_dim: int
...@@ -613,7 +606,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -613,7 +606,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout: float = 0.0 attention_dropout: float = 0.0
attn_mask_type: AttnMaskType = "causal" attn_mask_type: AttnMaskType = "causal"
attn_bias_type: AttnBiasType = None attn_bias_type: AttnBiasType = None
dtype: DType = jnp.float32
dropout_rng_name: str = "dropout" dropout_rng_name: str = "dropout"
float32_logits: bool = False float32_logits: bool = False
qkv_layout: str = "bshd_bshd_bshd" qkv_layout: str = "bshd_bshd_bshd"
...@@ -638,6 +630,24 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -638,6 +630,24 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
self.transpose_batch_sequence = False self.transpose_batch_sequence = False
super().__post_init__() super().__post_init__()
def _assert_dtypes(self, query: Array, key: Array, value: Array, qkv_layout: QKVLayout):
"""Asserts that the dtypes of query, key, and value dtypes are consistent."""
if qkv_layout.is_qkvpacked():
pass # No need to check dtypes for key and value since it is packed
elif qkv_layout.is_kvpacked():
assert (
key.dtype == query.dtype
), f"Expected kv dtype={key.dtype} to match query dtype={query.dtype}."
elif qkv_layout.is_separate():
assert (
key.dtype == query.dtype
), f"Expected key dtype={key.dtype} to match query dtype={query.dtype}."
assert (
value.dtype == query.dtype
), f"Expected value dtype={value.dtype} to match query dtype={query.dtype}."
else:
raise ValueError(f"Unsupported {qkv_layout=}.")
@nn.compact @nn.compact
def __call__( def __call__(
self, self,
...@@ -700,6 +710,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -700,6 +710,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
assert bias is None assert bias is None
else: else:
assert bias is not None assert bias is not None
bias = bias.astype(input_dtype)
self._assert_dtypes(query, key, value, qkv_layout)
# Use fused attn (if kernel check below passes) by default # Use fused attn (if kernel check below passes) by default
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1")) enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))
...@@ -720,8 +733,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -720,8 +733,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
has_fused_attn_kernel = is_fused_attn_kernel_available( has_fused_attn_kernel = is_fused_attn_kernel_available(
# This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode. # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
not deterministic, not deterministic,
self.dtype, input_dtype,
self.dtype, # self._assert_dtypes enforces that Q, K, V, and bias share the same dtype, so using input_dtype as the KV dtype is sufficient
input_dtype,
qkv_layout, qkv_layout,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
...@@ -743,7 +757,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -743,7 +757,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
"Fused attention is not enabled because there is no available kernel.\n" "Fused attention is not enabled because there is no available kernel.\n"
"Fall back to the unfused attention.\n" "Fall back to the unfused attention.\n"
"Please try to update the cuDNN and TE to the latest version.\n" "Please try to update the cuDNN and TE to the latest version.\n"
f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n" f"{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
f"{self.attention_dropout=}\n{self.num_attention_heads=}\n" f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n" f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n"
) )
...@@ -797,7 +811,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -797,7 +811,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout=self.attention_dropout, attention_dropout=self.attention_dropout,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
dtype=self.dtype,
float32_logits=self.float32_logits, float32_logits=self.float32_logits,
scale_factor=scale_factor, scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
...@@ -817,7 +830,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -817,7 +830,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout=self.attention_dropout, attention_dropout=self.attention_dropout,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
dtype=self.dtype,
scale_factor=scale_factor, scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
...@@ -1572,7 +1584,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1572,7 +1584,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=self.attn_mask_type, attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type, attn_bias_type=self.attn_bias_type,
attention_dropout=self.attention_dropout, attention_dropout=self.attention_dropout,
dtype=self.dtype,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits, float32_logits=self.float32_logits,
qkv_layout=qkv_layout.name, qkv_layout=qkv_layout.name,
......
...@@ -59,6 +59,14 @@ def _validate_mesh_resource_configuration(mesh_resource): ...@@ -59,6 +59,14 @@ def _validate_mesh_resource_configuration(mesh_resource):
) )
def is_mesh_available() -> bool:
"""
Check if a physical mesh is available.
"""
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
return mesh is not None and not mesh.empty
def get_sharding_map_logic_axis_to_mesh_axis(): def get_sharding_map_logic_axis_to_mesh_axis():
""" """
Generate a dict to map logical axes to mesh axes. Generate a dict to map logical axes to mesh axes.
......