Commit 970620a5 authored by wenjh

merge nv_release_v2.10 to release_v2.10


Signed-off-by: wenjh <wenjh@sugon.com>
parents c1a1c04e 769ed778
......@@ -100,7 +100,7 @@
"\n",
"</div>\n",
"\n",
"A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\times \\text{batch_size} \\times \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n",
"A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\cdot \\text{batch_size} \\cdot \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n",
"\n",
"To show this in action, let's first initialize NCCL with a trivial process group:"
]
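The initialization cell itself is elided from this diff. For reference, a minimal sketch of such a trivial NCCL process-group initialization (assuming a single process with rank 0 and one visible GPU; the rendezvous address and port are illustrative) looks like:

```python
import os
import torch

# Single-process rendezvous; values here are placeholders.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
torch.cuda.set_device(0)  # one GPU per distributed process
```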
......@@ -131,7 +131,7 @@
"id": "1f2b80d0",
"metadata": {},
"source": [
"We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\times \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
"We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\cdot \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
"\n",
"Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n",
"\n",
......
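A hedged sketch of both patterns, reusing the single-process group from the initialization step; `set_parallel_mode`, `tp_group`, and `sequence_parallel` follow Transformer Engine's PyTorch module arguments, but consult the API docs for the authoritative signatures:

```python
import torch
import transformer_engine.pytorch as te

group = torch.distributed.new_group(ranks=[0], backend="nccl")

# Tensor (and sequence) parallelism: let the TE module shard along hidden_size.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    set_parallel_mode=True,
    tp_group=group,
    sequence_parallel=True,
).cuda()

# Data parallelism: wrap the module as with any standard PyTorch model.
layer = torch.nn.parallel.DistributedDataParallel(layer, process_group=group)
```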
......@@ -174,7 +174,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "50852cb5",
"metadata": {},
"outputs": [
......@@ -266,7 +266,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": null,
"id": "906b8cf1",
"metadata": {},
"outputs": [
......@@ -299,7 +299,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"id": "d3637094",
"metadata": {},
"outputs": [
......@@ -509,10 +509,10 @@
"\n",
"* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
" - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
" - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
" - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor of shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors of shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
"\n",
"\n",
"* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"* JAX: Users should provide the `attention_mask` tensor of shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"\n",
"**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
"\n",
......@@ -521,7 +521,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": null,
"id": "a1f25a9b",
"metadata": {},
"outputs": [
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import jax
import jax.numpy as jnp
import time
import math
from typing import Callable, Any, Dict, Optional, Tuple
from flax import linen as nn
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention
def speedometer(
    model_apply_fn: Callable,
    variables: Any,
    input: jnp.ndarray,
    output_grad: jnp.ndarray,
    dropout_key: jax.random.PRNGKey,
    model_init_fn: Callable = None,
    forward_kwargs: dict = {},
    autocast_kwargs: Optional[dict] = None,
    timing_iters: int = 50,
    warmup_iters: int = 50,
) -> None:
    """Measure the average runtime of a JAX module.

    Performs forward and backward passes.
    """
    if autocast_kwargs is None:
        autocast_kwargs = {"enabled": False}
    model_init_fn = None

    train_step_fn = create_train_step_fn(model_apply_fn, autocast_kwargs, forward_kwargs)

    # Warmup runs
    key = dropout_key
    for _ in range(warmup_iters):
        key, step_key = jax.random.split(key)
        loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)

    # Timing runs
    start = time.time()
    for _ in range(timing_iters):
        key, step_key = jax.random.split(key)
        loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)
    end = time.time()

    print(f"Mean time: {(end - start) * 1000 / timing_iters} ms")


def create_train_step_fn(
    model_apply_fn: Callable,
    autocast_kwargs: Dict[str, Any],
    forward_kwargs: Dict[str, Any] = None,
) -> Callable:
    """Create a JIT-compiled function that performs one forward/backward pass."""
    if forward_kwargs is None:
        forward_kwargs = {}

    def loss_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
        rngs = {"dropout": dropout_key}
        with te.autocast(**autocast_kwargs):
            # Forward pass: apply the model with the current parameters and variables
            call_kwargs = {**forward_kwargs, "rngs": rngs}
            out = model_apply_fn(variables, inp, **call_kwargs)
        # grad_target is dL/dy, the derivative of the loss L w.r.t. the output y.
        # By the chain rule, dL/dw = (dL/dy) * (dy/dw), where dy/dw = J_model(w),
        # so contracting the output with grad_target yields a scalar whose gradient
        # w.r.t. the params is the desired VJP:
        #   grad_params[out . grad_target] = grad_target^T . J_out(params)
        return jnp.vdot(out, grad_target)

    def fwd_bwd_fn(*args, **kwargs):
        # jax.value_and_grad computes the loss value and the gradients w.r.t. the
        # variables and the input in one combined forward + backward pass.
        return jax.value_and_grad(loss_fn, argnums=(0, 1))(*args, **kwargs)

    # JIT-compile the combined forward/backward function
    return jax.jit(fwd_bwd_fn)
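A hypothetical usage of the helpers above with a toy Flax module (the module and shapes are placeholders, not from the tutorial):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

model = nn.Dense(features=1024)
x = jnp.ones((32, 1024), dtype=jnp.bfloat16)
variables = model.init(jax.random.PRNGKey(0), x)

# Runs 50 warmup and 50 timed forward/backward passes of the jitted step.
speedometer(model.apply, variables, x, jnp.ones_like(x), dropout_key=jax.random.PRNGKey(1))
```

Note that JAX dispatches work asynchronously; for large models, calling `jax.block_until_ready` on the step outputs before reading the clock gives more faithful timings.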
......@@ -38,7 +38,7 @@
"\n",
"For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n",
"\n",
"In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
"In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
"\n",
"This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n",
"\n",
......
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Getting Started
===============
Choose your framework to get started with Transformer Engine:
.. toctree::
:maxdepth: 1
PyTorch <examples/quickstart.ipynb>
JAX <examples/quickstart_jax.ipynb>
......@@ -4,7 +4,7 @@
See LICENSE for license information.
Transformer Engine documentation
==============================================
=================================
.. ifconfig:: "dev" in release
......@@ -29,7 +29,7 @@ Transformer Engine documentation
:caption: Getting Started
installation
examples/quickstart.ipynb
getting_started
faq
.. toctree::
......
......@@ -28,7 +28,7 @@ on `NVIDIA GPU Cloud <https://ngc.nvidia.com>`_.
pip - from PyPI
-----------------------
---------------
Transformer Engine can be directly installed from `our PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.
......@@ -47,7 +47,7 @@ The core package from Transformer Engine (without any framework extensions) can
By default, this will install the core library compiled for CUDA 12. The CUDA major version can be selected by modifying the extra dependency to `core_cu12` or `core_cu13`.
pip - from GitHub
-----------------------
-----------------
Additional Prerequisites
^^^^^^^^^^^^^^^^^^^^^^^^
......
......@@ -30,13 +30,13 @@ do
# Build Flash Attention
if [ "${fa_version}" \< "3.0.0" ]
then
pip3 install flash-attn==${fa_version}
pip3 install flash-attn==${fa_version} --no-build-isolation
else
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
cd flash-attention/hopper && python setup.py install
python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flash_attn_3
wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py
cp flash_attn_interface.py $python_path/flash_attn_3/
cd ../../
fi
......
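A quick sanity check after the build steps above (standard library only; it assumes the `flash_attn_3` layout created by the script):

```python
import importlib.util

# The script copies flash_attn_interface.py into site-packages/flash_attn_3/.
spec = importlib.util.find_spec("flash_attn_3.flash_attn_interface")
print("flash-attn 3 interface available:", spec is not None)
```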
......@@ -18,6 +18,7 @@ from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
AttnBiasType,
AttnMaskType,
AttnSoftmaxType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
......@@ -66,6 +67,7 @@ class TestDistributedSelfAttn:
bias_shape,
attn_mask_type,
dtype,
softmax_type,
use_shardy,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
......@@ -80,6 +82,7 @@ class TestDistributedSelfAttn:
QKVLayout.BS3HD,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
num_head,
num_head,
......@@ -109,6 +112,7 @@ class TestDistributedSelfAttn:
hidden,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -142,6 +146,14 @@ class TestDistributedSelfAttn:
],
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
def test_self_attn(
self,
device_count,
......@@ -153,6 +165,7 @@ class TestDistributedSelfAttn:
bias_shape,
attn_mask_type,
dtype,
softmax_type,
):
self.impl_test_self_attn(
device_count,
......@@ -164,6 +177,7 @@ class TestDistributedSelfAttn:
bias_shape,
attn_mask_type,
dtype,
softmax_type,
use_shardy=False,
)
......@@ -175,8 +189,23 @@ class TestDistributedSelfAttn:
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
],
)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
def test_self_attn_shardy(
self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
attn_bias_type,
bias_shape,
softmax_type,
):
data_shape = (32, 512, 12, 64)
self.impl_test_self_attn(
......@@ -189,6 +218,7 @@ class TestDistributedSelfAttn:
bias_shape,
AttnMaskType.PADDING_MASK,
jnp.bfloat16,
softmax_type,
use_shardy=True,
)
......@@ -213,8 +243,24 @@ class TestDistributedCrossAttn:
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
def test_cross_attn(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_mask_type,
dtype,
softmax_type,
):
attn_bias_type = AttnBiasType.NO_BIAS
bias_shape = None
......@@ -230,6 +276,7 @@ class TestDistributedCrossAttn:
QKVLayout.BSHD_BS2HD,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
num_head,
num_head,
......@@ -252,6 +299,7 @@ class TestDistributedCrossAttn:
hidden,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -322,6 +370,8 @@ class TestDistributedContextParallelSelfAttn:
bias_shape = None
dropout_prob = 0.0
is_training = True
# Context parallel does not support softmax_offset
softmax_type = AttnSoftmaxType.VANILLA_SOFTMAX
dp_size, cp_size, tp_size = mesh_shape
batch, seqlen, num_head, hidden = data_shape
......@@ -343,6 +393,7 @@ class TestDistributedContextParallelSelfAttn:
hidden,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -366,6 +417,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
attn_bias_type,
mask_type,
softmax_type,
dropout_prob,
num_head,
num_kv_heads,
......
......@@ -16,7 +16,7 @@ from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import autocast
from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.softmax import SoftmaxFusionType, softmax
DTYPES = [jnp.float16, jnp.bfloat16]
......@@ -29,12 +29,12 @@ class TestDistributedSoftmax:
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(
self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
self, shape, mesh_resource, softmax_fusion_type, dtype, bad_sharding, broadcast_batch_mask
):
batch, _, sqelen, _ = shape
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
if softmax_fusion_type == SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
mask = make_causal_mask(batch, sqelen)
else:
mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen)
......@@ -56,8 +56,10 @@ class TestDistributedSoftmax:
return (x, mask), (x_pspec, mask_pspec)
@staticmethod
def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
def target_func(x, mask, scale_factor=1.0, softmax_fusion_type=SoftmaxFusionType.SCALED):
return jnp.mean(
softmax(x, mask, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type)
)
@staticmethod
def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
......@@ -80,24 +82,29 @@ class TestDistributedSoftmax:
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
softmax_fusion_type,
scale_factor,
dtype,
bad_sharding,
broadcast_batch_mask,
use_shardy,
):
if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED:
if broadcast_batch_mask and softmax_fusion_type != SoftmaxFusionType.SCALED_MASKED:
pytest.skip("Softmax type has no mask.")
jax.config.update("jax_use_shardy_partitioner", use_shardy)
target_func = partial(
self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
self.target_func, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type
)
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
(x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
data_shape,
mesh_resource,
softmax_fusion_type,
dtype,
bad_sharding,
broadcast_batch_mask,
)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
......@@ -139,8 +146,12 @@ class TestDistributedSoftmax:
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
@pytest.mark.parametrize(
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
"softmax_fusion_type",
[
SoftmaxFusionType.SCALED,
SoftmaxFusionType.SCALED_MASKED,
SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED,
],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
......@@ -153,7 +164,7 @@ class TestDistributedSoftmax:
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
softmax_fusion_type,
scale_factor,
dtype,
bad_sharding,
......@@ -165,7 +176,7 @@ class TestDistributedSoftmax:
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
softmax_fusion_type,
scale_factor,
dtype,
bad_sharding,
......@@ -174,7 +185,9 @@ class TestDistributedSoftmax:
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
@pytest.mark.parametrize(
"softmax_fusion_type", [SoftmaxFusionType.SCALED, SoftmaxFusionType.SCALED_MASKED]
)
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax_gspmd(
......@@ -183,7 +196,7 @@ class TestDistributedSoftmax:
mesh_shape,
mesh_axes,
mesh_resource,
softmax_type,
softmax_fusion_type,
bad_sharding,
broadcast_batch_mask,
):
......@@ -193,7 +206,7 @@ class TestDistributedSoftmax:
mesh_axes,
mesh_resource,
data_shape=[32, 12, 128, 128],
softmax_type=softmax_type,
softmax_fusion_type=softmax_fusion_type,
scale_factor=1.0,
dtype=DTYPES[0],
bad_sharding=bad_sharding,
......
......@@ -27,6 +27,7 @@ from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
AttnSoftmaxType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
......@@ -59,14 +60,16 @@ def init():
yield
@partial(jax.jit, static_argnums=(5, 6, 7, 9))
@partial(jax.jit, static_argnums=(6, 7, 8, 9, 11))
def general_dot_product_attention(
query: ArrayLike,
key: ArrayLike,
value: ArrayLike,
softmax_offset: Optional[ArrayLike],
bias: ArrayLike,
mask: ArrayLike,
deterministic: bool,
softmax_type: AttnSoftmaxType,
scale_factor: float,
dropout_rate: float,
dropout_rng: ArrayLike,
......@@ -99,7 +102,25 @@ def general_dot_product_attention(
mask = jnp.expand_dims(mask, axis=-3)
logits = jnp.where(mask, jnp.finfo(dtype).min, logits)
softmax_out = jax.nn.softmax(logits).astype(dtype)
match softmax_type:
case AttnSoftmaxType.VANILLA_SOFTMAX:
softmax_out = jax.nn.softmax(logits).astype(dtype)
case AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
# Softmax with +1 in denominator: exp(x_i) / (sum(exp(x_j)) + 1)
# Append a zero logit, apply standard softmax, then remove last column
zero_logit = jnp.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
logits_with_extra = jnp.concatenate([logits, zero_logit], axis=-1)
softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
softmax_out = softmax_with_extra[..., :-1].astype(dtype)
case AttnSoftmaxType.LEARNABLE_SOFTMAX:
# Append learnable offset logit, apply standard softmax, then remove last column
learnable_logit = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
learnable_logit = jnp.broadcast_to(learnable_logit, logits.shape[:-1] + (1,))
logits_with_extra = jnp.concatenate([logits, learnable_logit], axis=-1)
softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
softmax_out = softmax_with_extra[..., :-1].astype(dtype)
case _:
raise NotImplementedError(f"Unknown {softmax_type=}")
if not deterministic and dropout_rate > 0.0:
keep_prob = 1.0 - dropout_rate
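As an illustrative check (not part of the test suite), the appended-zero-logit trick used for OFF_BY_ONE_SOFTMAX above matches the closed form exp(x_i) / (1 + sum_j exp(x_j)):

```python
import jax
import jax.numpy as jnp

x = jnp.array([0.5, -1.0, 2.0])
padded = jnp.concatenate([x, jnp.zeros((1,))])      # append a zero logit
via_trick = jax.nn.softmax(padded)[:-1]             # drop the extra column
direct = jnp.exp(x) / (1.0 + jnp.sum(jnp.exp(x)))   # off-by-one denominator
assert jnp.allclose(via_trick, direct, atol=1e-6)
```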
......@@ -238,7 +259,7 @@ def _split_valid_and_invalid(primitive, reference, pad):
return primitive_valid, primitive_invalid, reference_valid, reference_invalid
def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
def jax_dpa(query, key, value, bias, softmax_offset, mask, dropout_rng, **kwargs):
"""
JAX native dot product attention implementation
"""
......@@ -246,11 +267,13 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
query,
key,
value,
softmax_offset,
bias,
mask,
deterministic=not kwargs["is_training"],
scale_factor=kwargs["scaling_factor"],
dropout_rate=kwargs["dropout_probability"],
softmax_type=kwargs["softmax_type"],
dropout_rng=dropout_rng,
dtype=jnp.float32,
)
......@@ -262,6 +285,7 @@ def customcall_fused_dpa(
key,
value,
bias,
softmax_offset,
sequence_descriptor,
dropout_rng,
**kwargs,
......@@ -283,9 +307,9 @@ def customcall_fused_dpa(
qkv_args = (query, key, value)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
query.dtype
)
return fused_attn(
qkv_args, bias, sequence_descriptor, dropout_rng, softmax_offset=softmax_offset, **kwargs
).astype(query.dtype)
class BiasShape(Enum):
......@@ -320,6 +344,7 @@ class FusedAttnRunner:
head_dim_v: int
attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType
softmax_type: AttnSoftmaxType
dropout_prob: float
dtype: DTypeLike
is_training: bool
......@@ -402,6 +427,7 @@ class FusedAttnRunner:
self.qkv_layout,
self.attn_bias_type,
self.attn_mask_type,
self.softmax_type,
self.dropout_prob,
self.num_heads_q,
self.num_heads_kv,
......@@ -439,7 +465,7 @@ class FusedAttnRunner:
self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)
key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6)
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
......@@ -490,6 +516,13 @@ class FusedAttnRunner:
else:
pad_ratio = 0.0
if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
self.softmax_offset = jax.random.uniform(
softmax_key, (1, self.num_heads_q, 1, 1), jnp.float32, -1.0
)
else:
self.softmax_offset = None
def gen_valid(bs, max_seqlen, pad_ratio):
pad_len = int(max_seqlen * pad_ratio)
valid_len = max_seqlen - pad_len
......@@ -713,6 +746,16 @@ class FusedAttnRunner:
self.bias_pspec = PartitionSpec()
self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)
# Softmax offset sharding (1, num_heads, 1, 1)
# Use the same logic as HEAD_AXES: tpsp_resource if enabled, else tp_resource
head_resource = (
self.mesh_resource.tpsp_resource
if self.mesh_resource.tpsp_resource is not None
else self.mesh_resource.tp_resource
)
self.softmax_offset_pspec = PartitionSpec(None, head_resource, None, None)
self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec)
self.dropout_rng_pspec = PartitionSpec(
None,
)
......@@ -732,7 +775,7 @@ class FusedAttnRunner:
"""
self._setup_inputs()
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
customcall_args = [
# Put test data onto each GPU for distributed.
......@@ -742,12 +785,14 @@ class FusedAttnRunner:
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
"attn_bias_type": self.attn_bias_type,
"attn_mask_type": self.attn_mask_type,
"softmax_type": self.softmax_type,
"scaling_factor": self.scaling_factor,
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
......@@ -766,6 +811,7 @@ class FusedAttnRunner:
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.softmax_offset_sharding,
self.seq_desc_sharding,
self.dropout_rng_sharding,
],
......@@ -826,7 +872,7 @@ class FusedAttnRunner:
jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
).astype(self.dtype)
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
customcall_args = [
# TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
# THD params once we support those features on CP.
......@@ -834,12 +880,14 @@ class FusedAttnRunner:
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
"attn_bias_type": self.attn_bias_type,
"attn_mask_type": self.attn_mask_type,
"softmax_type": self.softmax_type,
"scaling_factor": self.scaling_factor,
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
......@@ -866,8 +914,16 @@ class FusedAttnRunner:
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
jitted_primitive = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(
customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
lambda q, k, v, bias, softmax_offset, *args: grad_func(
customcall_fused_dpa,
q,
k,
v,
bias,
softmax_offset,
*args,
cp_reverse_out=True,
**kwargs,
),
arg_nums,
),
......@@ -876,6 +932,7 @@ class FusedAttnRunner:
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.softmax_offset_sharding,
self.seq_desc_sharding,
self.dropout_rng_sharding,
),
......@@ -883,7 +940,9 @@ class FusedAttnRunner:
)
jitted_reference = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
lambda q, k, v, bias, softmax_offset, *args: grad_func(
jax_dpa, q, k, v, bias, softmax_offset, *args, **kwargs
),
arg_nums,
)
)
......@@ -976,6 +1035,14 @@ class FusedAttnRunner:
),
],
)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
@pytest.mark.parametrize(
"qkv_layout",
[
......@@ -1084,6 +1151,7 @@ class TestFusedAttn:
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -1110,6 +1178,7 @@ class TestFusedAttn:
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -1138,6 +1207,7 @@ class TestFusedAttn:
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
......@@ -1161,6 +1231,7 @@ class TestFusedAttn:
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
True,
......
......@@ -83,6 +83,7 @@ _KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size"
_KEY_OF_SOFTMAX_TYPE = "softmax_type"
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
......@@ -276,6 +277,14 @@ ATTRS = [
_KEY_OF_RELATIVE_EMBEDDING: True,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias",
},
# attrs31
{
_KEY_OF_SOFTMAX_TYPE: "off_by_one",
},
# attrs31
{
_KEY_OF_SOFTMAX_TYPE: "learnable",
},
]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......@@ -418,6 +427,9 @@ class EncoderRunner(BaseRunner):
"attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
"attention/query/scale": "pre_attention_layer_norm/scale",
"attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
"attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
"attention/DotProductAttention_0/softmax_offset"
),
"mlp/wi_kernel": "mlp/wi/kernel",
"mlp/wi_bias": "mlp/wi/bias",
"mlp/wo_kernel": "mlp/wo/kernel",
......@@ -463,10 +475,16 @@ class DecoderRunner(BaseRunner):
"encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
"encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
"encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
"encoder_decoder_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
"encoder_decoder_attention/DotProductAttention_0/softmax_offset"
),
"self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
"self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
"self_attention/query/scale": "pre_self_attention_layer_norm/scale",
"self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
"self_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
"self_attention/DotProductAttention_0/softmax_offset"
),
"mlp/wi_kernel": "mlp/wi/kernel",
"mlp/wi_bias": "mlp/wi/bias",
"mlp/wo_kernel": "mlp/wo/kernel",
......
......@@ -17,7 +17,8 @@ from jax.typing import DTypeLike
from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.cpp_extensions.attention import AttnSoftmaxType
from transformer_engine.jax.softmax import SoftmaxFusionType, softmax
from transformer_engine.jax.flax.module import Softmax
......@@ -50,8 +51,9 @@ class SoftmaxRunner:
max_seqlen_kv: int
num_heads: int
scale_factor: float
softmax_type: SoftmaxType
softmax_fusion_type: SoftmaxFusionType
dtype: DTypeLike
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
@staticmethod
def reference_softmax(logits, mask, scale_factor, **_):
......@@ -68,6 +70,7 @@ class SoftmaxRunner:
def _is_support(self):
return is_softmax_kernel_available(
self.softmax_fusion_type,
self.softmax_type,
self.batch_size,
self.num_heads,
......@@ -85,22 +88,22 @@ class SoftmaxRunner:
self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0)
match self.softmax_type:
case SoftmaxType.SCALED:
match self.softmax_fusion_type:
case SoftmaxFusionType.SCALED:
self.mask = None
case SoftmaxType.SCALED_MASKED:
case SoftmaxFusionType.SCALED_MASKED:
self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
case SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
case SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
case _:
raise ValueError(f"Unknown {self.softmax_type=}")
raise ValueError(f"Unknown {self.softmax_fusion_type=}")
def test_forward(self):
"""
Test transformer_engine.jax.softmax.softmax fwd rule
"""
self._setup_inputs()
primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_type)
primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_fusion_type)
reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
assert_allclose(primitive_out, reference_out, dtype=self.dtype)
......@@ -117,7 +120,7 @@ class SoftmaxRunner:
args = [self.logits, self.mask]
kwargs = {
"scale_factor": self.scale_factor,
"softmax_type": self.softmax_type,
"softmax_fusion_type": self.softmax_fusion_type,
}
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
......@@ -175,7 +178,7 @@ class SoftmaxModuleRunner:
rng = jax.random.PRNGKey(0)
softmax_module = Softmax(
scale_factor=runner.scale_factor,
softmax_type=runner.softmax_type,
softmax_fusion_type=runner.softmax_fusion_type,
)
softmax_vars = softmax_module.init(rng, runner.logits, runner.mask)
module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask)
......@@ -194,11 +197,11 @@ class SoftmaxModuleRunner:
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
"softmax_type",
"softmax_fusion_type",
[
pytest.param(SoftmaxType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
pytest.param(SoftmaxFusionType.SCALED, id="SCALED"),
pytest.param(SoftmaxFusionType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
],
)
@pytest.mark.parametrize(
......@@ -214,19 +217,19 @@ class TestSoftmaxPrimitives:
"""
@staticmethod
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
"""
Test forward with parameterized configs
"""
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
runner.test_forward()
@staticmethod
def test_backward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
def test_backward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
"""
Test backward with parameterized configs
"""
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
runner.test_backward()
......@@ -243,11 +246,11 @@ class TestSoftmaxPrimitives:
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
"softmax_type",
"softmax_fusion_type",
[
pytest.param(SoftmaxType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
pytest.param(SoftmaxFusionType.SCALED, id="SCALED"),
pytest.param(SoftmaxFusionType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
],
)
@pytest.mark.parametrize(
......@@ -263,11 +266,11 @@ class TestSoftmaxModule:
"""
@staticmethod
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
"""
Test forward with parameterized configs
"""
module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
bias = None
runner = SoftmaxModuleRunner(module_runner, bias)
runner.test_forward()
......@@ -21,6 +21,7 @@ from jax import random as jax_random
import pytest
from transformer_engine.jax.attention import (
AttnSoftmaxType,
canonicalize_attn_mask_type,
make_swa_mask,
)
......@@ -162,6 +163,7 @@ class DotProductAttention(nn.Module):
dropout_rate: float = 0.0
dtype: DType = jnp.float32
float32_logits: bool = False
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
......@@ -211,6 +213,24 @@ class DotProductAttention(nn.Module):
assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
# Infer number of attention heads from query shape
# query shape: [..., h, d] where h is num_attention_heads
num_attention_heads = query.shape[-2]
# Initialize softmax_offset for off-by-one or learnable softmax
softmax_offset = None
if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
# For off-by-one softmax, use zeros with shape (1, h, 1, 1)
softmax_offset = jnp.zeros((1, num_attention_heads, 1, 1), dtype=input_dtype)
elif self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
# For learnable softmax, create a learnable parameter with shape (1, h, 1, 1)
softmax_offset = self.param(
"softmax_offset",
nn.initializers.zeros,
(1, num_attention_heads, 1, 1),
jnp.float32,
)
if self.scale_attn_logits:
head_dim = query.shape[-1]
depth_scaling = jnp.sqrt(head_dim).astype(input_dtype)
......@@ -241,9 +261,23 @@ class DotProductAttention(nn.Module):
if bias is not None:
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
# Add attention sink to the last column if not vanilla softmax
if self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
# Add extra column with softmax_offset
# softmax_offset shape: (1, h, 1, 1), attn_weights shape: [b, h, q, k]
extra_col = jnp.broadcast_to(
softmax_offset,
(attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2], 1),
)
attn_weights = jnp.concatenate([attn_weights, extra_col], axis=-1)
# Normalize the attention weights across `kv_length` dimension.
attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype)
# Remove the extra column after softmax if not vanilla softmax
if self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
attn_weights = attn_weights[..., :-1]
# Apply attention dropout.
if not deterministic and self.dropout_rate > 0.0:
keep_prob = 1.0 - self.dropout_rate
......@@ -535,6 +569,7 @@ class MultiHeadAttention(nn.Module):
rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv: bool = True
use_bias: bool = False
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
def __post_init__(self):
if self.kernel_init is None:
......@@ -801,6 +836,7 @@ class MultiHeadAttention(nn.Module):
dropout_rate=self.dropout_rate,
dtype=self.dtype,
float32_logits=self.float32_logits,
softmax_type=self.softmax_type,
)(query, key, value, bias=attention_bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
......@@ -1058,6 +1094,7 @@ class EncoderLayer(nn.Module):
self_attn_bias_type: Any = None
self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1)
softmax_type: str = "vanilla"
def __post_init__(self):
if self.num_gqa_groups is None:
......@@ -1111,6 +1148,9 @@ class EncoderLayer(nn.Module):
else:
x = inputs
# Convert softmax_type string to AttnSoftmaxType enum
attn_softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x = MultiHeadAttention(
num_heads=self.num_attention_heads,
......@@ -1126,6 +1166,7 @@ class EncoderLayer(nn.Module):
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
use_bias=self.use_bias,
softmax_type=attn_softmax_type,
name="attention",
)(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
......@@ -1222,6 +1263,7 @@ class DecoderLayer(nn.Module):
self_attn_bias_type: Any = None
self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1)
softmax_type: str = "vanilla"
def __post_init__(self):
if self.num_gqa_groups is None:
......@@ -1290,6 +1332,9 @@ class DecoderLayer(nn.Module):
else:
x = inputs
# Convert softmax_type string to AttnSoftmaxType enum
attn_softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
# Self-attention block
x = MultiHeadAttention(
num_heads=self.num_attention_heads,
......@@ -1305,6 +1350,7 @@ class DecoderLayer(nn.Module):
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
softmax_type=attn_softmax_type,
name="self_attention",
)(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
......@@ -1343,6 +1389,7 @@ class DecoderLayer(nn.Module):
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
softmax_type=attn_softmax_type,
name="encoder_decoder_attention",
)(y, encoded, encoder_decoder_mask, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
......
......@@ -89,40 +89,47 @@ def generate_input_shapes(
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "thd":
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
cu_seqlens_q_padded = torch.cat(
[
torch.zeros([1], dtype=torch.int32),
seqlens_q_padded.cumsum(0, dtype=torch.int32),
]
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
# Since FlashAttention doesn't support padding between sequences, while FusedAttention does,
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
if kernel_backend == "FusedAttention":
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
# NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded`
# will not be the same as `cu_seqlens_q` and `cu_seqlens_q_padded` respectively.
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
total_tokens = cu_seqlens_q_padded[-1]
q_input_shape = (
config.batch_size * config.max_seqlen_q,
total_tokens,
config.num_heads,
config.head_dim_qk,
)
k_input_shape = (
config.batch_size * config.max_seqlen_q,
total_tokens,
config.num_gqa_groups,
config.head_dim_qk,
)
v_input_shape = (
config.batch_size * config.max_seqlen_q,
total_tokens,
config.num_gqa_groups,
config.head_dim_v,
)
attn_output_shape = (
config.batch_size * config.max_seqlen_q,
total_tokens,
config.num_heads * config.head_dim_v,
)
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
cu_seqlens_q_padded = torch.cat(
[
torch.zeros([1], dtype=torch.int32),
seqlens_q_padded.cumsum(0, dtype=torch.int32),
torch.tensor([q_input_shape[0]], dtype=torch.int32),
]
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
if kernel_backend == "FusedAttention":
cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
cu_seqlens_q[-1] = cu_seqlens_q[-2]
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
else:
assert False, f"{qkv_format=} is not supported!"
......
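A worked example of the padded-length arithmetic above with illustrative values (`world_size=2`, so each length is rounded up to a multiple of `2 * world_size = 4` for context parallelism):

```python
import torch

world_size = 2
seqlens_q = torch.tensor([5, 3], dtype=torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
print(seqlens_q_padded)  # tensor([8, 4], dtype=torch.int32)

cu_seqlens_q_padded = torch.cat(
    [torch.zeros(1, dtype=torch.int32), seqlens_q_padded.cumsum(0, dtype=torch.int32)]
)
print(cu_seqlens_q_padded)  # tensor([0, 8, 12], dtype=torch.int32)
```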
......@@ -117,7 +117,14 @@ model_configs_base = {
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
dtype,
model_configs,
model,
ckpt_attn,
workspace_opt,
qkv_layout,
swa,
pad_between_seqs,
):
"""Test DotProductAttention module"""
......@@ -308,6 +315,31 @@ def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
model_configs_num_splits = {
# test: ModelConfig(b, sq, hq, dqk)
"num_splits_1_0": ModelConfig(2, 2048, 24, 128, num_splits=2),
"num_splits_1_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_num_splits])
@pytest.mark.parametrize("model", model_configs_num_splits.keys())
def test_dpa_num_splits(dtype, model_configs, model):
"""Test DotProductAttention with FlashAttention-3 num_splits enabled"""
test_dot_product_attention(
dtype,
model_configs,
model,
False,
True,
None,
False,
False,
)
model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
......@@ -1153,6 +1185,8 @@ def _run_dot_product_attention(
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
# Only pass num_splits when exercising the FlashAttention path
num_splits=config.num_splits if backend == "FlashAttention" else 1,
)
max_logit = None
if config.return_max_logit:
......@@ -1787,9 +1821,10 @@ def test_mha_fp8_vs_f16(
fp8_meta=fp8_meta,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported < 1:
flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported_fp8 < 1:
pytest.skip("No FP8 attention backend available.")
fused_attn_supported_f16 = False
if not fp8_dpa_bwd:
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
......@@ -1797,8 +1832,8 @@ def test_mha_fp8_vs_f16(
qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
_, fused_attn_supported_f16, _ = available_backends
if not fused_attn_supported_f16:
pytest.skip("No attention backend available.")
if flash_attn_supported:
......@@ -1810,23 +1845,28 @@ def test_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
if fused_attn_supported_fp8:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
if fused_attn_supported_f16:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.15
if flash_attn_supported:
if flash_attn_supported and fused_attn_supported_f16:
logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output"))
compare_and_assert(
......@@ -1839,32 +1879,33 @@ def test_mha_fp8_vs_f16(
rmse_tol,
True,
)
logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output"))
compare_and_assert(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
True,
)
if fused_attn_supported_fp8 and fused_attn_supported_f16:
logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output"))
compare_and_assert(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
True,
)
if is_training:
for i in range(len(param_names[:1])):
logging.debug("========== {:^25s} ==========".format(param_names[i]))
compare_and_assert(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
f"fused_attn_bwd_f16[{i}]",
atol,
rtol,
rmse_tol,
True,
)
if is_training:
for i in range(len(param_names[:1])):
logging.debug("========== {:^25s} ==========".format(param_names[i]))
compare_and_assert(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
f"fused_attn_bwd_f16[{i}]",
atol,
rtol,
rmse_tol,
True,
)
def _run_mha_fp8_vs_f16(
......@@ -2490,7 +2531,6 @@ class _custom_mha_fp8(torch.autograd.Function):
max_s: int,
fast_zero_fill: bool,
fp8_meta: Dict[str, Any],
workspace: torch.Tensor,
is_training: bool,
mask_type: str,
quantizers: list[Quantizer],
......@@ -2519,7 +2559,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv, *_ = ext.general_gemm(
qkv_weight_fp8,
inp_fp8,
workspace,
bias=qkv_bias,
out_dtype=qkv_weight_fp8.dtype,
quantization_params=qkv_quantizer,
......@@ -2561,9 +2600,7 @@ class _custom_mha_fp8(torch.autograd.Function):
s_quantizer=s_quantizer,
)
tensors_to_save, tensor_objects = prepare_for_saving(
q, k, v, inp_fp8, qkv_weight_fp8, workspace, out
)
tensors_to_save, tensor_objects = prepare_for_saving(q, k, v, inp_fp8, qkv_weight_fp8, out)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
......@@ -2593,7 +2630,7 @@ class _custom_mha_fp8(torch.autograd.Function):
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_DPA"):
saved_tensors = ctx.saved_tensors
(q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved(
(q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
ctx.tensor_objects, saved_tensors
)
......@@ -2649,7 +2686,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv_dgrad, *_ = ext.general_gemm(
qkv_weight_fp8,
dqkv_c,
workspace,
ctx.dtype,
use_split_accumulator=_2X_ACC_DGRAD,
layout="NN",
......@@ -2659,7 +2695,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv_wgrad, *_ = ext.general_gemm(
inp_fp8,
dqkv,
workspace,
ctx.dtype,
use_split_accumulator=_2X_ACC_WGRAD,
layout="NT",
......@@ -2710,9 +2745,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
with torch.no_grad():
self.qkv_bias.zero_()
self.qkv_weight.fill_(1.0)
self.workspace = torch.empty(
_CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
)
def forward(
self,
......@@ -2731,7 +2763,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
max_s,
self.fast_zero_fill,
self.fp8_meta,
self.workspace,
self.training,
self.mask_type,
self.quantizers,
......
......@@ -7,7 +7,7 @@ import subprocess
import sys
import pathlib
import logging
import copy
import pytest
import torch
from transformer_engine.pytorch import (
......@@ -74,7 +74,7 @@ dtypes = ["bf16", "fp16"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"]
configs = ["cp_1_0", "cp_1_2", "cp_2_1", "cp_3_2", "cp_3_3"]
model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs}
dtypes = ["bf16"]
qkv_formats = ["sbhd", "thd"]
......@@ -97,12 +97,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if "a2a" in cp_comm_type and qkv_format == "thd":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if qkv_format == "thd":
if cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
" yet!"
)
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
......@@ -184,7 +188,7 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"]
......@@ -225,10 +229,14 @@ def test_cp_with_fused_attention(
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and "a2a" in cp_comm_type:
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if qkv_format == "thd":
if cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
" yet!"
)
if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
......@@ -282,6 +290,14 @@ def test_cp_with_fused_attention(
)
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
if qkv_format == "thd":
config = copy.deepcopy(config)
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
fp8_meta = {}
fp8_meta["recipe"] = None
fp8_meta["local_recipes"] = []
......
......@@ -82,7 +82,6 @@ def _fp8_gemm_kernel(tensor1, scale1, dtype1, tensor2, scale2, dtype2, use_split
out, *_ = tepytorch.cpp_extensions.general_gemm(
fp8_tensor1,
fp8_tensor2,
tepytorch.module.base.get_workspace(),
torch.float32,
use_split_accumulator=use_split_accumulator,
)
......@@ -199,7 +198,6 @@ def _emulate_linear(
wgrad, *_ = tepytorch.cpp_extensions.general_gemm(
wgrad_input,
wgrad_gradient,
tepytorch.module.base.get_workspace(),
torch.float32,
layout="NT",
grad=True,
......
......@@ -25,10 +25,8 @@ from transformer_engine.pytorch import (
MXFP8Quantizer,
)
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.module.base import (
fill_userbuffers_buffer_for_all_gather,
get_cublas_workspace_size_bytes,
)
from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes
from transformer_engine.pytorch.module.base import fill_userbuffers_buffer_for_all_gather
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
......@@ -420,10 +418,6 @@ def _main(opts):
std=opts.std,
)
# Allocate cuBLAS workspace
workspace_size = 1 * get_cublas_workspace_size_bytes()
workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
# Gather global tensors and calculate reference result (need these first for Fp8 scales)
if opts.bulk_overlap:
ker_g = torch.transpose(kernel_t, 0, 1)
......@@ -620,7 +614,6 @@ def _main(opts):
return tex.general_gemm(
kernel_t_fp8,
gemm_inp,
workspace,
out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
quantization_params=out_quantizer,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
......@@ -638,7 +631,6 @@ def _main(opts):
return tex.general_gemm(
kernel2_t_fp8,
gemm2_inp,
workspace,
out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
quantization_params=out2_quantizer,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
......@@ -651,7 +643,6 @@ def _main(opts):
return tex.general_gemm(
kernel_t,
gemm_inp,
workspace,
out_dtype=torch.bfloat16,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
ub=ub_obj,
......