{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "040f466a",
   "metadata": {},
   "source": [
    "# Attention Is All You Need!\n",
    "\n",
    "The core idea behind Transformer models is the attention mechanism [[1]](https://arxiv.org/abs/1706.03762). It identifies the correlation between words, selects the most important parts of the sentence to focus on, and captures meaningful patterns and dependencies in the data. Figure 1 shows a typical attention mechanism, where pre-softmax operations can be a combination of scaling, bias and masking while the post-softmax operation is often just dropout.\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"dot_product_attention.png\" width=\"70%\">\n",
    "<figcaption> Figure 1: Dot product attention. </figcaption>\n",
    "</figure>\n",
    "\n",
    "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in two frameworks, [PyTorch](https://github.com/pytorch/pytorch) and [JAX](https://github.com/google/jax). The API for each framework is\n",
    "\n",
    "- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n",
    "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89a7d849",
   "metadata": {},
   "source": [
    "## 1. Attention Backends\n",
    "\n",
    "Transformer Engine provides multiple attention backends for each supported framework. The framework-native backends provide a robust baseline, while the fused, GPU-optimized implementations offer more performance. For example, the flash-attention and cuDNN attention backends in PyTorch. The framework-native backends are often named with \"unfused\", while the more optimized backends are \"fused\" or \"flash\".\n",
    "<table class=\"docutils align-default\">\n",
    "    <tr>\n",
    "    <th>Framework</th>\n",
    "    <th>Backend (Module Name)</th>\n",
    "    <th>Module Location</th>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td rowspan=\"3\">PyTorch</td>\n",
    "    <td>cuDNN attention (`FusedAttention`)</td>\n",
    "    <td rowspan=\"3\"> [transformer_engine.pytorch.attention](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py)</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td> flash-attention (`FlashAttention`)</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "      <td>\n",
    "         PyTorch-native attention (`UnfusedDotProductAttention`)\n",
    "    </td>  \n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td rowspan=\"2\">JAX</td>\n",
    "    <td>cuDNN attention (`_FusedDotProductAttention`)</td>\n",
    "    <td rowspan=\"2\">[transformer_engine.jax.flax.transformer](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/flax/transformer.py)</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "      <td>JAX-native attention (`_UnfusedDotProductAttention`)</td>\n",
    "  </tr>\n",
    "          \n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c90a2573",
   "metadata": {},
   "source": [
    "### 1.1 Flash vs. Non-Flash\n",
    "\n",
    "The attention calculation has quadratic computational and memory complexities to the sequence length. Its runtime and memory requirements quadruple, when the sequence length doubles. This presents a significant challenge to scale Transformer models up for longer contexts, in order to achieve higher model quality.\n",
    "\n",
    "Compared to the standard, non-flash algorithm, the flash algorithm [[2]](https://arxiv.org/abs/2205.14135) was proposed to reduce the memory scaling to linear and improve the computational efficiency through optimized memory accesses. It employs the following two distinctive techniques.\n",
    "\n",
    "- **Tiling:** The non-flash algorithm tries to process the query, key, value tensors in one single step, requiring large amounts of global memory and incurring high volumes of reads/writes between global memory and shared memory. The flash algorithm decomposes the input into several tiles, based on the available shared memory and register size, and it computes the softmax one tile at a time.\n",
    "\n",
    "- **Recomputation:** The non-flash algorithm stores the softmax matrix (quadratic to sequence length) to global memory for the backward pass, while the flash algorithm only saves the softmax normalization factors (linear to sequence length). This reduces the amount of memory required as well as the bandwidth utilization between global memory and shared memory. Even though there is extra computation incurred in order to recalculate the attention in the backward pass, the bandwidth savings still provide significant improvement in efficiency.\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "<b>Note:</b> \n",
    "    \n",
    "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch and JAX, are both based on the flash algorithm.\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5ce567d",
   "metadata": {},
   "source": [
    "### 1.2 flash-attention\n",
    "\n",
    "The flash-attention backend, available only in PyTorch, is a module wrapped around the public `flash-attn` package [[3]](https://github.com/Dao-AILab/flash-attention). \n",
    "\n",
    "The flash-attention backend supports `flash-attn`'s features as well as a few extra functionalities to facilitate the use of `flash-attn`, such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask use cases. Please see `transformer_engine.pytorch.attention.FlashAttention` for details.\n",
    "\n",
    "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v2.0, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n",
    "\n",
    "To understand `flash-attn`'s performance, please refer to their benchmarks [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n",
    "\n",
    "### 1.3 cuDNN Attention\n",
    "\n",
    "The cuDNN attention backend, available in PyTorch and JAX, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n",
    "\n",
    "<table class=\"docutils align-default\">\n",
    "    <tr>\n",
    "    <th>Sub-Backend</th>\n",
    "    <th>Algorithm</th>\n",
    "    <th>Precision</th>\n",
    "    <th>Sequence Length</th>\n",
    "    <th>Architecture</th>\n",
    "    <th>Additional info</th>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>0</td>\n",
    "    <td>Non-Flash</td>\n",
    "    <td>BF16/FP16</td>\n",
    "    <td> &le;512 </td>\n",
    "    <td> sm80, 90 </td>\n",
    "    <td> [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-attention-fprop)</td>  \n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>1</td>\n",
    "    <td>Flash</td>\n",
    "    <td>BF16/FP16</td>\n",
    "    <td> Any </td>\n",
    "    <td> sm80+ </td>\n",
    "    <td> [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop),\n",
    "      [cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention)\n",
    "      </td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td rowspan=\"2\">2</td>\n",
    "    <td rowspan=\"2\">Flash</td>\n",
    "    <td rowspan=\"2\">FP8</td>\n",
    "    <td> cuDNN pre-9.0: &le;512 </td>\n",
    "    <td>cuDNN pre-9.0: sm90</td>\n",
    "    <td></td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td> cuDNN 9.0+: Any</td>\n",
    "    <td> cuDNN 9.0+: sm90+ </td>\n",
    "    <td> cuDNN 9.0+: [cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-fp8)\n",
    "    </td>  \n",
    "  </tr>\n",
    "</table>\n",
    "\n",
    "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 2.0, cuDNN 9.3 and `flash-attn` 2.4.2,\n",
    "\n",
    "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch and JAX.\n",
    "- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n",
    "- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three formats without transposes (see Section 3.1 for more details).\n",
    "- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n",
    "- flash-attention supports KV-caching and paged attention, and cuDNN attention does not.\n",
    "- flash-attention uses bottom right diagonal for `causal` mask in cross attention (see [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)), and cuDNN attention supports both top left and bottom right.\n",
    "- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n",
    "\n",
    "To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5b8e3d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_configs = {\n",
    "    #   test:             b,  h, hg,   d,   sq,  skv,   p,     mask,              bias\n",
    "    \"test_0\": ModelConfig(2, 16, 16,  64,  512,  512, 0.0, \"no_mask\",         \"no_bias\"), # short seq\n",
    "    \"test_1\": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0,  \"causal\",         \"no_bias\"), # longer seq, mask\n",
    "    \"test_2\": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0,  \"causal\", \"post_scale_bias\"), # bias\n",
    "    \"test_3\": ModelConfig(2, 32,  4, 128, 8192, 8192, 0.0,  \"causal\",         \"no_bias\"), # GQA\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "50852cb5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device 0: NVIDIA H100 80GB HBM3 GPU, sm90 compute capability, 79.1GB memory\n",
      "Running test_0 with cuDNN attention and flash-attention...\n",
      "Running test_1 with cuDNN attention and flash-attention...\n",
      "Running test_2 with cuDNN attention...\n",
      "Running test_3 with cuDNN attention and flash-attention...\n",
      "\n",
      "        cuDNN fwd+bwd (ms)  flash-attn fwd+bwd (ms)  cuDNN vs flash speedup\n",
      "test_0              0.0340                   0.0468                  1.3786\n",
      "test_1              0.3664                   0.5850                  1.5968\n",
      "test_2              0.9332                   0.0000                  0.0000\n",
      "test_3              7.4875                  11.8879                  1.5877\n"
     ]
    }
   ],
   "source": [
    "!cd ../../../benchmarks/attention/ && python benchmark_attention.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a615119",
   "metadata": {},
   "source": [
    "## 2. Backend Selection\n",
    "\n",
    "Given the various attention backends, Transformer Engine has a selection logic in place to choose the most appropriate backend for a particular set of user inputs and runtime environment. The selection logic is based on both backend availability and backend performance.\n",
    "\n",
    "Backend availability is determined by factors such as model configuration, training hyper-parameters, software versions, and the GPU architecture in question. For example, some considerations are the sequence length, number of attention heads, head size, attention mask type, attention bias type, training or inference mode, self or cross attention, MHA or MQA/GQA, `flash-attn`/cuDNN library versions, and the compute capability of the GPU.\n",
    "\n",
    "When there are multiple backends available, Transformer Engine makes backend selection based on performance. In general, there are a few rules being followed in our selection logic (see table below). As we monitor the performance of different backends, the selection logic may change.\n",
    "\n",
    "<table class=\"docutils align-default\">\n",
    "    <tr>\n",
    "    <th>Framework</th>\n",
    "    <th>Selection Order</th>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td rowspan=\"3\">PyTorch</td>\n",
    "    <td>sm90: cuDNN attention > flash-attention > PyTorch-native attention</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td> sm80: flash-attention > cuDNN attention > PyTorch-native attention</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "      <td>\n",
    "         cuDNN attention: sub-backend 1 > sub-backend 0\n",
    "    </td>  \n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>JAX</td>\n",
    "    <td>cuDNN attention > JAX-native attention</td>\n",
    "  </tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6c0f3f0",
   "metadata": {},
   "source": [
    "### 2.1 Debug Information\n",
    "\n",
    "To find out which backend is being used during runtime, we have the following two debugging flags. Logging is done by using the `logging` package.\n",
    "```\n",
    "NVTE_DEBUG       = 0/1   # disables/enables debugging\n",
    "NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages\n",
    "```\n",
    "<div class=\"alert alert-info\">\n",
    "<b>Note:</b>\n",
    "    \n",
    "These flags are supported in PyTorch only as of Transformer Engine 2.0. JAX support is expected to be added in the future.\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16660323",
   "metadata": {},
   "source": [
    "The example script [example_attention.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/example_attention.py) runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend is used in runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "906b8cf1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Run cuDNN attention...\n",
      "[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
      "\n",
      "Run flash-attention...\n",
      "[INFO     | DotProductAttention]: Running with FlashAttention backend\n",
      "\n",
      "Test passed.\n"
     ]
    }
   ],
   "source": [
    "!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python example_attention.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ca99461",
   "metadata": {},
   "source": [
    "`NVTE_DEBUG_LEVEL=2` allows us to find out more about the backend selection logic. Users are encouraged to double check the `config` and provide it to the Transformer Engine team if they would like to file a bug. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "d3637094",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Run cuDNN attention...\n",
      "[DEBUG    | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.3.0', 'qkv_type': <class 'torch.Tensor'>, 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}\n",
      "[DEBUG    | DotProductAttention]: Disabling FlashAttention due to NVTE_FLASH_ATTN=0\n",
      "[DEBUG    | DotProductAttention]: Available backends = {FlashAttention=False, FusedAttention=True (sub-backend 1), UnfusedDotProductAttention=True}\n",
      "[DEBUG    | DotProductAttention]: Selected backend = FusedAttention (sub-backend 1)\n",
      "[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
      "\n",
      "Run flash-attention...\n",
      "[DEBUG    | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.3.0', 'qkv_type': <class 'torch.Tensor'>, 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}\n",
      "[DEBUG    | DotProductAttention]: Disabling FusedAttention due to NVTE_FUSED_ATTN=0\n",
      "[DEBUG    | DotProductAttention]: Available backends = {FlashAttention=True, FusedAttention=False, UnfusedDotProductAttention=True}\n",
      "[DEBUG    | DotProductAttention]: Selected backend = FlashAttention\n",
      "[INFO     | DotProductAttention]: Running with FlashAttention backend\n",
      "\n",
      "Test passed.\n"
     ]
    }
   ],
   "source": [
    "!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 python example_attention.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "611d8fdb",
   "metadata": {},
   "source": [
    "### 2.2 User Control\n",
    "\n",
    "Users usually do not need to worry about the backend selection. However, if there is a convergence or performance issue encountered, Transformer Engine provides a few other environment variables for users to experiment with different backends.\n",
    "\n",
    "**flash-attention or cuDNN attention:**\n",
    "Users can enable/disable the flash-attention backend or cuDNN attention backend via the following two environment variables in PyTorch.\n",
    "```\n",
    "NVTE_FLASH_ATTN = 0 # disables flash-attention; default = 1\n",
    "NVTE_FUSED_ATTN = 0 # disables cuDNN attention; default = 1\n",
    "```\n",
    "\n",
    "**cuDNN attention sub-backends:**\n",
    "This environment variable allows users to express their preference of cuDNN attention sub-backends. However, the elected sub-backend will only be used *if* it is eligible, i.e. if it has support for the provided inputs and runtime environment.\n",
    "```\n",
    "NVTE_FUSED_ATTN_BACKEND = 0/1/2 # user preference of cuDNN sub-backend\n",
    "```\n",
    "\n",
    "**Execution paths of cuDNN sub-backend 1:**\n",
    "cuDNN attention sub-backend 1 also offers two execution paths: workspace optimization path and non-workspace optimization path. The workspace optimization path requires a larger amount of global memory, provides determinism, and offers bias gradient support. Before cuDNN 9.0, it also has 20-30% performance advantage over the non-workspace optimization path. But after cuDNN 9.0, it is 20-30% slower than the non-workspace optimization path.\n",
    "\n",
    "Users can experiment with these two paths through the following environment variable. However, please be aware of the possible Out-Of-Memory risks.\n",
    "```\n",
    "Before cuDNN 9.0:\n",
    "    NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 0 # disables workspace optimization path\n",
    "    NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 1 # enables workspace optimization path\n",
    "\n",
    "After cuDNN 9.0:\n",
    "    NVTE_ALLOW_NONDETERMINISTIC_ALGO = 1 # disables workspace optimization path\n",
    "    NVTE_ALLOW_NONDETERMINISTIC_ALGO = 0 # enables workspace optimization path\n",
    "```\n",
    "<div class=\"alert alert-info\">\n",
    "<b>Note</b>\n",
    "    \n",
    "Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code>, <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> and <code>NVTE_ALLOW_NONDETERMINISTIC_ALGO</code> are only supported in PyTorch, and will be added to JAX in the future.\n",
    "</div>\n",
    "\n",
    "### 2.3 Example Tests\n",
    "\n",
    "Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n",
    "\n",
    "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e60a2a3e",
   "metadata": {},
   "source": [
    "## 3. Backend Support\n",
    "\n",
    "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v2.0, Transformer Engine's attention backends have the following support matrix.\n",
    "\n",
    "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n",
    "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n",
    "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) |  sm80+ | No  | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n",
    "| flash-attention (PyTorch)           | BF16, FP16      |  sm80+ | Yes | Yes | No | Yes (`bshd`,`thd`)  | Yes                                                                                    |\n",
    "| Framework-native attention | BF16, FP16, FP32 |  Any   | No, unless used as a mask  | Yes | Yes (PyTorch only) | No                                  | Yes |\n",
    "\n",
    "Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
    "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
    "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
    "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
    "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbdcb327",
   "metadata": {},
   "source": [
    "### 3.1 QKV Layout\n",
    "\n",
    "Transformer Engine supports various layouts of the query `q`, key `k`, value `v` tensors. It has defined 15 QKV layouts, which are grouped into 3 QKV formats and 5 QKV layout groups to help with similar memory/computational operations across different layouts. The mapping relationships of these layouts and groups are,\n",
    "\n",
    "| `qkv_layout` &nbsp; &nbsp; &nbsp; &nbsp; | `qkv_layout_group`=`3hd` | `h3d` | `hd_2hd` | `hd_h2d` | `hd_hd_hd` |\n",
    "| ----------: | -----------: | -----: | ----------: | ----------: | -------------: |\n",
    "| `qkv_format`=`sbhd` | `sb3hd`                | `sbh3d` | `sbhd_sb2hd` | `sbhd_sbh2d` | `sbhd_sbhd_sbhd` |\n",
    "| `bshd` | `bs3hd`                | `bsh3d` | `bshd_bs2hd` | `bshd_bsh2d` | `bshd_bshd_bshd` |\n",
    "| `thd`  | `t3hd`                 | `th3d`  | `thd_t2hd`   | `thd_th2d`   | `thd_thd_thd`    |\n",
    "\n",
    "The notation system is that `b` stands for the batch size, `s` sequence length, `h` number of attention heads, `d` head dimension, and `t` the total number of tokens in the batch, i.e. `t = sum(s_i) for i in 0,...,b-1`. Here are a few examples of the layouts and their explanations to help clarify the definition.\n",
    "\n",
    "**qkv_layout=sb3hd:**\n",
    "`q`, `k`, `v` are sequence first, i.e. `s` is the leading dimension in each tensor. They are different slices of one tensor `qkv`: `q, k, v = [qkv[:,:,i,:,:] for i in range(3)]`. They are interleaved at the `h * d` dimension.\n",
    "\n",
    "**qkv_layout=bshd_bsh2d:**\n",
    "`q`, `k`, `v` are batch first, i.e. `b` is the leading dimension in each tensor. `q` is contiguous, and `k`, `v` are different slices of tensor `kv`: `k, v = [kv[:,:,:,i,:] for i in range(2)]`. `k`, `v` are interleaved at the `d` dimension.\n",
    "\n",
    "The `s` and `h` in `bsh2d` are the max sequence length and number of heads for `k`, `v`, which can be different from the `s` and `h` in `bshd` for `q`. We denoted them as the same for brevity reasons. Transformer Engine does differentiate their values for actual execution.\n",
    "\n",
    "**qkv_layout=thd_thd_thd:**\n",
    "`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n",
    "\n",
    "As of v2.0, Transformer Engine has the following support matrix.\n",
    "\n",
    "<table class=\"docutils align-default\">\n",
    "    <tr>\n",
    "    <th>Backend</th>\n",
    "    <th>Supported QKV Formats</th>\n",
    "    <th>Notes</th>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>flash-attention</td>\n",
    "    <td>`bshd`, `sbhd`, `thd`</td>\n",
    "    <td>PyTorch: 3 formats, i.e. 15 layouts</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td rowspan=\"2\">cuDNN attention</td>\n",
    "    <td rowspan=\"2\">`bshd`, `sbhd`, `thd`</td>\n",
    "    <td>PyTorch: 3 formats, i.e. 15 layouts</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "      <td>\n",
    "         JAX: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n",
    "    </td>  \n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>Framework-native attention</td>\n",
    "    <td>`bshd`, `sbhd`</td>\n",
    "    <td>PyTorch, JAX: 2 formats, i.e. 10 layouts</td>\n",
    "  </tr>\n",
    "</table>\n",
    "\n",
    "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "<b>Note</b>\n",
    "    \n",
    "When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "855d9616",
   "metadata": {},
   "source": [
    "### 3.2 Attention Mask\n",
    "\n",
    "Transformer Engine supports 7 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n",
    "\n",
    "- `no_mask`, `padding`, `causal`, `causal_bottom_right`, `padding_causal`, `padding_causal_bottom_right`, `arbitrary`\n",
    "\n",
    "Different backends offer different support for attention mask. As of Transformer Engine 2.0,\n",
    "\n",
    "<table class=\"docutils align-default\">\n",
    "    <tr>\n",
    "    <th>Backend</th>\n",
    "    <th>Supported Mask Types</th>\n",
    "    <th>Requires `attention_mask`</th>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>flash-attention</td>\n",
    "    <td><li>`no_mask`, `causal` (self-attention),</li><li>`padding`, `padding_causal` (self-attention),</li><li>`causal_bottom_right`, `padding_causal_bottom_right`</li></td>\n",
    "    <td rowspan=\"3\"><li>`no_mask`, `causal` `causal_bottom_right`: No</li><li>`padding`, `padding_causal`, `padding_causal_bottom_right`: Yes if `cu_seqlens` not provided</li><li>`arbitrary`: Yes</li></td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>cuDNN attention</td>\n",
    "    <td><li>`no_mask`, `causal`,</li><li>`padding`, `padding_causal`,</li><li>`causal_bottom_right`, `padding_causal_bottom_right`</li></td>\n",
    "    <td></td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>Framework-native attention</td>\n",
    "    <td><li>All (PyTorch)</li><li>`no_mask`, `causal`, `padding` (Jax)</li></td>\n",
    "  </tr>\n",
    "    <tr>\n",
504
    "        <td></td>\n",
505
506
507
    "    </tr>\n",
    "</table>\n",
    "\n",
    "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 2.0, there are two options to do so in PyTorch and one in JAX.\n",
    "\n",
    "* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
    "  - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
    "  - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
    "\n",
    "\n",
    "* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
    "\n",
    "**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
    "\n",
    "**Arbitrary mask:** cuDNN does not support `Arbitrary` mask type as of v9.3. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "a1f25a9b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run with post_scale_bias:\n",
      "[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
      "\n",
      "Run with arbitrary mask:\n",
      "[INFO     | DotProductAttention]: Running with UnfusedDotProductAttention backend\n",
      "\n",
      "Test passed!\n"
     ]
    }
   ],
   "source": [
    "!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python arbitrary_mask_to_post_scale_bias.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dda4a589",
   "metadata": {},
   "source": [
    "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n",
    "\n",
    "### 3.3 Attention Bias\n",
    "\n",
    "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 2.0, their support matrix is as follows.\n",
    "\n",
    "<table class=\"docutils align-default\">\n",
    "    <tr>\n",
    "    <th>Backend</th>\n",
    "    <th>Bias Type</th>\n",
    "    <th>Bias Shape</th>\n",
    "    <th>Bias Data Type</th>\n",
    "    <th>Architecture</th>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>flash-attention</td>\n",
    "    <td>`no_bias`, `ALiBi` (with slopes)</td>\n",
    "    <td>N/A</td>\n",
    "    <td>ALiBi slopes: FP32</td>\n",
    "    <td>sm80+</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td rowspan=\"2\">cuDNN attention</td>\n",
    "    <td>PyTorch: `no_bias`, `post_scale_bias`, `ALiBi` (without slopes)</td>\n",
    "    <td rowspan=\"2\">`post_scale_bias`: BHSS, 1HSS, B1SS, 11SS for forward, 1HSS for backward</td>\n",
    "      <td>`post_scale_bias`: same as QKV type</td>\n",
    "      <td>cuDNN 8.9.6+: sm90</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "      <td>JAX: `no_bias`, `post_scale_bias`</td>  \n",
    "      <td>ALiBi slopes: FP32</td>\n",
    "      <td>cuDNN 9.0+: sm80+</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>Framework-native attention</td>\n",
    "    <td>`no_bias`, `pre_scale_bias`, `post_scale_bias`</td>\n",
    "    <td>`post_scale_bias`: BHSS, 1HSS, B1SS, 11SS </td>\n",
    "      <td>`post_scale_bias`: same as QKV type</td>\n",
    "      <td>sm80+</td>\n",
    "  </tr>\n",
    "</table>\n",
    "\n",
    "The flash-attention backend enables `ALiBi` by asking user to pass in an `alibi_slopes` tensor, which can be the default slopes of vanilla ALiBi, or user-defined slopes. On the other hand, cuDNN attention supports `ALiBi` by taking in a `Boolean` flag, and it only supports vanilla ALiBi as of cuDNN 9.0.\n",
    "\n",
    "The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n",
    "\n",
    "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0702339",
   "metadata": {},
   "source": [
    "### 3.4 FP8 Attention\n",
    "\n",
    "A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n",
    "\n",
    "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v2.0. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n",
    "\n",
    "- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n",
    "\n",
    "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n",
    "\n",
    "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}