Unverified commit 88c0c914, authored by Charlene Yang, committed by GitHub
Browse files

[PyTorch] Update docs/example and benchmarks/ scripts (#1075)



* update example/benchmark scripts
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix head_dim after MLA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update notebook
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b8d453ef
...@@ -11,9 +11,7 @@ import nvtx ...@@ -11,9 +11,7 @@ import nvtx
import transformer_engine import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import ( from tests.pytorch.fused_attn.test_fused_attn import (
ModelConfig, ModelConfig,
_is_flash_attention_supported, _get_attention_backends,
_is_fused_attention_supported,
_is_unfused_attention_supported,
_run_dot_product_attention, _run_dot_product_attention,
) )
...@@ -29,8 +27,6 @@ ckpt_attn = False ...@@ -29,8 +27,6 @@ ckpt_attn = False
workspace_opt = True workspace_opt = True
# QKV memory layout # QKV memory layout
qkv_layout = "bshd_bshd_bshd" qkv_layout = "bshd_bshd_bshd"
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd # padding between sequences for qkv_format=thd
pad_between_seqs = False pad_between_seqs = False
# training mode # training mode
...@@ -64,7 +60,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp ...@@ -64,7 +60,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn, ckpt_attn,
qkv_layout, qkv_layout,
workspace_opt, workspace_opt,
swa,
pad_between_seqs, pad_between_seqs,
is_training, is_training,
) )
...@@ -76,7 +71,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp ...@@ -76,7 +71,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn, ckpt_attn,
qkv_layout, qkv_layout,
workspace_opt, workspace_opt,
swa,
pad_between_seqs, pad_between_seqs,
is_training, is_training,
) )
...@@ -97,7 +91,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp ...@@ -97,7 +91,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn, ckpt_attn,
qkv_layout, qkv_layout,
workspace_opt, workspace_opt,
swa,
pad_between_seqs, pad_between_seqs,
is_training, is_training,
) )
...@@ -115,7 +108,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp ...@@ -115,7 +108,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn, ckpt_attn,
qkv_layout, qkv_layout,
workspace_opt, workspace_opt,
swa,
pad_between_seqs, pad_between_seqs,
is_training, is_training,
) )
...@@ -205,13 +197,15 @@ def main(): ...@@ -205,13 +197,15 @@ def main():
) )
for model in model_configs.keys(): for model in model_configs.keys():
config = model_configs[model] config = model_configs[model]
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( available_backends, fused_attn_backends = _get_attention_backends(
config, config,
dtype, qkv_dtype=dtype,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
) )
fused_attn_supported = fused_attn_supported and not swa flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
flash_attn_supported = _is_flash_attention_supported(config)
print( print(
f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}' f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
f'{" and flash-attention" if flash_attn_supported else ""}...' f'{" and flash-attention" if flash_attn_supported else ""}...'
......
...@@ -6,7 +6,6 @@ import os ...@@ -6,7 +6,6 @@ import os
import torch import torch
from typing import Tuple from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
from transformer_engine.pytorch.distributed import _set_cuda_rng_state
from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import DotProductAttention
# Initialize RNG state # Initialize RNG state
...@@ -22,7 +21,7 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) ...@@ -22,7 +21,7 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
def reset_rng_states() -> None: def reset_rng_states() -> None:
"""Revert back to initial RNG state""" """Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state) torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state) torch.cuda.set_rng_state(_cuda_rng_state)
def _run_dot_product_attention( def _run_dot_product_attention(
...@@ -40,7 +39,7 @@ def _run_dot_product_attention( ...@@ -40,7 +39,7 @@ def _run_dot_product_attention(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
) )
inp = torch.randn( inp = torch.randn(
[config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim], [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk],
dtype=dtype, dtype=dtype,
device="cuda", device="cuda",
) )
...@@ -51,7 +50,7 @@ def _run_dot_product_attention( ...@@ -51,7 +50,7 @@ def _run_dot_product_attention(
k.requires_grad = True k.requires_grad = True
v.requires_grad = True v.requires_grad = True
out_grad = torch.randn( out_grad = torch.randn(
[config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim], [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim_v],
dtype=dtype, dtype=dtype,
device="cuda", device="cuda",
) )
...@@ -80,7 +79,7 @@ def _run_dot_product_attention( ...@@ -80,7 +79,7 @@ def _run_dot_product_attention(
block = DotProductAttention( block = DotProductAttention(
config.num_heads, config.num_heads,
config.head_dim, config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups, num_gqa_groups=config.num_gqa_groups,
qkv_format="bshd", qkv_format="bshd",
attention_dropout=config.dropout_p, attention_dropout=config.dropout_p,
...@@ -89,6 +88,8 @@ def _run_dot_product_attention( ...@@ -89,6 +88,8 @@ def _run_dot_product_attention(
get_rng_state_tracker=None, get_rng_state_tracker=None,
tp_group=None, tp_group=None,
layer_number=1, layer_number=1,
attn_mask_type="no_mask",
window_size=(-1, -1),
).to(dtype=dtype, device="cuda") ).to(dtype=dtype, device="cuda")
# Run a forward and backward pass # Run a forward and backward pass
...@@ -103,6 +104,7 @@ def _run_dot_product_attention( ...@@ -103,6 +104,7 @@ def _run_dot_product_attention(
attn_mask_type=config.attn_mask_type, # 'arbitrary' attn_mask_type=config.attn_mask_type, # 'arbitrary'
core_attention_bias_type=config.attn_bias_type, # 'no_bias' core_attention_bias_type=config.attn_bias_type, # 'no_bias'
core_attention_bias=bias, # None core_attention_bias=bias, # None
window_size=(-1, -1),
) )
out.backward(out_grad) out.backward(out_grad)
...@@ -116,6 +118,7 @@ def _run_dot_product_attention( ...@@ -116,6 +118,7 @@ def _run_dot_product_attention(
attn_mask_type=config.attn_mask_type, # no_mask attn_mask_type=config.attn_mask_type, # no_mask
core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias' core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias'
core_attention_bias=bias, # bias core_attention_bias=bias, # bias
window_size=(-1, -1),
) )
out.backward(out_grad) out.backward(out_grad)
...@@ -133,6 +136,7 @@ print("Run with post_scale_bias:") ...@@ -133,6 +136,7 @@ print("Run with post_scale_bias:")
config = model_configs["test_bias"] config = model_configs["test_bias"]
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd") fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
print()
print("Run with arbitrary mask:") print("Run with arbitrary mask:")
config = model_configs["test_mask"] config = model_configs["test_mask"]
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd") unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
...@@ -140,4 +144,6 @@ unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, " ...@@ -140,4 +144,6 @@ unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "
torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2) torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
for i in range(3): for i in range(3):
torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2) torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2)
print()
print("Test passed!") print("Test passed!")
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "8ae3bc43", "id": "040f466a",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Attention Is All You Need!\n", "# Attention Is All You Need!\n",
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "47421c01", "id": "89a7d849",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 1. Attention Backends\n", "## 1. Attention Backends\n",
...@@ -71,7 +71,7 @@ ...@@ -71,7 +71,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "e52f60f0", "id": "c90a2573",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 1.1 Flash vs. Non-Flash\n", "### 1.1 Flash vs. Non-Flash\n",
...@@ -85,30 +85,30 @@ ...@@ -85,30 +85,30 @@
"- **Recomputation:** The non-flash algorithm stores the softmax matrix (quadratic to sequence length) to global memory for the backward pass, while the flash algorithm only saves the softmax normalization factors (linear to sequence length). This reduces the amount of memory required as well as the bandwidth utilization between global memory and shared memory. Even though there is extra computation incurred in order to recalculate the attention in the backward pass, the bandwidth savings still provide significant improvement in efficiency.\n", "- **Recomputation:** The non-flash algorithm stores the softmax matrix (quadratic to sequence length) to global memory for the backward pass, while the flash algorithm only saves the softmax normalization factors (linear to sequence length). This reduces the amount of memory required as well as the bandwidth utilization between global memory and shared memory. Even though there is extra computation incurred in order to recalculate the attention in the backward pass, the bandwidth savings still provide significant improvement in efficiency.\n",
"\n", "\n",
"<div class=\"alert alert-info\">\n", "<div class=\"alert alert-info\">\n",
"<b>Note</b> \n", "<b>Note:</b> \n",
" \n", " \n",
"Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n", "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n",
"</div>\n" "</div>\n"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "bb909ac4", "id": "b5ce567d",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 1.2 flash-attention\n", "### 1.2 flash-attention\n",
"\n", "\n",
"The flash-attention backend, available only in PyTorch, is a module wrapped around the public `flash-attn` package [[3]](https://github.com/Dao-AILab/flash-attention). \n", "The flash-attention backend, available only in PyTorch, is a module wrapped around the public `flash-attn` package [[3]](https://github.com/Dao-AILab/flash-attention). \n",
"\n", "\n",
"The flash-attention backend supports `flash-attn`'s features as they are released, and to facilitate the use of `flash-attn`, flash-attention also offers a few functionalities such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask. Please see `transformer_engine.pytorch.attention.FlashAttention` for more details.\n", "The flash-attention backend supports `flash-attn`'s features as well as a few extra functionalities to facilitate the use of `flash-attn`, such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask use cases. Please see `transformer_engine.pytorch.attention.FlashAttention` for details.\n",
"\n", "\n",
"The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.7, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n", "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.10, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n",
"\n", "\n",
"To understand `flash-attn`'s performance, please refer to their [benchmarks](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n", "To understand `flash-attn`'s performance, please refer to their benchmarks [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n",
"\n", "\n",
"### 1.3 cuDNN Attention\n", "### 1.3 cuDNN Attention\n",
"\n", "\n",
"The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths. Out of the three, sub-backends 1 and 2 are based on the flash algorithm, as `flash-attn` is.\n", "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n",
"\n", "\n",
"<table class=\"docutils align-default\">\n", "<table class=\"docutils align-default\">\n",
" <tr>\n", " <tr>\n",
...@@ -153,14 +153,14 @@ ...@@ -153,14 +153,14 @@
" </tr>\n", " </tr>\n",
"</table>\n", "</table>\n",
"\n", "\n",
"The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.7, cuDNN 9.0 and `flash-attn` 2.4.2,\n", "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.10, cuDNN 9.3 and `flash-attn` 2.4.2,\n",
"\n", "\n",
"- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n", "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n",
"- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n", "- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n",
"- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three without transposes (see Section 3.1 for more details).\n", "- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three formats without transposes (see Section 3.1 for more details).\n",
"- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n", "- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n",
"- flash-attention supports sliding window attention, and cuDNN attention does not.\n", "- flash-attention supports KV-caching and paged attention, and cuDNN attention does not.\n",
"- flash-attention uses bottom right diagonal for `causal` mask in cross attention, and cuDNN attention uses top left (see `flash-attn`'s [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)).\n", "- flash-attention uses bottom right diagonal for `causal` mask in cross attention (see [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)), and cuDNN attention supports both top left and bottom right.\n",
"- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n", "- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n",
"\n", "\n",
"To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0." "To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0."
...@@ -169,7 +169,7 @@ ...@@ -169,7 +169,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "9a380859", "id": "c5b8e3d7",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -184,25 +184,25 @@ ...@@ -184,25 +184,25 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 1,
"id": "0584bb01", "id": "50852cb5",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Device 0: NVIDIA H100 PCIe GPU, sm90 compute capability, 79.1GB memory\n", "Device 0: NVIDIA H100 80GB HBM3 GPU, sm90 compute capability, 79.1GB memory\n",
"Running test_0 with cuDNN attention and flash-attention...\n", "Running test_0 with cuDNN attention and flash-attention...\n",
"Running test_1 with cuDNN attention and flash-attention...\n", "Running test_1 with cuDNN attention and flash-attention...\n",
"Running test_2 with cuDNN attention...\n", "Running test_2 with cuDNN attention...\n",
"Running test_3 with cuDNN attention and flash-attention...\n", "Running test_3 with cuDNN attention and flash-attention...\n",
"\n", "\n",
" cuDNN fwd+bwd (ms) flash-attn fwd+bwd (ms) cuDNN vs flash speedup\n", " cuDNN fwd+bwd (ms) flash-attn fwd+bwd (ms) cuDNN vs flash speedup\n",
"test_0 0.0638 0.0858 1.3454\n", "test_0 0.0340 0.0468 1.3786\n",
"test_1 0.5415 0.7496 1.3842\n", "test_1 0.3664 0.5850 1.5968\n",
"test_2 1.2302 0.0000 0.0000\n", "test_2 0.9332 0.0000 0.0000\n",
"test_3 12.0122 19.0716 1.5877\n" "test_3 7.4875 11.8879 1.5877\n"
] ]
} }
], ],
...@@ -212,7 +212,7 @@ ...@@ -212,7 +212,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "45e53fc9", "id": "9a615119",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 2. Backend Selection\n", "## 2. Backend Selection\n",
...@@ -253,35 +253,35 @@ ...@@ -253,35 +253,35 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "6dfeade3", "id": "e6c0f3f0",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 2.1 Debug Information\n", "### 2.1 Debug Information\n",
"\n", "\n",
"To find out which backend is being used during runtime, users can turn on these debugging flags. Logging is done using the `logging` package.\n", "To find out which backend is being used during runtime, we have the following two debugging flags. Logging is done by using the `logging` package.\n",
"```\n", "```\n",
"NVTE_DEBUG = 0/1 # disables/enables debugging\n", "NVTE_DEBUG = 0/1 # disables/enables debugging\n",
"NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages\n", "NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages\n",
"```\n", "```\n",
"<div class=\"alert alert-info\">\n", "<div class=\"alert alert-info\">\n",
"<b>Note</b>\n", "<b>Note:</b>\n",
" \n", " \n",
"These flags are supported in PyTorch only as of Transformer Engine 1.7. JAX and PaddlePaddle support is expected to be added in the future.\n", "These flags are supported in PyTorch only as of Transformer Engine 1.10. JAX and PaddlePaddle support is expected to be added in the future.\n",
"</div>" "</div>"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "7e3b7981", "id": "16660323",
"metadata": {}, "metadata": {},
"source": [ "source": [
"The [example_attention.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend was actually used during runtime." "The example script [example_attention.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/example_attention.py) runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend is used in runtime."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 24,
"id": "961c51d4", "id": "906b8cf1",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -293,7 +293,7 @@ ...@@ -293,7 +293,7 @@
"[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n", "[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
"\n", "\n",
"Run flash-attention...\n", "Run flash-attention...\n",
"[INFO | DotProductAttention]: Running with FlashAttention backend \n", "[INFO | DotProductAttention]: Running with FlashAttention backend\n",
"\n", "\n",
"Test passed.\n" "Test passed.\n"
] ]
...@@ -305,16 +305,16 @@ ...@@ -305,16 +305,16 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "11bfbbd7", "id": "8ca99461",
"metadata": {}, "metadata": {},
"source": [ "source": [
"To collect more information, users can turn on `NVTE_DEBUG_LEVEL=2`. In this example, it allows us to find out more about the run config. Users are encouraged to provide if users intend to file a bug with Transformer Engine. For example, " "`NVTE_DEBUG_LEVEL=2` allows us to find out more about the backend selection logic. Users are encouraged to double check the `config` and provide it to the Transformer Engine team if they would like to file a bug. "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 25, "execution_count": 23,
"id": "162a2be1", "id": "d3637094",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -323,16 +323,18 @@ ...@@ -323,16 +323,18 @@
"text": [ "text": [
"\n", "\n",
"Run cuDNN attention...\n", "Run cuDNN attention...\n",
"[DEBUG | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.3.0', 'qkv_type': <class 'torch.Tensor'>, 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}\n",
"[DEBUG | DotProductAttention]: Disabling FlashAttention due to NVTE_FLASH_ATTN=0\n", "[DEBUG | DotProductAttention]: Disabling FlashAttention due to NVTE_FLASH_ATTN=0\n",
"[DEBUG | DotProductAttention]: Available backends = {FlashAttention=False, FusedAttention=True (sub-backend 1), UnfusedDotProductAttention=True}\n",
"[DEBUG | DotProductAttention]: Selected backend = FusedAttention (sub-backend 1)\n",
"[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n", "[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
"[DEBUG | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': <Version('1.8.0.dev0')>, 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.2.0'}\n",
"[DEBUG | FusedAttnFunc ]: Running forward in torch.bfloat16\n",
"[DEBUG | FusedAttnFunc ]: Running backward in torch.bfloat16\n",
"\n", "\n",
"Run flash-attention...\n", "Run flash-attention...\n",
"[DEBUG | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.3.0', 'qkv_type': <class 'torch.Tensor'>, 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}\n",
"[DEBUG | DotProductAttention]: Disabling FusedAttention due to NVTE_FUSED_ATTN=0\n", "[DEBUG | DotProductAttention]: Disabling FusedAttention due to NVTE_FUSED_ATTN=0\n",
"[INFO | DotProductAttention]: Running with FlashAttention backend \n", "[DEBUG | DotProductAttention]: Available backends = {FlashAttention=True, FusedAttention=False, UnfusedDotProductAttention=True}\n",
"[DEBUG | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': <Version('1.8.0.dev0')>, 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.2.0'}\n", "[DEBUG | DotProductAttention]: Selected backend = FlashAttention\n",
"[INFO | DotProductAttention]: Running with FlashAttention backend\n",
"\n", "\n",
"Test passed.\n" "Test passed.\n"
] ]
...@@ -344,7 +346,7 @@ ...@@ -344,7 +346,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "779a51e6", "id": "611d8fdb",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 2.2 User Control\n", "### 2.2 User Control\n",
...@@ -392,28 +394,29 @@ ...@@ -392,28 +394,29 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "ccd5650d", "id": "e60a2a3e",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 3. Backend Support\n", "## 3. Backend Support\n",
"\n", "\n",
"Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.7, Transformer Engine's attention backends have the following support matrix.\n", "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.10, Transformer Engine's attention backends have the following support matrix.\n",
"\n", "\n",
"| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Determinism Possible |\n", "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n",
"| :---------------- | :-------- | :----------- | :----------------------- | :------ | :------------------ | :------------ |\n", "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n",
"| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes (only for `bshd`,`sbhd`) | Yes |\n", "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n",
"| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes (only for `bshd`,`thd`) | Yes |\n", "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | No | Yes (`bshd`,`thd`) | Yes |\n",
"| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | No | Yes |\n", "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
"\n", "\n",
"Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n", "Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
"- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)" "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "8439b389", "id": "fbdcb327",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 3.1 QKV Layout\n", "### 3.1 QKV Layout\n",
...@@ -439,7 +442,7 @@ ...@@ -439,7 +442,7 @@
"**qkv_layout=thd_thd_thd:**\n", "**qkv_layout=thd_thd_thd:**\n",
"`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n", "`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n",
"\n", "\n",
"As of v1.7, Transformer Engine has the following support matrix.\n", "As of v1.10, Transformer Engine has the following support matrix.\n",
"\n", "\n",
"<table class=\"docutils align-default\">\n", "<table class=\"docutils align-default\">\n",
" <tr>\n", " <tr>\n",
...@@ -480,16 +483,16 @@ ...@@ -480,16 +483,16 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "0290f8e9", "id": "855d9616",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 3.2 Attention Mask\n", "### 3.2 Attention Mask\n",
"\n", "\n",
"Transformer Engine supports 5 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n", "Transformer Engine supports 7 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n",
"\n", "\n",
"- `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), `arbitrary`\n", "- `no_mask`, `padding`, `causal`, `causal_bottom_right`, `padding_causal`, `padding_causal_bottom_right`, `arbitrary`\n",
"\n", "\n",
"Different backends offer different support for attention mask. As of Transformer Engine 1.7,\n", "Different backends offer different support for attention mask. As of Transformer Engine 1.10,\n",
"\n", "\n",
"<table class=\"docutils align-default\">\n", "<table class=\"docutils align-default\">\n",
" <tr>\n", " <tr>\n",
...@@ -498,34 +501,25 @@ ...@@ -498,34 +501,25 @@
" <th>Requires `attention_mask`</th>\n", " <th>Requires `attention_mask`</th>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td rowspan=\"2\">flash-attention</td>\n", " <td>flash-attention</td>\n",
" <td rowspan=\"2\">`no_mask`, `causal`, `padding`, `padding_causal`</td>\n", " <td><li>`no_mask`, `causal` (self-attention),</li><li>`padding`, `padding_causal` (self-attention),</li><li>`causal_bottom_right`, `padding_causal_bottom_right`</li></td>\n",
" <td>`no_mask`, `causal`: No</td>\n", " <td rowspan=\"3\"><li>`no_mask`, `causal`, `causal_bottom_right`: No</li><li>`padding`, `padding_causal`, `padding_causal_bottom_right`: Yes if `cu_seqlens` not provided</li><li>`arbitrary`: Yes</li></td>\n",
" </tr>\n",
" <tr>\n",
" <td>`padding`, `padding_causal`: Yes if `cu_seqlens` not provided</td>\n",
" </tr>\n",
" <tr>\n",
" <td rowspan=\"2\">cuDNN attention</td>\n",
" <td rowspan=\"2\">`no_mask`, `causal`, `padding`, `padding_causal`</td>\n",
" <td>`no_mask`, `causal`: No</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>\n", " <td>cuDNN attention</td>\n",
" `padding`, `padding_causal`: Yes if `cu_seqlens` not provided\n", " <td><li>`no_mask`, `causal`,</li><li>`padding`, `padding_causal`,</li><li>`causal_bottom_right`, `padding_causal_bottom_right`</li></td>\n",
" </td> \n", " <td></td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td rowspan=\"2\">Framework-native attention</td>\n", " <td>Framework-native attention</td>\n",
" <td rowspan=\"2\">`no_mask`, `causal`, `arbitrary`</td>\n", " <td><li>All (PyTorch)</li><li>`no_mask`, `causal`, `padding` (JAX, PaddlePaddle)</li></td>\n",
" <td>`no_mask`, `causal`: No</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>`arbitrary`: Yes</td>\n", " <td></td>\n",
" </tr>\n", " </tr>\n",
"</table>\n", "</table>\n",
"\n", "\n",
"**padding and padding_causal:** For these two mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.7, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n", "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.10, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n",
"\n", "\n",
"* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n", "* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
" - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n", " - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
...@@ -536,13 +530,13 @@ ...@@ -536,13 +530,13 @@
"\n", "\n",
"**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n", "**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
"\n", "\n",
"**Arbitrary mask:** cuDNN does not support `Arbitrary` mask type as of v9.0. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py).\n" "**Arbitrary mask:** cuDNN does not support `Arbitrary` mask type as of v9.3. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py).\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 33,
"id": "b1b7cdd4", "id": "a1f25a9b",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -550,27 +544,29 @@ ...@@ -550,27 +544,29 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Run with post_scale_bias:\n", "Run with post_scale_bias:\n",
"[DotProductAttention]: using cuDNN attention (sub-backend 1)\n", "[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
"\n",
"Run with arbitrary mask:\n", "Run with arbitrary mask:\n",
"[DotProductAttention]: using unfused DPA\n", "[INFO | DotProductAttention]: Running with UnfusedDotProductAttention backend\n",
"\n",
"Test passed!\n" "Test passed!\n"
] ]
} }
], ],
"source": [ "source": [
"!NVTE_DEBUG=1 python arbitrary_mask_to_post_scale_bias.py" "!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python arbitrary_mask_to_post_scale_bias.py"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "e045c284", "id": "dda4a589",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n", "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n",
"\n", "\n",
"### 3.3 Attention Bias\n", "### 3.3 Attention Bias\n",
"\n", "\n",
"Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.7, their support matrix is as follows.\n", "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.10, their support matrix is as follows.\n",
"\n", "\n",
"<table class=\"docutils align-default\">\n", "<table class=\"docutils align-default\">\n",
" <tr>\n", " <tr>\n",
...@@ -617,25 +613,20 @@ ...@@ -617,25 +613,20 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "8b8a4e40", "id": "a0702339",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 3.4 FP8 Attention\n", "### 3.4 FP8 Attention\n",
"\n", "\n",
"A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n", "A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n",
"\n", "\n",
"Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.7. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.10. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n",
"\n", "\n",
"- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n", "- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n",
"\n", "\n",
"- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n", "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n",
"\n", "\n",
"Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`. This should result in the following print when the debug flags are turned on, `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2`.\n", "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
"```\n",
"[DEBUG | DotProductAttention]: Running with fp8_recipe.fp8_mha=False, fp8_recipe.fp8_dpa=True and NVTE_FP8_DPA_BWD=0\n",
"[DEBUG | FusedAttnFunc ]: Running forward in FP8\n",
"[DEBUG | FusedAttnFunc ]: Running backward in torch.bfloat16\n",
"```"
] ]
} }
], ],
......
...@@ -11,9 +11,7 @@ import nvtx ...@@ -11,9 +11,7 @@ import nvtx
import transformer_engine import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import ( from tests.pytorch.fused_attn.test_fused_attn import (
ModelConfig, ModelConfig,
_is_flash_attention_supported, _get_attention_backends,
_is_fused_attention_supported,
_is_unfused_attention_supported,
_run_dot_product_attention, _run_dot_product_attention,
) )
...@@ -60,7 +58,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported): ...@@ -60,7 +58,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported):
ckpt_attn, ckpt_attn,
qkv_layout, qkv_layout,
workspace_opt, workspace_opt,
swa,
pad_between_seqs, pad_between_seqs,
is_training, is_training,
) )
...@@ -75,7 +72,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported): ...@@ -75,7 +72,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported):
ckpt_attn, ckpt_attn,
qkv_layout, qkv_layout,
workspace_opt, workspace_opt,
swa,
pad_between_seqs, pad_between_seqs,
is_training, is_training,
) )
...@@ -94,13 +90,14 @@ def main(): ...@@ -94,13 +90,14 @@ def main():
models = ["test_0"] models = ["test_0"]
for model in models: for model in models:
config = model_configs[model] config = model_configs[model]
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( available_backends, fused_attn_backends = _get_attention_backends(
config, config,
dtype, qkv_dtype=dtype,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
) )
fused_attn_supported = fused_attn_supported and not swa flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
flash_attn_supported = _is_flash_attention_supported(config)
example_attention(model, fused_attn_supported, flash_attn_supported) example_attention(model, fused_attn_supported, flash_attn_supported)
......
...@@ -8,6 +8,7 @@ import math ...@@ -8,6 +8,7 @@ import math
import os import os
from importlib.metadata import version from importlib.metadata import version
from typing import Any, Dict, List, Tuple, Union, Optional from typing import Any, Dict, List, Tuple, Union, Optional
from contextlib import contextmanager
import pytest import pytest
import torch import torch
...@@ -108,6 +109,16 @@ class ModelConfig: ...@@ -108,6 +109,16 @@ class ModelConfig:
self.window_size = window_size self.window_size = window_size
@contextmanager
def logging_context(highest_level=logging.WARNING):
    """Temporarily suppress log records at or below ``highest_level``.

    Saves the current process-wide logging disable threshold, raises it to
    ``highest_level`` via ``logging.disable`` for the duration of the ``with``
    body, and restores the saved threshold on exit, including when the body
    raises.

    Args:
        highest_level: Severity passed to ``logging.disable``; records at this
            level and below are suppressed while the context is active.
    """
    # `logging.disable` does not return the old threshold, so read it first.
    previous_level = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        # Restore the caller's threshold even if the body raised.
        logging.disable(previous_level)
def _get_attention_backends( def _get_attention_backends(
config: ModelConfig, config: ModelConfig,
qkv_dtype: torch.dtype, qkv_dtype: torch.dtype,
...@@ -180,6 +191,7 @@ def _get_attention_backends( ...@@ -180,6 +191,7 @@ def _get_attention_backends(
return available_backends, fused_attention_backend return available_backends, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
with logging_context():
for i in range(3): for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment