OpenDAS / Megatron-LM · Commits

Commit 8a13fabe
authored Dec 12, 2024 by wxj

Merge branch 'main' into 'main'

Add optimization options; add qwen and llama3. See merge request !5

Parents: f5ca0d94, 425a2473
Pipeline #2076 failed with stages in 0 seconds

Changes: 6 · Pipelines: 1
Showing 6 changed files with 188 additions and 17 deletions (+188 -17):
Llama_pretraining.sh                        +8   -5
megatron/legacy/model/rms_norm.py           +4   -0
megatron/legacy/model/transformer.py        +69  -10
megatron/legacy/model/utils.py              +3   -1
megatron/training/arguments.py              +8   -1
megatron/training/tokenizer/tokenizer.py    +96  -0
Llama_pretraining.sh  (view file @ 8a13fabe)

@@ -28,12 +28,12 @@ TENSORBOARD_LOGS_PATH=./tmp_7b #$2 #<Specify path>
 DATA_PATH="/datasets/oscar-1GB-llama_text_document" #<Specify path and file prefix>_text_document

 GPT_MODEL_ARGS=(
-    --num-layers 6
-    --hidden-size 1024
-    --ffn-hidden-size 2048
-    --num-attention-heads 16
+    --num-layers 36
+    --hidden-size 4096
+    --ffn-hidden-size 11008
+    --num-attention-heads 32
     --seq-length 4096 #4096
-    --max-position-embeddings 32768
+    --max-position-embeddings 4096
 )
 # export NVTE_FLASH_ATTN=1  # goes through autlass

@@ -69,7 +69,10 @@ TRAINING_ARGS=(
     --lr-decay-style cosine
     --min-lr 3.0e-6
     --lr-warmup-iters 1
+    --use-flash-attn-triton
 )
+# --use-flash-attn-ck
+# --use-flash-attn-triton

 MODEL_PARALLEL_ARGS=(
     --sequence-parallel
megatron/legacy/model/rms_norm.py  (view file @ 8a13fabe)

@@ -2,6 +2,8 @@
 import torch
 from torch import nn
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True

 class RMSNorm(torch.nn.Module):

@@ -24,9 +26,11 @@ class RMSNorm(torch.nn.Module):
             setattr(self.weight, 'sequence_parallel', sequence_parallel)

+    @torch.compile(mode="max-autotune-no-cudagraphs")
     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

+    @torch.compile(mode="max-autotune-no-cudagraphs")
     def forward(self, x):
         output = self._norm(x.float()).type_as(x)
         return output * self.weight
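For context, the effect of the two changes above can be reproduced in isolation: torch.compile with mode="max-autotune-no-cudagraphs" JIT-compiles the normalization, and torch._dynamo.config.suppress_errors = True makes compilation failures fall back to eager execution instead of raising. A minimal, self-contained sketch (not part of the commit; the class name, hidden size, and epsilon below are arbitrary):

import torch
from torch import nn

# Fall back to eager execution if torch.compile hits an unsupported construct.
torch._dynamo.config.suppress_errors = True


class TinyRMSNorm(nn.Module):
    """Standalone RMSNorm mirroring the compiled methods in the commit."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    @torch.compile(mode="max-autotune-no-cudagraphs")
    def _norm(self, x):
        # RMS normalization: scale by the reciprocal root-mean-square.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(self, x):
        return self._norm(x.float()).type_as(x) * self.weight


if __name__ == "__main__":
    norm = TinyRMSNorm(dim=1024)
    y = norm(torch.randn(2, 8, 1024))
    print(y.shape)  # torch.Size([2, 8, 1024])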
megatron/legacy/model/transformer.py  (view file @ 8a13fabe)

@@ -40,6 +40,9 @@ from megatron.legacy.model.utils import (
 )
 from megatron.training import get_args, get_timers
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True

 from .module import MegatronModule

 try:

@@ -57,6 +60,10 @@ except ImportError:
 except ImportError:
     flash_attn_unpadded_func = None

+try:
+    from flash_attn.flash_attn_triton import flash_attn_func
+except ImportError:
+    flash_attn_func = None

 """ We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads

@@ -133,6 +140,7 @@ class ParallelMLP(MegatronModule):
         elif args.onnx_safe:
             self.activation_func = erf_gelu
         elif args.swiglu:
+            @torch.compile(mode="max-autotune-no-cudagraphs")
             def swiglu(x):
                 x = torch.chunk(x, 2, dim=-1)
                 return F.silu(x[0]) * x[1]

@@ -157,6 +165,7 @@ class ParallelMLP(MegatronModule):
             is_expert=is_expert,
         )

+    @torch.compile(mode="max-autotune-no-cudagraphs")
     def forward(self, hidden_states):

         # [s, b, 4hp]

@@ -468,6 +477,10 @@ class FlashSelfAttention(torch.nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout

+        # Use FlashAttention-2 when args.use_flash_attn_ck is True
+        args = get_args()
+        self.flash_attn_func = flash_attn_unpadded_func

     def forward(self, q, k, v):
         """Implements the multihead softmax attention.
         Arguments

@@ -509,6 +522,38 @@ class FlashSelfAttention(torch.nn.Module):
         output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
         return output

+class FlashSelfAttentionTriton(torch.nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_func is not None, (
+            'Triton version of FlashAttention is not installed.')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda
+        q, k, v = [rearrange(x, 's b h d -> b h s d').contiguous() for x in (q, k, v)]
+        output = flash_attn_func(q, k, v, self.causal)
+        output = rearrange(output, 'b s h d -> h b (s d)').contiguous()
+        return output

 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.

@@ -537,13 +582,19 @@ class ParallelAttention(MegatronModule):
         else:
             kv_projection_size = args.kv_channels * args.num_attention_heads

-        self.use_flash_attn = args.use_flash_attn \
+        self.use_flash_attn = (args.use_flash_attn_ck or args.use_flash_attn_triton) \
             and attention_type == AttnType.self_attn \
             and self.attn_mask_type == AttnMaskType.causal
+        self.use_flash_attn_triton = args.use_flash_attn_triton
         if self.use_flash_attn:
-            if flash_attn_unpadded_func is None:
-                raise ImportError('FlashAttention is not installed, please install with '
-                                  'pip install flash-attn')
+            if args.use_flash_attn_ck:
+                if flash_attn_unpadded_func is None:
+                    raise ImportError('FlashAttention is not installed, please install with '
+                                      'pip install flash-attn')
+            if args.use_flash_attn_triton:
+                assert flash_attn_func != None, "Cannot import FlashAttention triton "
             assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                           'self-attention for now')
             assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '

@@ -603,7 +654,11 @@ class ParallelAttention(MegatronModule):
                 self.attn_mask_type)
             self.checkpoint_core_attention = config.recompute_granularity == 'selective'

         if self.use_flash_attn:
+            if self.use_flash_attn_triton:
+                self.core_attention_flash = FlashSelfAttentionTriton(
+                    causal=True, attention_dropout=args.attention_dropout)
+            elif self.use_flash_attn:
                 self.core_attention_flash = FlashSelfAttention(
                     causal=True, attention_dropout=config.attention_dropout)

@@ -711,7 +766,7 @@ class ParallelAttention(MegatronModule):
                 dim=3)

             # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] -
-            query_layer = query_layer.view(query_layer.size(0),
-                query_layer.size(1), -1, self.hidden_size_per_attention_head)
+            query_layer = query_layer.contiguous().view(query_layer.size(0),
+                query_layer.size(1), -1, self.hidden_size_per_attention_head)
         else:
             # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
             mixed_kv_layer, _ = self.key_value(encoder_output)

@@ -816,13 +871,17 @@ class ParallelAttention(MegatronModule):
             context_layer = self.core_attention(
                 query_layer, key_layer, value_layer, attention_mask)
         else:
-            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
-                       for x in (query_layer, key_layer, value_layer)]
+            if not self.use_flash_attn_triton:
+                query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
+                           for x in (query_layer, key_layer, value_layer)]
             if not self.sequence_parallel:
                 with tensor_parallel.get_cuda_rng_tracker().fork():
-                    context_layer = self.core_attention_flash(q, k, v)
+                    context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
             else:
-                context_layer = self.core_attention_flash(q, k, v)
+                context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
-            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
+            if not self.use_flash_attn_triton:
+                context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

             # =================
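The new FlashSelfAttentionTriton path hinges on moving tensors between Megatron's sequence-first [s, b, h, d] layout and the batch-first layout a fused kernel consumes, then folding the heads back into the hidden dimension. A minimal sketch of that round-trip (not part of the commit; torch's scaled_dot_product_attention stands in for the Triton flash_attn_func, and the shapes are arbitrary):

import torch
from einops import rearrange


def toy_flash_path(q, k, v):
    """Sketch of the layout handling around the new Triton attention path."""
    # [s, b, h, d] -> [b, h, s, d] for the stand-in kernel.
    q, k, v = [rearrange(x, 's b h d -> b h s d').contiguous() for x in (q, k, v)]
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Back to Megatron's sequence-first hidden-state layout: [s, b, (h d)].
    return rearrange(out, 'b h s d -> s b (h d)').contiguous()


if __name__ == "__main__":
    s, b, h, d = 16, 2, 4, 32
    q = torch.randn(s, b, h, d)
    out = toy_flash_path(q, q, q)
    print(out.shape)  # torch.Size([16, 2, 128])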
megatron/legacy/model/utils.py  (view file @ 8a13fabe)

@@ -9,6 +9,8 @@ import torch
 from megatron.training import get_args
 from megatron.legacy.model import LayerNorm, RMSNorm
 from megatron.core.jit import jit_fuser
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True

 def init_method_normal(sigma):
     """Init method based on N(0, sigma)."""

@@ -58,7 +60,7 @@ def openai_gelu(x):
 def erf_gelu(x):
     return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))

+@torch.compile(mode="max-autotune-no-cudagraphs")
 def get_norm(config):
     args = get_args()
     if args.normalization == "LayerNorm":
megatron/training/arguments.py  (view file @ 8a13fabe)

@@ -642,6 +642,9 @@ def validate_args(args, defaults={}):
         assert not args.use_legacy_models, \
             '--decoupled-lr and --decoupled-min-lr is not supported in legacy models.'

+    # FlashAttention
+    args.use_flash_attn = args.use_flash_attn_ck or args.use_flash_attn_triton

     # Legacy RoPE arguments
     if args.use_rotary_position_embeddings:
         args.position_embedding_type = 'rope'

@@ -1355,9 +1358,11 @@ def _add_training_args(parser):
     group.add_argument('--cross-entropy-loss-fusion', action='store_true',
                        help='Enabled fusion of cross entropy loss calculation.',
                        dest='cross_entropy_loss_fusion')
-    group.add_argument('--use-flash-attn', action='store_true',
+    group.add_argument('--use-flash-attn-ck', action='store_true',
                        help='use FlashAttention implementation of attention. '
                        'https://arxiv.org/abs/2205.14135')
+    group.add_argument('--use-flash-attn-triton', action='store_true',
+                       help='use FlashAttention implementation of attention using Triton.')
     group.add_argument('--disable-bias-linear', action='store_false',
                        help='Disable bias in the linear layers',
                        dest='add_bias_linear')

@@ -1824,6 +1829,8 @@ def _add_tokenizer_args(parser):
                                'GPTSentencePieceTokenizer',
                                'HuggingFaceTokenizer',
                                'Llama2Tokenizer',
+                               'Llama3Tokenizer',
+                               'QwenTokenizer',
                                'TikTokenizer',
                                'MultimodalTokenizer',
                                'NullTokenizer'],
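The two new flags replace the single --use-flash-attn switch, and validate_args derives the shared code-path flag from them. A self-contained sketch of how they compose (illustration only, detached from Megatron's real parser):

import argparse


def build_parser():
    # Mirrors the two flags the commit adds in _add_training_args.
    parser = argparse.ArgumentParser(description="flash-attn flag sketch")
    parser.add_argument('--use-flash-attn-ck', action='store_true',
                        help='use FlashAttention implementation of attention.')
    parser.add_argument('--use-flash-attn-triton', action='store_true',
                        help='use FlashAttention implementation of attention using Triton.')
    return parser


def validate(args):
    # Mirrors the derived flag in validate_args: either backend turns the
    # shared use_flash_attn code path on.
    args.use_flash_attn = args.use_flash_attn_ck or args.use_flash_attn_triton
    return args


if __name__ == "__main__":
    args = validate(build_parser().parse_args(['--use-flash-attn-triton']))
    print(args.use_flash_attn, args.use_flash_attn_ck, args.use_flash_attn_triton)
    # True False True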
megatron/training/tokenizer/tokenizer.py  (view file @ 8a13fabe)

@@ -15,6 +15,7 @@ from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
 from .bert_tokenization import FullTokenizer as FullBertTokenizer
 from .gpt2_tokenization import GPT2Tokenizer
 from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer
+from transformers import Qwen2Tokenizer

 def build_tokenizer(args, **kwargs):

@@ -50,6 +51,11 @@ def build_tokenizer(args, **kwargs):
     elif args.tokenizer_type == 'Llama2Tokenizer':
         assert args.tokenizer_model is not None
         tokenizer = _Llama2Tokenizer(args.tokenizer_model)
+    elif args.tokenizer_type == 'Llama3Tokenizer':
+        assert args.tokenizer_model is not None
+        tokenizer = _Llama3Tokenizer(args.tokenizer_model)
+    elif args.tokenizer_type == 'QwenTokenizer':
+        tokenizer = _Qwen2Tokenizer(args.vocab_file, args.merge_file)
     elif args.tokenizer_type == 'TikTokenizer':
         assert args.tokenizer_model is not None
         assert args.tiktoken_pattern is not None

@@ -606,6 +612,96 @@ class _Llama2Tokenizer(_SentencePieceTokenizer):
         return None

+class _Llama3Tokenizer(MegatronTokenizer):
+    """Llama 3 adaptation of the tiktoken tokenizer for Megatron."""
+    # https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py
+
+    def __init__(self, model_file):
+        super().__init__(model_file)
+        from pathlib import Path
+        import tiktoken
+        from tiktoken.load import load_tiktoken_bpe
+
+        tokenizer_path = model_file
+        special_tokens = [
+            "<|begin_of_text|>",
+            "<|end_of_text|>",
+            "<|reserved_special_token_0|>",
+            "<|reserved_special_token_1|>",
+            "<|reserved_special_token_2|>",
+            "<|reserved_special_token_3|>",
+            "<|start_header_id|>",
+            "<|end_header_id|>",
+            "<|reserved_special_token_4|>",
+            "<|eot_id|>",  # end of turn
+        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
+        mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
+        self.tokenizer = tiktoken.Encoding(
+            tokenizer_path,
+            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
+            mergeable_ranks=mergeable_ranks,
+            special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
+        )
+        self.eod_id = self.tokenizer.encode("<|end_of_text|>", allowed_special="all")[0]
+
+    @property
+    def vocab_size(self):
+        return self.tokenizer.n_vocab
+
+    @property
+    def vocab(self):
+        return self.tokenizer.encode
+
+    @property
+    def inv_vocab(self):
+        return self.tokenizer.encode
+
+    def tokenize(self, text):
+        return self.tokenizer.encode(text)
+
+    def detokenize(self, token_ids):
+        return self.tokenizer.encode(token_ids)
+
+    @property
+    def eod(self):
+        return self.eod_id
+
+
+class _Qwen2Tokenizer(MegatronTokenizer):
+
+    def __init__(self, vocab_file, merge_file, extra_vocab_size=0):
+        super().__init__(vocab_file, merge_file)
+        self.tokenizer = Qwen2Tokenizer(vocab_file, merge_file)
+        self.extra_vocab_size = extra_vocab_size
+        self.tokenizer.add_special_tokens(special_tokens_dict=dict(pad_token="<|extra_0|>"))
+
+    @property
+    def vocab_size(self):
+        return len(self.tokenizer.encoder) + self.extra_vocab_size
+
+    @property
+    def vocab(self):
+        return self.tokenizer.encoder
+
+    @property
+    def inv_vocab(self):
+        return self.tokenizer.decoder
+
+    def tokenize(self, text):
+        return self.tokenizer.encode(text)
+
+    def detokenize(self, token_ids):
+        return self.tokenizer.decode(token_ids)
+
+    @property
+    def eod(self):
+        return self.tokenizer.eos_token_id
+
+    @property
+    def eos_token(self):
+        return self.tokenizer.eos_token
+
+    @property
+    def pad_token_id(self):
+        return self.tokenizer.pad_token_id
+

 def reload_mergeable_ranks(path: str, max_vocab: Optional[int] = None) -> Dict[bytes, int]:
     """
     Reload our tokenizer JSON file and convert it to Tiktoken format.
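One detail worth flagging in the added _Llama3Tokenizer: detokenize passes token ids to tokenizer.encode, whereas a tiktoken Encoding normally maps ids back to text with decode. A minimal round-trip sketch (illustration only; the stock cl100k_base encoding stands in for the Llama 3 BPE model file, which is not available here):

import tiktoken

# cl100k_base stands in for the Llama 3 tokenizer model; the round-trip
# logic (encode text -> ids, decode ids -> text) is the same.
enc = tiktoken.get_encoding("cl100k_base")

token_ids = enc.encode("Hello, Megatron!")
text = enc.decode(token_ids)  # decode, not encode, maps ids back to text

print(token_ids)
print(text)  # Hello, Megatron!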