wuxk1 / Megatron-LM / Commits / d13ae546

Commit d13ae546 authored Jan 21, 2025 by wuxk1

    add torch fa

Parent: c271aaae
Pipeline #2234 failed in 0 seconds.

Showing 3 changed files with 53 additions and 4 deletions (+53 / -4).
    megatron/legacy/model/transformer.py   +34  -3
    megatron/training/arguments.py          +3  -1
    run.sh (new file)                      +16  -0
megatron/legacy/model/transformer.py (view file @ d13ae546)

...
@@ -456,6 +456,34 @@ class CoreAttention(MegatronModule):
         return context_layer
 
 
+class FlashSelfAttentionTorch(torch.nn.Module):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_func is not None, (
+            'Triton version of FlashAttention is not installed.')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.attention_dropout = attention_dropout
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda
+        if os.environ.get('USE_BSHD', None):
+            q, k, v = [rearrange(x, 's b h d -> b s h d').contiguous() for x in (q, k, v)]
+        else:
+            q, k, v = [rearrange(x, 's b h d -> b h s d').contiguous() for x in (q, k, v)]
+        output = SDPA(q, k, v, is_causal=self.causal, dropout_p=self.attention_dropout,
+                      scale=self.softmax_scale)
+        if os.environ.get('USE_BSHD', None):
+            output = rearrange(output, 'b s h d -> s b (h d)').contiguous()
+        else:
+            output = rearrange(output, 'b h s d -> s b (h d)').contiguous()
+        return output
+
+
 class FlashSelfAttention(torch.nn.Module):
     """Implement the scaled dot product attention with softmax.
...
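For reference, a minimal standalone sketch of the layout bookkeeping the new class performs. It assumes SDPA is an alias for torch.nn.functional.scaled_dot_product_attention (the aliasing import is outside this hunk) and follows the default, non-USE_BSHD path; the tensor sizes are illustrative:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    # requires a CUDA device, mirroring the class's q.is_cuda assert
    s, b, h, d = 1024, 2, 16, 128   # illustrative sequence/batch/head sizes
    q = torch.randn(s, b, h, d, device='cuda', dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # default path: (s, b, h, d) -> (b, h, s, d), SDPA's native layout
    q_, k_, v_ = [rearrange(x, 's b h d -> b h s d').contiguous() for x in (q, k, v)]
    out = F.scaled_dot_product_attention(q_, k_, v_, is_causal=True, dropout_p=0.0)
    # fold heads back into the hidden dimension: (b, h, s, d) -> (s, b, h*d)
    out = rearrange(out, 'b h s d -> s b (h d)').contiguous()
    assert out.shape == (s, b, h * d)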
@@ -582,10 +610,11 @@ class ParallelAttention(MegatronModule):
         else:
             kv_projection_size = args.kv_channels * args.num_attention_heads
 
-        self.use_flash_attn = (args.use_flash_attn_cutlass or args.use_flash_attn_triton) \
+        self.use_flash_attn = (args.use_flash_attn_cutlass or args.use_flash_attn_triton or args.use_flash_attn_torch) \
             and attention_type == AttnType.self_attn \
             and self.attn_mask_type == AttnMaskType.causal
         self.use_flash_attn_triton = args.use_flash_attn_triton
+        self.use_flash_attn_torch = args.use_flash_attn_torch
         if self.use_flash_attn:
             if args.use_flash_attn_cutlass:
...
@@ -658,6 +687,8 @@ class ParallelAttention(MegatronModule):
             self.core_attention_flash = FlashSelfAttentionTriton(
                 causal=True, attention_dropout=args.attention_dropout
             )
+        elif self.use_flash_attn_torch:
+            self.core_attention_flash = FlashSelfAttentionTorch(causal=True, attention_dropout=config.attention_dropout)
         elif self.use_flash_attn:
             self.core_attention_flash = FlashSelfAttention(
                 causal=True, attention_dropout=config.attention_dropout
...
@@ -871,7 +902,7 @@ class ParallelAttention(MegatronModule):
                 context_layer = self.core_attention(
                     query_layer, key_layer, value_layer, attention_mask)
         else:
-            if not self.use_flash_attn_triton:
+            if not self.use_flash_attn_triton and not self.use_flash_attn_torch:
                 query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
                     for x in (query_layer, key_layer, value_layer)]
...
@@ -881,7 +912,7 @@ class ParallelAttention(MegatronModule):
             else:
                 context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
 
-            if not self.use_flash_attn_triton:
+            if not self.use_flash_attn_triton and not self.use_flash_attn_torch:
                 context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
 
         # =================
...
megatron/training/arguments.py (view file @ d13ae546)

...
@@ -643,7 +643,7 @@ def validate_args(args, defaults={}):
         '--decoupled-lr and --decoupled-min-lr is not supported in legacy models.'
 
     # FlashAttention
-    args.use_flash_attn = args.use_flash_attn_cutlass or args.use_flash_attn_triton
+    args.use_flash_attn = args.use_flash_attn_cutlass or args.use_flash_attn_triton or args.use_flash_attn_torch
 
     # Legacy RoPE arguments
     if args.use_rotary_position_embeddings:
...
@@ -1368,6 +1368,8 @@ def _add_training_args(parser):
                        'https://arxiv.org/abs/2205.14135')
     group.add_argument('--use-flash-attn-triton', action='store_true',
                        help='use FlashAttention implementation of attention using Triton.')
+    group.add_argument('--use-flash-attn-torch', action='store_true',
+                       help='use FlashAttention implementation of attention using torch.')
     group.add_argument('--disable-bias-linear', action='store_false',
                        help='Disable bias in the linear layers',
                        dest='add_bias_linear')
...
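A minimal sketch of the flag plumbing these two hunks rely on: argparse maps --use-flash-attn-torch to args.use_flash_attn_torch (dashes become underscores), and validate_args then ORs the three backend flags into the single args.use_flash_attn switch. The parser group title below is illustrative; the flag names mirror the diff:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='training')   # illustrative group title
    group.add_argument('--use-flash-attn-cutlass', action='store_true')
    group.add_argument('--use-flash-attn-triton', action='store_true')
    group.add_argument('--use-flash-attn-torch', action='store_true',
                       help='use FlashAttention implementation of attention using torch.')

    args = parser.parse_args(['--use-flash-attn-torch'])
    # mirrors the validate_args change: any backend flag enables the generic switch
    args.use_flash_attn = (args.use_flash_attn_cutlass or args.use_flash_attn_triton
                           or args.use_flash_attn_torch)
    assert args.use_flash_attn and args.use_flash_attn_torch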
run.sh (new file, mode 100644, view file @ d13ae546)

+export TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1
+export TORCHINDUCTOR_BENCHMARK_FUSION=1
+export TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES=1
+# export TORCHINDUCTOR_BENCHMARK_KERNEL=1
+export TORCHINDUCTOR_MAX_AUTOTUNE=1
+# export FLASH_ATTENTION_PRINT_PARAM=1
+export TORCHINDUCTOR_CACHE_DIR=./cache
+# export USE_AOTRITON_FA=1
+# export USE_BSHD=1  # use fa bshd layout
+
+# for unique kernel names
+# export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
+
+mpirun --allow-run-as-root -np 8 ./Llama_pretraining.sh localhost
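The USE_BSHD toggle above is the same environment variable that FlashSelfAttentionTorch.forward checks; a small sketch of the switch it drives (the rearrange patterns are copied from the diff, everything else is illustrative):

    import os

    # run.sh ships with USE_BSHD commented out, so the default branch is taken
    # and tensors are rearranged into SDPA's native (b, h, s, d) layout.
    if os.environ.get('USE_BSHD', None):
        pattern = 's b h d -> b s h d'   # opt-in layout when USE_BSHD=1 is exported
    else:
        pattern = 's b h d -> b h s d'
    print('rearrange pattern:', pattern)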