{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8ae3bc43",
   "metadata": {},
   "source": [
    "# Attention Is All You Need!\n",
    "\n",
    "The core idea behind Transformer models is the attention mechanism [[1]](https://arxiv.org/abs/1706.03762). It identifies the correlation between words, selects the most important parts of the sentence to focus on, and captures meaningful patterns and dependencies in the data. Figure 1 shows a typical attention mechanism, where pre-softmax operations can be a combination of scaling, bias and masking while the post-softmax operation is often just dropout.\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"dot_product_attention.png\" width=\"70%\">\n",
    "<figcaption> Figure 1: Dot product attention. </figcaption>\n",
    "</figure>\n",
    "\n",
    "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for each framework is,\n",
    "- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n",
    "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)\n",
    "- [transformer_engine.paddle.DotProductAttention](../../api/paddle.rst#transformer_engine.paddle.DotProductAttention)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47421c01",
   "metadata": {},
   "source": [
    "## 1. Attention Backends\n",
    "\n",
    "Transformer Engine provides multiple attention backends for each supported framework. The framework-native backends provide a robust baseline, while the fused, GPU-optimized implementations offer more performance. For example, the flash-attention and cuDNN attention backends in PyTorch. The framework-native backends are often named with \"unfused\", while the more optimized backends are \"fused\" or \"flash\".\n",
    "\n",
    "| Framework | Backend (Module Name) | Module Location |\n",
    "| :-------- | :-------------------- | :-------------- |\n",
    "| PyTorch   | cuDNN attention (`FusedAttention`)<br> flash-attention (`FlashAttention`)<br> PyTorch-native attention (`UnfusedDotProductAttention`) | [transformer_engine.pytorch.attention](../../transformer_engine/pytorch/attention.py)      |\n",
    "| JAX       | cuDNN attention (`_FusedDotProductAttention`)<br> JAX-native attention (`_UnfusedDotProductAttention`)                                | [transformer_engine.jax.flax.transformer](../../transformer_engine/jax/flax/transformer.py)   |\n",
    "| PaddlePaddle    | cuDNN attention (`_te_forward`)<br> PaddlePaddle-native attention (`_pd_forward`)                                                           | [transformer_engine.paddle.layer.attention](../../transformer_engine/paddle/layer/attention.py) |\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e52f60f0",
   "metadata": {},
   "source": [
    "### 1.1 Flash vs. Non-Flash\n",
    "\n",
    "The attention calculation has quadratic computational and memory complexities to the sequence length. Its runtime and memory requirements quadruple, when the sequence length doubles. This presents a significant challenge to scale Transformer models up for longer contexts, in order to achieve higher model quality.\n",
    "\n",
    "Compared to the standard, non-flash algorithm, the flash algorithm [[2]](https://arxiv.org/abs/2205.14135) was proposed to reduce the memory scaling to linear and improve the computational efficiency through optimized memory accesses. It employs the following two distinctive techniques.\n",
    "\n",
    "- **Tiling:** The non-flash algorithm tries to process the query, key, value tensors in one single step, requiring large amounts of global memory and incurring high volumes of reads/writes between global memory and shared memory. The flash algorithm decomposes the input into several tiles, based on the available shared memory and register size, and it computes the softmax one tile at a time.\n",
    "\n",
    "- **Recomputation:** The non-flash algorithm stores the softmax matrix (quadratic to sequence length) to global memory for the backward pass, while the flash algorithm only saves the softmax normalization factors (linear to sequence length). This reduces the amount of memory required as well as the bandwidth utilization between global memory and shared memory. Even though there is extra computation incurred in order to recalculate the attention in the backward pass, the bandwidth savings still provide significant improvement in efficiency.\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "<b>Note:</b> Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb909ac4",
   "metadata": {},
   "source": [
    "### 1.2 flash-attention\n",
    "\n",
    "The flash-attention backend, available only in PyTorch, is a module wrapped around the public `flash-attn` package [[3]](https://github.com/Dao-AILab/flash-attention). \n",
    "\n",
    "The flash-attention backend supports `flash-attn`'s features as they are released, and to facilitate the use of `flash-attn`, flash-attention also offers a few functionalities such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask. Please see `transformer_engine.pytorch.attention.FlashAttention` for more details.\n",
    "\n",
    "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.7, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](../../setup.py)).\n",
    "\n",
    "To understand `flash-attn`'s performance, please refer to their [benchmarks](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n",
    "\n",
    "### 1.3 cuDNN Attention\n",
    "\n",
    "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) and [cudnn-frontend](../../3rdparty/cudnn-frontend) to run, and has several sub-backends to support the different precisions and sequence lengths. Out of the three, sub-backends 1 and 2 are based on the flash algorithm, as `flash-attn` is.\n",
    "\n",
    "| Sub-Backend |  Algorithm | Precision | Sequence Length | Architecture | Docs |\n",
    "| :---------- | :--------- | :-------- | :-------------- | :----------- | :--- |\n",
    "| 0 | Non-Flash | BF16/FP16       | <=512       | sm80, 90 | [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-attention-fprop) |\n",
    "| 1 | Flash     | BF16/FP16       | Any         | sm80+    | [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop),<br>[cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention) |\n",
    "| 2 | Flash     | FP8             | cuDNN pre-9.0: <=512<br>cuDNN 9.0+: Any | cuDNN pre-9.0: sm90<br>cuDNN 9.0+:  sm90+ | cuDNN 9.0+: [cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-fp8) |\n",
    "\n",
    "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.7, cuDNN 9.0 and `flash-attn` 2.4.2,\n",
    "\n",
    "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n",
    "- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n",
    "- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three without transposes (see Section 3.1 for more details).\n",
    "- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n",
    "- flash-attention supports sliding window attention, and cuDNN attention does not.\n",
    "- flash-attention uses bottom right diagonal for `causal` mask in cross attention, and cuDNN attention uses top left (see `flash-attn`'s [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)).\n",
    "- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n",
    "\n",
    "To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](../../benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a380859",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_configs = {\n",
    "    #   test:             b,  h, hg,   d,   sq,  skv,   p,     mask,              bias\n",
    "    \"test_0\": ModelConfig(2, 16, 16,  64,  512,  512, 0.0, \"no_mask\",         \"no_bias\"), # short seq\n",
    "    \"test_1\": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0,  \"causal\",         \"no_bias\"), # longer seq, mask\n",
    "    \"test_2\": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0,  \"causal\", \"post_scale_bias\"), # bias\n",
    "    \"test_3\": ModelConfig(2, 32,  4, 128, 8192, 8192, 0.0,  \"causal\",         \"no_bias\"), # GQA\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0584bb01",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device 0: NVIDIA H100 PCIe GPU, sm90 compute capability, 79.1GB memory\n",
      "Running test_0 with cuDNN attention and flash-attention...\n",
      "Running test_1 with cuDNN attention and flash-attention...\n",
      "Running test_2 with cuDNN attention...\n",
      "Running test_3 with cuDNN attention and flash-attention...\n",
      "\n",
      "        cuDNN fwd+bwd (ms)  flash-attn fwd+bwd (ms)  cuDNN vs flash speedup\n",
      "test_0              0.0638                   0.0858                  1.3454\n",
      "test_1              0.5415                   0.7496                  1.3842\n",
      "test_2              1.2302                   0.0000                  0.0000\n",
      "test_3             12.0122                  19.0716                  1.5877\n"
     ]
    }
   ],
   "source": [
    "!cd ../../../benchmarks/attention/ && python benchmark_attention.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45e53fc9",
   "metadata": {},
   "source": [
    "## 2. Backend Selection\n",
    "\n",
    "Given the various attention backends, Transformer Engine has a selection logic in place to choose the most appropriate backend for a particular set of user inputs and runtime environment. The selection logic is based on both backend availability and backend performance.\n",
    "\n",
    "Backend availability is determined by factors such as model configuration, training hyper-parameters, software versions, and the GPU architecture in question. For example, some considerations are the sequence length, number of attention heads, head size, attention mask type, attention bias type, training or inference mode, self or cross attention, MHA or MQA/GQA, `flash-attn`/cuDNN library versions, and the compute capability of the GPU.\n",
    "\n",
    "When there are multiple backends available, Transformer Engine makes backend selection based on performance. In general, there are a few rules being followed in our selection logic (see table below). As we monitor the performance of different backends, the selection logic may change.\n",
    "\n",
    "| Framework | Selection Order                                                                                                                              |\n",
    "| :-------- | :--------------------- |\n",
    "| PyTorch   | sm90: cuDNN attention > flash-attention > PyTorch-native attention<br>sm80: flash-attention > cuDNN attention > PyTorch-native attention<br>cuDNN attention: sub-backend 1 > sub-backend 0 |\n",
    "| JAX       | cuDNN attention > JAX-native attention |\n",
    "| PaddlePaddle    | cuDNN attention > PaddlePaddle-native attention |\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dfeade3",
   "metadata": {},
   "source": [
    "### 2.1 Debug Information\n",
    "\n",
    "To find out which backend is being used during runtime, users can turn on these debugging flags. Logging is done using the `logging` package.\n",
    "```\n",
    "NVTE_DEBUG       = 0/1   # disables/enables debugging\n",
    "NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages\n",
    "```\n",
    "<div class=\"alert alert-info\">\n",
    "<b>Note:</b> These flags are supported in PyTorch only as of Transformer Engine 1.7. JAX and PaddlePaddle support is expected to be added in the future.\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e3b7981",
   "metadata": {},
   "source": [
    "The [example_attention.py](./example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend was actually used during runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "961c51d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Run cuDNN attention...\n",
      "[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
      "\n",
      "Run flash-attention...\n",
      "[INFO     | DotProductAttention]: Running with FlashAttention backend \n",
      "\n",
      "Test passed.\n"
     ]
    }
   ],
   "source": [
    "!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python example_attention.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11bfbbd7",
   "metadata": {},
   "source": [
    "To collect more information, users can turn on `NVTE_DEBUG_LEVEL=2`. In this example, it allows us to find out more about the run config. Users are encouraged to provide if users intend to file a bug with Transformer Engine. For example, "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "162a2be1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Run cuDNN attention...\n",
      "[DEBUG    | DotProductAttention]: Disabling FlashAttention due to NVTE_FLASH_ATTN=0\n",
      "[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
      "[DEBUG    | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': <Version('1.8.0.dev0')>, 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.2.0'}\n",
      "[DEBUG    | FusedAttnFunc      ]: Running forward in torch.bfloat16\n",
      "[DEBUG    | FusedAttnFunc      ]: Running backward in torch.bfloat16\n",
      "\n",
      "Run flash-attention...\n",
      "[DEBUG    | DotProductAttention]: Disabling FusedAttention due to NVTE_FUSED_ATTN=0\n",
      "[INFO     | DotProductAttention]: Running with FlashAttention backend \n",
      "[DEBUG    | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': <Version('1.8.0.dev0')>, 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.2.0'}\n",
      "\n",
      "Test passed.\n"
     ]
    }
   ],
   "source": [
    "!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 python example_attention.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "779a51e6",
   "metadata": {},
   "source": [
    "### 2.2 User Control\n",
    "\n",
    "Users usually do not need to worry about the backend selection. However, if there is a convergence or performance issue encountered, Transformer Engine provides a few other environment variables for users to experiment with different backends.\n",
    "\n",
    "**flash-attention or cuDNN attention:**\n",
    "Users can enable/disable the flash-attention backend or cuDNN attention backend via the following two environment variables in PyTorch.\n",
    "```\n",
    "NVTE_FLASH_ATTN = 0 # disables flash-attention; default = 1\n",
    "NVTE_FUSED_ATTN = 0 # disables cuDNN attention; default = 1\n",
    "```\n",
    "\n",
    "**cuDNN attention sub-backends:**\n",
    "This environment variable allows users to express their preference of cuDNN attention sub-backends. However, the elected sub-backend will only be used *if* it is eligible, i.e. if it has support for the provided inputs and runtime environment.\n",
    "```\n",
    "NVTE_FUSED_ATTN_BACKEND = 0/1/2 # user preference of cuDNN sub-backend\n",
    "```\n",
    "\n",
    "**Execution paths of cuDNN sub-backend 1:**\n",
    "cuDNN attention sub-backend 1 also offers two execution paths: workspace optimization path and non-workspace optimization path. The workspace optimization path requires a larger amount of global memory, provides determinism, and offers bias gradient support. Before cuDNN 9.0, it also has 20-30% performance advantage over the non-workspace optimization path. But after cuDNN 9.0, it is 20-30% slower than the non-workspace optimization path.\n",
    "\n",
    "Users can experiment with these two paths through the following environment variable. However, please be aware of the possible Out-Of-Memory risks.\n",
    "```\n",
    "Before cuDNN 9.0:\n",
    "    NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 0 # disables workspace optimization path\n",
    "    NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 1 # enables workspace optimization path\n",
    "\n",
    "After cuDNN 9.0:\n",
    "    NVTE_ALLOW_NONDETERMINISTIC_ALGO = 1 # disables workspace optimization path\n",
    "    NVTE_ALLOW_NONDETERMINISTIC_ALGO = 0 # enables workspace optimization path\n",
    "```\n",
    "<div class=\"alert alert-info\">\n",
    "<b>Note:</b> Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code>, <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> and <code>NVTE_ALLOW_NONDETERMINISTIC_ALGO</code> are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n",
    "</div>\n",
    "\n",
    "### 2.3 Example Tests\n",
    "\n",
    "Our [unit tests](../../tests/) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n",
    "\n",
    "For example, in PyTorch, [test_dot_product_attention](../../tests/pytorch/fused_attention/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccd5650d",
   "metadata": {},
   "source": [
    "## 3. Backend Support\n",
    "\n",
    "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.7, Transformer Engine's attention backends have the following support matrix.\n",
    "\n",
    "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Deterministic |\n",
    "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :------------------ | :------------ |\n",
    "| cuDNN attention<br>(PyTorch, JAX, PaddlePaddle) | PyTorch: BF16, FP16, FP8<br>JAX, PaddlePaddle: BF16, FP16 |  sm80+ | No  | Yes | `bshd`,`sbhd`: Yes<br>`thd`: No | Sub-backend 0, 2: Yes<br>Sub-backend 1: Yes, if workspace optimization path |\n",
    "| flash-attention<br>(PyTorch)           | BF16, FP16      |  sm80+ | Yes | Yes | `bshd`,`thd`: Yes<br>`sbhd`: No  | Yes, if `deterministic=True`                                                                                    |\n",
    "| Framework-native attention<br>(PyTorch, JAX, PaddlePaddle) | BF16, FP16, FP32 |  Any   | No, unless used as a mask  | Yes | No                                  | Yes |\n",
    "\n",
    "Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
    "- sliding window attention: [test_dpa_swa](../../tests/pytorch/fused_attention/test_fused_attn.py)\n",
    "- MQA/GQA: [test_te_layer_mqa_gqa](../../tests/pytorch/fused_attention/test_fused_attn.py)\n",
    "- context parallelism: [test_cp_with_fused_attention](../../tests/pytorch/fused_attention/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](../../tests/pytorch/fused_attention/test_fused_attn_with_cp.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8439b389",
   "metadata": {},
   "source": [
    "### 3.1 QKV Layout\n",
    "\n",
    "Transformer Engine supports various layouts of the query `q`, key `k`, value `v` tensors. It has defined 15 QKV layouts, which are grouped into 3 QKV formats and 5 QKV layout groups to help with similar memory/computational operations across different layouts. The mapping relationships of these layouts and groups are,\n",
    "\n",
    "| `qkv_layout` &nbsp; &nbsp; &nbsp; &nbsp; | `qkv_layout_group`=`3hd` | `h3d` | `hd_2hd` | `hd_h2d` | `hd_hd_hd` |\n",
    "| ----------: | -----------: | -----: | ----------: | ----------: | -------------: |\n",
    "| `qkv_format`=`sbhd` | `sb3hd`                | `sbh3d` | `sbhd_sb2hd` | `sbhd_sbh2d` | `sbhd_sbhd_sbhd` |\n",
    "| `bshd` | `bs3hd`                | `bsh3d` | `bshd_bs2hd` | `bshd_bsh2d` | `bshd_bshd_bshd` |\n",
    "| `thd`  | `t3hd`                 | `th3d`  | `thd_t2hd`   | `thd_th2d`   | `thd_thd_thd`    |\n",
    "\n",
    "The notation system is that `b` stands for the batch size, `s` sequence length, `h` number of attention heads, `d` head dimension, and `t` the total number of tokens in the batch, i.e. `t = sum(s_i) for i in 0,...,b-1`. Here are a few examples of the layouts and their explanations to help clarify the definition.\n",
    "\n",
    "**`qkv_layout`=`sb3hd`:**\n",
    "`q`, `k`, `v` are sequence first, i.e. `s` is the leading dimension in each tensor. They are different slices of one tensor `qkv`: `q, k, v = [qkv[:,:,i,:,:] for i in range(3)]`. They are interleaved at the `h * d` dimension.\n",
    "\n",
    "**`qkv_layout`=`bshd_bsh2d`:**\n",
    "`q`, `k`, `v` are batch first, i.e. `b` is the leading dimension in each tensor. `q` is contiguous, and `k`, `v` are different slices of tensor `kv`: `k, v = [kv[:,:,:,i,:] for i in range(2)]`. `k`, `v` are interleaved at the `d` dimension.\n",
    "\n",
    "The `s` and `h` in `bsh2d` are the max sequence length and number of heads for `k`, `v`, which can be different from the `s` and `h` in `bshd` for `q`. We denoted them as the same for brevity reasons. Transformer Engine does differentiate their values for actual execution.\n",
    "\n",
    "**`qkv_layout`=`thd_thd_thd`:**\n",
    "`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n",
    "\n",
    "As of v1.7, Transformer Engine has the following support matrix.\n",
    "\n",
    "| Backend | Supported QKV Formats | Notes |\n",
    "| :--------------- | :-------------------- | :------ |\n",
    "| flash-attention | `bshd`, `sbhd`, `thd`<br>(`sbhd` requires transpose operations) | PyTorch: 3 formats, i.e. 15 layouts|\n",
    "| cuDNN attention  | `bshd`, `sbhd`, `thd`  | PyTorch: 3 formats, i.e. 15 layouts<br>JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts |\n",
    "| Framework-native attention | `bshd`, `sbhd`<br>(`sbhd` requires transpose operations) | PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts |\n",
    "\n",
    "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_dpa_qkv_layout_thd](../../tests/pytorch/fused_attention/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](../../transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "<b>Note:</b> When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0290f8e9",
   "metadata": {},
   "source": [
    "### 3.2 Attention Mask\n",
    "\n",
    "Transformer Engine supports 5 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n",
    "- `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), `arbitrary`\n",
    "\n",
    "Different backends offer different support for attention mask. As of Transformer Engine 1.7,\n",
    "\n",
    "| Backend          | Supported Mask Types  | Requires `attention_mask` |\n",
    "| :--------------- | :-------------------- | :------------------ |\n",
    "| flash-attention | `no_mask`, `causal`, `padding`, `padding_causal` | `no_mask`, `causal`: No<br>`padding`, `padding_causal`: Yes if `cu_seqlens` not provided|\n",
    "| cuDNN attention  | `no_mask`, `causal`, `padding`, `padding_causal` | `no_mask`, `causal`: No<br>`padding`, `padding_causal`: Yes if `cu_seqlens` not provided|\n",
    "| Framework-native attention | `no_mask`, `causal`, `arbitrary` | `no_mask`, `causal`: No<br>`arbitrary`: Yes |\n",
    "\n",
    "**`padding` and `padding_causal`:** For these two mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.7, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n",
    "\n",
    "* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
    "  - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
    "  - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
    "\n",
    "\n",
    "* JAX and PaddlePaddle: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
    "\n",
    "**`qkv_format`=`thd`:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
    "\n",
    "**`Arbitrary` mask:** cuDNN does not support `Arbitrary` mask type as of v9.0. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](./arbitrary_mask_to_post_scale_bias.py).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b1b7cdd4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run with post_scale_bias:\n",
      "[DotProductAttention]: using cuDNN attention (sub-backend 1)\n",
      "Run with arbitrary mask:\n",
      "[DotProductAttention]: using unfused DPA\n",
      "Test passed!\n"
     ]
    }
   ],
   "source": [
    "!NVTE_DEBUG=1 python arbitrary_mask_to_post_scale_bias.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e045c284",
   "metadata": {},
   "source": [
    "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](../../tests/pytorch/fused_attention/test_fused_attn.py).\n",
    "\n",
    "### 3.3 Attention Bias\n",
    "\n",
    "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.7, their support matrix is as follows.\n",
    "\n",
    "| Backend | Bias Type | Bias Shape | Bias Data Type | Architecture |\n",
    "| :------ | :-------- | :--------- | :--------- | :----------- |\n",
    "| flash-attention           | `no_bias`, `ALiBi` (with slopes) | N/A | ALiBi slopes: FP32 | sm80+ |\n",
    "| cuDNN attention            | PyTorch: `no_bias`, `post_scale_bias`, `ALiBi` (without slopes)<br>JAX, PaddlePaddle: `no_bias`, `post_scale_bias` | `post_scale_bias`: BHSS, 1HSS, B1SS, 11SS for forward, 1HSS for backward | `post_scale_bias`: same as QKV type<br>ALiBi slopes: FP32 | cuDNN 8.9.6+: sm90<br>cuDNN 9.0+: sm80+ |\n",
    "| Framework-native attention | `no_bias`, `pre_scale_bias`, `post_scale_bias` | `post_scale_bias`: BHSS, 1HSS, B1SS, 11SS | `post_scale_bias`: same as QKV type | sm80+ |\n",
    "\n",
    "The flash-attention backend enables `ALiBi` by asking the user to pass in an `alibi_slopes` tensor, which can hold either the default slopes of vanilla ALiBi or user-defined slopes. cuDNN attention, on the other hand, enables `ALiBi` through a `Boolean` flag, and it only supports vanilla ALiBi as of cuDNN 9.0.\n",
    "\n",
    "The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` tensor to achieve the same effect. In PyTorch, the utility function `transformer_engine.pytorch.attention.get_alibi` can help with the conversion.\n",
    "\n",
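    "To make the idea of expressing `ALiBi` as an additive bias concrete, here is a minimal sketch in plain PyTorch. It does not reproduce `get_alibi`; the helpers `alibi_slopes` and `alibi_bias` are hypothetical, and the sketch assumes self-attention (equal query and key/value lengths), a power-of-two number of heads, and the vanilla slope schedule.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def alibi_slopes(num_heads):\n",
    "    # Hypothetical helper: vanilla slopes 2^(-8*i/n) for i = 1..n.\n",
    "    assert num_heads & (num_heads - 1) == 0, 'sketch covers power-of-two head counts only'\n",
    "    start = 2.0 ** (-8.0 / num_heads)\n",
    "    return torch.tensor([start ** (i + 1) for i in range(num_heads)])\n",
    "\n",
    "def alibi_bias(num_heads, seqlen):\n",
    "    # Hypothetical helper: bias[0, h, i, j] = -slope[h] * |i - j|, in 1HSS layout.\n",
    "    pos = torch.arange(seqlen)\n",
    "    distance = (pos.view(-1, 1) - pos.view(1, -1)).abs().float()\n",
    "    slopes = alibi_slopes(num_heads).view(num_heads, 1, 1)\n",
    "    return (-slopes * distance).unsqueeze(0)\n",
    "\n",
    "bias = alibi_bias(num_heads=8, seqlen=4)\n",
    "print(bias.shape)\n",
    "```\n",
    "\n",
    "A tensor of this shape could then be supplied as a `post_scale_bias` in the `1HSS` layout listed in the table above.\n",
    "\n",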
    "More examples of how to use the various attention biases are at [test_dpa_bias](../../tests/pytorch/fused_attention/test_fused_attn.py)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b8a4e40",
   "metadata": {},
   "source": [
    "### 3.4 FP8 Attention\n",
    "\n",
    "A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n",
    "\n",
    "As of v1.7, Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst) and its [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention). The PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n",
    "\n",
    "- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2 when it supports the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns the attention output in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning of the module, and back to FP16/BF16 at the end.\n",
    "\n",
    "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental.\n",
    "\n",
    "Examples of using the two features are available at [test_dpa_fp8_vs_f16](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_mha_fp8_vs_f16](../../tests/pytorch/fused_attention/test_fused_attn.py). To use FP8 attention only in the forward pass and disable it in the backward pass, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`. With the debug flags `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2` turned on, this should print the following.\n",
    "```\n",
    "[DEBUG    | DotProductAttention]: Running with fp8_recipe.fp8_mha=False, fp8_recipe.fp8_dpa=True and NVTE_FP8_DPA_BWD=0\n",
    "[DEBUG    | FusedAttnFunc      ]: Running forward in FP8\n",
    "[DEBUG    | FusedAttnFunc      ]: Running backward in torch.bfloat16\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}