OpenDAS / Megatron-LM · Commit 1e0e555c
Authored Mar 31, 2023 by Mostofa Patwary
merging rope to main
Parent: 035cae2e
Showing 4 changed files with 198 additions and 37 deletions (+198 -37)
megatron/arguments.py                    +8   -0
megatron/model/language_model.py         +62  -24
megatron/model/rotary_pos_embedding.py   +56  -0
megatron/model/transformer.py            +72  -13
megatron/arguments.py
@@ -509,6 +509,14 @@ def _add_network_size_args(parser):
     group.add_argument('--max-position-embeddings', type=int, default=None,
                        help='Maximum number of position embeddings to use. '
                        'This is the size of position embedding.')
+    group.add_argument('--use-rotary-position-embeddings', action='store_true',
+                       help='Use rotary positional embeddings or not')
+    group.add_argument('--rotary-percent', type=float, default=1.0,
+                       help='Percent of rotary dimension to use, default 100%')
+    group.add_argument('--no-position-embedding', action='store_false',
+                       help='Disable position embedding.',
+                       dest='add_position_embedding')
     group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                        help='Pad the vocab size to be divisible by this value.'
                        'This is added for computational efficieny reasons.')
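For orientation, a minimal standalone sketch of how the three new flags behave once parsed. The parser below is a stand-in, not Megatron's full _add_network_size_args machinery, and the flag values are illustrative.

# Standalone sketch (not Megatron's argument parsing): mirrors the three flags
# added above, showing their defaults and how --no-position-embedding maps
# onto args.add_position_embedding via dest=.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-rotary-position-embeddings', action='store_true')
parser.add_argument('--rotary-percent', type=float, default=1.0)
parser.add_argument('--no-position-embedding', action='store_false',
                    dest='add_position_embedding')

args = parser.parse_args(['--use-rotary-position-embeddings',
                          '--rotary-percent', '0.5',
                          '--no-position-embedding'])

assert args.use_rotary_position_embeddings is True   # rotary embeddings on
assert args.rotary_percent == 0.5                    # rotate half of each head dim
assert args.add_position_embedding is False          # learned absolute embedding off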
megatron/model/language_model.py
@@ -11,6 +11,7 @@ from megatron.core import mpu, tensor_parallel
 from .enums import LayerType, AttnMaskType
 from .module import MegatronModule
 from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
+from .rotary_pos_embedding import apply_rotary_pos_emb, RotaryEmbedding
 from .transformer import ParallelTransformer
 from .utils import get_linear_layer
 from .utils import init_method_normal, scaled_init_method_normal
@@ -158,12 +159,14 @@ class Embedding(MegatronModule):
         self._word_embeddings_key = 'word_embeddings'

         # Position embedding (serial).
-        self.position_embeddings = torch.nn.Embedding(
-            max_sequence_length, self.hidden_size)
-        self._position_embeddings_key = 'position_embeddings'
-        # Initialize the position embeddings.
-        if args.perform_initialization:
-            self.init_method(self.position_embeddings.weight)
+        self.add_position_embedding = args.add_position_embedding
+        if self.add_position_embedding:
+            self.position_embeddings = torch.nn.Embedding(
+                max_sequence_length, self.hidden_size)
+            self._position_embeddings_key = 'position_embeddings'
+            # Initialize the position embeddings.
+            if args.perform_initialization:
+                self.init_method(self.position_embeddings.weight)

         # Token type embedding.
         # Add this as an optional field that can be added through
@@ -188,8 +191,9 @@ class Embedding(MegatronModule):
         """Zero out all parameters in embedding."""
         self.word_embeddings.weight.data.fill_(0)
         self.word_embeddings.weight.shared = True
-        self.position_embeddings.weight.data.fill_(0)
-        self.position_embeddings.weight.shared = True
+        if self.add_position_embedding:
+            self.position_embeddings.weight.data.fill_(0)
+            self.position_embeddings.weight.shared = True
         if self.num_tokentypes > 0:
             self.tokentype_embeddings.weight.data.fill_(0)
             self.tokentype_embeddings.weight.shared = True
@@ -214,8 +218,12 @@ class Embedding(MegatronModule):
     def forward(self, input_ids, position_ids, tokentype_ids=None):
         # Embeddings.
         words_embeddings = self.word_embeddings(input_ids)
-        position_embeddings = self.position_embeddings(position_ids)
-        embeddings = words_embeddings + position_embeddings
+        if self.add_position_embedding:
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings = words_embeddings + position_embeddings
+        else:
+            embeddings = words_embeddings
+
         if tokentype_ids is not None:
             assert self.tokentype_embeddings is not None
             embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
@@ -246,8 +254,9 @@ class Embedding(MegatronModule):
         state_dict_[self._word_embeddings_key] \
             = self.word_embeddings.state_dict(prefix=prefix,
                                               keep_vars=keep_vars)
-        state_dict_[self._position_embeddings_key] \
-            = self.position_embeddings.state_dict(prefix=prefix,
+        if self.add_position_embedding:
+            state_dict_[self._position_embeddings_key] \
+                = self.position_embeddings.state_dict(prefix=prefix,
                                                   keep_vars=keep_vars)
         if self.num_tokentypes > 0:
             state_dict_[self._tokentype_embeddings_key] \
@@ -272,16 +281,17 @@ class Embedding(MegatronModule):
         self.word_embeddings.load_state_dict(state_dict_, strict=strict)

         # Position embedding.
-        if self._position_embeddings_key in state_dict:
-            state_dict_ = state_dict[self._position_embeddings_key]
-        else:
-            # for backward compatibility.
-            state_dict_ = {}
-            for key in state_dict.keys():
-                if 'position_embeddings' in key:
-                    state_dict_[key.split('position_embeddings.')[1]] \
-                        = state_dict[key]
-        self.position_embeddings.load_state_dict(state_dict_, strict=strict)
+        if self.add_position_embedding:
+            if self._position_embeddings_key in state_dict:
+                state_dict_ = state_dict[self._position_embeddings_key]
+            else:
+                # for backward compatibility.
+                state_dict_ = {}
+                for key in state_dict.keys():
+                    if 'position_embeddings' in key:
+                        state_dict_[key.split('position_embeddings.')[1]] \
+                            = state_dict[key]
+            self.position_embeddings.load_state_dict(state_dict_, strict=strict)

         # Tokentype embedding.
         if self.num_tokentypes > 0:
@@ -351,6 +361,23 @@ class TransformerLanguageModel(MegatronModule):
                 self.num_tokentypes)
             self._embedding_key = 'embedding'

+        # Rotary positional embeddings
+        self.use_rotary_position_embeddings = False
+        if args.use_rotary_position_embeddings:
+            self.seq_length = args.seq_length
+            rotary_dim = args.hidden_size // args.num_attention_heads \
+                if args.kv_channels is None else args.kv_channels
+
+            if args.rotary_percent < 1.0:
+                rotary_dim = int(rotary_dim * args.rotary_percent)
+
+            # partial rotary embeddings, which is better than full rotary
+            # Wang and Komatsuzaki et al
+            # https://github.com/kingoflolz/mesh-transformer-jax/
+            self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
+            self.use_rotary_position_embeddings = \
+                args.use_rotary_position_embeddings
+
         # Retriever (bi-directional transformer with cross attention)
         if args.retro_add_retriever:
             self.retriever = ParallelRetroEncoder(
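As a quick sanity check of the rotary_dim arithmetic above (the numbers below are illustrative, not taken from this commit):

# Illustrative values only: how rotary_dim is derived in the constructor above.
hidden_size, num_attention_heads, kv_channels = 4096, 32, None
rotary_percent = 0.5

rotary_dim = hidden_size // num_attention_heads \
    if kv_channels is None else kv_channels        # 4096 // 32 = 128
if rotary_percent < 1.0:
    rotary_dim = int(rotary_dim * rotary_percent)  # int(128 * 0.5) = 64

assert rotary_dim == 64  # only the first 64 dims of each 128-dim head get rotated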
@@ -458,6 +485,15 @@ class TransformerLanguageModel(MegatronModule):
         else:
             encoder_input = None

+        # Rotary positional embeddings
+        rotary_pos_emb = None
+        if self.use_rotary_position_embeddings:
+            if inference_params is not None:
+                rotary_pos_emb = \
+                    self.rotary_pos_emb(inference_params.max_sequence_len)
+            else:
+                rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+
         # Run encoder.
         if enc_hidden_states is None:
             if self.encoder is not None:
@@ -472,7 +508,8 @@ class TransformerLanguageModel(MegatronModule):
                 encoder_output = self.encoder(
                     encoder_input,
                     enc_attn_mask,
-                    inference_params=inference_params)
+                    inference_params=inference_params,
+                    rotary_pos_emb=rotary_pos_emb)
             else:
                 encoder_output = self.encoder_hidden_state
         else:
@@ -505,7 +542,8 @@ class TransformerLanguageModel(MegatronModule):
             dec_attn_mask,
             encoder_output=encoder_output,
             enc_dec_attn_mask=enc_dec_attn_mask,
-            inference_params=inference_params)
+            inference_params=inference_params,
+            rotary_pos_emb=rotary_pos_emb)

         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
megatron/model/rotary_pos_embedding.py (new file, 0 → 100644)
# coding=utf-8

# The following code has been taken from https://github.com/NVIDIA/NeMo/blob/ \
# 782b4e1652aaa43c8be390d9db0dc89544afa080/nemo/collections/nlp/modules/ \
# common/megatron/rotary_pos_embedding.py

import importlib.util
import torch

from torch import einsum, nn

__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb']


class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        if importlib.util.find_spec('einops') is None:
            raise RuntimeError("einops is required for Rotary Embedding")

    def forward(self, max_seq_len, offset=0):
        seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
        freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
        # first part even vector components, second part odd vector components,
        # 2 * dim in dimension size
        emb = torch.cat((freqs, freqs), dim=-1)
        # emb [seq_length, .., dim]
        from einops import rearrange
        return rearrange(emb, 'n d -> n 1 1 d')


def _rotate_half(x):
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    from einops import rearrange
    x = rearrange(x, '... (j d) -> ... j d', j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs):
    """
    input tensor t is of shape [seq_length, ..., dim]
    rotary positional embeding tensor freqs is of shape [seq_length, ..., dim]
    check https://kexue.fm/archives/8265 for detailed formulas
    """
    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
    return torch.cat((t, t_pass), dim=-1)
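A short usage sketch of the new module, assuming the [seq, batch, heads, head_dim] layout that ParallelAttention feeds into apply_rotary_pos_emb; the shapes and sizes below are illustrative, and einops must be installed.

# Usage sketch (illustrative shapes; requires torch and einops).
import torch
from megatron.model.rotary_pos_embedding import RotaryEmbedding, apply_rotary_pos_emb

seq_len, batch, heads, head_dim = 16, 2, 4, 8
rope = RotaryEmbedding(head_dim)

freqs = rope(seq_len)                          # [seq_len, 1, 1, head_dim]
q = torch.randn(seq_len, batch, heads, head_dim)
k = torch.randn(seq_len, batch, heads, head_dim)

q_rot = apply_rotary_pos_emb(q, freqs)         # rotate queries by position
k_rot = apply_rotary_pos_emb(k, freqs)         # rotate keys by position

assert q_rot.shape == q.shape and k_rot.shape == k.shape

Because positions enter only through these rotations of queries and keys, the resulting attention scores depend on relative offsets between tokens, which is the point of rotary embeddings.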
megatron/model/transformer.py
@@ -14,6 +14,7 @@ from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
+from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

 try:
@@ -444,20 +445,27 @@ class ParallelAttention(MegatronModule):
             **_args_to_kwargs())

     def _checkpointed_attention_forward(self, query_layer, key_layer,
-                                        value_layer, attention_mask):
+                                        value_layer, attention_mask,
+                                        rotary_pos_emb=None):
         """Forward method with activation checkpointing."""
         def custom_forward(*inputs):
             query_layer = inputs[0]
             key_layer = inputs[1]
             value_layer = inputs[2]
             attention_mask = inputs[3]
+            rotary_pos_emb = inputs[4] if inputs[4] is None \
+                else (inputs[4], inputs[5])
             output_ = self.core_attention(query_layer, key_layer,
                                           value_layer, attention_mask)
             return output_

+        q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \
+            else rotary_pos_emb
+
         hidden_states = tensor_parallel.checkpoint(
             custom_forward,
-            False, query_layer, key_layer, value_layer, attention_mask)
+            False, query_layer, key_layer, value_layer, attention_mask,
+            q_pos_emb, k_pos_emb)

         return hidden_states
@@ -471,7 +479,8 @@ class ParallelAttention(MegatronModule):
                                               device=torch.cuda.current_device())

     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, inference_params=None):
+                encoder_output=None, inference_params=None,
+                rotary_pos_emb=None):
         # hidden_states: [sq, b, h]

         # =================================================
@@ -536,6 +545,11 @@ class ParallelAttention(MegatronModule):
         # Adjust key and value for inference
         # ==================================

+        # duplicate the pos_emb for self attention
+        if rotary_pos_emb is not None:
+            rotary_pos_emb = rotary_pos_emb if isinstance(rotary_pos_emb, \
+                tuple) else ((rotary_pos_emb,) * 2)
+
         if inference_params:
             batch_start = inference_params.batch_size_offset
             batch_end = batch_start + key_layer.size(1)
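A minimal standalone sketch of the tuple normalization above: a single frequency tensor from TransformerLanguageModel is duplicated so the same rotation applies to both queries and keys in self-attention, while an explicit (q, k) pair passes through unchanged. The helper name and values below are illustrative.

# Standalone sketch of the normalization above (names are illustrative).
def normalize_rotary_pos_emb(rotary_pos_emb):
    if rotary_pos_emb is None:
        return None
    return rotary_pos_emb if isinstance(rotary_pos_emb, tuple) \
        else ((rotary_pos_emb,) * 2)

assert normalize_rotary_pos_emb(None) is None
assert normalize_rotary_pos_emb('freqs') == ('freqs', 'freqs')
assert normalize_rotary_pos_emb(('q_freqs', 'k_freqs')) == ('q_freqs', 'k_freqs')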
@@ -553,10 +567,42 @@ class ParallelAttention(MegatronModule):
             value_layer = inference_value_memory[
                 :sequence_end, batch_start:batch_end, ...]

+            # adjust the key rotary positional embedding
+            if rotary_pos_emb is not None:
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # need to cross check this condition during inference
+                # if not set_inference_key_value_memory:
+                if not is_first_step:
+                    # In inference, we compute one token at a time.
+                    # Select the correct positional embedding
+                    # (only the last token in the sequence)
+                    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
+                else:
+                    # In the first forward pass of inference,
+                    # we use the entire provided prefix.
+                    # q_pos_emb here has the rope embeddings of the entire
+                    # prefix + to-be-generated output so
+                    # we slice to just the prefix.
+                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
+                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
+                rotary_pos_emb = (q_pos_emb, k_pos_emb)

         # ==================================
         # core attention computation
         # ==================================

+        # apply relative positional encoding (rotary embedding)
+        if rotary_pos_emb is not None:
+            q_pos_emb, k_pos_emb = rotary_pos_emb
+            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
+            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
+            # TODO, can apply positional embedding to value_layer so it has
+            # absolute positional embedding.
+            # otherwise, only relative positional embedding takes effect
+            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
+
         if not self.use_flash_attn:
             if self.checkpoint_core_attention:
                 context_layer = self._checkpointed_attention_forward(
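A toy illustration of the inference-time slicing above (not Megatron code): after the first decode step, only the newest position's frequencies are kept for the single query token, while the keys retain frequencies for the whole prefix. All sizes are illustrative.

# Toy tensors, illustrative sizes only.
import torch

seq_len, head_dim = 8, 4
freqs = torch.randn(seq_len, 1, 1, head_dim)   # stands in for rotary_pos_emb

sequence_end = 5            # tokens seen so far, including the new one
is_first_step = False       # past the prompt-processing step

q_pos_emb, k_pos_emb = freqs, freqs
if not is_first_step:
    q_pos_emb = q_pos_emb[sequence_end - 1:sequence_end]   # just the last token
else:
    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]          # the whole prefix
k_pos_emb = k_pos_emb[:sequence_end, :, :, :]

assert q_pos_emb.shape[0] == 1
assert k_pos_emb.shape[0] == sequence_end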
@@ -688,17 +734,21 @@ class ParallelTransformerLayer(MegatronModule):
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
-                inference_params=None):
+                inference_params=None, rotary_pos_emb=None):
         # hidden_states: [s, b, h]

         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)

         # Self attention.
+        self_attention_pos_emb = None
+        if rotary_pos_emb is not None:
+            self_attention_pos_emb = rotary_pos_emb
+
         attention_output, attention_bias = \
             self.self_attention(
                 layernorm_output,
                 attention_mask,
-                inference_params=inference_params)
+                inference_params=inference_params,
+                rotary_pos_emb=self_attention_pos_emb)

         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
@@ -1032,7 +1082,8 @@ class ParallelTransformer(MegatronModule):
         return self.layers[layer_number]

     def _checkpointed_forward(self, hidden_states, attention_mask,
-                              encoder_output, enc_dec_attn_mask, is_first_microbatch):
+                              encoder_output, enc_dec_attn_mask,
+                              rotary_pos_emb, is_first_microbatch):
         """Forward method with activation checkpointing."""
         def custom(start, end, is_transformer_engine=False):
             def custom_forward(*args, **kwargs):
@@ -1059,12 +1110,14 @@ class ParallelTransformer(MegatronModule):
self
.
distribute_saved_activations
,
tensor_parallel
.
get_cuda_rng_tracker
,
mpu
.
get_tensor_model_parallel_group
(),
hidden_states
,
attention_mask
,
encoder_output
,
enc_dec_attn_mask
)
hidden_states
,
attention_mask
,
encoder_output
,
enc_dec_attn_mask
,
rotary_pos_emb
)
else
:
hidden_states
=
tensor_parallel
.
checkpoint
(
custom
(
l
,
l
+
self
.
recompute_num_layers
),
self
.
distribute_saved_activations
,
hidden_states
,
attention_mask
,
encoder_output
,
enc_dec_attn_mask
)
hidden_states
,
attention_mask
,
encoder_output
,
enc_dec_attn_mask
,
rotary_pos_emb
)
l
+=
self
.
recompute_num_layers
...
...
@@ -1080,19 +1133,23 @@ class ParallelTransformer(MegatronModule):
                         self.distribute_saved_activations,
                         tensor_parallel.get_cuda_rng_tracker,
                         mpu.get_tensor_model_parallel_group(),
-                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                        hidden_states, attention_mask, encoder_output,
+                        enc_dec_attn_mask, rotary_pos_emb)
                 else:
                     hidden_states = tensor_parallel.checkpoint(
                         custom(l, l + 1),
                         self.distribute_saved_activations,
-                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                        hidden_states, attention_mask, encoder_output,
+                        enc_dec_attn_mask, rotary_pos_emb)
             else:
                 if self.transformer_impl == 'transformer_engine':
                     hidden_states = custom(l, l + 1, is_transformer_engine=True)(
-                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                        hidden_states, attention_mask, encoder_output,
+                        enc_dec_attn_mask, rotary_pos_emb)
                 else:
                     hidden_states = custom(l, l + 1)(
-                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                        hidden_states, attention_mask, encoder_output,
+                        enc_dec_attn_mask, rotary_pos_emb)
         else:
             raise ValueError("Invalid activation recompute method.")
@@ -1110,7 +1167,7 @@ class ParallelTransformer(MegatronModule):
def
forward
(
self
,
hidden_states
,
attention_mask
,
encoder_output
=
None
,
enc_dec_attn_mask
=
None
,
inference_params
=
None
):
inference_params
=
None
,
rotary_pos_emb
=
None
):
# hidden_states: [s, b, h]
# Checks.
...
...
@@ -1168,12 +1225,14 @@ class ParallelTransformer(MegatronModule):
                     attention_mask,
                     encoder_output,
                     enc_dec_attn_mask,
+                    rotary_pos_emb,
                     is_first_microbatch)
             else:
                 forward_kwargs = {
                     'encoder_output': encoder_output,
                     'enc_dec_attn_mask': enc_dec_attn_mask,
                     'inference_params': inference_params,
+                    'rotary_pos_emb': rotary_pos_emb,
                 }

                 if self.transformer_impl == 'transformer_engine':