Commit 970620a5 authored by wenjh

merge nv_release_v2.10 to release_v2.10


Signed-off-by: wenjh <wenjh@sugon.com>
parents c1a1c04e 769ed778
@@ -100,7 +100,7 @@
 "\n",
 "</div>\n",
 "\n",
- "A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\times \\text{batch_size} \\times \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n",
+ "A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\cdot \\text{batch_size} \\cdot \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n",
 "\n",
 "To show this in action, let's first initialize NCCL with a trivial process group:"
 ]
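A minimal sketch of the "trivial process group" initialization referred to above (not part of this diff; the rendezvous address, port, and device index are placeholders for a single-GPU run):

import os
import torch

os.environ.setdefault("MASTER_ADDR", "localhost")  # placeholder rendezvous address
os.environ.setdefault("MASTER_PORT", "29500")      # placeholder rendezvous port
torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
torch.cuda.set_device(0)
world_group = torch.distributed.new_group(ranks=[0], backend="nccl")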
@@ -131,7 +131,7 @@
 "id": "1f2b80d0",
 "metadata": {},
 "source": [
- "We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\times \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
+ "We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\cdot \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
 "\n",
 "Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n",
 "\n",
...
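A minimal sketch of the two options described above, assuming the single-process `world_group` from the previous sketch (layer sizes and names are illustrative and not taken from this diff):

import torch
import transformer_engine.pytorch as te

# Data parallelism: wrap a TE module with DDP, as with any standard PyTorch module
layer = te.Linear(1024, 1024)
layer_dp = torch.nn.parallel.DistributedDataParallel(layer)

# Tensor (and optionally sequence) parallelism: hand the module a tensor-parallel process group
layer_tp = te.Linear(
    1024,
    1024,
    tp_group=world_group,
    parallel_mode="column",
    sequence_parallel=True,
)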
@@ -174,7 +174,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
 "id": "50852cb5",
 "metadata": {},
 "outputs": [
@@ -266,7 +266,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 24,
+ "execution_count": null,
 "id": "906b8cf1",
 "metadata": {},
 "outputs": [
@@ -299,7 +299,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 23,
+ "execution_count": null,
 "id": "d3637094",
 "metadata": {},
 "outputs": [
@@ -509,10 +509,10 @@
 "\n",
 "* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
 " - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
- " - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
+ " - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor of shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors of shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
 "\n",
 "\n",
- "* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
+ "* JAX: Users should provide the `attention_mask` tensor of shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
 "\n",
 "**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
 "\n",
@@ -521,7 +521,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 33,
+ "execution_count": null,
 "id": "a1f25a9b",
 "metadata": {},
 "outputs": [
...
{
"cells": [
{
"cell_type": "markdown",
"id": "962d87bb",
"metadata": {},
"source": [
"\n",
"\n",
"# Getting Started\n",
"\n",
"## Overview\n",
"\n",
"Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper, Ada, as well as 8-bit and 4-bit floating point (NVFP4) precision on Blackwell GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your JAX code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n",
"\n",
"This guide shows how to start using Transformer Engine with JAX. Similar tutorial for pyTorch is available [here](quickstart.ipynb).\n",
"We recommend you to try understanding the basics of JAX first, using these resources:\n",
"\n",
"- Thinking in JAX: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html\n",
"- JAX 101: https://docs.jax.dev/en/latest/jax-101.html\n",
"- Key concepts in JAX: https://docs.jax.dev/en/latest/key-concepts.html#jax-arrays-jax-array\n",
"- Flax 101: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/index.html\n",
"\n",
"## Let's build a Transformer decoder layer!\n",
"<small>_This is based upon the GPT decoder layer with causal masking, which prevents each position from attending to future positions._</small>\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We build a basic Transformer layer using regular Flax modules. This will be our baseline for later comparisons with Transformer Engine.\n",
"\n",
"</div>\n",
"\n",
"Let's start with creating the transformer layer using plain [FLAX Linen](https://flax.readthedocs.io/en/stable/) . Figure 1 shows the overall structure.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"transformer_layer.png\" width=\"20%\">\n",
"<figcaption> Figure 1: Structure of a GPT decoder layer.</figcaption>\n",
"</figure>\n",
"\n",
"We construct the components as follows:\n",
"\n",
"- `LayerNorm`: `nn.LayerNorm` (Flax)\n",
"- `QKV Projection`: `nn.Dense` (conceptually there are three seperate `Dense` layers for Q, K, and V separately, but we fuse them together into a single `Dense` layer that is three times larger)\n",
"- `DotProductAttention`: `nn.MuliheadDotProductAttention` (Flax)\n",
"- `Projection`: `nn.Dense` (Flax)\n",
"- `Dropout`: `nn.Dropout` (Flax)\n",
"- `MLP`: `FlaxMLP` implemented using `nn.Dense` and `nn.gelu`\n",
"\n",
"Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_jax_utils.py](quickstart_jax_utils.py). Putting it all together: \n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d5284a38",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax import linen as nn\n",
"import quickstart_jax_utils as utils\n",
"from typing import Optional"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a4d1cfdc",
"metadata": {},
"outputs": [],
"source": [
"class FlaxMLP(nn.Module):\n",
" \"\"\"Feed-forward network in Transformer layer\n",
" Built with plain Flax modules.\n",
" \"\"\"\n",
" hidden_size: int\n",
" ffn_hidden_size: int\n",
"\n",
" @nn.compact\n",
" def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n",
" x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x)\n",
" x = nn.gelu(x, approximate=True) # equivalent to tanh approximation\n",
" x = nn.Dense(features=self.hidden_size, use_bias=True)(x)\n",
" return x\n",
"\n",
"class FlaxTransformerLayer(nn.Module):\n",
" \"\"\"Basic Transformer layer using plain Flax modules\"\"\"\n",
" hidden_size: int\n",
" ffn_hidden_size: int\n",
" num_attention_heads: int\n",
" layernorm_eps: float = 1e-5\n",
" attention_dropout: float = 0.1\n",
" \n",
" def setup(self):\n",
" self.kv_channels = self.hidden_size // self.num_attention_heads\n",
"\n",
" @nn.compact\n",
" def __call__(\n",
" self, \n",
" x: jnp.ndarray, \n",
" attention_mask: Optional[jnp.ndarray] = None,\n",
" deterministic: bool = False\n",
" ) -> jnp.ndarray:\n",
" # Create causal mask if not provided\n",
" if attention_mask is None:\n",
" attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
" \n",
" res = x\n",
" x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
" \n",
" # Fused QKV projection\n",
" qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x)\n",
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n",
" \n",
" # Reshape to [batch, seq_len, num_heads * head_dim] for Flax MultiHeadDotProductAttention\n",
" q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n",
" k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n",
" v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n",
" \n",
" # Attention using Flax's MultiHeadDotProductAttention\n",
" attention = nn.MultiHeadDotProductAttention(\n",
" num_heads=self.num_attention_heads,\n",
" qkv_features=self.kv_channels,\n",
" dropout_rate=self.attention_dropout,\n",
" )\n",
" x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n",
"\n",
" x = res + x\n",
" \n",
" # Second residual connection\n",
" res = x\n",
" x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
" \n",
" # MLP\n",
" mlp = FlaxMLP(\n",
" hidden_size=self.hidden_size,\n",
" ffn_hidden_size=self.ffn_hidden_size,\n",
" )\n",
" x = mlp(x)\n",
" \n",
" return x + res\n"
]
},
{
"cell_type": "markdown",
"id": "fbc3510b",
"metadata": {},
"source": [
"## Testing Performance\n",
"\n",
"Now let's test the performance of our FlaxTransformerLayer:\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8b44649d",
"metadata": {},
"outputs": [],
"source": [
"# Layer configuration\n",
"hidden_size = 4096\n",
"sequence_length = 2048\n",
"batch_size = 4\n",
"ffn_hidden_size = 16384\n",
"num_attention_heads = 32\n",
"dtype = jnp.bfloat16\n",
"\n",
"# Synthetic data\n",
"key, dropout_key = jax.random.split(jax.random.PRNGKey(42))\n",
"x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n",
"dy = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e44ed26d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pure Flax FlaxTransformerLayer initialized successfully!\n",
"Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}, 'MultiHeadDotProductAttention_0': {'key': {'bias': (32, 4), 'kernel': (4096, 32, 4)}, 'out': {'bias': (4096,), 'kernel': (32, 4, 4096)}, 'query': {'bias': (32, 4), 'kernel': (4096, 32, 4)}, 'value': {'bias': (32, 4), 'kernel': (4096, 32, 4)}}}}\n"
]
}
],
"source": [
"# Initialize the FlaxTransformerLayer\n",
"flax_transformer = FlaxTransformerLayer(\n",
" hidden_size=hidden_size,\n",
" ffn_hidden_size=ffn_hidden_size,\n",
" num_attention_heads=num_attention_heads,\n",
")\n",
"\n",
"# Initialize parameters\n",
"params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n",
"print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "de91af7a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape: (4, 2048, 4096)\n",
"Output shape: (4, 2048, 4096)\n",
"Output dtype: float32\n",
"Forward pass completed successfully!\n"
]
}
],
"source": [
"# Example usage of forward pass\n",
"y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)\n",
"print(f\"Input shape: {x.shape}\")\n",
"print(f\"Output shape: {y.shape}\")\n",
"print(f\"Output dtype: {y.dtype}\")\n",
"print(\"Forward pass completed successfully!\")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "037bc8d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 17.708301544189453 ms\n"
]
}
],
"source": [
"import importlib\n",
"import quickstart_jax_utils\n",
"importlib.reload(quickstart_jax_utils)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=flax_transformer.apply,\n",
" variables=params,\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ccb16f31",
"metadata": {},
"source": [
"## Meet Transformer Engine\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"Now that we have a basic Transformer layer in Flax, let's use Transformer Engine to speed up the training. The following examples show how to use TE modules.\n",
"\n",
"</div>\n",
"\n",
"As a reminder, the FlaxTransformerLayer above used:\n",
"\n",
"- `nn.LayerNorm`: Flax LayerNorm\n",
"- `nn.Dense`: Flax Dense layer for QKV projection \n",
"- `nn.MultiheadDotProductAttention`: Flax MultiheadDotProductAttention\n",
"- `nn.Dense`: Flax Dense layer for projection\n",
"- `nn.Dropout`: Flax Dropout\n",
"- `FlaxMLP`: Custom MLP implemented from `nn.Dense`\n",
"\n",
"Below we show how to use Transformer Engine Flax modules for better performance:\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bed20d6b",
"metadata": {},
"outputs": [],
"source": [
"import transformer_engine.jax as te\n",
"import transformer_engine.jax.flax as te_flax"
]
},
{
"cell_type": "markdown",
"id": "f28cb444",
"metadata": {},
"source": [
"TE provides a set of Flax Linen modules that can be used to build Transformer layers. The simplest of the provided modules are the `DenseGeneral ` and `LayerNorm` layers, which we can use instead of `flax.linen.Dense` and ` flax.linen.LayerNorm`. Let's modify our `FlaxTransformerLayer`:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "56105579",
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention\n",
"\n",
"\n",
"class TEUnfusedMLP(nn.Module):\n",
" hidden_size : int\n",
" ffn_hidden_size: int\n",
"\n",
" @nn.compact\n",
" def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:\n",
" x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True) (x)\n",
" x = x.reshape(*x.shape[:-1], 1, x.shape[-1])\n",
" x = te.activation.activation(x, activation_type=('gelu',))\n",
" x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True) (x)\n",
" return x\n",
"\n",
"class TEUnfusedTransformerLayer(nn.Module):\n",
" hidden_size: int\n",
" ffn_hidden_size: int \n",
" num_attention_heads: int \n",
" layernorm_eps: float = 1e-5\n",
" attention_dropout: float = 0.1 \n",
" use_te_attention: bool = True # True for TE attention, False for Flax attention\n",
"\n",
" def setup(self):\n",
" self.kv_channels = self.hidden_size // self.num_attention_heads\n",
"\n",
" @nn.compact\n",
" def __call__(\n",
" self, \n",
" x: jnp.ndarray,\n",
" attention_mask: Optional[jnp.ndarray] = None,\n",
" deterministic: bool = False\n",
" ) -> jnp.ndarray:\n",
" # Create causal mask if not provided\n",
" if attention_mask is None:\n",
" attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
" \n",
" res = x\n",
" x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
"\n",
" # Fused QKV projection\n",
" qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x)\n",
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n",
"\n",
" # Attention - either TE or Flax implementation\n",
" if self.use_te_attention:\n",
" # Use TE's DotProductAttention\n",
" attention = TEDotProductAttention(\n",
" head_dim=self.kv_channels,\n",
" num_attention_heads=self.num_attention_heads,\n",
" num_gqa_groups=self.num_attention_heads, # No GQA\n",
" attention_dropout=self.attention_dropout,\n",
" attn_mask_type='causal',\n",
" )\n",
" x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
" # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
" x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n",
" x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n",
" x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n",
" else:\n",
" # Use Flax's MultiHeadDotProductAttention\n",
" q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n",
" k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n",
" v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n",
" \n",
" attention = nn.MultiHeadDotProductAttention(\n",
" num_heads=self.num_attention_heads,\n",
" qkv_features=self.kv_channels,\n",
" dropout_rate=self.attention_dropout,\n",
" )\n",
" x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n",
"\n",
" x = res + x\n",
"\n",
" # Second residual connection\n",
" res = x\n",
" x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
"\n",
" # MLP\n",
" mlp = TEUnfusedMLP(\n",
" hidden_size=self.hidden_size,\n",
" ffn_hidden_size=self.ffn_hidden_size\n",
" )\n",
"\n",
" x = mlp(x, deterministic=deterministic)\n",
"\n",
" return x + res"
]
},
{
"cell_type": "markdown",
"id": "a76911ac",
"metadata": {},
"source": [
"Testing performance of the model, using `DenseGeneral`, `LayerNorm` and activation from TE, while keeping Flax's `MultiHeadDotProductAttention` the same as the first simple Transformer in JAX implementation. To read more about this implementation from Flax, you can refer to this documentation: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4b67511f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 16.505107879638672 ms\n"
]
}
],
"source": [
"te_unfused_transformer_with_flax_MHA = TEUnfusedTransformerLayer(\n",
" hidden_size, \n",
" ffn_hidden_size, \n",
" num_attention_heads,\n",
" use_te_attention=False\n",
")\n",
"\n",
"te_params = te_unfused_transformer_with_flax_MHA.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=te_unfused_transformer_with_flax_MHA.apply,\n",
" variables=te_params, # Ensure the correct `params` is passed\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "0b230058",
"metadata": {},
"source": [
"Now, we move on to also replace the attention sub-layer with TE's `DotProductAttention` implementation"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5146cd99",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 12.80329704284668 ms\n"
]
}
],
"source": [
"te_unfused_transformer = TEUnfusedTransformerLayer(\n",
" hidden_size, \n",
" ffn_hidden_size, \n",
" num_attention_heads,\n",
")\n",
"\n",
"te_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=te_unfused_transformer.apply,\n",
" variables=te_params, # Ensure the correct `params` is passed\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c9a101d3",
"metadata": {},
"source": [
"## Enabling Quantization (FP8 or FP4)\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We configure a TE module to perform compute in FP8.\n",
"\n",
"</div>\n",
"\n",
"Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/jax.rst#transformer_engine.jax.fp8_autocast) context manager. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options.\n",
"\n",
"<div class=\"alert alert-warning\">\n",
"\n",
"<b>Important: FP8 Metadata Initialization</b>\n",
"\n",
"When using FP8, the model **must be initialized within the `autocast` context**. This creates a special collection called `fp8_metas` that contains scaling factors and other metadata required for FP8 computation. If you initialize a model outside of `autocast` and then try to use it with FP8, you will get a `ScopeCollectionNotFound` error because the `fp8_metas` collection was never created.\n",
"\n",
"</div>"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c2eee376",
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.common.recipe import Format, DelayedScaling\n",
"fp8_format = Format.HYBRID\n",
"fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "de96827c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 9.615030288696289 ms\n"
]
}
],
"source": [
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" te_unfused_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
" # Example usage of forward \n",
" y = te_unfused_transformer.apply(te_unfused_params, x, attention_mask=None, deterministic=True)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=te_unfused_transformer.apply,\n",
" variables=te_unfused_params, # Ensure the correct `params` is passed\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe}\n",
")"
]
},
{
"cell_type": "markdown",
"id": "3801b201",
"metadata": {},
"source": [
"\n",
"## Fused TE Modules\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We optimize the example Transformer layer with TE modules for fused operations.\n",
"\n",
"</div>\n",
"\n",
"The `DenseGeneral` layer is enough to build any Transformer model and it enables usage of the Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations such as kernel fusions in mixed-precision recipes, increasing the achievable speedup.\n",
"\n",
"Transformer Engine therefore provides coarser modules that span multiple layers:\n",
"\n",
"* `LayerNormDenseGeneral`\n",
"* `LayerNormMLP`\n",
"* `TransformerLayer`\n",
"\n",
"To see a complete list of all the functions TE Flax support, you can view it here: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#modules\n",
"\n",
"Building a third iteration of our Transformer layer with `LayerNormDenseGeneral` and `LayerNormMLP`:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "11203785",
"metadata": {},
"outputs": [],
"source": [
"class TEFusedTransformerLayer(nn.Module):\n",
" hidden_size: int\n",
" ffn_hidden_size: int \n",
" num_attention_heads: int \n",
" layernorm_eps: float = 1e-5\n",
" attention_dropout: float = 0.1\n",
"\n",
" def setup(self):\n",
" self.kv_channels = self.hidden_size // self.num_attention_heads\n",
"\n",
" @nn.compact\n",
" def __call__(\n",
" self, \n",
" x: jnp.ndarray,\n",
" attention_mask: Optional[jnp.ndarray] = None,\n",
" deterministic: bool = False\n",
" ) -> jnp.ndarray:\n",
" res = x\n",
"\n",
" # Fused QKV projection\n",
" qkv,_ = te_flax.LayerNormDenseGeneral(features=3 * self.hidden_size, \n",
" epsilon=self.layernorm_eps, \n",
" use_bias=True, \n",
" return_layernorm_output=False)(x)\n",
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n",
"\n",
" # Attention using TE's DotProductAttention\n",
" attention = TEDotProductAttention(\n",
" head_dim=self.kv_channels,\n",
" num_attention_heads=self.num_attention_heads,\n",
" num_gqa_groups=self.num_attention_heads, \n",
" attention_dropout=self.attention_dropout,\n",
" attn_mask_type='causal',\n",
" )\n",
" x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
" # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
" x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n",
" x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n",
" x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n",
"\n",
" x = res + x\n",
"\n",
" # Second residual connection\n",
" res = x\n",
" x,_ = te_flax.LayerNormMLP(intermediate_dim=self.ffn_hidden_size, \n",
" epsilon=self.layernorm_eps,\n",
" use_bias=True,\n",
" activations=('gelu',),\n",
" intermediate_dropout_rate=0.0,\n",
" return_layernorm_output=False\n",
" )(x, deterministic=deterministic)\n",
"\n",
" return x + res"
]
},
{
"cell_type": "markdown",
"id": "334cff59",
"metadata": {},
"source": [
"Similar to the unnfused model, we also compare the performance of fused model when using Flax's MultiheadDotProductAttention implementation and TE's."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "6b0c705e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 9.331779479980469 ms\n"
]
}
],
"source": [
"te_fused_transformer = TEFusedTransformerLayer(\n",
" hidden_size, \n",
" ffn_hidden_size, \n",
" num_attention_heads\n",
")\n",
"\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" te_fused_params = te_fused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
" # Example usage of forward \n",
" y = te_fused_transformer.apply(te_fused_params, x, attention_mask=None, deterministic=True)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=te_fused_transformer.apply,\n",
" variables=te_fused_params,\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe}\n",
")"
]
},
{
"cell_type": "markdown",
"id": "a45c12c8",
"metadata": {},
"source": [
"Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b2aaa8ef",
"metadata": {},
"outputs": [],
"source": [
"\n",
"te_transformer = te_flax.TransformerLayer(\n",
" hidden_size=hidden_size,\n",
" mlp_hidden_size=ffn_hidden_size, \n",
" num_attention_heads=num_attention_heads,\n",
" mlp_activations=(\"gelu\",),\n",
" self_attn_mask_type='causal',\n",
" layernorm_epsilon=1e-5,\n",
" use_bias=True,\n",
" intermediate_dropout=0.0,\n",
" enable_relative_embedding=False,\n",
" self_attn_bias_type='no_bias',\n",
" hidden_dropout=0.0\n",
")\n",
"\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" te_transformer_params = te_transformer.init(key, x, deterministic=False)\n",
" y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "b9cdbf22",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 9.23741340637207 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" model_apply_fn=te_transformer.apply,\n",
" model_init_fn=te_transformer.init,\n",
" variables=te_transformer_params,\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe }\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import jax
import jax.numpy as jnp
import time
import math
from typing import Callable, Any, Dict, Optional, Tuple
from flax import linen as nn
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention
def speedometer(
    model_apply_fn: Callable,
    variables: Any,
    input: jnp.ndarray,
    output_grad: jnp.ndarray,
    dropout_key: jax.random.PRNGKey,
    model_init_fn: Callable = None,
    forward_kwargs: dict = {},
    autocast_kwargs: Optional[dict] = None,
    timing_iters: int = 50,
    warmup_iters: int = 50,
) -> None:
    """Measure the average runtime of a JAX module.

    Runs a full forward and backward pass in each iteration.
    """
    if autocast_kwargs is None:
        autocast_kwargs = {"enabled": False}
        model_init_fn = None

    train_step_fn = create_train_step_fn(model_apply_fn, autocast_kwargs, forward_kwargs)

    # Warmup runs
    key = dropout_key
    for _ in range(warmup_iters):
        key, step_key = jax.random.split(key)
        loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)

    # Timing runs
    start = time.time()
    for _ in range(timing_iters):
        key, step_key = jax.random.split(key)
        loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)
    end = time.time()

    print(f"Mean time: {(end - start) * 1000 / timing_iters} ms")


def create_train_step_fn(
    model_apply_fn: Callable,
    autocast_kwargs: Dict[str, Any],
    forward_kwargs: Dict[str, Any] = None,
) -> Callable:
    """
    Creates a JIT-compiled function that performs one forward/backward pass.
    """
    if forward_kwargs is None:
        forward_kwargs = {}

    def loss_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
        rngs = {"dropout": dropout_key}
        with te.autocast(**autocast_kwargs):
            # Forward pass: apply the model using the current parameters and variables
            call_kwargs = {**forward_kwargs, "rngs": rngs}
            out = model_apply_fn(variables, inp, **call_kwargs)
            # grad_target plays the role of dL/dy, the gradient of the loss w.r.t. the output.
            # By the chain rule, dL/dw = dL/dy * dy/dw, so taking the scalar loss
            # <out, grad_target> makes jax.grad compute exactly that vector-Jacobian product.
            return jnp.vdot(out, grad_target)

    def fwd_bwd_fn(*args, **kwargs):
        return jax.value_and_grad(loss_fn, argnums=(0, 1))(*args, **kwargs)

    # jax.value_and_grad returns the loss value and the gradients w.r.t. the variables and
    # the input simultaneously (forward + backward pass):
    # grad_params[out^T . grad_target] = grad_target^T . J_out(params) = VJP
    # JIT-compile the combined forward/backward function
    return jax.jit(fwd_bwd_fn)
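

if __name__ == "__main__":
    # Minimal smoke test (editor's sketch, not part of the original file): times a single
    # Flax Dense layer with the helpers above. Assumes a GPU-enabled JAX and Transformer
    # Engine installation; the shapes and the model are illustrative only.
    key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key, (8, 128, 256), dtype=jnp.bfloat16)
    dy = jax.random.normal(key, (8, 128, 256), dtype=jnp.bfloat16)

    model = nn.Dense(features=256)
    params = model.init(key, x)

    speedometer(
        model_apply_fn=model.apply,
        variables=params,
        input=x,
        output_grad=dy,
        dropout_key=dropout_key,
        forward_kwargs={},
    )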
@@ -38,7 +38,7 @@
 "\n",
 "For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n",
 "\n",
- "In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
+ "In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
 "\n",
 "This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n",
 "\n",
...
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

Getting Started
===============

Choose your framework to get started with Transformer Engine:

.. toctree::
   :maxdepth: 1

   PyTorch <examples/quickstart.ipynb>
   JAX <examples/quickstart_jax.ipynb>
@@ -4,7 +4,7 @@
 See LICENSE for license information.
 Transformer Engine documentation
-==============================================
+=================================
 .. ifconfig:: "dev" in release
@@ -29,7 +29,7 @@ Transformer Engine documentation
 :caption: Getting Started
 installation
- examples/quickstart.ipynb
+ getting_started
 faq
 .. toctree::
...
@@ -28,7 +28,7 @@ on `NVIDIA GPU Cloud <https://ngc.nvidia.com>`_.
 pip - from PyPI
------------------------
+---------------
 Transformer Engine can be directly installed from `our PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.
@@ -47,7 +47,7 @@ The core package from Transformer Engine (without any framework extensions) can
 By default, this will install the core library compiled for CUDA 12. The cuda major version can be specified by modified the extra dependency to `core_cu12` or `core_cu13`.
 pip - from GitHub
------------------------
+-----------------
 Additional Prerequisites
 ^^^^^^^^^^^^^^^^^^^^^^^^
...
@@ -30,13 +30,13 @@ do
 # Build Flash Attention
 if [ "${fa_version}" \< "3.0.0" ]
 then
- pip3 install flash-attn==${fa_version}
+ pip3 install flash-attn==${fa_version} --no-build-isolation
 else
 git clone https://github.com/Dao-AILab/flash-attention.git
- cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
+ cd flash-attention/hopper && python setup.py install
 python_path=`python -c "import site; print(site.getsitepackages()[0])"`
 mkdir -p $python_path/flash_attn_3
- wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py
+ cp flash_attn_interface.py $python_path/flash_attn_3/
 cd ../../
 fi
...
@@ -18,6 +18,7 @@ from transformer_engine.jax.attention import (
 is_fused_attn_kernel_available,
 AttnBiasType,
 AttnMaskType,
+ AttnSoftmaxType,
 QKVLayout,
 QKVFormat,
 reorder_causal_load_balancing,
@@ -66,6 +67,7 @@ class TestDistributedSelfAttn:
 bias_shape,
 attn_mask_type,
 dtype,
+ softmax_type,
 use_shardy,
 ):
 jax.config.update("jax_use_shardy_partitioner", use_shardy)
@@ -80,6 +82,7 @@
 QKVLayout.BS3HD,
 attn_bias_type,
 attn_mask_type,
+ softmax_type,
 dropout_prob,
 num_head,
 num_head,
@@ -109,6 +112,7 @@
 hidden,
 attn_bias_type,
 attn_mask_type,
+ softmax_type,
 dropout_prob,
 dtype,
 is_training,
@@ -142,6 +146,14 @@
 ],
 )
 @pytest.mark.parametrize("dtype", DTYPES)
+ @pytest.mark.parametrize(
+ "softmax_type",
+ [
+ pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
+ pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
+ pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
+ ],
+ )
 def test_self_attn(
 self,
 device_count,
@@ -153,6 +165,7 @@
 bias_shape,
 attn_mask_type,
 dtype,
+ softmax_type,
 ):
 self.impl_test_self_attn(
 device_count,
@@ -164,6 +177,7 @@
 bias_shape,
 attn_mask_type,
 dtype,
+ softmax_type,
 use_shardy=False,
 )
@@ -175,8 +189,23 @@
 pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
 ],
 )
+ @pytest.mark.parametrize(
+ "softmax_type",
+ [
+ pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
+ pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
+ pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
+ ],
+ )
 def test_self_attn_shardy(
- self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
+ self,
+ device_count,
+ mesh_shape,
+ mesh_axes,
+ mesh_resource,
+ attn_bias_type,
+ bias_shape,
+ softmax_type,
 ):
 data_shape = (32, 512, 12, 64)
 self.impl_test_self_attn(
@@ -189,6 +218,7 @@
 bias_shape,
 AttnMaskType.PADDING_MASK,
 jnp.bfloat16,
+ softmax_type,
 use_shardy=True,
 )
@@ -213,8 +243,24 @@ class TestDistributedCrossAttn:
 "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
 )
 @pytest.mark.parametrize("dtype", DTYPES)
+ @pytest.mark.parametrize(
+ "softmax_type",
+ [
+ pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
+ pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
+ pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
+ ],
+ )
 def test_cross_attn(
- self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
+ self,
+ device_count,
+ mesh_shape,
+ mesh_axes,
+ mesh_resource,
+ data_shape,
+ attn_mask_type,
+ dtype,
+ softmax_type,
 ):
 attn_bias_type = AttnBiasType.NO_BIAS
 bias_shape = None
@@ -230,6 +276,7 @@
 QKVLayout.BSHD_BS2HD,
 attn_bias_type,
 attn_mask_type,
+ softmax_type,
 dropout_prob,
 num_head,
 num_head,
@@ -252,6 +299,7 @@
 hidden,
 attn_bias_type,
 attn_mask_type,
+ softmax_type,
 dropout_prob,
 dtype,
 is_training,
@@ -322,6 +370,8 @@ class TestDistributedContextParallelSelfAttn:
 bias_shape = None
 dropout_prob = 0.0
 is_training = True
+ # Context parallel does not support softmax_offset
+ softmax_type = AttnSoftmaxType.VANILLA_SOFTMAX
 dp_size, cp_size, tp_size = mesh_shape
 batch, seqlen, num_head, hidden = data_shape
@@ -343,6 +393,7 @@
 hidden,
 attn_bias_type,
 attn_mask_type,
+ softmax_type,
 dropout_prob,
 dtype,
 is_training,
@@ -366,6 +417,7 @@
 qkv_layout,
 attn_bias_type,
 mask_type,
+ softmax_type,
 dropout_prob,
 num_head,
 num_kv_heads,
...
@@ -16,7 +16,7 @@ from distributed_test_base import generate_configs, generate_collectives_count
 from distributed_test_base import compare_ops
 from utils import make_causal_mask, make_self_mask
 from transformer_engine.jax import autocast
- from transformer_engine.jax.softmax import SoftmaxType, softmax
+ from transformer_engine.jax.softmax import SoftmaxFusionType, softmax
 DTYPES = [jnp.float16, jnp.bfloat16]
@@ -29,12 +29,12 @@ class TestDistributedSoftmax:
 return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
 def generate_inputs(
- self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
+ self, shape, mesh_resource, softmax_fusion_type, dtype, bad_sharding, broadcast_batch_mask
 ):
 batch, _, sqelen, _ = shape
 x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
- if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
+ if softmax_fusion_type == SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
 mask = make_causal_mask(batch, sqelen)
 else:
 mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen)
@@ -56,8 +56,10 @@
 return (x, mask), (x_pspec, mask_pspec)
 @staticmethod
- def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
- return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
+ def target_func(x, mask, scale_factor=1.0, softmax_fusion_type=SoftmaxFusionType.SCALED):
+ return jnp.mean(
+ softmax(x, mask, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type)
+ )
 @staticmethod
 def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
@@ -80,24 +82,29 @@
 mesh_axes,
 mesh_resource,
 data_shape,
- softmax_type,
+ softmax_fusion_type,
 scale_factor,
 dtype,
 bad_sharding,
 broadcast_batch_mask,
 use_shardy,
 ):
- if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED:
+ if broadcast_batch_mask and softmax_fusion_type != SoftmaxFusionType.SCALED_MASKED:
 pytest.skip("Softmax type has no mask.")
 jax.config.update("jax_use_shardy_partitioner", use_shardy)
 target_func = partial(
- self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
+ self.target_func, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type
 )
 ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
 (x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
- data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
+ data_shape,
+ mesh_resource,
+ softmax_fusion_type,
+ dtype,
+ bad_sharding,
+ broadcast_batch_mask,
 )
 collective_count_ref = self.generate_collectives_count_ref()
 devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
@@ -139,8 +146,12 @@
 @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
 @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
 @pytest.mark.parametrize(
- "softmax_type",
- [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
+ "softmax_fusion_type",
+ [
+ SoftmaxFusionType.SCALED,
+ SoftmaxFusionType.SCALED_MASKED,
+ SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED,
+ ],
 )
 @pytest.mark.parametrize("scale_factor", [1.0, 3.0])
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -153,7 +164,7 @@
 mesh_axes,
 mesh_resource,
 data_shape,
- softmax_type,
+ softmax_fusion_type,
 scale_factor,
 dtype,
 bad_sharding,
@@ -165,7 +176,7 @@
 mesh_axes,
 mesh_resource,
 data_shape,
- softmax_type,
+ softmax_fusion_type,
 scale_factor,
 dtype,
 bad_sharding,
@@ -174,7 +185,9 @@
 )
 @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
- @pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
+ @pytest.mark.parametrize(
+ "softmax_fusion_type", [SoftmaxFusionType.SCALED, SoftmaxFusionType.SCALED_MASKED]
+ )
 @pytest.mark.parametrize("bad_sharding", [False, True])
 @pytest.mark.parametrize("broadcast_batch_mask", [False, True])
 def test_softmax_gspmd(
@@ -183,7 +196,7 @@
 mesh_shape,
 mesh_axes,
 mesh_resource,
- softmax_type,
+ softmax_fusion_type,
 bad_sharding,
 broadcast_batch_mask,
 ):
@@ -193,7 +206,7 @@
 mesh_axes,
 mesh_resource,
 data_shape=[32, 12, 128, 128],
- softmax_type=softmax_type,
+ softmax_fusion_type=softmax_fusion_type,
 scale_factor=1.0,
 dtype=DTYPES[0],
 bad_sharding=bad_sharding,
...
...@@ -27,6 +27,7 @@ from transformer_engine.jax.sharding import MeshResource ...@@ -27,6 +27,7 @@ from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import ( from transformer_engine.jax.attention import (
AttnBiasType, AttnBiasType,
AttnMaskType, AttnMaskType,
AttnSoftmaxType,
QKVLayout, QKVLayout,
QKVFormat, QKVFormat,
reorder_causal_load_balancing, reorder_causal_load_balancing,
...@@ -59,14 +60,16 @@ def init(): ...@@ -59,14 +60,16 @@ def init():
yield yield
@partial(jax.jit, static_argnums=(5, 6, 7, 9)) @partial(jax.jit, static_argnums=(6, 7, 8, 9, 11))
def general_dot_product_attention( def general_dot_product_attention(
query: ArrayLike, query: ArrayLike,
key: ArrayLike, key: ArrayLike,
value: ArrayLike, value: ArrayLike,
softmax_offset: Optional[ArrayLike],
bias: ArrayLike, bias: ArrayLike,
mask: ArrayLike, mask: ArrayLike,
deterministic: bool, deterministic: bool,
softmax_type: AttnSoftmaxType,
scale_factor: float, scale_factor: float,
dropout_rate: float, dropout_rate: float,
dropout_rng: ArrayLike, dropout_rng: ArrayLike,
...@@ -99,7 +102,25 @@ def general_dot_product_attention( ...@@ -99,7 +102,25 @@ def general_dot_product_attention(
mask = jnp.expand_dims(mask, axis=-3) mask = jnp.expand_dims(mask, axis=-3)
logits = jnp.where(mask, jnp.finfo(dtype).min, logits) logits = jnp.where(mask, jnp.finfo(dtype).min, logits)
softmax_out = jax.nn.softmax(logits).astype(dtype) match softmax_type:
case AttnSoftmaxType.VANILLA_SOFTMAX:
softmax_out = jax.nn.softmax(logits).astype(dtype)
case AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
# Softmax with +1 in denominator: exp(x_i) / (sum(exp(x_j)) + 1)
# Append a zero logit, apply standard softmax, then remove last column
zero_logit = jnp.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
logits_with_extra = jnp.concatenate([logits, zero_logit], axis=-1)
softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
softmax_out = softmax_with_extra[..., :-1].astype(dtype)
case AttnSoftmaxType.LEARNABLE_SOFTMAX:
# Append learnable offset logit, apply standard softmax, then remove last column
learnable_logit = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
learnable_logit = jnp.broadcast_to(learnable_logit, logits.shape[:-1] + (1,))
logits_with_extra = jnp.concatenate([logits, learnable_logit], axis=-1)
softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
softmax_out = softmax_with_extra[..., :-1].astype(dtype)
case _:
raise NotImplementedError(f"Unknown {softmax_type=}")
if not deterministic and dropout_rate > 0.0: if not deterministic and dropout_rate > 0.0:
keep_prob = 1.0 - dropout_rate keep_prob = 1.0 - dropout_rate
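As a quick sanity check on the zero-logit trick used in the OFF_BY_ONE_SOFTMAX branch above, here is a minimal sketch (illustrative only; the array values are made up and nothing beyond jax is assumed) showing that appending a zero logit and dropping the last column reproduces exp(x_i) / (sum_j exp(x_j) + 1):

import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 0.5]])

# Explicit off-by-one softmax: exp(x_i) / (sum_j exp(x_j) + 1)
explicit = jnp.exp(logits) / (jnp.sum(jnp.exp(logits), axis=-1, keepdims=True) + 1.0)

# Zero-logit trick: append a zero logit, take a standard softmax, drop the extra column
zero = jnp.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
trick = jax.nn.softmax(jnp.concatenate([logits, zero], axis=-1), axis=-1)[..., :-1]

assert jnp.allclose(explicit, trick)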
...@@ -238,7 +259,7 @@ def _split_valid_and_invalid(primitive, reference, pad): ...@@ -238,7 +259,7 @@ def _split_valid_and_invalid(primitive, reference, pad):
return primitive_valid, primitive_invalid, reference_valid, reference_invalid return primitive_valid, primitive_invalid, reference_valid, reference_invalid
def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs): def jax_dpa(query, key, value, bias, softmax_offset, mask, dropout_rng, **kwargs):
""" """
JAX native dot product attention implementation JAX native dot product attention implementation
""" """
...@@ -246,11 +267,13 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs): ...@@ -246,11 +267,13 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
query, query,
key, key,
value, value,
softmax_offset,
bias, bias,
mask, mask,
deterministic=not kwargs["is_training"], deterministic=not kwargs["is_training"],
scale_factor=kwargs["scaling_factor"], scale_factor=kwargs["scaling_factor"],
dropout_rate=kwargs["dropout_probability"], dropout_rate=kwargs["dropout_probability"],
softmax_type=kwargs["softmax_type"],
dropout_rng=dropout_rng, dropout_rng=dropout_rng,
dtype=jnp.float32, dtype=jnp.float32,
) )
...@@ -262,6 +285,7 @@ def customcall_fused_dpa( ...@@ -262,6 +285,7 @@ def customcall_fused_dpa(
key, key,
value, value,
bias, bias,
softmax_offset,
sequence_descriptor, sequence_descriptor,
dropout_rng, dropout_rng,
**kwargs, **kwargs,
...@@ -283,9 +307,9 @@ def customcall_fused_dpa( ...@@ -283,9 +307,9 @@ def customcall_fused_dpa(
qkv_args = (query, key, value) qkv_args = (query, key, value)
case _: case _:
raise ValueError(f"Unsupported {qkv_layout=}") raise ValueError(f"Unsupported {qkv_layout=}")
return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype( return fused_attn(
query.dtype qkv_args, bias, sequence_descriptor, dropout_rng, softmax_offset=softmax_offset, **kwargs
) ).astype(query.dtype)
class BiasShape(Enum): class BiasShape(Enum):
...@@ -320,6 +344,7 @@ class FusedAttnRunner: ...@@ -320,6 +344,7 @@ class FusedAttnRunner:
head_dim_v: int head_dim_v: int
attn_bias_type: AttnBiasType attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType attn_mask_type: AttnMaskType
softmax_type: AttnSoftmaxType
dropout_prob: float dropout_prob: float
dtype: DTypeLike dtype: DTypeLike
is_training: bool is_training: bool
...@@ -402,6 +427,7 @@ class FusedAttnRunner: ...@@ -402,6 +427,7 @@ class FusedAttnRunner:
self.qkv_layout, self.qkv_layout,
self.attn_bias_type, self.attn_bias_type,
self.attn_mask_type, self.attn_mask_type,
self.softmax_type,
self.dropout_prob, self.dropout_prob,
self.num_heads_q, self.num_heads_q,
self.num_heads_kv, self.num_heads_kv,
...@@ -439,7 +465,7 @@ class FusedAttnRunner: ...@@ -439,7 +465,7 @@ class FusedAttnRunner:
self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1) self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5) q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6)
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk) q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk) k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
...@@ -490,6 +516,13 @@ class FusedAttnRunner: ...@@ -490,6 +516,13 @@ class FusedAttnRunner:
else: else:
pad_ratio = 0.0 pad_ratio = 0.0
if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
self.softmax_offset = jax.random.uniform(
softmax_key, (1, self.num_heads_q, 1, 1), jnp.float32, -1.0
)
else:
self.softmax_offset = None
def gen_valid(bs, max_seqlen, pad_ratio): def gen_valid(bs, max_seqlen, pad_ratio):
pad_len = int(max_seqlen * pad_ratio) pad_len = int(max_seqlen * pad_ratio)
valid_len = max_seqlen - pad_len valid_len = max_seqlen - pad_len
...@@ -713,6 +746,16 @@ class FusedAttnRunner: ...@@ -713,6 +746,16 @@ class FusedAttnRunner:
self.bias_pspec = PartitionSpec() self.bias_pspec = PartitionSpec()
self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec) self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)
# Softmax offset sharding (1, num_heads, 1, 1)
# Use the same logic as HEAD_AXES: tpsp_resource if enabled, else tp_resource
head_resource = (
self.mesh_resource.tpsp_resource
if self.mesh_resource.tpsp_resource is not None
else self.mesh_resource.tp_resource
)
self.softmax_offset_pspec = PartitionSpec(None, head_resource, None, None)
self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec)
self.dropout_rng_pspec = PartitionSpec( self.dropout_rng_pspec = PartitionSpec(
None, None,
) )
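The (1, num_heads, 1, 1) offset is therefore sharded only along its head dimension. A minimal single-device sketch of that layout (the mesh and the "tp" axis name here are hypothetical stand-ins, not the test's mesh_resource):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical one-device mesh with a single tensor-parallel axis
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("tp",))
softmax_offset = jnp.zeros((1, 8, 1, 1), dtype=jnp.float32)

# Replicate every dimension except the head dimension, which follows the "tp" axis
sharding = NamedSharding(mesh, PartitionSpec(None, "tp", None, None))
softmax_offset = jax.device_put(softmax_offset, sharding)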
...@@ -732,7 +775,7 @@ class FusedAttnRunner: ...@@ -732,7 +775,7 @@ class FusedAttnRunner:
""" """
self._setup_inputs() self._setup_inputs()
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng] args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
customcall_args = [ customcall_args = [
# Put test data onto each GPU for distributed. # Put test data onto each GPU for distributed.
...@@ -742,12 +785,14 @@ class FusedAttnRunner: ...@@ -742,12 +785,14 @@ class FusedAttnRunner:
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding), jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding), jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding), jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
] ]
kwargs = { kwargs = {
"attn_bias_type": self.attn_bias_type, "attn_bias_type": self.attn_bias_type,
"attn_mask_type": self.attn_mask_type, "attn_mask_type": self.attn_mask_type,
"softmax_type": self.softmax_type,
"scaling_factor": self.scaling_factor, "scaling_factor": self.scaling_factor,
"dropout_probability": self.dropout_prob, "dropout_probability": self.dropout_prob,
"is_training": self.is_training, "is_training": self.is_training,
...@@ -766,6 +811,7 @@ class FusedAttnRunner: ...@@ -766,6 +811,7 @@ class FusedAttnRunner:
self.qkvo_sharding, self.qkvo_sharding,
self.qkvo_sharding, self.qkvo_sharding,
self.bias_sharding, self.bias_sharding,
self.softmax_offset_sharding,
self.seq_desc_sharding, self.seq_desc_sharding,
self.dropout_rng_sharding, self.dropout_rng_sharding,
], ],
...@@ -826,7 +872,7 @@ class FusedAttnRunner: ...@@ -826,7 +872,7 @@ class FusedAttnRunner:
jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
).astype(self.dtype) ).astype(self.dtype)
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng] args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
customcall_args = [ customcall_args = [
# TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
# THD params once we support those features on CP. # THD params once we support those features on CP.
...@@ -834,12 +880,14 @@ class FusedAttnRunner: ...@@ -834,12 +880,14 @@ class FusedAttnRunner:
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding), jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding), jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding), jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
] ]
kwargs = { kwargs = {
"attn_bias_type": self.attn_bias_type, "attn_bias_type": self.attn_bias_type,
"attn_mask_type": self.attn_mask_type, "attn_mask_type": self.attn_mask_type,
"softmax_type": self.softmax_type,
"scaling_factor": self.scaling_factor, "scaling_factor": self.scaling_factor,
"dropout_probability": self.dropout_prob, "dropout_probability": self.dropout_prob,
"is_training": self.is_training, "is_training": self.is_training,
...@@ -866,8 +914,16 @@ class FusedAttnRunner: ...@@ -866,8 +914,16 @@ class FusedAttnRunner:
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
jitted_primitive = jit( jitted_primitive = jit(
value_and_grad( value_and_grad(
lambda q, k, v, bias, *args: grad_func( lambda q, k, v, bias, softmax_offset, *args: grad_func(
customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs customcall_fused_dpa,
q,
k,
v,
bias,
softmax_offset,
*args,
cp_reverse_out=True,
**kwargs,
), ),
arg_nums, arg_nums,
), ),
...@@ -876,6 +932,7 @@ class FusedAttnRunner: ...@@ -876,6 +932,7 @@ class FusedAttnRunner:
self.qkvo_sharding, self.qkvo_sharding,
self.qkvo_sharding, self.qkvo_sharding,
self.bias_sharding, self.bias_sharding,
self.softmax_offset_sharding,
self.seq_desc_sharding, self.seq_desc_sharding,
self.dropout_rng_sharding, self.dropout_rng_sharding,
), ),
...@@ -883,7 +940,9 @@ class FusedAttnRunner: ...@@ -883,7 +940,9 @@ class FusedAttnRunner:
) )
jitted_reference = jit( jitted_reference = jit(
value_and_grad( value_and_grad(
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs), lambda q, k, v, bias, softmax_offset, *args: grad_func(
jax_dpa, q, k, v, bias, softmax_offset, *args, **kwargs
),
arg_nums, arg_nums,
) )
) )
...@@ -976,6 +1035,14 @@ class FusedAttnRunner: ...@@ -976,6 +1035,14 @@ class FusedAttnRunner:
), ),
], ],
) )
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"qkv_layout", "qkv_layout",
[ [
...@@ -1084,6 +1151,7 @@ class TestFusedAttn: ...@@ -1084,6 +1151,7 @@ class TestFusedAttn:
d_v, d_v,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
softmax_type,
dropout_prob, dropout_prob,
dtype, dtype,
is_training, is_training,
...@@ -1110,6 +1178,7 @@ class TestFusedAttn: ...@@ -1110,6 +1178,7 @@ class TestFusedAttn:
d_v, d_v,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
softmax_type,
dropout_prob, dropout_prob,
dtype, dtype,
is_training, is_training,
...@@ -1138,6 +1207,7 @@ class TestFusedAttn: ...@@ -1138,6 +1207,7 @@ class TestFusedAttn:
d_v, d_v,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
softmax_type,
dropout_prob, dropout_prob,
dtype, dtype,
qkv_layout, qkv_layout,
...@@ -1161,6 +1231,7 @@ class TestFusedAttn: ...@@ -1161,6 +1231,7 @@ class TestFusedAttn:
d_v, d_v,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
softmax_type,
dropout_prob, dropout_prob,
dtype, dtype,
True, True,
......
...@@ -83,6 +83,7 @@ _KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits" ...@@ -83,6 +83,7 @@ _KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias" _KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding" _KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size" _KEY_OF_WINDOW_SIZE = "window_size"
_KEY_OF_SOFTMAX_TYPE = "softmax_type"
BASE_ATTRS = { BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True, _KEY_OF_TRANSPOSE_BS: True,
...@@ -276,6 +277,14 @@ ATTRS = [ ...@@ -276,6 +277,14 @@ ATTRS = [
_KEY_OF_RELATIVE_EMBEDDING: True, _KEY_OF_RELATIVE_EMBEDDING: True,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias", _KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias",
}, },
# attrs31
{
_KEY_OF_SOFTMAX_TYPE: "off_by_one",
},
# attrs32
{
_KEY_OF_SOFTMAX_TYPE: "learnable",
},
] ]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
...@@ -418,6 +427,9 @@ class EncoderRunner(BaseRunner): ...@@ -418,6 +427,9 @@ class EncoderRunner(BaseRunner):
"attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias", "attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
"attention/query/scale": "pre_attention_layer_norm/scale", "attention/query/scale": "pre_attention_layer_norm/scale",
"attention/query/ln_bias": "pre_attention_layer_norm/ln_bias", "attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
"attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
"attention/DotProductAttention_0/softmax_offset"
),
"mlp/wi_kernel": "mlp/wi/kernel", "mlp/wi_kernel": "mlp/wi/kernel",
"mlp/wi_bias": "mlp/wi/bias", "mlp/wi_bias": "mlp/wi/bias",
"mlp/wo_kernel": "mlp/wo/kernel", "mlp/wo_kernel": "mlp/wo/kernel",
...@@ -463,10 +475,16 @@ class DecoderRunner(BaseRunner): ...@@ -463,10 +475,16 @@ class DecoderRunner(BaseRunner):
"encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias", "encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
"encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale", "encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
"encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias", "encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
"encoder_decoder_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
"encoder_decoder_attention/DotProductAttention_0/softmax_offset"
),
"self_attention/qkv/scale": "pre_self_attention_layer_norm/scale", "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
"self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias", "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
"self_attention/query/scale": "pre_self_attention_layer_norm/scale", "self_attention/query/scale": "pre_self_attention_layer_norm/scale",
"self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias", "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
"self_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
"self_attention/DotProductAttention_0/softmax_offset"
),
"mlp/wi_kernel": "mlp/wi/kernel", "mlp/wi_kernel": "mlp/wi/kernel",
"mlp/wi_bias": "mlp/wi/bias", "mlp/wi_bias": "mlp/wi/bias",
"mlp/wo_kernel": "mlp/wo/kernel", "mlp/wo_kernel": "mlp/wo/kernel",
......
...@@ -17,7 +17,8 @@ from jax.typing import DTypeLike ...@@ -17,7 +17,8 @@ from jax.typing import DTypeLike
from utils import assert_allclose from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
from transformer_engine.jax.softmax import SoftmaxType, softmax from transformer_engine.jax.cpp_extensions.attention import AttnSoftmaxType
from transformer_engine.jax.softmax import SoftmaxFusionType, softmax
from transformer_engine.jax.flax.module import Softmax from transformer_engine.jax.flax.module import Softmax
...@@ -50,8 +51,9 @@ class SoftmaxRunner: ...@@ -50,8 +51,9 @@ class SoftmaxRunner:
max_seqlen_kv: int max_seqlen_kv: int
num_heads: int num_heads: int
scale_factor: float scale_factor: float
softmax_type: SoftmaxType softmax_fusion_type: SoftmaxFusionType
dtype: DTypeLike dtype: DTypeLike
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
@staticmethod @staticmethod
def reference_softmax(logits, mask, scale_factor, **_): def reference_softmax(logits, mask, scale_factor, **_):
...@@ -68,6 +70,7 @@ class SoftmaxRunner: ...@@ -68,6 +70,7 @@ class SoftmaxRunner:
def _is_support(self): def _is_support(self):
return is_softmax_kernel_available( return is_softmax_kernel_available(
self.softmax_fusion_type,
self.softmax_type, self.softmax_type,
self.batch_size, self.batch_size,
self.num_heads, self.num_heads,
...@@ -85,22 +88,22 @@ class SoftmaxRunner: ...@@ -85,22 +88,22 @@ class SoftmaxRunner:
self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0) self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0)
match self.softmax_type: match self.softmax_fusion_type:
case SoftmaxType.SCALED: case SoftmaxFusionType.SCALED:
self.mask = None self.mask = None
case SoftmaxType.SCALED_MASKED: case SoftmaxFusionType.SCALED_MASKED:
self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8) self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
case SoftmaxType.SCALED_UPPER_TRIANG_MASKED: case SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8) self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
case _: case _:
raise ValueError(f"Unknown {self.softmax_type=}") raise ValueError(f"Unknown {self.softmax_fusion_type=}")
def test_forward(self): def test_forward(self):
""" """
Test transformer_engine.jax.softmax.softmax fwd rule Test transformer_engine.jax.softmax.softmax fwd rule
""" """
self._setup_inputs() self._setup_inputs()
primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_type) primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_fusion_type)
reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor) reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
assert_allclose(primitive_out, reference_out, dtype=self.dtype) assert_allclose(primitive_out, reference_out, dtype=self.dtype)
...@@ -117,7 +120,7 @@ class SoftmaxRunner: ...@@ -117,7 +120,7 @@ class SoftmaxRunner:
args = [self.logits, self.mask] args = [self.logits, self.mask]
kwargs = { kwargs = {
"scale_factor": self.scale_factor, "scale_factor": self.scale_factor,
"softmax_type": self.softmax_type, "softmax_fusion_type": self.softmax_fusion_type,
} }
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
...@@ -175,7 +178,7 @@ class SoftmaxModuleRunner: ...@@ -175,7 +178,7 @@ class SoftmaxModuleRunner:
rng = jax.random.PRNGKey(0) rng = jax.random.PRNGKey(0)
softmax_module = Softmax( softmax_module = Softmax(
scale_factor=runner.scale_factor, scale_factor=runner.scale_factor,
softmax_type=runner.softmax_type, softmax_fusion_type=runner.softmax_fusion_type,
) )
softmax_vars = softmax_module.init(rng, runner.logits, runner.mask) softmax_vars = softmax_module.init(rng, runner.logits, runner.mask)
module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask) module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask)
...@@ -194,11 +197,11 @@ class SoftmaxModuleRunner: ...@@ -194,11 +197,11 @@ class SoftmaxModuleRunner:
) )
@pytest.mark.parametrize("scale_factor", [0.125]) @pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"softmax_type", "softmax_fusion_type",
[ [
pytest.param(SoftmaxType.SCALED, id="SCALED"), pytest.param(SoftmaxFusionType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"), pytest.param(SoftmaxFusionType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"), pytest.param(SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -214,19 +217,19 @@ class TestSoftmaxPrimitives: ...@@ -214,19 +217,19 @@ class TestSoftmaxPrimitives:
""" """
@staticmethod @staticmethod
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype): def test_forward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
""" """
Test forward with parameterized configs Test forward with parameterized configs
""" """
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
runner.test_forward() runner.test_forward()
@staticmethod @staticmethod
def test_backward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype): def test_backward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
""" """
Test backward with parameterized configs Test backward with parameterized configs
""" """
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
runner.test_backward() runner.test_backward()
...@@ -243,11 +246,11 @@ class TestSoftmaxPrimitives: ...@@ -243,11 +246,11 @@ class TestSoftmaxPrimitives:
) )
@pytest.mark.parametrize("scale_factor", [0.125]) @pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"softmax_type", "softmax_fusion_type",
[ [
pytest.param(SoftmaxType.SCALED, id="SCALED"), pytest.param(SoftmaxFusionType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"), pytest.param(SoftmaxFusionType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"), pytest.param(SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -263,11 +266,11 @@ class TestSoftmaxModule: ...@@ -263,11 +266,11 @@ class TestSoftmaxModule:
""" """
@staticmethod @staticmethod
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype): def test_forward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
""" """
Test forward with parameterized configs Test forward with parameterized configs
""" """
module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
bias = None bias = None
runner = SoftmaxModuleRunner(module_runner, bias) runner = SoftmaxModuleRunner(module_runner, bias)
runner.test_forward() runner.test_forward()
...@@ -21,6 +21,7 @@ from jax import random as jax_random ...@@ -21,6 +21,7 @@ from jax import random as jax_random
import pytest import pytest
from transformer_engine.jax.attention import ( from transformer_engine.jax.attention import (
AttnSoftmaxType,
canonicalize_attn_mask_type, canonicalize_attn_mask_type,
make_swa_mask, make_swa_mask,
) )
...@@ -162,6 +163,7 @@ class DotProductAttention(nn.Module): ...@@ -162,6 +163,7 @@ class DotProductAttention(nn.Module):
dropout_rate: float = 0.0 dropout_rate: float = 0.0
dtype: DType = jnp.float32 dtype: DType = jnp.float32
float32_logits: bool = False float32_logits: bool = False
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
"""Computes dot-product attention given query, key, and value. """Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on This is the core function for applying attention based on
...@@ -211,6 +213,24 @@ class DotProductAttention(nn.Module): ...@@ -211,6 +213,24 @@ class DotProductAttention(nn.Module):
assert key.shape[-2] == value.shape[-2], "k, v num_heads must match." assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
assert query.shape[-1] == key.shape[-1], "q, k head_dim must match." assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
# Infer number of attention heads from query shape
# query shape: [..., h, d] where h is num_attention_heads
num_attention_heads = query.shape[-2]
# Initialize softmax_offset for off-by-one or learnable softmax
softmax_offset = None
if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
# For off-by-one softmax, use zeros with shape (1, h, 1, 1)
softmax_offset = jnp.zeros((1, num_attention_heads, 1, 1), dtype=input_dtype)
elif self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
# For learnable softmax, create a learnable parameter with shape (1, h, 1, 1)
softmax_offset = self.param(
"softmax_offset",
nn.initializers.zeros,
(1, num_attention_heads, 1, 1),
jnp.float32,
)
if self.scale_attn_logits: if self.scale_attn_logits:
head_dim = query.shape[-1] head_dim = query.shape[-1]
depth_scaling = jnp.sqrt(head_dim).astype(input_dtype) depth_scaling = jnp.sqrt(head_dim).astype(input_dtype)
...@@ -241,9 +261,23 @@ class DotProductAttention(nn.Module): ...@@ -241,9 +261,23 @@ class DotProductAttention(nn.Module):
if bias is not None: if bias is not None:
attn_weights = attn_weights + bias.astype(attn_weights.dtype) attn_weights = attn_weights + bias.astype(attn_weights.dtype)
# Add attention sink to the last column if not vanilla softmax
if self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
# Add extra column with softmax_offset
# softmax_offset shape: (1, h, 1, 1), attn_weights shape: [b, h, q, k]
extra_col = jnp.broadcast_to(
softmax_offset,
(attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2], 1),
)
attn_weights = jnp.concatenate([attn_weights, extra_col], axis=-1)
# Normalize the attention weights across `kv_length` dimension. # Normalize the attention weights across `kv_length` dimension.
attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype) attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype)
# Remove the extra column after softmax if not vanilla softmax
if self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
attn_weights = attn_weights[..., :-1]
# Apply attention dropout. # Apply attention dropout.
if not deterministic and self.dropout_rate > 0.0: if not deterministic and self.dropout_rate > 0.0:
keep_prob = 1.0 - self.dropout_rate keep_prob = 1.0 - self.dropout_rate
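A compact, self-contained sketch of the same attention-sink mechanism as a standalone Flax module (illustrative only; TinyAttnSink and its shapes are invented here and are not the DotProductAttention module above), useful for seeing where the softmax_offset parameter lands in the variables tree:

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyAttnSink(nn.Module):
    num_heads: int = 4

    @nn.compact
    def __call__(self, attn_weights):  # [batch, heads, q_len, kv_len]
        offset = self.param(
            "softmax_offset", nn.initializers.zeros, (1, self.num_heads, 1, 1), jnp.float32
        )
        extra_col = jnp.broadcast_to(offset, attn_weights.shape[:-1] + (1,))
        probs = jax.nn.softmax(jnp.concatenate([attn_weights, extra_col], axis=-1), axis=-1)
        return probs[..., :-1]  # drop the sink column after normalization

logits = jnp.zeros((2, 4, 8, 8))
variables = TinyAttnSink().init(jax.random.PRNGKey(0), logits)
# variables["params"]["softmax_offset"] has shape (1, 4, 1, 1)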
...@@ -535,6 +569,7 @@ class MultiHeadAttention(nn.Module): ...@@ -535,6 +569,7 @@ class MultiHeadAttention(nn.Module):
rotary_pos_emb_group_method: str = "consecutive" rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv: bool = True fuse_qkv: bool = True
use_bias: bool = False use_bias: bool = False
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -801,6 +836,7 @@ class MultiHeadAttention(nn.Module): ...@@ -801,6 +836,7 @@ class MultiHeadAttention(nn.Module):
dropout_rate=self.dropout_rate, dropout_rate=self.dropout_rate,
dtype=self.dtype, dtype=self.dtype,
float32_logits=self.float32_logits, float32_logits=self.float32_logits,
softmax_type=self.softmax_type,
)(query, key, value, bias=attention_bias, deterministic=deterministic) )(query, key, value, bias=attention_bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
...@@ -1058,6 +1094,7 @@ class EncoderLayer(nn.Module): ...@@ -1058,6 +1094,7 @@ class EncoderLayer(nn.Module):
self_attn_bias_type: Any = None self_attn_bias_type: Any = None
self_attn_mask_type: str = "no_mask" self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1) window_size: Tuple[int, int] = (-1, -1)
softmax_type: str = "vanilla"
def __post_init__(self): def __post_init__(self):
if self.num_gqa_groups is None: if self.num_gqa_groups is None:
...@@ -1111,6 +1148,9 @@ class EncoderLayer(nn.Module): ...@@ -1111,6 +1148,9 @@ class EncoderLayer(nn.Module):
else: else:
x = inputs x = inputs
# Convert softmax_type string to AttnSoftmaxType enum
attn_softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
# [batch, length, emb_dim] -> [batch, length, emb_dim] # [batch, length, emb_dim] -> [batch, length, emb_dim]
x = MultiHeadAttention( x = MultiHeadAttention(
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
...@@ -1126,6 +1166,7 @@ class EncoderLayer(nn.Module): ...@@ -1126,6 +1166,7 @@ class EncoderLayer(nn.Module):
enable_rotary_pos_emb=self.enable_rotary_pos_emb, enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
use_bias=self.use_bias, use_bias=self.use_bias,
softmax_type=attn_softmax_type,
name="attention", name="attention",
)(x, x, encoder_mask, encoder_bias, deterministic=deterministic) )(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)( x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
...@@ -1222,6 +1263,7 @@ class DecoderLayer(nn.Module): ...@@ -1222,6 +1263,7 @@ class DecoderLayer(nn.Module):
self_attn_bias_type: Any = None self_attn_bias_type: Any = None
self_attn_mask_type: str = "no_mask" self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1) window_size: Tuple[int, int] = (-1, -1)
softmax_type: str = "vanilla"
def __post_init__(self): def __post_init__(self):
if self.num_gqa_groups is None: if self.num_gqa_groups is None:
...@@ -1290,6 +1332,9 @@ class DecoderLayer(nn.Module): ...@@ -1290,6 +1332,9 @@ class DecoderLayer(nn.Module):
else: else:
x = inputs x = inputs
# Convert softmax_type string to AttnSoftmaxType enum
attn_softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
# Self-attention block # Self-attention block
x = MultiHeadAttention( x = MultiHeadAttention(
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
...@@ -1305,6 +1350,7 @@ class DecoderLayer(nn.Module): ...@@ -1305,6 +1350,7 @@ class DecoderLayer(nn.Module):
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params, fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias, use_bias=self.use_bias,
softmax_type=attn_softmax_type,
name="self_attention", name="self_attention",
)(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode) )(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)( x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
...@@ -1343,6 +1389,7 @@ class DecoderLayer(nn.Module): ...@@ -1343,6 +1389,7 @@ class DecoderLayer(nn.Module):
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params, fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias, use_bias=self.use_bias,
softmax_type=attn_softmax_type,
name="encoder_decoder_attention", name="encoder_decoder_attention",
)(y, encoded, encoder_decoder_mask, deterministic=deterministic) )(y, encoded, encoder_decoder_mask, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)( y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
......
...@@ -89,40 +89,47 @@ def generate_input_shapes( ...@@ -89,40 +89,47 @@ def generate_input_shapes(
cu_seqlens_q_padded = None cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None cu_seqlens_kv_padded = None
elif qkv_format == "thd": elif qkv_format == "thd":
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
cu_seqlens_q_padded = torch.cat(
[
torch.zeros([1], dtype=torch.int32),
seqlens_q_padded.cumsum(0, dtype=torch.int32),
]
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
# Since FlashAttention doesn't support padding between sequences while FusedAttention does,
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
if kernel_backend == "FusedAttention":
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
# NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded`
# will not be the same as `cu_seqlens_q` and `cu_seqlens_q_padded` respectively.
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
total_tokens = cu_seqlens_q_padded[-1]
q_input_shape = ( q_input_shape = (
config.batch_size * config.max_seqlen_q, total_tokens,
config.num_heads, config.num_heads,
config.head_dim_qk, config.head_dim_qk,
) )
k_input_shape = ( k_input_shape = (
config.batch_size * config.max_seqlen_q, total_tokens,
config.num_gqa_groups, config.num_gqa_groups,
config.head_dim_qk, config.head_dim_qk,
) )
v_input_shape = ( v_input_shape = (
config.batch_size * config.max_seqlen_q, total_tokens,
config.num_gqa_groups, config.num_gqa_groups,
config.head_dim_v, config.head_dim_v,
) )
attn_output_shape = ( attn_output_shape = (
config.batch_size * config.max_seqlen_q, total_tokens,
config.num_heads * config.head_dim_v, config.num_heads * config.head_dim_v,
) )
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
cu_seqlens_q_padded = torch.cat(
[
torch.zeros([1], dtype=torch.int32),
seqlens_q_padded.cumsum(0, dtype=torch.int32),
torch.tensor([q_input_shape[0]], dtype=torch.int32),
]
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
if kernel_backend == "FusedAttention":
cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
cu_seqlens_q[-1] = cu_seqlens_q[-2]
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
else: else:
assert False, f"{qkv_format=} is not supported!" assert False, f"{qkv_format=} is not supported!"
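To make the rounding arithmetic above concrete, a small worked example (assuming world_size = 2 and two illustrative sequence lengths; these numbers are not taken from any test config):

import torch

world_size = 2
seqlens_q = torch.tensor([5, 7], dtype=torch.int32)

# Round each length up to a multiple of 2 * world_size (= 4) for context-parallel slicing
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
# seqlens_q_padded -> tensor([8, 8], dtype=torch.int32)

cu_seqlens_q_padded = torch.cat(
    [torch.zeros([1], dtype=torch.int32), seqlens_q_padded.cumsum(0, dtype=torch.int32)]
)
# cu_seqlens_q_padded -> tensor([0, 8, 16], dtype=torch.int32); total_tokens = 16

# FusedAttention additionally tracks the unpadded cumulative lengths
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32)
# cu_seqlens_q -> tensor([0, 5, 12], dtype=torch.int32)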
......
...@@ -117,7 +117,14 @@ model_configs_base = { ...@@ -117,7 +117,14 @@ model_configs_base = {
@pytest.mark.parametrize("swa", [False]) @pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False]) @pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention( def test_dot_product_attention(
dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs dtype,
model_configs,
model,
ckpt_attn,
workspace_opt,
qkv_layout,
swa,
pad_between_seqs,
): ):
"""Test DotProductAttention module""" """Test DotProductAttention module"""
...@@ -308,6 +315,31 @@ def test_dpa_max_logit(dtype, model_configs, model, qkv_layout): ...@@ -308,6 +315,31 @@ def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False) test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
model_configs_num_splits = {
# test: ModelConfig(b, sq, hq, dqk)
"num_splits_1_0": ModelConfig(2, 2048, 24, 128, num_splits=2),
"num_splits_1_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_num_splits])
@pytest.mark.parametrize("model", model_configs_num_splits.keys())
def test_dpa_num_splits(dtype, model_configs, model):
"""Test DotProductAttention with FlashAttention-3 num_splits enabled"""
test_dot_product_attention(
dtype,
model_configs,
model,
False,
True,
None,
False,
False,
)
model_configs_softmax = { model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk) # test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
...@@ -1153,6 +1185,8 @@ def _run_dot_product_attention( ...@@ -1153,6 +1185,8 @@ def _run_dot_product_attention(
core_attention_bias=bias, core_attention_bias=bias,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
fast_zero_fill=True, fast_zero_fill=True,
# Only pass num_splits when exercising the FlashAttention path
num_splits=config.num_splits if backend == "FlashAttention" else 1,
) )
max_logit = None max_logit = None
if config.return_max_logit: if config.return_max_logit:
...@@ -1787,9 +1821,10 @@ def test_mha_fp8_vs_f16( ...@@ -1787,9 +1821,10 @@ def test_mha_fp8_vs_f16(
fp8_meta=fp8_meta, fp8_meta=fp8_meta,
is_training=is_training, is_training=is_training,
) )
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported < 1: if flash_attn_supported + fused_attn_supported_fp8 < 1:
pytest.skip("No FP8 attention backend available.") pytest.skip("No FP8 attention backend available.")
fused_attn_supported_f16 = False
if not fp8_dpa_bwd: if not fp8_dpa_bwd:
available_backends, _, fused_attn_backends = get_available_attention_backends( available_backends, _, fused_attn_backends = get_available_attention_backends(
config, config,
...@@ -1797,8 +1832,8 @@ def test_mha_fp8_vs_f16( ...@@ -1797,8 +1832,8 @@ def test_mha_fp8_vs_f16(
qkv_layout=qkv_format.replace("hd", "h3d"), qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training, is_training=is_training,
) )
_, fused_attn_supported, _ = available_backends _, fused_attn_supported_f16, _ = available_backends
if not fused_attn_supported: if not fused_attn_supported_f16:
pytest.skip("No attention backend available.") pytest.skip("No attention backend available.")
if flash_attn_supported: if flash_attn_supported:
...@@ -1810,23 +1845,28 @@ def test_mha_fp8_vs_f16( ...@@ -1810,23 +1845,28 @@ def test_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
) )
os.environ["NVTE_FLASH_ATTN"] = "0" if fused_attn_supported_fp8:
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True os.environ["NVTE_FUSED_ATTN"] = "1"
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") _attention_backends["backend_selection_requires_update"] = True
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
) dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") if fused_attn_supported_f16:
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( os.environ["NVTE_FLASH_ATTN"] = "0"
dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe os.environ["NVTE_FUSED_ATTN"] = "1"
) _attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
atol = 5e-1 atol = 5e-1
rtol = 5e-1 rtol = 5e-1
rmse_tol = 0.15 rmse_tol = 0.15
if flash_attn_supported: if flash_attn_supported and fused_attn_supported_f16:
logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output")) logging.debug("========== {:^25s} ==========".format("forward output"))
compare_and_assert( compare_and_assert(
...@@ -1839,32 +1879,33 @@ def test_mha_fp8_vs_f16( ...@@ -1839,32 +1879,33 @@ def test_mha_fp8_vs_f16(
rmse_tol, rmse_tol,
True, True,
) )
logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) if fused_attn_supported_fp8 and fused_attn_supported_f16:
logging.debug("========== {:^25s} ==========".format("forward output")) logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
compare_and_assert( logging.debug("========== {:^25s} ==========".format("forward output"))
fused_attn_fwd_fp8, compare_and_assert(
fused_attn_fwd_f16, fused_attn_fwd_fp8,
"fused_attn_fwd_fp8", fused_attn_fwd_f16,
"fused_attn_fwd_f16", "fused_attn_fwd_fp8",
atol, "fused_attn_fwd_f16",
rtol, atol,
rmse_tol, rtol,
True, rmse_tol,
) True,
)
if is_training: if is_training:
for i in range(len(param_names[:1])): for i in range(len(param_names[:1])):
logging.debug("========== {:^25s} ==========".format(param_names[i])) logging.debug("========== {:^25s} ==========".format(param_names[i]))
compare_and_assert( compare_and_assert(
fused_attn_bwd_fp8[i], fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i], fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]", f"fused_attn_bwd_fp8[{i}]",
f"fused_attn_bwd_f16[{i}]", f"fused_attn_bwd_f16[{i}]",
atol, atol,
rtol, rtol,
rmse_tol, rmse_tol,
True, True,
) )
def _run_mha_fp8_vs_f16( def _run_mha_fp8_vs_f16(
...@@ -2490,7 +2531,6 @@ class _custom_mha_fp8(torch.autograd.Function): ...@@ -2490,7 +2531,6 @@ class _custom_mha_fp8(torch.autograd.Function):
max_s: int, max_s: int,
fast_zero_fill: bool, fast_zero_fill: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
workspace: torch.Tensor,
is_training: bool, is_training: bool,
mask_type: str, mask_type: str,
quantizers: list[Quantizer], quantizers: list[Quantizer],
...@@ -2519,7 +2559,6 @@ class _custom_mha_fp8(torch.autograd.Function): ...@@ -2519,7 +2559,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv, *_ = ext.general_gemm( qkv, *_ = ext.general_gemm(
qkv_weight_fp8, qkv_weight_fp8,
inp_fp8, inp_fp8,
workspace,
bias=qkv_bias, bias=qkv_bias,
out_dtype=qkv_weight_fp8.dtype, out_dtype=qkv_weight_fp8.dtype,
quantization_params=qkv_quantizer, quantization_params=qkv_quantizer,
...@@ -2561,9 +2600,7 @@ class _custom_mha_fp8(torch.autograd.Function): ...@@ -2561,9 +2600,7 @@ class _custom_mha_fp8(torch.autograd.Function):
s_quantizer=s_quantizer, s_quantizer=s_quantizer,
) )
tensors_to_save, tensor_objects = prepare_for_saving( tensors_to_save, tensor_objects = prepare_for_saving(q, k, v, inp_fp8, qkv_weight_fp8, out)
q, k, v, inp_fp8, qkv_weight_fp8, workspace, out
)
ctx.save_for_backward(*tensors_to_save) ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects ctx.tensor_objects = tensor_objects
...@@ -2593,7 +2630,7 @@ class _custom_mha_fp8(torch.autograd.Function): ...@@ -2593,7 +2630,7 @@ class _custom_mha_fp8(torch.autograd.Function):
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_DPA"): with torch.cuda.nvtx.range("_DPA"):
saved_tensors = ctx.saved_tensors saved_tensors = ctx.saved_tensors
(q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved( (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
ctx.tensor_objects, saved_tensors ctx.tensor_objects, saved_tensors
) )
...@@ -2649,7 +2686,6 @@ class _custom_mha_fp8(torch.autograd.Function): ...@@ -2649,7 +2686,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv_dgrad, *_ = ext.general_gemm( qkv_dgrad, *_ = ext.general_gemm(
qkv_weight_fp8, qkv_weight_fp8,
dqkv_c, dqkv_c,
workspace,
ctx.dtype, ctx.dtype,
use_split_accumulator=_2X_ACC_DGRAD, use_split_accumulator=_2X_ACC_DGRAD,
layout="NN", layout="NN",
...@@ -2659,7 +2695,6 @@ class _custom_mha_fp8(torch.autograd.Function): ...@@ -2659,7 +2695,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv_wgrad, *_ = ext.general_gemm( qkv_wgrad, *_ = ext.general_gemm(
inp_fp8, inp_fp8,
dqkv, dqkv,
workspace,
ctx.dtype, ctx.dtype,
use_split_accumulator=_2X_ACC_WGRAD, use_split_accumulator=_2X_ACC_WGRAD,
layout="NT", layout="NT",
...@@ -2710,9 +2745,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule): ...@@ -2710,9 +2745,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
with torch.no_grad(): with torch.no_grad():
self.qkv_bias.zero_() self.qkv_bias.zero_()
self.qkv_weight.fill_(1.0) self.qkv_weight.fill_(1.0)
self.workspace = torch.empty(
_CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
)
def forward( def forward(
self, self,
...@@ -2731,7 +2763,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule): ...@@ -2731,7 +2763,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
max_s, max_s,
self.fast_zero_fill, self.fast_zero_fill,
self.fp8_meta, self.fp8_meta,
self.workspace,
self.training, self.training,
self.mask_type, self.mask_type,
self.quantizers, self.quantizers,
......
...@@ -7,7 +7,7 @@ import subprocess ...@@ -7,7 +7,7 @@ import subprocess
import sys import sys
import pathlib import pathlib
import logging import logging
import copy
import pytest import pytest
import torch import torch
from transformer_engine.pytorch import ( from transformer_engine.pytorch import (
...@@ -74,7 +74,7 @@ dtypes = ["bf16", "fp16"] ...@@ -74,7 +74,7 @@ dtypes = ["bf16", "fp16"]
qkv_formats = ["bshd", "sbhd", "thd"] qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential: if test_essential:
configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"] configs = ["cp_1_0", "cp_1_2", "cp_2_1", "cp_3_2", "cp_3_3"]
model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs} model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs}
dtypes = ["bf16"] dtypes = ["bf16"]
qkv_formats = ["sbhd", "thd"] qkv_formats = ["sbhd", "thd"]
...@@ -97,12 +97,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ...@@ -97,12 +97,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!") pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!") pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if "a2a" in cp_comm_type and qkv_format == "thd": if qkv_format == "thd":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") if cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
" yet!"
)
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!") pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
...@@ -184,7 +188,7 @@ dtypes = ["bf16", "fp16", "fp8"] ...@@ -184,7 +188,7 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"] qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential: if test_essential:
configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"] dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"] qkv_formats = ["sbhd", "thd"]
...@@ -225,10 +229,14 @@ def test_cp_with_fused_attention( ...@@ -225,10 +229,14 @@ def test_cp_with_fused_attention(
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!") pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather": if qkv_format == "thd":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if cp_comm_type == "all_gather":
if qkv_format == "thd" and "a2a" in cp_comm_type: pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") if cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
" yet!"
)
if dtype == "fp8" and cp_comm_type == "all_gather": if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip( pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!" "CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
...@@ -282,6 +290,14 @@ def test_cp_with_fused_attention( ...@@ -282,6 +290,14 @@ def test_cp_with_fused_attention(
) )
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
if qkv_format == "thd":
config = copy.deepcopy(config)
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
fp8_meta = {} fp8_meta = {}
fp8_meta["recipe"] = None fp8_meta["recipe"] = None
fp8_meta["local_recipes"] = [] fp8_meta["local_recipes"] = []
......
...@@ -82,7 +82,6 @@ def _fp8_gemm_kernel(tensor1, scale1, dtype1, tensor2, scale2, dtype2, use_split ...@@ -82,7 +82,6 @@ def _fp8_gemm_kernel(tensor1, scale1, dtype1, tensor2, scale2, dtype2, use_split
out, *_ = tepytorch.cpp_extensions.general_gemm( out, *_ = tepytorch.cpp_extensions.general_gemm(
fp8_tensor1, fp8_tensor1,
fp8_tensor2, fp8_tensor2,
tepytorch.module.base.get_workspace(),
torch.float32, torch.float32,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
) )
...@@ -199,7 +198,6 @@ def _emulate_linear( ...@@ -199,7 +198,6 @@ def _emulate_linear(
wgrad, *_ = tepytorch.cpp_extensions.general_gemm( wgrad, *_ = tepytorch.cpp_extensions.general_gemm(
wgrad_input, wgrad_input,
wgrad_gradient, wgrad_gradient,
tepytorch.module.base.get_workspace(),
torch.float32, torch.float32,
layout="NT", layout="NT",
grad=True, grad=True,
......
...@@ -25,10 +25,8 @@ from transformer_engine.pytorch import ( ...@@ -25,10 +25,8 @@ from transformer_engine.pytorch import (
MXFP8Quantizer, MXFP8Quantizer,
) )
import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.module.base import ( from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes
fill_userbuffers_buffer_for_all_gather, from transformer_engine.pytorch.module.base import fill_userbuffers_buffer_for_all_gather
get_cublas_workspace_size_bytes,
)
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=FutureWarning)
...@@ -420,10 +418,6 @@ def _main(opts): ...@@ -420,10 +418,6 @@ def _main(opts):
std=opts.std, std=opts.std,
) )
# Allocate cuBLAS workspace
workspace_size = 1 * get_cublas_workspace_size_bytes()
workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
# Gather global tensors and calculate reference result (need these first for Fp8 scales) # Gather global tensors and calculate reference result (need these first for Fp8 scales)
if opts.bulk_overlap: if opts.bulk_overlap:
ker_g = torch.transpose(kernel_t, 0, 1) ker_g = torch.transpose(kernel_t, 0, 1)
...@@ -620,7 +614,6 @@ def _main(opts): ...@@ -620,7 +614,6 @@ def _main(opts):
return tex.general_gemm( return tex.general_gemm(
kernel_t_fp8, kernel_t_fp8,
gemm_inp, gemm_inp,
workspace,
out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16, out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
quantization_params=out_quantizer, quantization_params=out_quantizer,
use_split_accumulator=te.module.base._2X_ACC_FPROP, use_split_accumulator=te.module.base._2X_ACC_FPROP,
...@@ -638,7 +631,6 @@ def _main(opts): ...@@ -638,7 +631,6 @@ def _main(opts):
return tex.general_gemm( return tex.general_gemm(
kernel2_t_fp8, kernel2_t_fp8,
gemm2_inp, gemm2_inp,
workspace,
out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16, out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
quantization_params=out2_quantizer, quantization_params=out2_quantizer,
use_split_accumulator=te.module.base._2X_ACC_FPROP, use_split_accumulator=te.module.base._2X_ACC_FPROP,
...@@ -651,7 +643,6 @@ def _main(opts): ...@@ -651,7 +643,6 @@ def _main(opts):
return tex.general_gemm( return tex.general_gemm(
kernel_t, kernel_t,
gemm_inp, gemm_inp,
workspace,
out_dtype=torch.bfloat16, out_dtype=torch.bfloat16,
use_split_accumulator=te.module.base._2X_ACC_FPROP, use_split_accumulator=te.module.base._2X_ACC_FPROP,
ub=ub_obj, ub=ub_obj,
......