OpenDAS / Megatron-LM · Commits

Commit 899c20e8
Authored Mar 20, 2025 by wangxj

Optimize the fixed-length flash attention (fa) interface in legacy

Parent: 9dabea91
Pipeline #2564 passed
Showing 3 changed files with 50 additions and 4 deletions (+50 -4)

examples/llama/Llama2_7b.sh            +4   -3
megatron/legacy/model/transformer.py   +44  -1
pretrain_gpt.py                        +2   -0
examples/llama/Llama2_7b.sh (view file @ 899c20e8)

@@ -56,7 +56,7 @@ export cache_size_limit=64
 # CHECKPOINT_PATH=./Llama-2-7b-hf-to-meg-tp1-pp2 #CHECKPOINT_PATH=./tmp_7b #
 SAVE_PATH=./tmp_7b
 TENSORBOARD_LOGS_PATH=./tmp_7b #$2 #<Specify path>
-DATA_PATH="/public/home/wangxj/Downloads/datasets/oscar-1GB-head/oscar-1GB_head-llama2_text_document" #<Specify path and file prefix>_text_document
+DATA_PATH="/public/home/wangxj/Downloads/datasets/oscar-1GB/oscar-1GB-llama2_text_document" #<Specify path and file prefix>_text_document
 # DATA_PATH="/data/datasets/oscar-1GB-head/oscar-1GB_head-llama2_text_document" #<Specify path and file prefix>_text_document
 GPT_MODEL_ARGS=(

@@ -83,6 +83,8 @@ TRAINING_ARGS=(
     --micro-batch-size 1
     --global-batch-size 256 #256 #240 #60 #512 #64
     --train-iters 50
+    --eval-interval 10
+    --eval-iters 3
     --weight-decay 0.1
     --adam-beta1 0.9
     --adam-beta2 0.95

@@ -125,6 +127,7 @@ MODEL_PARALLEL_ARGS=(
     --sequence-parallel
     --tensor-model-parallel-size 1
     --pipeline-model-parallel-size 2
     # --context-parallel-size 2
     # --num-layers-per-virtual-pipeline-stage 4
     # --microbatch-group-size-per-virtual-pipeline-stage 1
     # --no-overlap-p2p-communication # when enabled

@@ -143,10 +146,8 @@ EVAL_AND_LOGGING_ARGS=(
     --log-interval 1
     --log-throughput
     --save-interval 1000
-    --eval-interval 1000
     --save $SAVE_PATH
     --load $SAVE_PATH
-    --eval-iters 10
     --tensorboard-dir $TENSORBOARD_LOGS_PATH
 )
megatron/legacy/model/transformer.py (view file @ 899c20e8)

@@ -47,6 +47,11 @@ try:
 except ImportError:
     rearrange = None
 
+try:
+    # use the fixed-length flash attention interface
+    from flash_attn import flash_attn_func
+except ImportError:
+    flash_attn_func = None
 try:
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
 except ImportError:

@@ -510,6 +515,41 @@ class FlashSelfAttention(torch.nn.Module):
         output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
         return output
 
 
+class FlashFixedSelfAttention(torch.nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_func is not None, ('Please install FlashAttention first, '
+                                             'e.g., with pip install flash-attn')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+        self.flash_attn_func = flash_attn_func
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
+        assert all((i.is_cuda for i in (q, k, v)))
+
+        output = self.flash_attn_func(q, k, v, dropout_p=self.dropout_p,
+                                      softmax_scale=self.softmax_scale,
+                                      causal=self.causal)  # [b, s, a, dim]
+        return output
+
+
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.

@@ -605,7 +645,10 @@ class ParallelAttention(MegatronModule):
         self.checkpoint_core_attention = config.recompute_granularity == 'selective'
 
         if self.use_flash_attn:
-            self.core_attention_flash = FlashSelfAttention(
+            # self.core_attention_flash = FlashSelfAttention(
+            #     causal=True, attention_dropout=config.attention_dropout
+            # )
+            self.core_attention_flash = FlashFixedSelfAttention(
                 causal=True, attention_dropout=config.attention_dropout
             )
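The new FlashFixedSelfAttention wraps flash_attn_func directly: unlike the existing FlashSelfAttention, which flattens the batch and calls the variable-length (unpadded) kernel, it passes fixed-length (B, S, H, D) tensors straight through. Below is a minimal sketch of the underlying call, assuming flash-attn 2.x, einops, and a CUDA device; the shapes are illustrative and not taken from this commit.

import torch
from flash_attn import flash_attn_func  # same import the commit adds

# Illustrative fixed-length batch: (batch, seqlen, heads, head_dim).
q = torch.randn(2, 1024, 32, 128, dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)

# Equivalent to FlashFixedSelfAttention(causal=True)(q, k, v) with dropout disabled.
out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)
print(out.shape)  # torch.Size([2, 1024, 32, 128]); output keeps the input layout

Because every sequence in the batch has the same length here, the '(b s) ... -> b s ...' rearranges and the cu_seqlens bookkeeping of the variable-length path are unnecessary, which is what the fixed-length legacy interface in this commit avoids.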
pretrain_gpt.py (view file @ 899c20e8)

@@ -137,6 +137,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
         rope_scaling=args.use_rope_scaling
     )
 
+    print_rank_0(model)
+    # model = torch.compile(model, mode="max-autotune-no-cudagraphs")
     return model
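The added print_rank_0(model) dumps the module structure once on rank 0 for inspection, while the torch.compile line is left commented out as an opt-in switch. If someone did want to try it, the change would look like the sketch below (mode string taken verbatim from the comment; "max-autotune-no-cudagraphs" is a standard PyTorch 2.x compile mode, and whether it actually helps this model is not established by the commit):

    print_rank_0(model)
    # opt-in: compile the model graph before returning it from model_provider
    model = torch.compile(model, mode="max-autotune-no-cudagraphs")

    return model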