Unverified commit 43569381, authored by Charlene Yang, committed by GitHub

Add documentation for dot product attention (#889)



* add attention docs
* WIP: update attention doc
* WIP: update attention doc
* WIP: update attention doc
* WIP: update attn doc
* WIP: update attn doc
* WIP: update attn doc
* WIP: update attention doc
* first draft
* minor tweak to first draft
* clean up pictures
* first draft for review
* minor fixes
* add logging info/debug
* minor fix of an SWA message
* use subprocess instead of os.sys
* clean up benchmark script
* add example script and update notebook
* minor tweak
* minor tweaks
* fix lint
* fix JAX/Paddle related comments
* rerun H100 benchmark
* restrict fp8 tests to sm90+
* move get_cudnn_version from common to pytorch utils

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 905d94f4
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os, sys, time
import subprocess
import pandas as pd
import numpy as np
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
    ModelConfig,
    _is_flash_attention_supported,
    _is_fused_attention_supported,
    _is_unfused_attention_supported,
    _run_dot_product_attention,
)
pd.set_option("display.precision", 4)
# data type
dtype = torch.bfloat16
# number of iterations after 3 warmup iterations
num_iters = 3
# checkpointing
ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = 'bshd_bshd_bshd'
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
is_training = True
model_configs = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
}
def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
    config = model_configs[model]
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)

    # Warm up both backends and check that their results agree
    warmup_iters = 3
    for i in range(warmup_iters):
        if fused_attn_supported:
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype, config, "FusedAttention",
                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
            )
        if flash_attn_supported:
            flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype, config, "FlashAttention",
                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
            )
    if fused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)

    # Time the module-level runs between cudaProfilerStart/Stop
    torch.cuda.cudart().cudaProfilerStart()
    torch.cuda.synchronize()
    fused_attn_start = time.time()
    if fused_attn_supported:
        for i in range(num_iters):
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype, config, "FusedAttention",
                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
            )
    torch.cuda.synchronize()
    fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0

    torch.cuda.synchronize()
    flash_attn_start = time.time()
    if flash_attn_supported:
        for i in range(num_iters):
            flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype, config, "FlashAttention",
                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
            )
    torch.cuda.synchronize()
    flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0

    # Append the per-iteration module times (ms) to times.csv
    df = pd.read_csv('times.csv')
    df = pd.concat([
        df,
        pd.DataFrame(
            [[fused_attn_time * 1e3 / num_iters, 0, 0, 0,
              flash_attn_time * 1e3 / num_iters, 0, 0, 0, 0]], columns=df.columns)],
        ignore_index=True,
    )
    df.to_csv('times.csv', index=False)
    torch.cuda.cudart().cudaProfilerStop()
def parse_results(per_cudnn, per_flash, model):
    filename = f'prof_{model}_cuda_gpu_trace.csv'
    df = pd.read_csv(os.path.join('./', filename))
    df_times = pd.read_csv('times.csv')
    row = len(df_times.index) - 1
    if per_cudnn > 0:
        # Average the cuDNN kernel times over iterations (ns -> ms)
        t_cudnn_all = df[df['Name'].str.contains('cudnn')]['Duration (ns)'].to_numpy()
        t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
        t_cudnn_avg = np.average(t_cudnn_all, axis=0)
        df_times.loc[row, 'FusedAttention Kernels (fwd)'] = t_cudnn_avg[0] / 1e6
        df_times.loc[row, 'FusedAttention Kernels (bwd)'] = t_cudnn_avg[1:4].sum() / 1e6
        df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)'] = t_cudnn_avg.sum() / 1e6
    if per_flash > 0:
        # Average the flash-attention kernel times over iterations (ns -> ms)
        t_flash_all = df[df['Name'].str.contains('void flash')]['Duration (ns)'].to_numpy()
        t_flash_all = t_flash_all.reshape(-1, per_flash)
        t_flash_avg = np.average(t_flash_all, axis=0)
        df_times.loc[row, 'FlashAttention Kernels (fwd)'] = t_flash_avg[0] / 1e6
        df_times.loc[row, 'FlashAttention Kernels (bwd)'] = t_flash_avg[1:4].sum() / 1e6
        df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] = t_flash_avg.sum() / 1e6
    if per_cudnn > 0 and per_flash > 0:
        df_times.loc[row, 'Fused vs Flash Kernels Speedup (fwd+bwd)'] = \
            df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] / \
            df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)']
    df_times.to_csv('times.csv', index=False)
def main():
    times = pd.DataFrame(
        columns=[
            'FusedAttention Module',
            'FusedAttention Kernels (fwd)',
            'FusedAttention Kernels (bwd)',
            'FusedAttention Kernels (fwd+bwd)',
            'FlashAttention Module',
            'FlashAttention Kernels (fwd)',
            'FlashAttention Kernels (bwd)',
            'FlashAttention Kernels (fwd+bwd)',
            'Fused vs Flash Kernels Speedup (fwd+bwd)',
        ])
    times.to_csv('times.csv', index=False)

    device_id = torch.cuda.current_device()
    device_properties = torch.cuda.get_device_properties(device_id)
    print(f"Device {device_id}: "
          f"{device_properties.name} GPU, "
          f"sm{device_properties.major}{device_properties.minor} compute capability, "
          f"{device_properties.total_memory / 1024 ** 3:.1f}GB memory")
    for model in model_configs.keys():
        config = model_configs[model]
        fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
            config, dtype, qkv_layout=qkv_layout,
        )
        fused_attn_supported = fused_attn_supported and not swa
        flash_attn_supported = _is_flash_attention_supported(config)
        print(f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
              f'{" and flash-attention" if flash_attn_supported else ""}...')

        # Profile one benchmark run with Nsight Systems
        prof_cmd = [
            "nsys",
            "profile",
            "--capture-range=cudaProfilerApi",
            "--capture-range-end=stop-shutdown",
            "--force-overwrite=true",
            f"--output=prof_{model}",
            "python",
            "-c",
            f""" "import benchmark_attention;""",
            f"""benchmark_attention.benchmark_dot_product_attention("""
            f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
        ]
        prof_cmd = ' '.join(prof_cmd)
        subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

        # Export the GPU kernel trace to CSV
        stats_cmd = [
            "nsys",
            "stats",
            "-q",
            "-r",
            "cuda_gpu_trace",
            "--format",
            "csv,column",
            "--force-overwrite=true",
            "--force-export=true",
            f"--output=prof_{model}",
            f"prof_{model}.nsys-rep",
        ]
        # Expected number of attention kernels per fwd+bwd iteration
        if fused_attn_supported:
            num_kernels_cudnn = 4
            if config.attn_bias_type == 'post_scale_bias':
                num_kernels_cudnn = num_kernels_cudnn + 1
            if config.num_heads != config.num_gqa_groups:
                num_kernels_cudnn = num_kernels_cudnn + 2
        else:
            num_kernels_cudnn = 0
        num_kernels_flash = 4 if flash_attn_supported else 0
        stats_cmd = ' '.join(stats_cmd)
        subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

        # Parse the kernel trace and update times.csv
        parse_cmd = [
            "python",
            "-c",
            f""" "import benchmark_attention;""",
            f"""benchmark_attention.parse_results("""
            f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
        ]
        parse_cmd = ' '.join(parse_cmd)
        subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

    df_times = pd.read_csv('times.csv')
    df_times.index = list(model_configs.keys())
    a = df_times[['FusedAttention Kernels (fwd+bwd)',
                  'FlashAttention Kernels (fwd+bwd)',
                  'Fused vs Flash Kernels Speedup (fwd+bwd)']]
    a.columns = ['cuDNN fwd+bwd (ms)', 'flash-attn fwd+bwd (ms)', 'cuDNN vs flash speedup']
    print()
    print(a)


if __name__ == "__main__":
    main()
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import torch
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
from transformer_engine.pytorch.distributed import _set_cuda_rng_state
from transformer_engine.pytorch.attention import DotProductAttention
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
def reset_rng_states() -> None:
    """Revert back to initial RNG state"""
    torch.set_rng_state(_cpu_rng_state)
    _set_cuda_rng_state(_cuda_rng_state)
def _run_dot_product_attention(
    dtype: torch.dtype,
    config: ModelConfig,
    qkv_layout: str,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run DotProductAttention module with one forward pass and one backward pass"""
    reset_rng_states()

    seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
                           dtype=torch.int32, device="cuda")
    seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
                            dtype=torch.int32, device="cuda")
    inp = torch.randn(
        [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
        dtype=dtype, device="cuda")
    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    q.requires_grad = True
    k.requires_grad = True
    v.requires_grad = True
    out_grad = torch.randn(
        [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
        dtype=dtype, device="cuda")

    # Create attention mask / bias
    attention_mask = None
    bias = None
    if config.attn_mask_type == "arbitrary":
        attention_mask = torch.randint(
            -10, 10,
            [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv],
        ).to(dtype=torch.bool, device="cuda")
    if config.attn_bias_type == "post_scale_bias":
        # convert mask to bias
        attention_mask = torch.randint(
            -10, 10,
            [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv],
        ).to(dtype=torch.bool, device="cuda")
        bias = attention_mask.clone()
        neginf = -2**50 if dtype == torch.bfloat16 else -2**15
        bias = torch.where(bias == 0, 0, neginf).to(dtype=dtype, device='cuda')
        bias.requires_grad = False
        attention_mask = None

    block = (
        DotProductAttention(
            config.num_heads,
            config.head_dim,
            num_gqa_groups=config.num_gqa_groups,
            qkv_format='bshd',
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=None,
            tp_group=None,
            layer_number=1,
        ).to(dtype=dtype, device="cuda")
    )

    # Run a forward and backward pass
    out = None
    if config.attn_mask_type == "arbitrary":
        out = block(q, k, v,
                    attention_mask=attention_mask,                   # attention_mask
                    qkv_format='bshd',
                    attn_mask_type=config.attn_mask_type,            # 'arbitrary'
                    core_attention_bias_type=config.attn_bias_type,  # 'no_bias'
                    core_attention_bias=bias,                        # None
                    )
        out.backward(out_grad)
    if config.attn_bias_type == "post_scale_bias":
        out = block(q, k, v,
                    attention_mask=attention_mask,                   # None
                    qkv_format='bshd',
                    attn_mask_type=config.attn_mask_type,            # 'no_mask'
                    core_attention_bias_type=config.attn_bias_type,  # 'post_scale_bias'
                    core_attention_bias=bias,                        # bias
                    )
        out.backward(out_grad)

    return out, (q.grad, k.grad, v.grad)
dtype = torch.bfloat16
model_configs = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
    "test_mask": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "arbitrary", "no_bias"),
    "test_bias": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
}

print('Run with post_scale_bias:')
config = model_configs["test_bias"]
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd')

print('Run with arbitrary mask:')
config = model_configs["test_mask"]
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd')

torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
for i in range(3):
    torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2)
print('Test passed!')
{
"cells": [
{
"cell_type": "markdown",
"id": "141fa8bd",
"metadata": {},
"source": [
"# Attention Is All You Need!\n",
"\n",
"The core idea behind Transformer models is the attention mechanism [[1]](https://arxiv.org/abs/1706.03762). It identifies the correlation between words, selects the most important parts of the sentence to focus on, and captures meaningful patterns and dependencies in the data. Figure 1 shows a typical attention mechanism, where pre-softmax operations can be a combination of scaling, bias and masking while the post-softmax operation is often just dropout.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"dot_product_attention.png\" width=\"70%\">\n",
"<figcaption> Figure 1: Dot product attention. </figcaption>\n",
"</figure>\n",
"\n",
"[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for each framework is,\n",
"- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n",
"- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)\n",
"- [transformer_engine.paddle.DotProductAttention](../../api/paddle.rst#transformer_engine.paddle.DotProductAttention)"
]
},
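{
"cell_type": "markdown",
"id": "1a2b3c4d",
"metadata": {},
"source": [
"As a quick orientation, here is a minimal sketch of calling the PyTorch API on random `bshd` tensors. The shapes and hyper-parameters below are illustrative placeholders, not recommendations; the constructor and forward arguments mirror those used in the example scripts of this guide."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e6f7a8b",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformer_engine.pytorch import DotProductAttention\n",
"\n",
"# illustrative sizes: batch 2, sequence 512, 16 heads, head dim 64\n",
"b, s, h, d = 2, 512, 16, 64\n",
"dtype = torch.bfloat16\n",
"q, k, v = [torch.randn(b, s, h, d, dtype=dtype, device=\"cuda\", requires_grad=True)\n",
"           for _ in range(3)]\n",
"\n",
"# DotProductAttention computes softmax(QK^T/sqrt(d) + mask/bias)V only;\n",
"# the QKV and output projections live in MultiheadAttention/TransformerLayer\n",
"attn = DotProductAttention(h, d, qkv_format=\"bshd\").to(dtype=dtype, device=\"cuda\")\n",
"\n",
"out = attn(q, k, v, attn_mask_type=\"causal\")  # shape [b, s, h * d]\n",
"out.backward(torch.randn_like(out))"
]
},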
{
"cell_type": "markdown",
"id": "09a60057",
"metadata": {},
"source": [
"## 1. Attention Backends\n",
"\n",
"Transformer Engine provides multiple attention backends for each supported framework. The framework-native backends provide a robust baseline, while the fused, GPU-optimized implementations offer more performance. For example, the flash-attention and cuDNN attention backends in PyTorch. The framework-native backends are often named with \"unfused\", while the more optimized backends are \"fused\" or \"flash\".\n",
"\n",
"| Framework | Backend (Module Name) | Module Location |\n",
"| :-------- | :-------------------- | :-------------- |\n",
"| PyTorch | cuDNN attention (`FusedAttention`)<br> flash-attention (`FlashAttention`)<br> PyTorch-native attention (`UnfusedDotProductAttention`) | [transformer_engine.pytorch.attention](../../transformer_engine/pytorch/attention.py) |\n",
"| JAX | cuDNN attention (`_FusedDotProductAttention`)<br> JAX-native attention (`_UnfusedDotProductAttention`) | [transformer_engine.jax.flax.transformer](../../transformer_engine/jax/flax/transformer.py) |\n",
"| PaddlePaddle | cuDNN attention (`_te_forward`)<br> PaddlePaddle-native attention (`_pd_forward`) | [transformer_engine.paddle.layer.attention](../../transformer_engine/paddle/layer/attention.py) |\n"
]
},
{
"cell_type": "markdown",
"id": "f387274e",
"metadata": {},
"source": [
"### 1.1 Flash vs. Non-Flash\n",
"\n",
"The attention calculation has quadratic computational and memory complexities to the sequence length. Its runtime and memory requirements quadruple, when the sequence length doubles. This presents a significant challenge to scale Transformer models up for longer contexts, in order to achieve higher model quality.\n",
"\n",
"Compared to the standard, non-flash algorithm, the flash algorithm [[2]](https://arxiv.org/abs/2205.14135) was proposed to reduce the memory scaling to linear and improve the computational efficiency through optimized memory accesses. It employs the following two distinctive techniques.\n",
"\n",
"- **Tiling:** The non-flash algorithm tries to process the query, key, value tensors in one single step, requiring large amounts of global memory and incurring high volumes of reads/writes between global memory and shared memory. The flash algorithm decomposes the input into several tiles, based on the available shared memory and register size, and it computes the softmax one tile at a time.\n",
"\n",
"- **Recomputation:** The non-flash algorithm stores the softmax matrix (quadratic to sequence length) to global memory for the backward pass, while the flash algorithm only saves the softmax normalization factors (linear to sequence length). This reduces the amount of memory required as well as the bandwidth utilization between global memory and shared memory. Even though there is extra computation incurred in order to recalculate the attention in the backward pass, the bandwidth savings still provide significant improvement in efficiency.\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note:</b> Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n",
"</div>\n"
]
},
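{
"cell_type": "markdown",
"id": "9c0d1e2f",
"metadata": {},
"source": [
"To make the tiling idea concrete, the cell below sketches an \"online\" softmax over one block of attention scores: the row is processed tile by tile while only a running maximum `m` and a running normalizer `l` are carried across tiles. This is a toy illustration of the bookkeeping only; the real kernels fuse this with the two matrix multiplications and never materialize the full row, which we keep dense here just to check the result."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3b4c5d6",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"rows, seq_kv, tile = 8, 1024, 128\n",
"scores = torch.randn(rows, seq_kv)  # one block of pre-softmax attention scores\n",
"\n",
"# reference: direct softmax needs the full row in memory at once\n",
"ref = torch.softmax(scores, dim=-1)\n",
"\n",
"# online softmax: carry a running max m and normalizer l across tiles\n",
"m = torch.full((rows, 1), float(\"-inf\"))\n",
"l = torch.zeros(rows, 1)\n",
"acc = torch.zeros(rows, seq_kv)  # kept dense here only to verify the result\n",
"for i in range(0, seq_kv, tile):\n",
"    s = scores[:, i:i + tile]\n",
"    m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)\n",
"    acc *= torch.exp(m - m_new)  # rescale earlier tiles to the new max\n",
"    l = l * torch.exp(m - m_new) + torch.exp(s - m_new).sum(dim=-1, keepdim=True)\n",
"    acc[:, i:i + tile] = torch.exp(s - m_new)\n",
"    m = m_new\n",
"\n",
"print(torch.allclose(acc / l, ref, atol=1e-5))  # True"
]
},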
{
"cell_type": "markdown",
"id": "f1389145",
"metadata": {},
"source": [
"### 1.2 flash-attention\n",
"\n",
"The flash-attention backend, available only in PyTorch, is a module wrapped around the public `flash-attn` package [[3]](https://github.com/Dao-AILab/flash-attention). \n",
"\n",
"The flash-attention backend supports `flash-attn`'s features as they are released, and to facilitate the use of `flash-attn`, flash-attention also offers a few functionalities such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask. Please see `transformer_engine.pytorch.attention.FlashAttention` for more details.\n",
"\n",
"The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.7, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](../../setup.py)).\n",
"\n",
"To understand `flash-attn`'s performance, please refer to their [benchmarks](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n",
"\n",
"### 1.3 cuDNN Attention\n",
"\n",
"The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) and [cudnn-frontend](../../3rdparty/cudnn-frontend) to run, and has several sub-backends to support the different precisions and sequence lengths. Out of the three, sub-backends 1 and 2 are based on the flash algorithm, as `flash-attn` is.\n",
"\n",
"| Sub-Backend | Algorithm | Precision | Sequence Length | Architecture | Docs |\n",
"| :---------- | :--------- | :-------- | :-------------- | :----------- | :--- |\n",
"| 0 | Non-Flash | BF16/FP16 | <=512 | sm80, 90 | [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-attention-fprop) |\n",
"| 1 | Flash | BF16/FP16 | Any | sm80+ | [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop),<br>[cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention) |\n",
"| 2 | Flash | FP8 | cuDNN pre-9.0: <=512<br>cuDNN 9.0+: Any | cuDNN pre-9.0: sm90<br>cuDNN 9.0+: sm90+ | cuDNN 9.0+: [cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-fp8) |\n",
"\n",
"The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.7, cuDNN 9.0 and `flash-attn` 2.4.2,\n",
"\n",
"- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n",
"- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n",
"- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three without transposes (see Section 3.1 for more details).\n",
"- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n",
"- flash-attention supports sliding window attention, and cuDNN attention does not.\n",
"- flash-attention uses bottom right diagonal for `causal` mask in cross attention, and cuDNN attention uses top left (see `flash-attn`'s [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)).\n",
"- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n",
"\n",
"To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](../../benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbc5c73f",
"metadata": {},
"outputs": [],
"source": [
"model_configs = {\n",
" # test: b, h, hg, d, sq, skv, p, mask, bias\n",
" \"test_0\": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, \"no_mask\", \"no_bias\"), # short seq\n",
" \"test_1\": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, \"causal\", \"no_bias\"), # longer seq, mask\n",
" \"test_2\": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, \"causal\", \"post_scale_bias\"), # bias\n",
" \"test_3\": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, \"causal\", \"no_bias\"), # GQA\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "173638b6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Device 0: NVIDIA H100 PCIe GPU, sm90 compute capability, 79.1GB memory\n",
"Running test_0 with cuDNN attention and flash-attention...\n",
"Running test_1 with cuDNN attention and flash-attention...\n",
"Running test_2 with cuDNN attention...\n",
"Running test_3 with cuDNN attention and flash-attention...\n",
"\n",
" cuDNN fwd+bwd (ms) flash-attn fwd+bwd (ms) cuDNN vs flash speedup\n",
"test_0 0.0638 0.0858 1.3454\n",
"test_1 0.5415 0.7496 1.3842\n",
"test_2 1.2302 0.0000 0.0000\n",
"test_3 12.0122 19.0716 1.5877\n"
]
}
],
"source": [
"!cd ../../../benchmarks/attention/ && python benchmark_attention.py"
]
},
{
"cell_type": "markdown",
"id": "0f62d2fa",
"metadata": {},
"source": [
"## 2. Backend Selection\n",
"\n",
"Given the various attention backends, Transformer Engine has a selection logic in place to choose the most appropriate backend for a particular set of user inputs and runtime environment. The selection logic is based on both backend availability and backend performance.\n",
"\n",
"Backend availability is determined by factors such as model configuration, training hyper-parameters, software versions, and the GPU architecture in question. For example, some considerations are the sequence length, number of attention heads, head size, attention mask type, attention bias type, training or inference mode, self or cross attention, MHA or MQA/GQA, `flash-attn`/cuDNN library versions, and the compute capability of the GPU.\n",
"\n",
"When there are multiple backends available, Transformer Engine makes backend selection based on performance. In general, there are a few rules being followed in our selection logic (see table below). As we monitor the performance of different backends, the selection logic may change.\n",
"\n",
"| Framework | Selection Order |\n",
"| :-------- | :--------------------- |\n",
"| PyTorch | sm90: cuDNN attention > flash-attention > PyTorch-native attention<br>sm80: flash-attention > cuDNN attention > PyTorch-native attention<br>cuDNN attention: sub-backend 1 > sub-backend 0 |\n",
"| JAX | cuDNN attention > JAX-native attention |\n",
"| PaddlePaddle | cuDNN attention > PaddlePaddle-native attention |\n"
]
},
{
"cell_type": "markdown",
"id": "86e16a2b",
"metadata": {},
"source": [
"### 2.1 Debug Information\n",
"\n",
"To find out which backend is being used during runtime, users can turn on these debugging flags. Logging is done using the `logging` package.\n",
"```\n",
"NVTE_DEBUG = 0/1 # disables/enables debugging\n",
"NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages\n",
"```\n",
"<div class=\"alert alert-info\">\n",
"<b>Note:</b> These flags are supported in PyTorch only as of Transformer Engine 1.7. JAX and PaddlePaddle support is expected to be added in the future.\n",
"</div>"
]
},
{
"cell_type": "markdown",
"id": "e439434e",
"metadata": {},
"source": [
"The [example_attention.py](./example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend was actually used during runtime."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "9d002327",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Run cuDNN attention...\n",
"[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
"\n",
"Run flash-attention...\n",
"[INFO | DotProductAttention]: Running with FlashAttention backend \n",
"\n",
"Test passed.\n"
]
}
],
"source": [
"!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python example_attention.py"
]
},
{
"cell_type": "markdown",
"id": "bbf1756c",
"metadata": {},
"source": [
"To collect more information, users can turn on `NVTE_DEBUG_LEVEL=2`. In this example, it allows us to find out more about the run config. Users are encouraged to provide if users intend to file a bug with Transformer Engine. For example, "
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "66a2f34c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Run cuDNN attention...\n",
"[DEBUG | DotProductAttention]: Disabling FlashAttention due to NVTE_FLASH_ATTN=0\n",
"[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
"[DEBUG | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': <Version('1.8.0.dev0')>, 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.2.0'}\n",
"[DEBUG | FusedAttnFunc ]: Running forward in torch.bfloat16\n",
"[DEBUG | FusedAttnFunc ]: Running backward in torch.bfloat16\n",
"\n",
"Run flash-attention...\n",
"[DEBUG | DotProductAttention]: Disabling FusedAttention due to NVTE_FUSED_ATTN=0\n",
"[INFO | DotProductAttention]: Running with FlashAttention backend \n",
"[DEBUG | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': <Version('1.8.0.dev0')>, 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.2.0'}\n",
"\n",
"Test passed.\n"
]
}
],
"source": [
"!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 python example_attention.py"
]
},
{
"cell_type": "markdown",
"id": "9f964732",
"metadata": {},
"source": [
"### 2.2 User Control\n",
"\n",
"Users usually do not need to worry about the backend selection. However, if there is a convergence or performance issue encountered, Transformer Engine provides a few other environment variables for users to experiment with different backends.\n",
"\n",
"**flash-attention or cuDNN attention:**\n",
"Users can enable/disable the flash-attention backend or cuDNN attention backend via the following two environment variables in PyTorch.\n",
"```\n",
"NVTE_FLASH_ATTN = 0 # disables flash-attention; default = 1\n",
"NVTE_FUSED_ATTN = 0 # disables cuDNN attention; default = 1\n",
"```\n",
"\n",
"**cuDNN attention sub-backends:**\n",
"This environment variable allows users to express their preference of cuDNN attention sub-backends. However, the elected sub-backend will only be used *if* it is eligible, i.e. if it has support for the provided inputs and runtime environment.\n",
"```\n",
"NVTE_FUSED_ATTN_BACKEND = 0/1/2 # user preference of cuDNN sub-backend\n",
"```\n",
"\n",
"**Execution paths of cuDNN sub-backend 1:**\n",
"cuDNN attention sub-backend 1 also offers two execution paths: workspace optimization path and non-workspace optimization path. The workspace optimization path requires a larger amount of global memory, provides determinism, and offers bias gradient support. Before cuDNN 9.0, it also has 20-30% performance advantage over the non-workspace optimization path. But after cuDNN 9.0, it is 20-30% slower than the non-workspace optimization path.\n",
"\n",
"Users can experiment with these two paths through the following environment variable. However, please be aware of the possible Out-Of-Memory risks.\n",
"```\n",
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 0 # disables workspace optimization path\n",
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 1 # enables workspace optimization path\n",
"```\n",
"<div class=\"alert alert-info\">\n",
"<b>Note:</b> Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code> and <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> are only supported in PyTorch, not JAX or PaddlePaddle.\n",
"</div>\n",
"\n",
"### 2.3 Example Tests\n",
"\n",
"Our [unit tests](../../tests/) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n",
"\n",
"For example, in PyTorch, [test_dot_product_attention](../../tests/pytorch/fused_attention/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
]
},
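{
"cell_type": "markdown",
"id": "e7f8a9b0",
"metadata": {},
"source": [
"Putting Sections 2.1 and 2.2 together, the sketch below shows one way to set the selection and debug variables from Python rather than on the command line. The particular values are arbitrary examples; setting the variables before the first attention call (ideally before importing Transformer Engine) is the safe pattern."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1d2e3f4",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# set before importing/running Transformer Engine attention\n",
"os.environ[\"NVTE_FLASH_ATTN\"] = \"0\"          # disable flash-attention\n",
"os.environ[\"NVTE_FUSED_ATTN\"] = \"1\"          # keep cuDNN attention enabled\n",
"os.environ[\"NVTE_FUSED_ATTN_BACKEND\"] = \"1\"  # prefer cuDNN sub-backend 1\n",
"os.environ[\"NVTE_DEBUG\"] = \"1\"               # log the backend selection\n",
"os.environ[\"NVTE_DEBUG_LEVEL\"] = \"1\""
]
},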
{
"cell_type": "markdown",
"id": "3ad85b86",
"metadata": {},
"source": [
"## 3. Backend Support\n",
"\n",
"Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.7, Transformer Engine's attention backends have the following support matrix.\n",
"\n",
"| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Deterministic |\n",
"| :---------------- | :-------- | :----------- | :----------------------- | :------ | :------------------ | :------------ |\n",
"| cuDNN attention<br>(PyTorch, JAX, PaddlePaddle) | PyTorch: BF16, FP16, FP8<br>JAX, PaddlePaddle: BF16, FP16 | sm80+ | No | Yes | `bshd`,`sbhd`: Yes<br>`thd`: No | Sub-backend 0, 2: Yes<br>Sub-backend 1: Yes, if workspace optimization path |\n",
"| flash-attention<br>(PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | `bshd`,`thd`: Yes<br>`sbhd`: No | Yes, if `deterministic=True` |\n",
"| Framework-native attention<br>(PyTorch, JAX, PaddlePaddle) | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | No | Yes |\n",
"\n",
"Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
"- sliding window attention: [test_dpa_swa](../../tests/pytorch/fused_attention/test_fused_attn.py)\n",
"- MQA/GQA: [test_te_layer_mqa_gqa](../../tests/pytorch/fused_attention/test_fused_attn.py)\n",
"- context parallelism: [test_cp_with_fused_attention](../../tests/pytorch/fused_attention/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](../../tests/pytorch/fused_attention/test_fused_attn_with_cp.py)"
]
},
{
"cell_type": "markdown",
"id": "37920af4",
"metadata": {},
"source": [
"### 3.1 QKV Layout\n",
"\n",
"Transformer Engine supports various layouts of the query `q`, key `k`, value `v` tensors. It has defined 15 QKV layouts, which are grouped into 3 QKV formats and 5 QKV layout groups to help with similar memory/computational operations across different layouts. The mapping relationships of these layouts and groups are,\n",
"\n",
"| `qkv_layout` &nbsp; &nbsp; &nbsp; &nbsp; | `qkv_layout_group`=`3hd` | `h3d` | `hd_2hd` | `hd_h2d` | `hd_hd_hd` |\n",
"| ----------: | -----------: | -----: | ----------: | ----------: | -------------: |\n",
"| `qkv_format`=`sbhd` | `sb3hd` | `sbh3d` | `sbhd_sb2hd` | `sbhd_sbh2d` | `sbhd_sbhd_sbhd` |\n",
"| `bshd` | `bs3hd` | `bsh3d` | `bshd_bs2hd` | `bshd_bsh2d` | `bshd_bshd_bshd` |\n",
"| `thd` | `t3hd` | `th3d` | `thd_t2hd` | `thd_th2d` | `thd_thd_thd` |\n",
"\n",
"The notation system is that `b` stands for the batch size, `s` sequence length, `h` number of attention heads, `d` head dimension, and `t` the total number of tokens in the batch, i.e. `t = sum(s_i) for i in 0,...,b-1`. Here are a few examples of the layouts and their explanations to help clarify the definition.\n",
"\n",
"**`qkv_layout`=`sb3hd`:**\n",
"`q`, `k`, `v` are sequence first, i.e. `s` is the leading dimension in each tensor. They are different slices of one tensor `qkv`: `q, k, v = [qkv[:,:,i,:,:] for i in range(3)]`. They are interleaved at the `h * d` dimension.\n",
"\n",
"**`qkv_layout`=`bshd_bsh2d`:**\n",
"`q`, `k`, `v` are batch first, i.e. `b` is the leading dimension in each tensor. `q` is contiguous, and `k`, `v` are different slices of tensor `kv`: `k, v = [kv[:,:,:,i,:] for i in range(2)]`. `k`, `v` are interleaved at the `d` dimension.\n",
"\n",
"The `s` and `h` in `bsh2d` are the max sequence length and number of heads for `k`, `v`, which can be different from the `s` and `h` in `bshd` for `q`. We denoted them as the same for brevity reasons. Transformer Engine does differentiate their values for actual execution.\n",
"\n",
"**`qkv_layout`=`thd_thd_thd`:**\n",
"`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n",
"\n",
"As of v1.7, Transformer Engine has the following support matrix.\n",
"\n",
"| Backend | Supported QKV Formats | Notes |\n",
"| :--------------- | :-------------------- | :------ |\n",
"| flash-attention | `bshd`, `sbhd`, `thd`<br>(`sbhd` requires transpose operations) | PyTorch: 3 formats, i.e. 15 layouts|\n",
"| cuDNN attention | `bshd`, `sbhd`, `thd` | PyTorch: 3 formats, i.e. 15 layouts<br>JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts |\n",
"| Framework-native attention | `bshd`, `sbhd`<br>(`sbhd` requires transpose operations) | PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts |\n",
"\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_dpa_qkv_layout_thd](../../tests/pytorch/fused_attention/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention._get_qkv_layout](../../transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note:</b> When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [_get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
"</div>\n"
]
},
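{
"cell_type": "markdown",
"id": "0a1b2c3d",
"metadata": {},
"source": [
"The slicing rules above can be written out directly. The snippet below constructs tensors for three of the layouts using the formulas from this section; the shapes are illustrative placeholders."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e5f6a7b",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"b, s, h, d = 2, 512, 16, 64\n",
"\n",
"# qkv_layout = 'sb3hd': q, k, v are slices of one tensor, interleaved at h * d\n",
"qkv = torch.randn(s, b, 3, h, d)\n",
"q, k, v = [qkv[:, :, i, :, :] for i in range(3)]\n",
"\n",
"# qkv_layout = 'bshd_bsh2d': q is contiguous; k, v are slices of one kv tensor,\n",
"# interleaved at the d dimension\n",
"q = torch.randn(b, s, h, d)\n",
"kv = torch.randn(b, s, h, 2, d)\n",
"k, v = [kv[:, :, :, i, :] for i in range(2)]\n",
"\n",
"# qkv_layout = 'thd_thd_thd': variable-length sequences packed into t tokens\n",
"t = 1000  # t = sum of per-sequence lengths in the batch\n",
"q, k, v = [torch.randn(t, h, d) for _ in range(3)]"
]
},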
{
"cell_type": "markdown",
"id": "94c69fae",
"metadata": {},
"source": [
"### 3.2 Attention Mask\n",
"\n",
"Transformer Engine supports 5 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n",
"- `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), `arbitrary`\n",
"\n",
"Different backends offer different support for attention mask. As of Transformer Engine 1.7,\n",
"\n",
"| Backend | Supported Mask Types | Requires `attention_mask` |\n",
"| :--------------- | :-------------------- | :------------------ |\n",
"| flash-attention | `no_mask`, `causal`, `padding`, `padding_causal` | `no_mask`, `causal`: No<br>`padding`, `padding_causal`: Yes if `cu_seqlens` not provided|\n",
"| cuDNN attention | `no_mask`, `causal`, `padding`, `padding_causal` | `no_mask`, `causal`: No<br>`padding`, `padding_causal`: Yes if `cu_seqlens` not provided|\n",
"| Framework-native attention | `no_mask`, `causal`, `arbitrary` | `no_mask`, `causal`: No<br>`arbitrary`: Yes |\n",
"\n",
"**`padding` and `padding_causal`:** For these two mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.7, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n",
"\n",
"* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
" - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
" - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
"\n",
"\n",
"* JAX and PaddlePaddle: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"\n",
"**`qkv_format`=`thd`:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
"\n",
"**`Arbitrary` mask:** cuDNN does not support `Arbitrary` mask type as of v9.0. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](./arbitrary_mask_to_post_scale_bias.py).\n"
]
},
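{
"cell_type": "markdown",
"id": "8c9d0e1f",
"metadata": {},
"source": [
"Before moving on to the mask-conversion example, here is a small sketch of the `cu_seqlens` convention described above. It builds `cu_seqlens` for the `[aa000, bbbb0, c0000]` example, both from the raw sequence lengths and from a `[batch_size, 1, 1, seqlen]` padding mask (where `True` means masked out)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2c3d4e5",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# batch of 3 padded sequences [aa000, bbbb0, c0000] -> lengths [2, 4, 1]\n",
"seqlens = torch.tensor([2, 4, 1], dtype=torch.int32)\n",
"\n",
"# cu_seqlens is the zero-prefixed cumulative sum of the sequence lengths\n",
"cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)\n",
"cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)\n",
"print(cu_seqlens)  # tensor([0, 2, 6, 7], dtype=torch.int32)\n",
"\n",
"# equivalently, derive the lengths from a [b, 1, 1, sq] padding mask,\n",
"# where True masks a position out\n",
"mask = torch.ones(3, 1, 1, 5, dtype=torch.bool)\n",
"mask[0, ..., :2] = False\n",
"mask[1, ..., :4] = False\n",
"mask[2, ..., :1] = False\n",
"print((~mask).sum(dim=-1).flatten())  # tensor([2, 4, 1])"
]
},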
{
"cell_type": "code",
"execution_count": 6,
"id": "4c87df64",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Run with post_scale_bias:\n",
"[DotProductAttention]: using cuDNN attention (sub-backend 1)\n",
"Run with arbitrary mask:\n",
"[DotProductAttention]: using unfused DPA\n",
"Test passed!\n"
]
}
],
"source": [
"!NVTE_DEBUG=1 python arbitrary_mask_to_post_scale_bias.py"
]
},
{
"cell_type": "markdown",
"id": "5ec0c75d",
"metadata": {},
"source": [
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](../../tests/pytorch/fused_attention/test_fused_attn.py).\n",
"\n",
"### 3.3 Attention Bias\n",
"\n",
"Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.7, their support matrix is as follows.\n",
"\n",
"| Backend | Bias Type | Bias Shape | Bias Data Type | Architecture |\n",
"| :------ | :-------- | :--------- | :--------- | :----------- |\n",
"| flash-attention | `no_bias`, `ALiBi` (with slopes) | N/A | ALiBi slopes: FP32 | sm80+ |\n",
"| cuDNN attention | PyTorch: `no_bias`, `post_scale_bias`, `ALiBi` (without slopes)<br>JAX, PaddlePaddle: `no_bias`, `post_scale_bias` | `post_scale_bias`: BHSS, 1HSS, B1SS, 11SS for forward, 1HSS for backward | `post_scale_bias`: same as QKV type<br>ALiBi slopes: FP32 | cuDNN 8.9.6+: sm90<br>cuDNN 9.0+: sm80+ |\n",
"| Framework-native attention | `no_bias`, `pre_scale_bias`, `post_scale_bias` | `post_scale_bias`: BHSS, 1HSS, B1SS, 11SS | `post_scale_bias`: same as QKV type | sm80+ |\n",
"\n",
"The flash-attention backend enables `ALiBi` by asking user to pass in an `alibi_slopes` tensor, which can be the default slopes of vanilla ALiBi, or user-defined slopes. On the other hand, cuDNN attention supports `ALiBi` by taking in a `Boolean` flag, and it only supports vanilla ALiBi as of cuDNN 9.0.\n",
"\n",
"The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n",
"\n",
"More examples of how to use the various attention biases are at [test_dpa_bias](../../tests/pytorch/fused_attention/test_fused_attn.py)."
]
},
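{
"cell_type": "markdown",
"id": "f6a7b8c9",
"metadata": {},
"source": [
"As a reference for the naming, the sketch below spells out where the bias enters relative to the scaling of `QK^T`, following the usual reading of `pre_scale_bias` and `post_scale_bias` (bias added before vs. after the `1/sqrt(d)` scaling). This is plain PyTorch for illustration, not Transformer Engine's internal implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0e1f2a3",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def attention_with_bias(q, k, v, bias, bias_type):\n",
"    # q, k, v: [b, h, s, d]; bias: [1, h, s, s] (1HSS)\n",
"    scale = q.shape[-1] ** -0.5\n",
"    scores = torch.matmul(q, k.transpose(-2, -1))\n",
"    if bias_type == \"pre_scale_bias\":\n",
"        scores = (scores + bias) * scale   # bias added before scaling\n",
"    elif bias_type == \"post_scale_bias\":\n",
"        scores = scores * scale + bias     # bias added after scaling\n",
"    else:\n",
"        scores = scores * scale\n",
"    return torch.matmul(torch.softmax(scores, dim=-1), v)\n",
"\n",
"b, h, s, d = 2, 16, 128, 64\n",
"q, k, v = [torch.randn(b, h, s, d) for _ in range(3)]\n",
"bias = torch.randn(1, h, s, s)  # 1HSS, the shape supported for backward\n",
"out = attention_with_bias(q, k, v, bias, \"post_scale_bias\")"
]
},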
{
"cell_type": "markdown",
"id": "3f8f6f1c",
"metadata": {},
"source": [
"### 3.4 FP8 Attention\n",
"\n",
"A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n",
"\n",
"Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.7. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n",
"\n",
"- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n",
"\n",
"- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n",
"\n",
"Examples of using the two features are available at [test_dpa_fp8_vs_f16](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_mha_fp8_vs_f16](../../tests/pytorch/fused_attention/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`. This should result in the following print when the debug flags are turned on, `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2`.\n",
"```\n",
"[DEBUG | DotProductAttention]: Running with fp8_recipe.fp8_mha=False, fp8_recipe.fp8_dpa=True and NVTE_FP8_DPA_BWD=0\n",
"[DEBUG | FusedAttnFunc ]: Running forward in FP8\n",
"[DEBUG | FusedAttnFunc ]: Running backward in torch.bfloat16\n",
"```"
]
}
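,
{
"cell_type": "markdown",
"id": "3b4c5d6e",
"metadata": {},
"source": [
"The sketch below shows where `fp8_dpa` plugs in: the recipe is passed to `fp8_autocast`, and the module is then called as usual. It is a minimal, hypothetical configuration (the shapes are placeholders) and assumes hardware and software that satisfy the FP8 requirements above (sm90+, cuDNN 9.0+)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f8a9b0c",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import transformer_engine.pytorch as te\n",
"from transformer_engine.common.recipe import DelayedScaling, Format\n",
"\n",
"# FP8 attention is controlled through the recipe, not the module arguments\n",
"fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True)\n",
"\n",
"b, s, h, d = 2, 512, 16, 64\n",
"attn = te.DotProductAttention(h, d, qkv_format=\"bshd\").to(dtype=torch.bfloat16, device=\"cuda\")\n",
"q, k, v = [torch.randn(b, s, h, d, dtype=torch.bfloat16, device=\"cuda\") for _ in range(3)]\n",
"\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"    out = attn(q, k, v)  # inputs/outputs stay BF16; the attention math runs in FP8"
]
}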
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os, sys, time
import subprocess
import pandas as pd
import numpy as np
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
    ModelConfig,
    _is_flash_attention_supported,
    _is_fused_attention_supported,
    _is_unfused_attention_supported,
    _run_dot_product_attention,
)
# data type
dtype = torch.bfloat16
# number of iterations after 3 warmup iterations
num_iters = 3
# checkpointing
ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = 'bshd_bshd_bshd'
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
is_training = True
model_configs = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
}
def example_attention(model, fused_attn_supported, flash_attn_supported):
    config = model_configs[model]
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)

    if fused_attn_supported:
        print()
        print('Run cuDNN attention...')
        fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
            dtype, config, "FusedAttention",
            ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
        )
    if flash_attn_supported:
        print()
        print('Run flash-attention...')
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
            dtype, config, "FlashAttention",
            ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
        )
    if fused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
        print()
        print('Test passed.')
def main():
    models = ['test_0']
    for model in models:
        config = model_configs[model]
        fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
            config, dtype, qkv_layout=qkv_layout,
        )
        fused_attn_supported = fused_attn_supported and not swa
        flash_attn_supported = _is_flash_attention_supported(config)
        example_attention(model, fused_attn_supported, flash_attn_supported)


if __name__ == "__main__":
    main()
...@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Tuple, Union ...@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Tuple, Union
from pkg_resources import packaging from pkg_resources import packaging
import pytest import pytest
import torch import torch
import logging
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
...@@ -38,6 +39,7 @@ from transformer_engine.pytorch.utils import ( ...@@ -38,6 +39,7 @@ from transformer_engine.pytorch.utils import (
scaled_init_method_normal, scaled_init_method_normal,
is_bf16_compatible, is_bf16_compatible,
) )
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import NVTE_Fused_Attn_Backend from transformer_engine_torch import NVTE_Fused_Attn_Backend
...@@ -51,8 +53,6 @@ torch.cuda.manual_seed(seed) ...@@ -51,8 +53,6 @@ torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state() _cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state() _cuda_rng_state = torch.cuda.get_rng_state()
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
def reset_rng_states() -> None: def reset_rng_states() -> None:
"""Revert back to initial RNG state""" """Revert back to initial RNG state"""
...@@ -66,16 +66,6 @@ def reset_global_fp8_state(): ...@@ -66,16 +66,6 @@ def reset_global_fp8_state():
fp8.FP8GlobalStateManager.reset() fp8.FP8GlobalStateManager.reset()
@functools.cache
def _cudnn_version() -> Tuple[int, int, int]:
"""Runtime cuDNN version (major, minor, patch)"""
encoded_version = ext.get_cudnn_version()
major_version_magnitude = 1000 if encoded_version < 90000 else 10000
major, encoded_version = divmod(encoded_version, major_version_magnitude)
minor, patch = divmod(encoded_version, 100)
return (major, minor, patch)
class ModelConfig: class ModelConfig:
def __init__( def __init__(
self, self,
...@@ -237,7 +227,7 @@ def get_swa(seq_q, seq_kv, w=None): ...@@ -237,7 +227,7 @@ def get_swa(seq_q, seq_kv, w=None):
return w, ml return w, ml
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base]) @pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys()) @pytest.mark.parametrize("model", model_configs_base.keys())
...@@ -322,32 +312,28 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, ...@@ -322,32 +312,28 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn,
) )
if unfused_attn_supported and fused_attn_supported: if unfused_attn_supported and fused_attn_supported:
if _NVTE_DEBUG: logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
print("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
for i,_ in enumerate(unfused_attn_bwd): for i,_ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if unfused_attn_supported and flash_attn_supported: if unfused_attn_supported and flash_attn_supported:
if _NVTE_DEBUG: logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
print("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols) torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd): for i,_ in enumerate(flash_attn_bwd):
torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols) torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported: if fused_attn_supported and flash_attn_supported:
if _NVTE_DEBUG: logging.info("[test_dot_product_attention]: fused attn vs flash attn")
print("[test_dot_product_attention]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd): for i,_ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols) torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and len(fused_attn_backend) == 2: if fused_attn_supported and len(fused_attn_backend) == 2:
if _NVTE_DEBUG: logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
print("[test_dot_product_attention]: fused attn backend 0 vs 1")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols) torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i,_ in enumerate(fused_attn_bwd): for i,_ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols) torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base]) @pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"]) @pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
...@@ -373,7 +359,7 @@ model_configs_mask = { ...@@ -373,7 +359,7 @@ model_configs_mask = {
} }
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask]) @pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys()) @pytest.mark.parametrize("model", model_configs_mask.keys())
...@@ -411,7 +397,7 @@ model_configs_bias = { ...@@ -411,7 +397,7 @@ model_configs_bias = {
} }
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias]) @pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys()) @pytest.mark.parametrize("model", model_configs_bias.keys())
@@ -438,7 +424,7 @@ model_configs_bias_shapes = {
}

@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
@@ -504,7 +490,7 @@ model_configs_layout = {
}

@pytest.mark.skipif(get_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@@ -532,7 +518,7 @@ model_configs_layout_thd = {
}

@pytest.mark.skipif(get_cudnn_version() < (9,0,0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@@ -848,7 +834,7 @@ model_configs_te_layer = {
}

@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@@ -917,23 +903,20 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
    )
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_transformer_layer]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_transformer_layer]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_transformer_layer]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)

@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
@@ -947,7 +930,7 @@ def test_te_layer_misc(dtype, model_configs, model, qkv_format):
        ckpt_attn, qkv_format, fused_qkv_params, RoPE)

@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
@@ -1118,9 +1101,9 @@ def _rmse(a, b):
    return math.sqrt((torch.pow((a-b), 2)/a.numel()).sum())
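The FP8 tests below pair this helper with a relative tolerance: the RMSE between two implementations must stay under rmse_tol times the combined dynamic range of the outputs being compared. A minimal self-contained sketch of that criterion (tensor values and tolerance are illustrative, not taken from the test suite):

import math
import torch

def _rmse(a, b):
    # Root-mean-square error between two tensors.
    return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())

a = torch.randn(1024)
b = a + 0.01 * torch.randn(1024)  # stand-in for an FP8-vs-F16 mismatch
rmse_tol = 0.1                    # hypothetical tolerance
value_range = (max(a.max().item(), b.max().item())
               - min(a.min().item(), b.min().item()))
assert _rmse(a, b) < rmse_tol * value_range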
@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@@ -1132,14 +1115,12 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
    config = model_configs_fp8_vs_f16[model]

    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"

    logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
    fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
        dtype, config, True, qkv_format, input_layernorm)

    logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
    fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
        dtype, config, False, qkv_format, input_layernorm)
@@ -1149,19 +1130,18 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
    fwd_range = max(fused_attn_fwd_fp8.max().item(),
        fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
        fused_attn_fwd_f16.min().item())

    logging.debug('========== {:^25s} =========='.format('forward output'))
    logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
        fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()))
    logging.debug('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
        fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()))
    logging.debug('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse))

    try:
        torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
    except Exception as e:
        logging.debug(e)

    assert(fwd_rmse < rmse_tol * fwd_range
        ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
        fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
@@ -1170,19 +1150,18 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
        bwd_range = max(fused_attn_bwd_fp8[i].max().item(),
            fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(),
            fused_attn_bwd_f16[i].min().item())

        logging.debug('========== {:^25s} =========='.format(param_names[i]))
        logging.debug('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i,
            fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()))
        logging.debug('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i,
            fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()))
        logging.debug('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse))

        try:
            torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
        except Exception as e:
            logging.debug(e)

        assert(bwd_rmse < rmse_tol * bwd_range
            ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
            bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
@@ -1275,9 +1254,9 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
    return out, param_names, tuple(x.grad for x in params)

@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@@ -1289,13 +1268,13 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
        pytest.skip("qkv_layout not applicable for MQA/GQA")

    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"

    logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
        dtype, config, True, qkv_layout)

    logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
    fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
        dtype, config, False, qkv_layout)
@@ -1306,19 +1285,18 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
    fwd_range = max(fused_attn_fwd_fp8.max().item(),
        fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
        fused_attn_fwd_f16.min().item())

    logging.debug('========== {:^25s} =========='.format('forward output'))
    logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
        fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()))
    logging.debug('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
        fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()))
    logging.debug('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse))

    try:
        torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
    except Exception as e:
        logging.debug(e)

    assert(fwd_rmse < rmse_tol * fwd_range
        ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
        fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
@@ -1327,19 +1305,18 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
        bwd_range = max(fused_attn_bwd_fp8[i].max().item(),
            fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(),
            fused_attn_bwd_f16[i].min().item())

        logging.debug('========== {:^25s} =========='.format(bwd_names[i]))
        logging.debug('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i,
            fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()))
        logging.debug('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i,
            fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()))
        logging.debug('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse))

        try:
            torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
        except Exception as e:
            logging.debug(e)

        assert(bwd_rmse < rmse_tol * bwd_range
            ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
            bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
@@ -1469,9 +1446,9 @@ models_v0 = ['fp8_1', 'fp8_2', 'fp8_5', 'fp8_6']
models_v1 = ['fp8_3', 'fp8_4', 'fp8_7', 'fp8_8']

@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
@@ -1498,29 +1475,29 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
    bwd_range = max(fused_attn_bwd_fp8.max().item(),
        unfused_attn_bwd_f16.max().item()) - min(fused_attn_bwd_fp8.min().item(),
        unfused_attn_bwd_f16.min().item())

    logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
        fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()))
    logging.debug('unfused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
        unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()))
    logging.debug('fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}'.format(
        fwd_rmse))
    try:
        torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols)
    except Exception as e:
        logging.debug(e)

    logging.debug('fused_attn_bwd_fp8 min {:.6f} max {:.6f}'.format(
        fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()))
    logging.debug('unfused_attn_bwd_f16 min {:.6f} max {:.6f}'.format(
        unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()))
    logging.debug('fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}'.format(
        bwd_rmse))
    try:
        torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols)
    except Exception as e:
        logging.debug(e)

    assert(fwd_rmse < rmse_tol * fwd_range
        ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
        fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
...
@@ -10,6 +10,7 @@ import math
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
import logging

import numpy as np
from packaging.version import Version as PkgVersion

@@ -18,6 +19,8 @@ import torch
import torch.nn.functional as F

import transformer_engine_torch as tex
import transformer_engine as te
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
@@ -88,7 +91,17 @@ META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
    format='[%(levelname)-8s | %(name)-19s]: %(message)s',
    level=log_levels[log_level if log_level in [0,1,2] else 2],
)
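For reference, a minimal sketch of how these two variables combine (variable names come from the code above; since logging.basicConfig runs at import time, they must be set in the environment before the module is imported):

import os

os.environ["NVTE_DEBUG"] = "1"        # 0/1: disable/enable debug mode
os.environ["NVTE_DEBUG_LEVEL"] = "2"  # 0=WARNING, 1=INFO, 2=DEBUG

# Import after setting the variables; the log level is fixed at import time.
import transformer_engine.pytorch  # noqa: E402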
_alibi_cache = {
    "_num_heads": None,
    "_alibi_slopes": None,
@@ -2297,9 +2310,9 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen, fused_attention_backend, use_FAv2_bwd,
                fp8, fp8_meta):
        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
        if fp8:
            logger.debug("Running forward in FP8")
            if fp8_meta["recipe"].fp8_mha:
                assert (isinstance(qkv, Float8Tensor)), "qkv must be Float8Tensors for FP8 MHA."
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
@@ -2356,8 +2369,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone())
        else:
            logger.debug("Running forward in %s", qkv.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
                fused_attention_backend, attn_bias,
@@ -2390,6 +2402,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    @staticmethod
    def backward(ctx, d_out):
        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert (isinstance(d_out, Float8Tensor)
                ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
@@ -2419,8 +2432,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
        else:
            with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
@@ -2466,8 +2478,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                        ctx.fp8_meta["scaling_bwd"], META_DQKV,
                        fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape)
                else:
                    logger.debug("Running backward in %s", qkv.dtype)
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(qkv.dtype)
                    dqkv, *rest = fused_attn_bwd_qkvpacked(
@@ -2499,9 +2510,9 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
                use_FAv2_bwd, fp8, fp8_meta):
        logger = logging.getLogger("FusedAttnFunc_kvpacked")
        if fp8:
            logger.debug("Running forward in FP8")
            if fp8_meta["recipe"].fp8_mha:
                assert (isinstance(q, Float8Tensor)
                    and isinstance(kv, Float8Tensor)), "q/kv must be Float8Tensors for FP8 MHA."
@@ -2565,8 +2576,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone())
        else:
            logger.debug("Running forward in %s", q.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, kv, qkv_dtype, fused_attention_backend, attn_bias,
@@ -2600,6 +2610,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
    @staticmethod
    def backward(ctx, d_out):
        logger = logging.getLogger("FusedAttnFunc_kvpacked")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert (isinstance(d_out, Float8Tensor)
                ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
@@ -2631,8 +2642,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
        else:
            with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
@@ -2689,8 +2699,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                        ctx.fp8_meta["scaling_bwd"], META_DQKV,
                        fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape)
                else:
                    logger.debug("Running backward in %s", q.dtype)
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dkv, *rest = fused_attn_bwd_kvpacked(
@@ -2722,9 +2731,9 @@ class FusedAttnFunc(torch.autograd.Function):
                q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
                use_FAv2_bwd, fp8, fp8_meta):
        logger = logging.getLogger("FusedAttnFunc")
        if fp8:
            logger.debug("Running forward in FP8")
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if fp8_meta["recipe"].fp8_mha:
@@ -2837,8 +2846,7 @@ class FusedAttnFunc(torch.autograd.Function):
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone())
        else:
            logger.debug("Running forward in %s", q.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd(
                is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, k, v, qkv_dtype, fused_attention_backend, attn_bias,
@@ -2880,6 +2888,7 @@ class FusedAttnFunc(torch.autograd.Function):
    @staticmethod
    def backward(ctx, d_out):
        logger = logging.getLogger("FusedAttnFunc")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert (isinstance(d_out, Float8Tensor)
                ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
@@ -2913,8 +2922,7 @@ class FusedAttnFunc(torch.autograd.Function):
        else:
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False)
@@ -3006,8 +3014,7 @@ class FusedAttnFunc(torch.autograd.Function):
                        ctx.fp8_meta["scaling_bwd"], META_DQKV,
                        fp8_dtype_backward, ctx.qkv_dtype).view(dv_fp8.shape)
                else:
                    logger.debug("Running backward in %s", q.dtype)
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dk, dv, *rest = fused_attn_bwd(
@@ -3072,6 +3079,7 @@ class FusedAttention(TransformerEngineBaseModule):
    ) -> None:
        super().__init__()
        self.logger = logging.getLogger("FusedAttention")
        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
@@ -3266,12 +3274,14 @@ class FusedAttention(TransformerEngineBaseModule):
                if not self.fp8_meta["recipe"].fp8_dpa:
                    self.fp8_meta["recipe"].fp8_dpa = True
                    forced_fp8_dpa = " (forced)"

        if fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8:
            self.logger.debug(
                "Running with fp8_recipe.fp8_mha=%s, "
                "fp8_recipe.fp8_dpa=%s%s, and NVTE_FP8_DPA_BWD=%s",
                self.fp8_meta["recipe"].fp8_mha,
                self.fp8_meta["recipe"].fp8_dpa,
                forced_fp8_dpa,
                int(os.getenv("NVTE_FP8_DPA_BWD", "1")))

        output = FusedAttnFunc.apply(
            self.training,
            max_seqlen_q, max_seqlen_kv,
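For context, a sketch of the user-facing path that reaches this branch. fp8_dpa and fp8_mha are the recipe flags read above; the surrounding calls follow Transformer Engine's public PyTorch API, and tensor shapes are illustrative only:

import torch
import transformer_engine.pytorch as te_pt
from transformer_engine.common.recipe import DelayedScaling

recipe = DelayedScaling(fp8_dpa=True, fp8_mha=False)  # FP8 attention only
dpa = te_pt.DotProductAttention(num_attention_heads=16, kv_channels=64)

# Default sbhd layout: [seq, batch, heads, head_dim].
q = torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
with te_pt.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = dpa(q, k, v)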
@@ -3411,6 +3421,7 @@ class DotProductAttention(torch.nn.Module):
    ) -> None:
        super().__init__()
        self.logger = logging.getLogger("DotProductAttention")
        self.qkv_format = qkv_format
        attn_mask_type = attn_mask_type.replace(",", "_")
        if attn_mask_type == "causal_padding":
@@ -3456,9 +3467,14 @@ class DotProductAttention(torch.nn.Module):
            int(os.getenv("NVTE_FLASH_ATTN", "1"))
            and self.device_compute_capability >= (8, 0)
        )
        if int(os.getenv("NVTE_FLASH_ATTN", "1")) == 0:
            self.logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
        if self.device_compute_capability < (8, 0):
            self.logger.debug("Disabling FlashAttention for compute capability < sm80")
        if not _flash_attn_2_4_1_plus and self.deterministic:
            self.use_flash_attention = False
            self.logger.warning(
                "Disabling usage of FlashAttention since version <2.4.1 does not support "
                "deterministic execution. In order to use FA with deterministic behavior,"
                " please install FlashAttention version >=2.4.1."
@@ -3468,6 +3484,10 @@ class DotProductAttention(torch.nn.Module):
            int(os.getenv("NVTE_FUSED_ATTN", "1"))
            and self.device_compute_capability >= (8, 0)
        )
        if int(os.getenv("NVTE_FUSED_ATTN", "1")) == 0:
            self.logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
        if self.device_compute_capability < (8, 0):
            self.logger.debug("Disabling FusedAttention for compute capability < sm80")
        assert (
            attention_type in AttnTypes
@@ -3835,44 +3855,69 @@ class DotProductAttention(torch.nn.Module):
        # certain asserts before executing the forward pass.

        # Filter: QKV layout.
        if use_unfused_attention and qkv_format == 'thd':
            self.logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
            use_unfused_attention = False

        # Filter: ONNX export.
        if is_in_onnx_export_mode():
            if use_flash_attention:
                self.logger.debug("Disabling FlashAttention for ONNX mode")
                use_flash_attention = False
            if use_fused_attention:
                self.logger.debug("Disabling FusedAttention for ONNX mode")
                use_fused_attention = False
        # Filter: Input type.
        if (use_flash_attention
            and (query_layer.dtype not in [torch.bfloat16, torch.float16]
            or key_layer.dtype not in [torch.bfloat16, torch.float16]
            or value_layer.dtype not in [torch.bfloat16, torch.float16]
            or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]))
        ):
            self.logger.debug(
                "Disabling FlashAttention due to unsupported QKV data types. "
                "Supported: [torch.bfloat16, torch.float16]. "
                "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
                query_layer.dtype, key_layer.dtype, value_layer.dtype)
            use_flash_attention = False
        if (use_fused_attention
            and (query_layer.dtype not in [torch.bfloat16, torch.float16]
            or key_layer.dtype not in [torch.bfloat16, torch.float16]
            or value_layer.dtype not in [torch.bfloat16, torch.float16])
        ):
            self.logger.debug(
                "Disabling FusedAttention due to unsupported QKV data types. "
                "Supported: [torch.bfloat16, torch.float16, Float8Tensor]. "
                "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
                query_layer.dtype, key_layer.dtype, value_layer.dtype)
            use_fused_attention = False
        # Filter: Device and dimensions.
        # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
        # FAv2 requires head_dim % 8 == 0
        if (use_flash_attention
            and (query_layer.shape[-1] > 256
            or query_layer.shape[-1] % 8 != 0
            or (query_layer.shape[-1] > 192
            and self.device_compute_capability not in ((8, 0), (9, 0))))):
            self.logger.debug(
                "Disabling FlashAttention due to unsupported head_dim. "
                "Supported: %%8 == 0, and <= 256; sm80/90 for >192. "
                "Found: query_layer.shape[-1]=%s, key_layer.shape[-1]=%s, sm=%s",
                query_layer.shape[-1], key_layer.shape[-1],
                '.'.join([str(i) for i in self.device_compute_capability]))
            use_flash_attention = False
        # Filter: cross attention + causal mask.
        # (in training mode)
        if (use_flash_attention
            and inference_params is None
            and _flash_attn_2_1_plus
            and "causal" in attn_mask_type
            and max_seqlen_q != max_seqlen_kv
        ):
            self.logger.warning(
                "In training mode, disabling the use of FlashAttention since version 2.1+ has "
                "changed its behavior for causal mask in cross attention. See "
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
@@ -3885,8 +3930,14 @@ class DotProductAttention(torch.nn.Module):

        # Filter: sliding window attention.
        # UnfusedDotProductAttention can support SWA via arbitrary attention mask.
        if window_size not in ((-1, -1), (-1, 0)):
            if use_fused_attention:
                self.logger.debug("Disabling FusedAttention for SWA")
                use_fused_attention = False
            if (not _flash_attn_2_3_plus) or context_parallel:
                if use_flash_attention:
                    self.logger.debug(
                        "Disabling FlashAttention as SWA requires flash-attn 2.3+ "
                        "and no context parallelism")
                    use_flash_attention = False
        # Filter: Attention mask type.
@@ -3899,13 +3950,19 @@ class DotProductAttention(torch.nn.Module):
        #   arbitrary | UnfusedDotProductAttention
        #
        if attn_mask_type == "arbitrary":
            if use_flash_attention:
                self.logger.debug("Disabling FlashAttention for arbitrary mask")
                use_flash_attention = False
            if use_fused_attention:
                self.logger.debug("Disabling FusedAttention for arbitrary mask")
                use_fused_attention = False
        if (use_unfused_attention
            and inference_params is None
            and "causal" in attn_mask_type
            and max_seqlen_q != max_seqlen_kv
        ):
            self.logger.debug(
                "Disabling UnfusedDotProductAttention for causal mask in cross attention "
                "(max_seqlen_q != max_seqlen_kv)")
            use_unfused_attention = False
        # Filter: bias.
@@ -3926,7 +3983,10 @@ class DotProductAttention(torch.nn.Module):
            _alibi_cache["_alibi_slopes_require_update"] = True
            _alibi_cache["_alibi_bias_require_update"] = True

        if (use_flash_attention
            and (core_attention_bias_type not in ["no_bias", "alibi"]
            or core_attention_bias is not None)):
            self.logger.debug("Disabling FlashAttention for pre/post_scale_bias")
            use_flash_attention = False

        fu_core_attention_bias_type = core_attention_bias_type
@@ -3943,6 +4003,7 @@ class DotProductAttention(torch.nn.Module):
            if fu_core_attention_bias.requires_grad:
                # remove this line when cuDNN adds bwd support for
                # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
                self.logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
                use_fused_attention = False
            else:
                # max512 backend will only support [1, h, s, s]
@@ -3977,6 +4038,8 @@ class DotProductAttention(torch.nn.Module):
            and fu_core_attention_bias_type == "post_scale_bias"
            and (fu_core_attention_bias.shape[0] != 1
            or fu_core_attention_bias.shape[1] != query_layer.shape[-2])):
            self.logger.debug(
                "Disabling FusedAttention as no backend supports the provided input")
            use_fused_attention = False
        # Filter: determinism.
@@ -3995,6 +4058,7 @@ class DotProductAttention(torch.nn.Module):
            and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and self.deterministic
            and self.device_compute_capability != (9, 0)):
            self.logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
        # Select FusedAttention on sm90 and FlashAttention on others for performance
@@ -4002,11 +4066,35 @@ class DotProductAttention(torch.nn.Module):
            and use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
            if self.device_compute_capability == (9, 0):
                self.logger.debug(
                    "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
                    "for performance reasons")
                use_flash_attention = False
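Taken together, the filters above mean the final backend can also be steered from the environment through the NVTE_FLASH_ATTN and NVTE_FUSED_ATTN variables read in __init__. A small sketch (values illustrative); the flags are read once at construction time, so set them before instantiating the module:

import os

os.environ["NVTE_FLASH_ATTN"] = "0"  # opt out of FlashAttention
os.environ["NVTE_FUSED_ATTN"] = "1"  # keep cuDNN FusedAttention enabled (default)
# Construct DotProductAttention only after setting these variables.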
        run_config = {
            "compute_capability": "sm" + str(self.device_compute_capability[0] * 10
                + self.device_compute_capability[1]),
            "q_dtype": query_layer.dtype,
            "k_dtype": key_layer.dtype,
            "v_dtype": value_layer.dtype,
            "q_shape": list(query_layer.shape),
            "k_shape": list(key_layer.shape),
            "v_shape": list(value_layer.shape),
            "qkv_format": qkv_format,
            "qkv_layout": qkv_layout,
            "mask_type": attn_mask_type,
            "bias_type": core_attention_bias_type,
            "bias_shape": core_attention_bias.shape if core_attention_bias is not None else None,
            "dropout": self.attention_dropout,
            "context_parallel": context_parallel,
            "is_training": self.training,
            "transformer_engine_version": te.__version__,
            "flash_attn_version": _flash_attn_version,
            "cudnn_version": '.'.join([str(i) for i in get_cudnn_version()]),
        }
        if use_flash_attention:
            self.logger.info("Running with FlashAttention backend")
            self.logger.debug("Running with config=%s", run_config)
            if core_attention_bias_type == "alibi":
                alibi_slopes, _ = get_alibi(
                    query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes)
@@ -4027,9 +4115,10 @@ class DotProductAttention(torch.nn.Module):
                max_seqlen_kv=max_seqlen_kv)

        if use_fused_attention:
            self.logger.info(
                "Running with FusedAttention backend (sub-backend %s)",
                int(fused_attention_backend))
            self.logger.debug("Running with config=%s", run_config)
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(
                    self.fused_attention,
@@ -4089,9 +4178,9 @@ class DotProductAttention(torch.nn.Module):
                "with Flash Attention and Fused Attention!"
            )

        if use_unfused_attention:
            self.logger.info("Running with UnfusedDotProductAttention backend")
            self.logger.debug("Running with config=%s", run_config)
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(
                    self.unfused_attention,
...
@@ -4,8 +4,10 @@
"""Utility functions for Transformer Engine modules"""
import math
import functools
from typing import Any, Callable, Optional, Tuple

import torch
import transformer_engine.pytorch.cpp_extensions as ext
def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
@@ -241,8 +243,19 @@ def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None:
        f"but got tensor with dims={list(tensor.size())}"
    )
def is_bf16_compatible() -> bool:
    """Replaces torch.cuda.is_bf16_compatible() with an explicit
    check on device compute capability to enforce sm_80 or higher.
    """
    return torch.cuda.get_device_capability()[0] >= 8
@functools.cache
def get_cudnn_version() -> Tuple[int, int, int]:
"""Runtime cuDNN version (major, minor, patch)"""
encoded_version = ext.get_cudnn_version()
major_version_magnitude = 1000 if encoded_version < 90000 else 10000
major, encoded_version = divmod(encoded_version, major_version_magnitude)
minor, patch = divmod(encoded_version, 100)
return (major, minor, patch)
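A quick illustrative check of the decoding above (the sample encodings follow cuDNN's scheme: 8.9.7 is reported as 8907, while 9.x versions such as 9.1.0 are reported as 90100):

def decode_cudnn_version(encoded_version):
    # Mirrors get_cudnn_version() above, minus the extension call.
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)

assert decode_cudnn_version(8907) == (8, 9, 7)
assert decode_cudnn_version(90100) == (9, 1, 0)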