Commit d693034e
Authored Dec 08, 2022 by Tri Dao
Parent commit: 52e63688

Integrate FlashAttention into Megatron-LM
Showing 3 changed files, with 102 additions and 5 deletions:

  README.md                        +12  -0
  megatron/arguments.py             +3  -0
  megatron/model/transformer.py    +87  -5
README.md
@@ -333,6 +333,18 @@ Theoretical memory savings vary depending on the combination of the model's para
 | bf16 param, fp32 grads | 18 | 6 + 12/d |
 | fp32 param, fp32 grads | 16 | 8 + 8/d |
 
+## FlashAttention
+
+Usage: `--use-flash-attn`. Supports attention head dimensions of at most 128.
+
+[FlashAttention](https://github.com/HazyResearch/flash-attention) is a fast and
+memory-efficient algorithm to compute exact attention. It speeds up model
+training and reduces memory requirements.
+
+To install FlashAttention:
+```sh
+pip install flash-attn
+```
+
 ## GPT-3 Example
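The snippet below is an illustrative smoke test, not part of the commit: it calls `flash_attn_unpadded_func`, the interface that `megatron/model/transformer.py` imports from the installed package, on a tiny causal self-attention problem. It assumes a CUDA GPU, fp16 inputs, and a flash-attn release that still exposes this function name.

```python
# Illustrative smoke test (not from the commit): exercise the interface that
# megatron/model/transformer.py imports from the flash-attn package.
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

batch, seqlen, heads, headdim = 2, 128, 8, 64      # head dimension must be <= 128
q, k, v = [torch.randn(batch * seqlen, heads, headdim,
                       dtype=torch.float16, device='cuda') for _ in range(3)]
# Cumulative sequence lengths for the packed ("unpadded") layout; every
# sequence here has the same length, so offsets are 0, seqlen, 2*seqlen, ...
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device='cuda')
out = flash_attn_unpadded_func(q, k, v, cu_seqlens, cu_seqlens,
                               seqlen, seqlen, 0.0, causal=True)
print(out.shape)  # (batch * seqlen, heads, headdim)
```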
megatron/arguments.py
@@ -612,6 +612,9 @@ def _add_training_args(parser):
     group.add_argument('--no-bias-dropout-fusion', action='store_false',
                        help='Disable bias and dropout fusion.',
                        dest='bias_dropout_fusion')
+    group.add_argument('--use-flash-attn', action='store_true',
+                       help='use FlashAttention implementation of attention. '
+                       'https://arxiv.org/abs/2205.14135')
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
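For reference, `--use-flash-attn` is a plain `store_true` switch: `args.use_flash_attn` defaults to `False` and becomes `True` only when the flag is passed on the command line. A minimal stand-alone sketch of the same pattern (hypothetical parser, not Megatron's actual argument plumbing):

```python
# Minimal illustration of the store_true flag added above
# (stand-alone parser for demonstration, not Megatron's own).
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--use-flash-attn', action='store_true',
                   help='use FlashAttention implementation of attention.')

print(parser.parse_args([]).use_flash_attn)                    # False (default)
print(parser.parse_args(['--use-flash-attn']).use_flash_attn)  # True
```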
megatron/model/transformer.py
@@ -15,6 +15,16 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
 
+try:
+    from einops import rearrange
+except ImportError:
+    rearrange = None
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+except ImportError:
+    flash_attn_unpadded_func = None
+
 """ We use the following notation throughout this file:
     h: hidden size
@@ -306,6 +316,48 @@ class CoreAttention(MegatronModule):
 
         return context_layer
 
+
+class FlashSelfAttention(torch.nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
+                                                      'e.g., with pip install flash-attn')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda
+        batch_size, seqlen = q.shape[0], q.shape[1]
+        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+        max_s = seqlen
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                  device=q.device)
+        output = flash_attn_unpadded_func(
+            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
+            self.dropout_p if self.training else 0.0,
+            softmax_scale=self.softmax_scale, causal=self.causal)
+        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        return output
+
+
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.
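`FlashSelfAttention` expects half-precision CUDA tensors in `(batch, seqlen, heads, headdim)` layout, flattens them into the packed layout, and builds the `cu_seqlens` offsets itself since every sequence in the batch has the same length. A minimal usage sketch, assuming the class above is in scope, flash-attn and einops are installed, and a GPU is available (illustrative only, not part of the commit):

```python
# Illustrative usage of the FlashSelfAttention wrapper defined in the hunk above.
import torch

attn = FlashSelfAttention(causal=True, attention_dropout=0.1).cuda().eval()

batch, seqlen, heads, headdim = 4, 512, 16, 64     # (B, S, H, D), with D <= 128
q, k, v = [torch.randn(batch, seqlen, heads, headdim,
                       dtype=torch.float16, device='cuda') for _ in range(3)]

with torch.no_grad():
    out = attn(q, k, v)   # dropout is skipped because the module is in eval mode
print(out.shape)          # torch.Size([4, 512, 16, 64])
```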
@@ -323,6 +375,21 @@ class ParallelAttention(MegatronModule):
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
+        self.sequence_parallel = args.sequence_parallel
+
+        self.use_flash_attn = args.use_flash_attn
+        if self.use_flash_attn:
+            if flash_attn_unpadded_func is None:
+                raise ImportError('FlashAttention is not installed, please install with '
+                                  'pip install flash-attn')
+            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
+                                                          'self-attention for now')
+            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
+                                                                'supports causal mask for now')
+            headdim = args.hidden_size / args.num_attention_heads
+            assert headdim <= 128, 'FlashAttention only supports head dimension at most 128'
+            if rearrange is None:
+                raise ImportError('einops is not installed, please install with pip install einops')
 
         projection_size = args.kv_channels * args.num_attention_heads
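The head-dimension guard is simple arithmetic: `hidden_size / num_attention_heads` must not exceed 128. A small illustration with made-up configurations (not taken from the commit):

```python
# Illustrative check of the head-dimension constraint asserted above
# (example configurations, not from the commit).
def flash_attn_headdim_ok(hidden_size, num_attention_heads):
    headdim = hidden_size / num_attention_heads
    return headdim <= 128

print(flash_attn_headdim_ok(2048, 16))  # 128 per head -> True, supported
print(flash_attn_headdim_ok(2048, 8))   # 256 per head -> False, the assert would fire
```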
@@ -365,6 +432,11 @@ class ParallelAttention(MegatronModule):
                                             self.attn_mask_type)
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'
 
+        if self.use_flash_attn:
+            self.core_attention_flash = FlashSelfAttention(
+                causal=True, attention_dropout=args.attention_dropout
+            )
+
         # Output.
         self.dense = tensor_parallel.RowParallelLinear(
             projection_size,
@@ -487,12 +559,22 @@ class ParallelAttention(MegatronModule):
         # core attention computation
         # ==================================
 
-        if self.checkpoint_core_attention:
-            context_layer = self._checkpointed_attention_forward(
-                query_layer, key_layer, value_layer, attention_mask)
+        if not self.use_flash_attn:
+            if self.checkpoint_core_attention:
+                context_layer = self._checkpointed_attention_forward(
+                    query_layer, key_layer, value_layer, attention_mask)
+            else:
+                context_layer = self.core_attention(
+                    query_layer, key_layer, value_layer, attention_mask)
         else:
-            context_layer = self.core_attention(
-                query_layer, key_layer, value_layer, attention_mask)
+            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
+                       for x in (query_layer, key_layer, value_layer)]
+            if not self.sequence_parallel:
+                with tensor_parallel.get_cuda_rng_tracker().fork():
+                    context_layer = self.core_attention_flash(q, k, v)
+            else:
+                context_layer = self.core_attention_flash(q, k, v)
+            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
 
         # =================
         # Output. [sq, b, h]
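Megatron keeps attention activations as `[seqlen, batch, heads, headdim]`, while `FlashSelfAttention` consumes `[batch, seqlen, heads, headdim]`; the two `rearrange` calls in the hunk above convert into that layout and then fold the flash output back to Megatron's `[seqlen, batch, hidden]` shape. A shape-only sketch of the same einops patterns (illustrative tensors, not Megatron code):

```python
# Shape-only illustration of the einops patterns used in the flash path above.
import torch
from einops import rearrange

s, b, h, d = 1024, 2, 16, 64              # seqlen, batch, heads, head dimension
query_layer = torch.randn(s, b, h, d)     # Megatron's [seqlen, batch, heads, headdim]

q = rearrange(query_layer, 's b ... -> b s ...').contiguous()
print(q.shape)                            # torch.Size([2, 1024, 16, 64])

context = torch.randn(b, s, h, d)         # FlashSelfAttention output layout
out = rearrange(context, 'b s h d -> s b (h d)').contiguous()
print(out.shape)                          # torch.Size([1024, 2, 1024]) = [seqlen, batch, h*d]
```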