OpenDAS / Megatron-LM / Commits / c63906a6

Commit c63906a6, authored Sep 29, 2020 by Mohammad Shoeybi

Merge branch 'vijay/weight_mapping' into 'main'

Vijay/weight mapping

See merge request ADLR/megatron-lm!138

Parents: 577b4657, 64e45f29
Showing 2 changed files with 48 additions and 10 deletions (+48, -10):

    megatron/checkpointing.py        +16  -0
    megatron/model/transformer.py    +32  -10
megatron/checkpointing.py  (view file @ c63906a6)

@@ -27,6 +27,17 @@ from megatron import mpu, get_args
 from megatron import get_args
 from megatron import print_rank_0
 
+_CHECKPOINT_VERSION = None
+
+def set_checkpoint_version(value):
+    global _CHECKPOINT_VERSION
+    assert _CHECKPOINT_VERSION is None, \
+        "checkpoint version already set"
+    _CHECKPOINT_VERSION = value
+
+def get_checkpoint_version():
+    global _CHECKPOINT_VERSION
+    return _CHECKPOINT_VERSION
 
 def check_checkpoint_args(checkpoint_args):
     """Ensure fixed arguments for a model are the same for the input

@@ -90,6 +101,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     # Arguments, iteration, and model.
     state_dict = {}
     state_dict['args'] = args
+    state_dict['checkpoint_version'] = 1.0
     state_dict['iteration'] = iteration
     state_dict['model'] = model.state_dict_for_save_checkpoint()

@@ -184,6 +196,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         print_rank_0('could not load the checkpoint')
         sys.exit()
 
+    # set checkpoint version
+    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+
     # Set iteration.
     if args.finetune or release:
         iteration = 0

@@ -198,6 +213,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                          'iteration from checkpoint {}, exiting'.format(
                              checkpoint_name))
             sys.exit()
 
     # Check arguments.
     if 'args' in state_dict:
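A note on how these pieces fit together: save_checkpoint now stamps every new state dict with 'checkpoint_version' = 1.0, and load_checkpoint records whatever version it finds (defaulting to 0 for checkpoints written before this change) via set_checkpoint_version, so model code can later query it with get_checkpoint_version. Below is a minimal sketch of that round trip, assuming only torch; the save/load wrappers and the path argument are illustrative, not part of this diff.

    # Minimal sketch (not Megatron-LM code): how the version helpers above
    # are intended to round-trip through a checkpoint file.
    import torch

    _CHECKPOINT_VERSION = None

    def set_checkpoint_version(value):
        global _CHECKPOINT_VERSION
        assert _CHECKPOINT_VERSION is None, "checkpoint version already set"
        _CHECKPOINT_VERSION = value

    def get_checkpoint_version():
        return _CHECKPOINT_VERSION

    def save_sketch(path, model_state):
        # New checkpoints are stamped with version 1.0.
        torch.save({'model': model_state, 'checkpoint_version': 1.0}, path)

    def load_sketch(path):
        state_dict = torch.load(path, map_location='cpu')
        # Old checkpoints carry no version key and default to 0, which is
        # what later triggers the QKV re-ordering in transformer.py.
        set_checkpoint_version(state_dict.get('checkpoint_version', 0))
        return state_dict['model']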
megatron/model/transformer.py  (view file @ c63906a6)

@@ -23,6 +23,7 @@ from megatron import get_args
 from megatron import mpu
 from megatron.mpu import LayerNorm
 from megatron.module import MegatronModule
+from megatron.checkpointing import get_checkpoint_version
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import openai_gelu, erf_gelu
@@ -170,7 +171,23 @@ class ParallelSelfAttention(MegatronModule):
...
@@ -170,7 +171,23 @@ class ParallelSelfAttention(MegatronModule):
input_is_parallel
=
True
,
input_is_parallel
=
True
,
init_method
=
output_layer_init_method
,
init_method
=
output_layer_init_method
,
skip_bias_add
=
True
)
skip_bias_add
=
True
)
def
_transpose_last_dim
(
self
,
mixed_layer
):
"""[s, b, 3 * hp] -->(view) [s, b, 3, hp] -->(tranpose)
[s, b, hp, 3] -->(view) [s, b, 3 * hp] """
input_shape
=
mixed_layer
.
size
();
last_dim
=
input_shape
[
-
1
]
assert
last_dim
%
3
==
0
,
"expected QKV dimension"
last_dim_split
=
last_dim
//
3
intermediate_shape
=
input_shape
[:
-
1
]
+
\
(
3
,
last_dim_split
)
mixed_layer
=
mixed_layer
.
view
(
*
intermediate_shape
)
mixed_layer
=
mixed_layer
.
transpose
(
-
1
,
-
2
).
contiguous
()
mixed_layer
=
mixed_layer
.
view
(
*
input_shape
)
return
mixed_layer
def
forward
(
self
,
hidden_states
,
attention_mask
,
layer_past
=
None
,
def
forward
(
self
,
hidden_states
,
attention_mask
,
layer_past
=
None
,
get_key_value
=
False
):
get_key_value
=
False
):
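For context on what _transpose_last_dim does: per the comments in the version == 0 branch in the next hunk, an activation whose last dimension is three contiguous blocks [q | k | v] of size hp gets re-ordered into per-element (hp, 3) interleaving, which is the layout the new forward() view [s, b, np, hn, 3] can index directly. Below is a small self-contained check of that remapping, assuming only torch; the tensor sizes are arbitrary toy values.

    # Toy check (not Megatron-LM code) of the layout conversion done by
    # _transpose_last_dim; s, b, num_heads, hn are arbitrary small sizes.
    import torch

    s, b, num_heads, hn = 2, 3, 4, 5
    hp = num_heads * hn

    # Pretend the fused QKV output's last dim is three contiguous blocks
    # [q | k | v], each of size hp (the "3 * hp" layout).
    q = torch.randn(s, b, hp)
    k = torch.randn(s, b, hp)
    v = torch.randn(s, b, hp)
    mixed = torch.cat([q, k, v], dim=-1)            # [s, b, 3 * hp]

    # Same view / transpose / view sequence as _transpose_last_dim above.
    shape = mixed.size()
    mixed = mixed.view(*shape[:-1], 3, hp)          # [s, b, 3, hp]
    mixed = mixed.transpose(-1, -2).contiguous()    # [s, b, hp, 3]
    mixed = mixed.view(*shape)                      # [s, b, hp * 3]

    # The new forward() view then recovers q/k/v by indexing the last axis.
    mixed = mixed.view(s, b, num_heads, hn, 3)      # [s, b, np, hn, 3]
    assert torch.equal(mixed[..., 0].reshape(s, b, hp), q)
    assert torch.equal(mixed[..., 1].reshape(s, b, hp), k)
    assert torch.equal(mixed[..., 2].reshape(s, b, hp), v)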
@@ -180,20 +197,25 @@ class ParallelSelfAttention(MegatronModule):
         # Query, Key, and Value
         # =====================
 
-        # Attention heads [s, b, hp] --> [s, b, 3 * hp]
+        # Attention heads [s, b, hp] --> [s, b, hp * 3]
         mixed_x_layer, _ = self.query_key_value(hidden_states)
 
-        # [s, b, 3 * hp] --> [s, b, np, 3 * hn]
+        checkpoint_version = get_checkpoint_version()
+        if checkpoint_version is not None and \
+           checkpoint_version == 0:
+            # [s, b, 3 * hp] --> [s, b, hp * 3]
+            mixed_x_layer = self._transpose_last_dim(mixed_x_layer)
+
+        # [s, b, hp * 3] --> [s, b, np, hn, 3]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
             (self.num_attention_heads_per_partition,
-             3 * self.hidden_size_per_attention_head)
+             self.hidden_size_per_attention_head, 3)
         mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
-        # [s, b, np, 3 * hn] --> 3 [s, b, np, hn]
-        (query_layer,
-         key_layer,
-         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
+        # [s, b, np, hn, 3] --> 3 [s, b, np, hn]
+        query_layer = mixed_x_layer[:,:,:,:,0]
+        key_layer = mixed_x_layer[:,:,:,:,1]
+        value_layer = mixed_x_layer[:,:,:,:,2]
 
         # ==================================
         # Adjust key and value for inference
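One way to see why the version gate matters: the old extraction (view to [s, b, np, 3 * hn], then split the last dimension into three chunks) and the new extraction (view to [s, b, np, hn, 3], then index the trailing axis) read different elements out of the same buffer, so old activations have to be re-ordered before the new view is applied. The sketch below illustrates that difference; torch.chunk stands in for mpu.split_tensor_along_last_dim, and the sizes are toy values.

    # Illustration (not Megatron-LM code) contrasting the old and new
    # q/k/v extraction on the same fused activation.
    import torch

    s, b, num_heads, hn = 2, 3, 4, 5
    x = torch.arange(s * b * num_heads * 3 * hn, dtype=torch.float32)
    x = x.view(s, b, num_heads * 3 * hn)

    # Old path: [s, b, np, 3 * hn], then split the last dim into 3 chunks.
    q_old, k_old, v_old = torch.chunk(x.view(s, b, num_heads, 3 * hn), 3, dim=-1)

    # New path: [s, b, np, hn, 3], then index the trailing axis.
    new_view = x.view(s, b, num_heads, hn, 3)
    q_new = new_view[:, :, :, :, 0]

    # Each result is [s, b, np, hn], but the two conventions pick different
    # elements, which is why version-0 activations go through
    # _transpose_last_dim before the new view is applied.
    assert q_old.shape == q_new.shape == (s, b, num_heads, hn)
    assert not torch.equal(q_old, q_new)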