Commit d693034e authored by Tri Dao

Integrate FlashAttention into Megatron-LM

parent 52e63688
@@ -333,6 +333,18 @@ Theoretical memory savings vary depending on the combination of the model's para
| bf16 param, fp32 grads | 18 | 6 + 12/d |
| fp32 param, fp32 grads | 16 | 8 + 8/d |
## FlashAttention
Usage: `--use-flash-attn`. Supports attention head dimensions of at most 128.

[FlashAttention](https://github.com/HazyResearch/flash-attention) is a fast and
memory-efficient algorithm to compute exact attention. It speeds up model
training and reduces memory requirements.
To install FlashAttention:
```sh
pip install flash-attn
```
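As a quick sanity check before launching training with `--use-flash-attn`, a small standalone script can confirm that the required packages are importable and that the head dimension stays within the 128 limit. This is only a sketch, not part of Megatron-LM; the `hidden_size` and `num_attention_heads` values are illustrative placeholders.

```python
# Sketch: verify the environment supports --use-flash-attn before launching training.
import importlib.util

def flash_attn_usable(hidden_size, num_attention_heads):
    """Return True if flash-attn and einops are installed and the head dim is <= 128."""
    for mod in ("flash_attn", "einops"):
        if importlib.util.find_spec(mod) is None:
            return False
    head_dim = hidden_size // num_attention_heads
    return head_dim <= 128  # FlashAttention supports head dimensions up to 128

if __name__ == "__main__":
    # head dim = 2048 / 16 = 128 -> supported (if the packages are installed)
    print(flash_attn_usable(hidden_size=2048, num_attention_heads=16))
```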
## GPT-3 Example
...
@@ -612,6 +612,9 @@ def _add_training_args(parser):
    group.add_argument('--no-bias-dropout-fusion', action='store_false',
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
    group.add_argument('--use-flash-attn', action='store_true',
                       help='use FlashAttention implementation of attention. '
                       'https://arxiv.org/abs/2205.14135')
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd'],
                       help='Optimizer function')
...
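For reference, the flag added above is a plain argparse `store_true` option: it defaults to `False` and becomes `True` only when `--use-flash-attn` is passed. A standalone sketch of that behavior (not Megatron-LM's actual parser setup):

```python
# Sketch of how the --use-flash-attn flag behaves with argparse's store_true action.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-flash-attn', action='store_true',
                    help='use FlashAttention implementation of attention. '
                         'https://arxiv.org/abs/2205.14135')

print(parser.parse_args([]).use_flash_attn)                    # False (default)
print(parser.parse_args(['--use-flash-attn']).use_flash_attn)  # True
```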
@@ -15,6 +15,16 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
try:
    from einops import rearrange
except ImportError:
    rearrange = None

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    flash_attn_unpadded_func = None
""" We use the following notation throughout this file: """ We use the following notation throughout this file:
h: hidden size h: hidden size
@@ -306,6 +316,48 @@ class CoreAttention(MegatronModule):
        return context_layer
class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
                                                      'e.g., with pip install flash-attn')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda
        batch_size, seqlen = q.shape[0], q.shape[1]
        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
        max_s = seqlen
        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                  device=q.device)
        output = flash_attn_unpadded_func(
            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
            self.dropout_p if self.training else 0.0,
            softmax_scale=self.softmax_scale, causal=self.causal
        )
        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        return output
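A minimal usage sketch for the class above, assuming a CUDA device plus the flash-attn and einops packages (the shapes are illustrative): inputs follow the (batch, seqlen, nheads, headdim) layout documented in `forward`, and the module packs them to (batch * seqlen, nheads, headdim) with cumulative sequence lengths before calling `flash_attn_unpadded_func`.

```python
# Sketch: calling FlashSelfAttention directly with (batch, seqlen, nheads, headdim) tensors.
# Assumes a CUDA GPU and the flash-attn and einops packages; shapes are illustrative.
import torch

batch, seqlen, nheads, headdim = 2, 1024, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)

flash_attn_module = FlashSelfAttention(causal=True, attention_dropout=0.1)
out = flash_attn_module(q, k, v)  # dropout only applies while the module is in training mode
print(out.shape)                  # torch.Size([2, 1024, 16, 64])
```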
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.
@@ -323,6 +375,21 @@ class ParallelAttention(MegatronModule):
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype
        self.sequence_parallel = args.sequence_parallel

        self.use_flash_attn = args.use_flash_attn
        if self.use_flash_attn:
            if flash_attn_unpadded_func is None:
                raise ImportError('FlashAttention is not installed, please install with '
                                  'pip install flash-attn')
            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                          'self-attention for now')
            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
                                                                'supports causal mask for now')
            headdim = args.hidden_size / args.num_attention_heads
            assert headdim <= 128, 'FlashAttention only supports head dimension at most 128'
            if rearrange is None:
                raise ImportError('einops is not installed, please install with pip install einops')

        projection_size = args.kv_channels * args.num_attention_heads
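The head-dimension guard above computes `args.hidden_size / args.num_attention_heads`; a quick sketch with a few illustrative configurations shows which ones fall within the 128 limit:

```python
# Sketch: the head-dimension guard for a few illustrative (hidden_size, num_heads) pairs.
for hidden_size, num_heads in [(1024, 16), (4096, 32), (12288, 96), (8192, 32)]:
    headdim = hidden_size / num_heads
    print(hidden_size, num_heads, headdim, headdim <= 128)
# 1024  16  64.0   True
# 4096  32  128.0  True
# 12288 96  128.0  True
# 8192  32  256.0  False -> not supported by the FlashAttention path
```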
@@ -365,6 +432,11 @@ class ParallelAttention(MegatronModule):
            self.attn_mask_type)
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        if self.use_flash_attn:
            self.core_attention_flash = FlashSelfAttention(
                causal=True, attention_dropout=args.attention_dropout
            )

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
@@ -487,12 +559,22 @@ class ParallelAttention(MegatronModule):
        # core attention computation
        # ==================================

        if not self.use_flash_attn:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask)
            else:
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask)
        else:
            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
                       for x in (query_layer, key_layer, value_layer)]
            if not self.sequence_parallel:
                with tensor_parallel.get_cuda_rng_tracker().fork():
                    context_layer = self.core_attention_flash(q, k, v)
            else:
                context_layer = self.core_attention_flash(q, k, v)
            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

        # =================
        # Output. [sq, b, h]
...
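The forward-path change hinges on two einops rearranges: Megatron keeps activations in [sq, b, np, hn] layout, FlashAttention expects [b, sq, np, hn], and the result is flattened back to [sq, b, h]. A shape-only sketch of those conversions (illustrative sizes, assumes einops is installed):

```python
# Sketch of the layout conversions around core_attention_flash, using illustrative sizes.
import torch
from einops import rearrange

sq, b, np_, hn = 1024, 2, 16, 64           # seqlen, batch, heads per partition, head dim
query_layer = torch.randn(sq, b, np_, hn)  # Megatron layout: [sq, b, np, hn]

q = rearrange(query_layer, 's b ... -> b s ...').contiguous()
print(q.shape)                             # torch.Size([2, 1024, 16, 64]) -> FlashAttention layout

context = torch.randn(b, sq, np_, hn)      # shape returned by core_attention_flash
out = rearrange(context, 'b s h d -> s b (h d)').contiguous()
print(out.shape)                           # torch.Size([1024, 2, 1024]) -> [sq, b, h]
```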