Unverified Commit 098e3006 authored by Przemyslaw Tredak and committed by GitHub

Link attention docs to the main docs and fix errors reported by Sphinx (#1062)



* Link attention docs to the main docs and fix errors reported by Sphinx
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Lower the version of nbsphinx
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* More fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Change the URL of example_attention.py to GitHub
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* More fixes in the attention tutorial
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent 9c127ef5
@@ -17,8 +17,8 @@ jobs:
         uses: actions/checkout@v3
       - name: 'Install dependencies'
         run: |
-          pip install sphinx==7.1.2 sphinx_rtd_theme==2.0.0 nbsphinx==0.9.4 IPython ipython_genutils==0.2.0 ipywidgets==8.1.3 astroid==3.2.2
-          pip install breathe==4.35.0 sphinx-autoapi==3.1.1
+          pip install sphinx==5.1.1 sphinx_rtd_theme==1.0.0 nbsphinx==0.8.10 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==2.15.7
+          pip install breathe==4.34.0 sphinx-autoapi==2.0.1
           sudo apt-get install -y pandoc graphviz doxygen
           export GIT_SHA=$(git show-ref --hash HEAD)
       - name: 'Build docs'
...
@@ -70,7 +70,7 @@
   color: #8c0;
 }
-html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt {
+html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.citation):not(.glossary):not(.simple)>dt {
   background: rgba(118, 185, 0, 0.1);
   color: rgba(59,93,0,1);
   border-top: solid 3px rgba(59,93,0,1);
...
@@ -109,6 +109,8 @@ napoleon_custom_sections = [
     ("Parallelism parameters", "params_style"),
     ("Optimization parameters", "params_style"),
     ("Values", "params_style"),
+    ("Graphing parameters", "params_style"),
+    ("FP8-related parameters", "params_style"),
 ]
 breathe_projects = {"TransformerEngine": os.path.abspath("doxygen/xml/")}
...
@@ -14,7 +14,8 @@
 "<figcaption> Figure 1: Dot product attention. </figcaption>\n",
 "</figure>\n",
 "\n",
-"[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for each framework is,\n",
+"[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for each framework is\n",
+"\n",
 "- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n",
 "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)\n",
 "- [transformer_engine.paddle.DotProductAttention](../../api/paddle.rst#transformer_engine.paddle.DotProductAttention)"
@@ -28,12 +29,44 @@
 "## 1. Attention Backends\n",
 "\n",
 "Transformer Engine provides multiple attention backends for each supported framework. The framework-native backends provide a robust baseline, while the fused, GPU-optimized implementations offer more performance. For example, the flash-attention and cuDNN attention backends in PyTorch. The framework-native backends are often named with \"unfused\", while the more optimized backends are \"fused\" or \"flash\".\n",
-"\n",
-"| Framework | Backend (Module Name) | Module Location |\n",
-"| :-------- | :-------------------- | :-------------- |\n",
-"| PyTorch | cuDNN attention (`FusedAttention`)<br> flash-attention (`FlashAttention`)<br> PyTorch-native attention (`UnfusedDotProductAttention`) | [transformer_engine.pytorch.attention](../../transformer_engine/pytorch/attention.py) |\n",
-"| JAX | cuDNN attention (`_FusedDotProductAttention`)<br> JAX-native attention (`_UnfusedDotProductAttention`) | [transformer_engine.jax.flax.transformer](../../transformer_engine/jax/flax/transformer.py) |\n",
-"| PaddlePaddle | cuDNN attention (`_te_forward`)<br> PaddlePaddle-native attention (`_pd_forward`) | [transformer_engine.paddle.layer.attention](../../transformer_engine/paddle/layer/attention.py) |\n"
+"<table class=\"docutils align-default\">\n",
+" <tr>\n",
+" <th>Framework</th>\n",
+" <th>Backend (Module Name)</th>\n",
+" <th>Module Location</th>\n",
+" </tr>\n",
+" <tr>\n",
+" <td rowspan=\"3\">PyTorch</td>\n",
+" <td>cuDNN attention (`FusedAttention`)</td>\n",
+" <td rowspan=\"3\"> [transformer_engine.pytorch.attention](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py)</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <td> flash-attention (`FlashAttention`)</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <td>\n",
+" PyTorch-native attention (`UnfusedDotProductAttention`)\n",
+" </td> \n",
+" </tr>\n",
+" <tr>\n",
+" <td rowspan=\"2\">JAX</td>\n",
+" <td>cuDNN attention (`_FusedDotProductAttention`)</td>\n",
+" <td rowspan=\"2\">[transformer_engine.jax.flax.transformer](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/flax/transformer.py)</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <td>JAX-native attention (`_UnfusedDotProductAttention`)</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <td rowspan=\"2\"> PaddlePaddle</td>\n",
+" <td> cuDNN attention (`_te_forward`) </td>\n",
+" <td rowspan=\"2\"> [transformer_engine.paddle.layer.attention](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/paddle/layer/attention.py)\n",
+" </td> \n",
+" </tr>\n",
+" <tr>\n",
+" <td>PaddlePaddle-native attention (`_pd_forward`)</td>\n",
+" </tr>\n",
+" \n",
+"</table>"
 ]
},
{
@@ -52,7 +85,9 @@
 "- **Recomputation:** The non-flash algorithm stores the softmax matrix (quadratic to sequence length) to global memory for the backward pass, while the flash algorithm only saves the softmax normalization factors (linear to sequence length). This reduces the amount of memory required as well as the bandwidth utilization between global memory and shared memory. Even though there is extra computation incurred in order to recalculate the attention in the backward pass, the bandwidth savings still provide significant improvement in efficiency.\n",
 "\n",
 "<div class=\"alert alert-info\">\n",
-"<b>Note:</b> Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n",
+"<b>Note</b> \n",
+" \n",
+"Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n",
 "</div>\n"
 ]
},
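The recomputation trade-off described in the hunk above can be made concrete with a quick back-of-the-envelope sketch (illustrative only; the batch, head, and sequence sizes below are made up, and element sizes assume FP16 activations with FP32 normalization factors):

```python
# Illustrative only: memory saved by storing per-row softmax normalization
# factors (linear in sequence length) instead of the full softmax matrix
# (quadratic in sequence length) for the backward pass.

def softmax_bytes(batch: int, heads: int, seqlen: int, elem_size: int = 2) -> int:
    # Non-flash: one s x s softmax matrix per head kept in global memory.
    return batch * heads * seqlen * seqlen * elem_size

def norm_factor_bytes(batch: int, heads: int, seqlen: int, elem_size: int = 4) -> int:
    # Flash: one normalization factor per row; attention is recomputed in backward.
    return batch * heads * seqlen * elem_size

b, h, s = 8, 16, 4096
print(softmax_bytes(b, h, s) // 2**20, "MiB vs", norm_factor_bytes(b, h, s) // 2**20, "MiB")
# 4096 MiB vs 2 MiB
```

The quadratic term dominates quickly, which is why the bandwidth savings outweigh the extra recomputation.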
@@ -67,19 +102,56 @@
 "\n",
 "The flash-attention backend supports `flash-attn`'s features as they are released, and to facilitate the use of `flash-attn`, flash-attention also offers a few functionalities such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask. Please see `transformer_engine.pytorch.attention.FlashAttention` for more details.\n",
 "\n",
-"The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.7, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](../../setup.py)).\n",
+"The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.7, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n",
 "\n",
 "To understand `flash-attn`'s performance, please refer to their [benchmarks](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n",
 "\n",
 "### 1.3 cuDNN Attention\n",
 "\n",
-"The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) and [cudnn-frontend](../../3rdparty/cudnn-frontend) to run, and has several sub-backends to support the different precisions and sequence lengths. Out of the three, sub-backends 1 and 2 are based on the flash algorithm, as `flash-attn` is.\n",
+"The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths. Out of the three, sub-backends 1 and 2 are based on the flash algorithm, as `flash-attn` is.\n",
 "\n",
-"| Sub-Backend | Algorithm | Precision | Sequence Length | Architecture | Docs |\n",
-"| :---------- | :--------- | :-------- | :-------------- | :----------- | :--- |\n",
-"| 0 | Non-Flash | BF16/FP16 | <=512 | sm80, 90 | [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-attention-fprop) |\n",
-"| 1 | Flash | BF16/FP16 | Any | sm80+ | [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop),<br>[cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention) |\n",
-"| 2 | Flash | FP8 | cuDNN pre-9.0: <=512<br>cuDNN 9.0+: Any | cuDNN pre-9.0: sm90<br>cuDNN 9.0+: sm90+ | cuDNN 9.0+: [cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-fp8) |\n",
+"<table class=\"docutils align-default\">\n",
+" <tr>\n",
+" <th>Sub-Backend</th>\n",
+" <th>Algorithm</th>\n",
+" <th>Precision</th>\n",
+" <th>Sequence Length</th>\n",
+" <th>Architecture</th>\n",
+" <th>Additional info</th>\n",
+" </tr>\n",
+" <tr>\n",
+" <td>0</td>\n",
+" <td>Non-Flash</td>\n",
+" <td>BF16/FP16</td>\n",
+" <td> &le;512 </td>\n",
+" <td> sm80, 90 </td>\n",
+" <td> [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-attention-fprop)</td> \n",
+" </tr>\n",
+" <tr>\n",
+" <td>1</td>\n",
+" <td>Flash</td>\n",
+" <td>BF16/FP16</td>\n",
+" <td> Any </td>\n",
+" <td> sm80+ </td>\n",
+" <td> [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop),\n",
+" [cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention)\n",
+" </td>\n",
+" </tr>\n",
+" <tr>\n",
+" <td rowspan=\"2\">2</td>\n",
+" <td rowspan=\"2\">Flash</td>\n",
+" <td rowspan=\"2\">FP8</td>\n",
+" <td> cuDNN pre-9.0: &le;512 </td>\n",
+" <td>cuDNN pre-9.0: sm90</td>\n",
+" <td></td>\n",
+" </tr>\n",
+" <tr>\n",
+" <td> cuDNN 9.0+: Any</td>\n",
+" <td> cuDNN 9.0+: sm90+ </td>\n",
+" <td> cuDNN 9.0+: [cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-fp8)\n",
+" </td> \n",
+" </tr>\n",
+"</table>\n",
 "\n",
 "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.7, cuDNN 9.0 and `flash-attn` 2.4.2,\n",
 "\n",
@@ -91,7 +163,7 @@
 "- flash-attention uses bottom right diagonal for `causal` mask in cross attention, and cuDNN attention uses top left (see `flash-attn`'s [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)).\n",
 "- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n",
 "\n",
-"To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](../../benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0."
+"To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0."
 ]
},
{
@@ -151,11 +223,32 @@
 "\n",
 "When there are multiple backends available, Transformer Engine makes backend selection based on performance. In general, there are a few rules being followed in our selection logic (see table below). As we monitor the performance of different backends, the selection logic may change.\n",
 "\n",
-"| Framework | Selection Order |\n",
-"| :-------- | :--------------------- |\n",
-"| PyTorch | sm90: cuDNN attention > flash-attention > PyTorch-native attention<br>sm80: flash-attention > cuDNN attention > PyTorch-native attention<br>cuDNN attention: sub-backend 1 > sub-backend 0 |\n",
-"| JAX | cuDNN attention > JAX-native attention |\n",
-"| PaddlePaddle | cuDNN attention > PaddlePaddle-native attention |\n"
+"<table class=\"docutils align-default\">\n",
+" <tr>\n",
+" <th>Framework</th>\n",
+" <th>Selection Order</th>\n",
+" </tr>\n",
+" <tr>\n",
+" <td rowspan=\"3\">PyTorch</td>\n",
+" <td>sm90: cuDNN attention > flash-attention > PyTorch-native attention</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <td> sm80: flash-attention > cuDNN attention > PyTorch-native attention</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <td>\n",
+" cuDNN attention: sub-backend 1 > sub-backend 0\n",
+" </td> \n",
+" </tr>\n",
+" <tr>\n",
+" <td>JAX</td>\n",
+" <td>cuDNN attention > JAX-native attention</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <td> PaddlePaddle</td>\n",
+" <td> cuDNN attention > PaddlePaddle-native attention </td>\n",
+" </tr>\n",
+"</table>"
 ]
},
{
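Read as code, the precedence rules in the selection table above amount to a small lookup. The sketch below is an illustration only, not Transformer Engine's actual selection code; it assumes just the sm80/sm90 split shown in the table and ignores per-input support checks:

```python
def selection_order(framework: str, arch: str = "") -> list:
    """Backend preference per the selection table (illustrative sketch only)."""
    if framework == "PyTorch":
        if arch == "sm90":
            return ["cuDNN attention", "flash-attention", "PyTorch-native attention"]
        # sm80 ordering; within cuDNN attention, sub-backend 1 > sub-backend 0.
        return ["flash-attention", "cuDNN attention", "PyTorch-native attention"]
    if framework == "JAX":
        return ["cuDNN attention", "JAX-native attention"]
    return ["cuDNN attention", "PaddlePaddle-native attention"]

print(selection_order("PyTorch", "sm90")[0])  # cuDNN attention
```

In practice the real logic also filters out backends that do not support the given precision, mask type, or layout before applying this ordering.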
@@ -171,7 +264,9 @@
 "NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages\n",
 "```\n",
 "<div class=\"alert alert-info\">\n",
-"<b>Note:</b> These flags are supported in PyTorch only as of Transformer Engine 1.7. JAX and PaddlePaddle support is expected to be added in the future.\n",
+"<b>Note</b>\n",
+" \n",
+"These flags are supported in PyTorch only as of Transformer Engine 1.7. JAX and PaddlePaddle support is expected to be added in the future.\n",
 "</div>"
 ]
},
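A minimal sketch of how these logging flags could be set from Python before Transformer Engine is imported (the variable names come from the snippet above; the helper function is hypothetical, added here only to show the flag semantics):

```python
import os

# Set the debug flags before importing Transformer Engine
# (PyTorch only, per the note above).
os.environ["NVTE_DEBUG"] = "1"        # 0/1 disables/enables debug mode
os.environ["NVTE_DEBUG_LEVEL"] = "2"  # 0/1/2 selects WARNING/INFO/DEBUG messages

def debug_messages_enabled() -> bool:
    # Hypothetical helper mirroring the flag semantics described above:
    # both variables must be set for debug output to appear.
    return (os.environ.get("NVTE_DEBUG") == "1"
            and int(os.environ.get("NVTE_DEBUG_LEVEL", "0")) >= 1)

print(debug_messages_enabled())  # True
```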
@@ -180,7 +275,7 @@
 "id": "7e3b7981",
 "metadata": {},
 "source": [
-"The [example_attention.py](./example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend was actually used during runtime."
+"The [example_attention.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend was actually used during runtime."
 ]
},
{
@@ -283,14 +378,16 @@
 " NVTE_ALLOW_NONDETERMINISTIC_ALGO = 0 # enables workspace optimization path\n",
 "```\n",
 "<div class=\"alert alert-info\">\n",
-"<b>Note:</b> Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code>, <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> and <code>NVTE_ALLOW_NONDETERMINISTIC_ALGO</code> are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n",
+"<b>Note</b>\n",
+" \n",
+"Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code>, <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> and <code>NVTE_ALLOW_NONDETERMINISTIC_ALGO</code> are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n",
 "</div>\n",
 "\n",
 "### 2.3 Example Tests\n",
 "\n",
-"Our [unit tests](../../tests/) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n",
+"Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n",
 "\n",
-"For example, in PyTorch, [test_dot_product_attention](../../tests/pytorch/fused_attention/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
+"For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
 ]
},
{
@@ -302,16 +399,16 @@
 "\n",
 "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.7, Transformer Engine's attention backends have the following support matrix.\n",
 "\n",
-"| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Deterministic |\n",
+"| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Determinism Possible |\n",
 "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :------------------ | :------------ |\n",
-"| cuDNN attention<br>(PyTorch, JAX, PaddlePaddle) | PyTorch: BF16, FP16, FP8<br>JAX, PaddlePaddle: BF16, FP16 | sm80+ | No | Yes | `bshd`,`sbhd`: Yes<br>`thd`: No | Sub-backend 0, 2: Yes<br>Sub-backend 1: Yes, if workspace optimization path |\n",
-"| flash-attention<br>(PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | `bshd`,`thd`: Yes<br>`sbhd`: No | Yes, if `deterministic=True` |\n",
-"| Framework-native attention<br>(PyTorch, JAX, PaddlePaddle) | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | No | Yes |\n",
+"| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes (only for `bshd`,`sbhd`) | Yes |\n",
+"| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes (only for `bshd`,`thd`) | Yes |\n",
+"| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | No | Yes |\n",
 "\n",
 "Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
-"- sliding window attention: [test_dpa_swa](../../tests/pytorch/fused_attention/test_fused_attn.py)\n",
-"- MQA/GQA: [test_te_layer_mqa_gqa](../../tests/pytorch/fused_attention/test_fused_attn.py)\n",
-"- context parallelism: [test_cp_with_fused_attention](../../tests/pytorch/fused_attention/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](../../tests/pytorch/fused_attention/test_fused_attn_with_cp.py)"
+"- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
+"- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
+"- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)"
 ]
},
{
@@ -331,29 +428,53 @@
 "\n",
 "The notation system is that `b` stands for the batch size, `s` sequence length, `h` number of attention heads, `d` head dimension, and `t` the total number of tokens in the batch, i.e. `t = sum(s_i) for i in 0,...,b-1`. Here are a few examples of the layouts and their explanations to help clarify the definition.\n",
 "\n",
-"**`qkv_layout`=`sb3hd`:**\n",
+"**qkv_layout=sb3hd:**\n",
 "`q`, `k`, `v` are sequence first, i.e. `s` is the leading dimension in each tensor. They are different slices of one tensor `qkv`: `q, k, v = [qkv[:,:,i,:,:] for i in range(3)]`. They are interleaved at the `h * d` dimension.\n",
 "\n",
-"**`qkv_layout`=`bshd_bsh2d`:**\n",
+"**qkv_layout=bshd_bsh2d:**\n",
 "`q`, `k`, `v` are batch first, i.e. `b` is the leading dimension in each tensor. `q` is contiguous, and `k`, `v` are different slices of tensor `kv`: `k, v = [kv[:,:,:,i,:] for i in range(2)]`. `k`, `v` are interleaved at the `d` dimension.\n",
 "\n",
 "The `s` and `h` in `bsh2d` are the max sequence length and number of heads for `k`, `v`, which can be different from the `s` and `h` in `bshd` for `q`. We denoted them as the same for brevity reasons. Transformer Engine does differentiate their values for actual execution.\n",
 "\n",
-"**`qkv_layout`=`thd_thd_thd`:**\n",
+"**qkv_layout=thd_thd_thd:**\n",
 "`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n",
 "\n",
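The three layouts above can be sanity-checked with plain array shapes. In the sketch below NumPy stands in for the framework tensors and the sizes are arbitrary; the slicing expressions are the ones quoted in the tutorial text:

```python
import numpy as np

b, s, h, d = 2, 8, 4, 16  # batch, max sequence length, heads, head dimension

# qkv_layout = "sb3hd": one interleaved tensor, sequence-first.
qkv = np.zeros((s, b, 3, h, d))
q, k, v = [qkv[:, :, i, :, :] for i in range(3)]  # each slice is sbhd

# qkv_layout = "bshd_bsh2d": contiguous q, interleaved kv, batch-first.
q2 = np.zeros((b, s, h, d))
kv = np.zeros((b, s, h, 2, d))
k2, v2 = [kv[:, :, :, i, :] for i in range(2)]    # each slice is bshd

# qkv_layout = "thd_thd_thd": variable-length sequences packed along t = sum(s_i).
seqlens = [3, 5]
t = sum(seqlens)
q3 = np.zeros((t, h, d))

print(q.shape, k2.shape, q3.shape)
```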
 "As of v1.7, Transformer Engine has the following support matrix.\n",
 "\n",
-"| Backend | Supported QKV Formats | Notes |\n",
-"| :--------------- | :-------------------- | :------ |\n",
-"| flash-attention | `bshd`, `sbhd`, `thd`<br>(`sbhd` requires transpose operations) | PyTorch: 3 formats, i.e. 15 layouts|\n",
-"| cuDNN attention | `bshd`, `sbhd`, `thd` | PyTorch: 3 formats, i.e. 15 layouts<br>JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts |\n",
-"| Framework-native attention | `bshd`, `sbhd`<br>(`sbhd` requires transpose operations) | PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts |\n",
-"\n",
-"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_dpa_qkv_layout_thd](../../tests/pytorch/fused_attention/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](../../transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
+"<table class=\"docutils align-default\">\n",
+" <tr>\n",
+" <th>Backend</th>\n",
+" <th>Supported QKV Formats</th>\n",
+" <th>Notes</th>\n",
+" </tr>\n",
+" <tr>\n",
+" <td>flash-attention</td>\n",
+" <td>`bshd`, `sbhd`, `thd`</td>\n",
+" <td>PyTorch: 3 formats, i.e. 15 layouts</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <td rowspan=\"2\">cuDNN attention</td>\n",
+" <td rowspan=\"2\">`bshd`, `sbhd`, `thd`</td>\n",
+" <td>PyTorch: 3 formats, i.e. 15 layouts</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <td>\n",
+" JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n",
+" </td> \n",
+" </tr>\n",
+" <tr>\n",
+" <td>Framework-native attention</td>\n",
+" <td>`bshd`, `sbhd`</td>\n",
+" <td>PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts</td>\n",
+" </tr>\n",
+"</table>\n",
+"\n",
+"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
 "\n",
 "<div class=\"alert alert-info\">\n",
-"<b>Note:</b> When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
+"<b>Note</b>\n",
+" \n",
+"When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
 "</div>\n"
 ]
},
@@ -365,17 +486,46 @@
 "### 3.2 Attention Mask\n",
 "\n",
 "Transformer Engine supports 5 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n",
+"\n",
 "- `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), `arbitrary`\n",
 "\n",
"Different backends offer different support for attention mask. As of Transformer Engine 1.7,\n", "Different backends offer different support for attention mask. As of Transformer Engine 1.7,\n",
"\n", "\n",
"| Backend | Supported Mask Types | Requires `attention_mask` |\n", "<table class=\"docutils align-default\">\n",
"| :--------------- | :-------------------- | :------------------ |\n", " <tr>\n",
"| flash-attention | `no_mask`, `causal`, `padding`, `padding_causal` | `no_mask`, `causal`: No<br>`padding`, `padding_causal`: Yes if `cu_seqlens` not provided|\n", " <th>Backend</th>\n",
"| cuDNN attention | `no_mask`, `causal`, `padding`, `padding_causal` | `no_mask`, `causal`: No<br>`padding`, `padding_causal`: Yes if `cu_seqlens` not provided|\n", " <th>Supported Mask Types</th>\n",
"| Framework-native attention | `no_mask`, `causal`, `arbitrary` | `no_mask`, `causal`: No<br>`arbitrary`: Yes |\n", " <th>Requires `attention_mask`</th>\n",
"\n", " </tr>\n",
"**`padding` and `padding_causal`:** For these two mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.7, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n", " <tr>\n",
" <td rowspan=\"2\">flash-attention</td>\n",
" <td rowspan=\"2\">`no_mask`, `causal`, `padding`, `padding_causal`</td>\n",
" <td>`no_mask`, `causal`: No</td>\n",
" </tr>\n",
" <tr>\n",
" <td>`padding`, `padding_causal`: Yes if `cu_seqlens` not provided</td>\n",
" </tr>\n",
" <tr>\n",
" <td rowspan=\"2\">cuDNN attention</td>\n",
" <td rowspan=\"2\">`no_mask`, `causal`, `padding`, `padding_causal`</td>\n",
" <td>`no_mask`, `causal`: No</td>\n",
" </tr>\n",
" <tr>\n",
" <td>\n",
" `padding`, `padding_causal`: Yes if `cu_seqlens` not provided\n",
" </td> \n",
" </tr>\n",
" <tr>\n",
" <td rowspan=\"2\">Framework-native attention</td>\n",
" <td rowspan=\"2\">`no_mask`, `causal`, `arbitrary`</td>\n",
" <td>`no_mask`, `causal`: No</td>\n",
" </tr>\n",
" <tr>\n",
" <td>`arbitrary`: Yes</td>\n",
" </tr>\n",
"</table>\n",
"\n",
"**padding and padding_causal:** For these two mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.7, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n",
"\n", "\n",
"* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n", "* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
" - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n", " - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
...@@ -384,9 +534,9 @@ ...@@ -384,9 +534,9 @@
"\n", "\n",
"* JAX and PaddlePaddle: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n", "* JAX and PaddlePaddle: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"\n", "\n",
"**`qkv_format`=`thd`:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n", "**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
"\n", "\n",
"**`Arbitrary` mask:** cuDNN does not support `Arbitrary` mask type as of v9.0. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](./arbitrary_mask_to_post_scale_bias.py).\n" "**Arbitrary mask:** cuDNN does not support `Arbitrary` mask type as of v9.0. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py).\n"
] ]
}, },
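To make the `cu_seqlens` option above concrete, here is a small framework-free sketch of how cumulative sequence lengths are derived from per-sequence lengths; in practice these would be int32 `torch.Tensor`s on the GPU rather than Python lists:

```python
# Sketch: build the cumulative sequence length list (cu_seqlens) from
# per-sequence lengths, as expected by the flash-attention / cuDNN backends.
import itertools

def make_cu_seqlens(seqlens):
    """[2, 4, 1] -> [0, 2, 6, 7]: prefix sums with a leading zero."""
    return [0] + list(itertools.accumulate(seqlens))

# Batch of 3 padded sequences: [aa000, bbbb0, c0000]
print(make_cu_seqlens([2, 4, 1]))  # [0, 2, 6, 7]
```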
{
...@@ -416,23 +566,53 @@
"id": "e045c284",
"metadata": {},
"source": [
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n",
"\n",
"### 3.3 Attention Bias\n",
"\n",
"Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.7, their support matrix is as follows.\n",
"\n",
"<table class=\"docutils align-default\">\n",
" <tr>\n",
" <th>Backend</th>\n",
" <th>Bias Type</th>\n",
" <th>Bias Shape</th>\n",
" <th>Bias Data Type</th>\n",
" <th>Architecture</th>\n",
" </tr>\n",
" <tr>\n",
" <td>flash-attention</td>\n",
" <td>`no_bias`, `ALiBi` (with slopes)</td>\n",
" <td>N/A</td>\n",
" <td>ALiBi slopes: FP32</td>\n",
" <td>sm80+</td>\n",
" </tr>\n",
" <tr>\n",
" <td rowspan=\"2\">cuDNN attention</td>\n",
" <td>PyTorch: `no_bias`, `post_scale_bias`, `ALiBi` (without slopes)</td>\n",
" <td rowspan=\"2\">`post_scale_bias`: BHSS, 1HSS, B1SS, 11SS for forward, 1HSS for backward</td>\n",
" <td>`post_scale_bias`: same as QKV type</td>\n",
" <td>cuDNN 8.9.6+: sm90</td>\n",
" </tr>\n",
" <tr>\n",
" <td>JAX, PaddlePaddle: `no_bias`, `post_scale_bias`</td> \n",
" <td>ALiBi slopes: FP32</td>\n",
" <td>cuDNN 9.0+: sm80+</td>\n",
" </tr>\n",
" <tr>\n",
" <td>Framework-native attention</td>\n",
" <td>`no_bias`, `pre_scale_bias`, `post_scale_bias`</td>\n",
" <td>`post_scale_bias`: BHSS, 1HSS, B1SS, 11SS </td>\n",
" <td>`post_scale_bias`: same as QKV type</td>\n",
" <td>sm80+</td>\n",
" </tr>\n",
"</table>\n",
"\n", "\n",
"The flash-attention backend enables `ALiBi` by asking user to pass in an `alibi_slopes` tensor, which can be the default slopes of vanilla ALiBi, or user-defined slopes. On the other hand, cuDNN attention supports `ALiBi` by taking in a `Boolean` flag, and it only supports vanilla ALiBi as of cuDNN 9.0.\n", "The flash-attention backend enables `ALiBi` by asking user to pass in an `alibi_slopes` tensor, which can be the default slopes of vanilla ALiBi, or user-defined slopes. On the other hand, cuDNN attention supports `ALiBi` by taking in a `Boolean` flag, and it only supports vanilla ALiBi as of cuDNN 9.0.\n",
"\n", "\n",
"The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n", "The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n",
"\n", "\n",
"More examples of how to use the various attention biases are at [test_dpa_bias](../../tests/pytorch/fused_attention/test_fused_attn.py)." "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)."
]
},
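As a rough illustration of the ALiBi discussion above, the following sketch computes the vanilla ALiBi slopes and materializes them as an additive `post_scale_bias`-style array. It assumes the number of heads is a power of two (the interpolation used for other head counts is omitted) and uses plain Python lists instead of framework tensors:

```python
# Sketch (not Transformer Engine code): vanilla ALiBi slopes and the
# equivalent additive bias of shape [num_heads, seqlen, seqlen].
def alibi_slopes(n_heads):
    """Vanilla ALiBi slopes 2^(-8*i/n), assuming n_heads is a power of two."""
    return [2.0 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)]

def alibi_bias(n_heads, seqlen):
    """ALiBi expressed as a post_scale_bias-style additive term.

    bias[h][q][k] = -slope_h * (q - k); entries with k > q would be
    removed by the causal mask in an actual attention computation.
    """
    slopes = alibi_slopes(n_heads)
    return [[[-m * (q - k) for k in range(seqlen)] for q in range(seqlen)]
            for m in slopes]

print(alibi_slopes(8)[:2])  # [0.5, 0.25]
```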
{
...@@ -450,7 +630,7 @@
"\n", "\n",
"- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n", "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n",
"\n", "\n",
"Examples of using the two features are available at [test_dpa_fp8_vs_f16](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_mha_fp8_vs_f16](../../tests/pytorch/fused_attention/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`. This should result in the following print when the debug flags are turned on, `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2`.\n", "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`. This should result in the following print when the debug flags are turned on, `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2`.\n",
"```\n", "```\n",
"[DEBUG | DotProductAttention]: Running with fp8_recipe.fp8_mha=False, fp8_recipe.fp8_dpa=True and NVTE_FP8_DPA_BWD=0\n", "[DEBUG | DotProductAttention]: Running with fp8_recipe.fp8_mha=False, fp8_recipe.fp8_dpa=True and NVTE_FP8_DPA_BWD=0\n",
"[DEBUG | FusedAttnFunc ]: Running forward in FP8\n", "[DEBUG | FusedAttnFunc ]: Running forward in FP8\n",
......
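The debug and backward-precision switches above are read from the environment, so they must be set before the relevant attention code runs. A minimal sketch, shown with `os.environ` for illustration; setting them on the command line (`NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 NVTE_FP8_DPA_BWD=0 python script.py`) is equivalent:

```python
# Sketch: configure Transformer Engine's attention debug output and keep the
# attention backward pass out of FP8 while the forward pass stays in FP8.
import os

os.environ["NVTE_DEBUG"] = "1"        # enable debug messages
os.environ["NVTE_DEBUG_LEVEL"] = "2"  # verbose backend-selection details
os.environ["NVTE_FP8_DPA_BWD"] = "0"  # FP8 forward, higher-precision backward
```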
...@@ -51,3 +51,4 @@ Transformer Engine documentation
:caption: Advanced
api/c/index
examples/attention/attention.ipynb
...@@ -366,8 +366,8 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
class DenseGeneral(TransformerEngineBase):
r"""
Applies a linear transformation to the incoming data :math:`y = xA^T + b`.
Parameters
----------
......
...@@ -1531,19 +1531,20 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Indicate the min and max time-scales of rotary position embedding,
only used when :attr:`enable_rotary_pos_emb=True`
rotary_pos_emb_group_method: str, default = 'consecutive'
Indicate the method to couple the coordinates. It should be one of
['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`,
where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with
:math:`i + 1`.
low_rank_adaptation_scope: str, default = 'none'
Indicate the scope to apply low rank adaptation. It should be one of
['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
'exclude_output_proj', 'exclude_mlp']
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot.
......
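The two grouping methods named by `rotary_pos_emb_group_method` can be illustrated with a small sketch (not Transformer Engine code) that lists which coordinate indices are rotated together:

```python
# Sketch: index pairs rotated together by RoPE for hidden dimension d.
def rope_pairs(d, method="consecutive"):
    """Return the coordinate index pairs for the given grouping method."""
    if method == "consecutive":
        # Pair index i with i + 1: (0, 1), (2, 3), ...
        return [(i, i + 1) for i in range(0, d, 2)]
    if method == "alternate":
        # Pair index i with i + d/2: (0, d/2), (1, d/2 + 1), ...
        return [(i, i + d // 2) for i in range(d // 2)]
    raise ValueError(f"unknown method: {method}")

print(rope_pairs(8, "consecutive"))  # [(0, 1), (2, 3), (4, 5), (6, 7)]
print(rope_pairs(8, "alternate"))    # [(0, 4), (1, 5), (2, 6), (3, 7)]
```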
...@@ -328,8 +328,8 @@ def fp8_autocast(
pjit(transformer.init, ...)(...)
.. note::
We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`,
and :attr:`amax_compute_algo` (with value 'max' and 'most_recent') in
recipe.DelayedScaling currently. Other parameters in recipe.DelayedScaling
will trigger an assertion.
......
...@@ -9,9 +9,11 @@ import warnings
import paddle
from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
from .layernorm_mlp import LayerNormMLP
from .layernorm import LayerNorm
from .attention import MultiHeadAttention
from ..constants import AttnMaskTypes, LayerTypes, dist_group_type
from ..distributed import get_tp_group_and_world_size, track_rng_state
class TransformerLayer(paddle.nn.Layer):
......
...@@ -10,7 +10,7 @@ from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.fuser import OperationFuser
......