{
"cells": [
{
"cell_type": "markdown",
"id": "8ae3bc43",
"metadata": {},
"source": [
"# Attention Is All You Need!\n",
"\n",
"The core idea behind Transformer models is the attention mechanism [[1]](https://arxiv.org/abs/1706.03762). It identifies the correlation between words, selects the most important parts of the sentence to focus on, and captures meaningful patterns and dependencies in the data. Figure 1 shows a typical attention mechanism, where pre-softmax operations can be a combination of scaling, bias and masking while the post-softmax operation is often just dropout.\n",
"\n",
"
\n",
"
flash-attention (`FlashAttention`)
PyTorch-native attention (`UnfusedDotProductAttention`) | [transformer_engine.pytorch.attention](../../transformer_engine/pytorch/attention.py) |\n",
"| JAX | cuDNN attention (`_FusedDotProductAttention`)
JAX-native attention (`_UnfusedDotProductAttention`) | [transformer_engine.jax.flax.transformer](../../transformer_engine/jax/flax/transformer.py) |\n",
"| PaddlePaddle | cuDNN attention (`_te_forward`)
PaddlePaddle-native attention (`_pd_forward`) | [transformer_engine.paddle.layer.attention](../../transformer_engine/paddle/layer/attention.py) |\n"
]
},
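{
"cell_type": "markdown",
"id": "7f3c9a2e",
"metadata": {},
"source": [
"As a minimal usage sketch (assuming a CUDA GPU and a recent Transformer Engine version; the head count, channel size and input shapes below are illustrative choices), the PyTorch backends above are all reached through the `DotProductAttention` module, which selects a backend at runtime:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4d1e8f5",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformer_engine.pytorch import DotProductAttention\n",
"\n",
"# 16 attention heads with 64 channels each; Transformer Engine picks the\n",
"# best available backend (cuDNN, flash-attention, or PyTorch-native).\n",
"attn = DotProductAttention(num_attention_heads=16, kv_channels=64)\n",
"\n",
"# Default qkv_format is \"sbhd\": [sequence, batch, heads, head_dim].\n",
"q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device=\"cuda\")\n",
"k = torch.randn_like(q)\n",
"v = torch.randn_like(q)\n",
"\n",
"out = attn(q, k, v)\n",
"print(out.shape)  # [128, 2, 1024] = [seq, batch, heads * head_dim]\n"
]
},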
{
"cell_type": "markdown",
"id": "e52f60f0",
"metadata": {},
"source": [
"### 1.1 Flash vs. Non-Flash\n",
"\n",
"The attention calculation has quadratic computational and memory complexities to the sequence length. Its runtime and memory requirements quadruple, when the sequence length doubles. This presents a significant challenge to scale Transformer models up for longer contexts, in order to achieve higher model quality.\n",
"\n",
"Compared to the standard, non-flash algorithm, the flash algorithm [[2]](https://arxiv.org/abs/2205.14135) was proposed to reduce the memory scaling to linear and improve the computational efficiency through optimized memory accesses. It employs the following two distinctive techniques.\n",
"\n",
"- **Tiling:** The non-flash algorithm tries to process the query, key, value tensors in one single step, requiring large amounts of global memory and incurring high volumes of reads/writes between global memory and shared memory. The flash algorithm decomposes the input into several tiles, based on the available shared memory and register size, and it computes the softmax one tile at a time.\n",
"\n",
"- **Recomputation:** The non-flash algorithm stores the softmax matrix (quadratic to sequence length) to global memory for the backward pass, while the flash algorithm only saves the softmax normalization factors (linear to sequence length). This reduces the amount of memory required as well as the bandwidth utilization between global memory and shared memory. Even though there is extra computation incurred in order to recalculate the attention in the backward pass, the bandwidth savings still provide significant improvement in efficiency.\n",
"\n",
"
NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n",
"qkv_layout may change in Transformer Engine PyTorch through [get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding hd_hd_hd layout. For example, from sbh3d in pytorch.MultiHeadAttention before RoPE, to sbhd_sbhd_sbhd in pytorch.DotProductAttention after RoPE.\n",
"