OpenDAS / Megatron-LM · commit ec561daa

Better handling of padding in embedding table.

Authored Feb 18, 2022 by Jared Casper
Parent: cdf0a5d4

2 changed files with 55 additions and 7 deletions:

    tools/checkpoint_loader_megatron.py   +24  -1
    tools/checkpoint_saver_megatron.py    +31  -6
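The change in a nutshell: the loader now reports the true (unpadded) vocab size and the `make-vocab-size-divisible-by` setting, and the saver recomputes the padded vocab size for the target parallel configuration, then trims or extends the embedding table to match. Below is a minimal standalone sketch of that resize step on toy tensors; `resize_embedding_rows` is a hypothetical helper written for illustration, not code from this repo.

```python
import torch

def resize_embedding_rows(word_embed: torch.Tensor, padded_vocab_size: int) -> torch.Tensor:
    """Trim or extend an embedding table to `padded_vocab_size` rows.

    Mirrors the approach in checkpoint_saver_megatron.py below: surplus padding
    rows are cut off, and missing rows are filled by replicating the final entry.
    (Hypothetical helper for illustration only.)
    """
    orig_vocab_size = word_embed.shape[0]
    if orig_vocab_size > padded_vocab_size:
        # Cut out extra padding we don't need.
        return word_embed[0:padded_vocab_size, :]
    if orig_vocab_size < padded_vocab_size:
        # Expand to the larger size by replicating the final entry.
        padding_size = padded_vocab_size - orig_vocab_size
        return torch.cat((word_embed, word_embed[-1].unsqueeze(0).expand(padding_size, -1)))
    return word_embed

# Toy example: a 10-row table resized up to 12 rows and down to 8 rows.
table = torch.randn(10, 4)
print(resize_embedding_rows(table, 12).shape)  # torch.Size([12, 4])
print(resize_embedding_rows(table, 8).shape)   # torch.Size([8, 4])
```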
tools/checkpoint_loader_megatron.py  (+24, -1)

+import json
 import os
 import sys
 import types
@@ -7,6 +8,11 @@ import torch
 def add_arguments(parser):
     group = parser.add_argument_group(title='Megatron loader')

+    group.add_argument('--true-vocab-size', type=int, default=None,
+                       help='original size of vocab, if specified will trim padding from embedding table.')
+    group.add_argument('--vocab-file', type=str, default=None,
+                       help='Path to the vocab file. If specified will use this to get vocab size and '
+                       'trim padding from the embedding table.')
     group.add_argument('--megatron-path', type=str, default=None,
                        help='Base directory of deepspeed repository')
@@ -21,7 +27,7 @@ def _load_checkpoint(queue, args):
     try:
         from megatron.arguments import parse_args, validate_args
-        from megatron.global_vars import set_args, set_global_variables, rebuild_tokenizer
+        from megatron.global_vars import set_args, set_global_variables
         from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
         from megatron.model import ModelType
         from megatron import mpu, fused_kernels
@@ -111,6 +117,19 @@ def _load_checkpoint(queue, args):
     mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
     fused_kernels.load(margs)

+    # Get true (non-padded) vocab size
+    if args.true_vocab_size is not None:
+        true_vocab_size = args.true_vocab_size
+    elif args.vocab_file is not None:
+        vocab = json.load(open(args.vocab_file))
+        true_vocab_size = len(vocab)
+        if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size:
+            print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.")
+            queue.put("exit")
+            exit(1)
+    else:
+        true_vocab_size = None
+
     # short aliases
     tp_size = margs.tensor_model_parallel_size
     pp_size = margs.pipeline_model_parallel_size
@@ -129,6 +148,8 @@ def _load_checkpoint(queue, args):
     md.bert_binary_head = margs.bert_binary_head
     md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
     md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
+    md.true_vocab_size = true_vocab_size
+    md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
     queue.put(md)

     # Get first pipe stage
@@ -137,6 +158,7 @@ def _load_checkpoint(queue, args):
     models = get_models(tp_size, md.params_dtype, True, post_process)

     # Send embeddings
     word_embed = []
     for tp_rank in range(tp_size):
         if tp_rank == 0:
@@ -144,6 +166,7 @@ def _load_checkpoint(queue, args):
             queue.put(models[tp_rank].language_model.embedding.position_embeddings.weight.data)
         word_embed.append(models[tp_rank].language_model.embedding.word_embeddings.weight.data)
     full_word_embed = torch.cat(word_embed, dim=0)
     print("Sending word embeddings")
     queue.put(full_word_embed)
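As context for the change above: the loader gathers the per-tensor-parallel-rank word-embedding shards into one full table before sending it, via `full_word_embed = torch.cat(word_embed, dim=0)`. Below is a toy sketch of that gather step, with made-up shard sizes standing in for the real model weights.

```python
import torch

tp_size = 4         # assumed source tensor-parallel size
rows_per_rank = 32  # assumed per-rank slice of the padded vocab
hidden = 8          # assumed hidden size

# Stand-ins for models[tp_rank].language_model.embedding.word_embeddings.weight.data
word_embed = [torch.randn(rows_per_rank, hidden) for _ in range(tp_size)]

# The same row-wise concatenation the loader performs before queue.put(full_word_embed).
full_word_embed = torch.cat(word_embed, dim=0)
print(full_word_embed.shape)  # torch.Size([128, 8])
```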
tools/checkpoint_saver_megatron.py  (+31, -6)

@@ -31,6 +31,7 @@ def save_checkpoint(queue, args):
         from megatron.checkpointing import save_checkpoint
         from megatron.global_vars import set_global_variables, get_args
         from megatron.model import ModelType
+        from megatron.tokenizer.tokenizer import _vocab_size_with_padding
         from megatron import mpu, fused_kernels
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
@@ -91,6 +92,9 @@ def save_checkpoint(queue, args):
                 '--save-interval', '1',
                 '--save', args.save_dir
                 ]
+
+    if md.make_vocab_size_divisible_by is not None:
+        sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)])
     if md.params_dtype == torch.float16:
         sys.argv.append('--fp16')
     elif md.params_dtype == torch.bfloat16:
@@ -127,13 +131,33 @@ def save_checkpoint(queue, args):
     # Embeddings
     #-----------
     pos_embed = queue_get()
-    full_word_embed = queue_get()
+    orig_word_embed = queue_get()

-    # Tell Megatron what our full size is
-    margs.padded_vocab_size = full_word_embed.shape[0]
-    if margs.padded_vocab_size % args.target_tensor_parallel_size != 0:
-        print("source vocab size is not evenly divisble by target tensor parallel size")
-        exit(1)
+    # Deal with padding
+    if md.true_vocab_size is not None:
+        # figure out what our padded vocab size is
+        orig_vocab_size = orig_word_embed.shape[0]
+        margs.padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)
+
+        # Cut out extra padding we don't need
+        if orig_vocab_size > margs.padded_vocab_size:
+            full_word_embed = orig_word_embed[0:margs.padded_vocab_size,:]
+
+        # Expanding embedding to larger size by replicating final entry
+        elif orig_vocab_size < margs.padded_vocab_size:
+            padding_size = margs.padded_vocab_size - orig_vocab_size
+            full_word_embed = torch.cat((orig_word_embed, orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))
+
+        # Same size!
+        else:
+            full_word_embed = orig_word_embed
+    else:
+        print("Original vocab size not specified, leaving embedding table as-is. "
+              "If you've changed the tensor parallel size this could cause problems.")
+        full_word_embed = orig_word_embed

     # Split into new tensor model parallel sizes
     out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
@@ -143,6 +167,7 @@ def save_checkpoint(queue, args):
     post_process = args.target_pipeline_parallel_size == 1
     models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)

     for tp_rank, model in enumerate(models):
+        print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
         model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
         model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)
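The padded size requested above comes from `_vocab_size_with_padding` in `megatron/tokenizer/tokenizer.py`, which (as of this point in the repo's history) rounds the true vocab size up until it is divisible by `make_vocab_size_divisible_by * tensor_model_parallel_size`; the resized table is then split back into per-rank shards with `torch.chunk`. Below is a rough standalone sketch of both steps, assuming that rounding behaviour (the helper itself is authoritative); `padded_vocab_size` here is a hypothetical stand-in, not the Megatron function.

```python
import torch

def padded_vocab_size(true_vocab_size: int, divisible_by: int, tp_size: int) -> int:
    # Assumed behaviour of _vocab_size_with_padding: round up to a multiple of
    # make_vocab_size_divisible_by * tensor_model_parallel_size.
    multiple = divisible_by * tp_size
    return ((true_vocab_size + multiple - 1) // multiple) * multiple

true_vocab = 50257   # e.g. the GPT-2 BPE vocab
target_tp = 4
padded = padded_vocab_size(true_vocab, divisible_by=128, tp_size=target_tp)
print(padded)        # 50688

# Resize a toy table to the padded size, then split it for the target TP ranks,
# mirroring out_word_embed = torch.chunk(full_word_embed, target_tensor_parallel_size, dim=0).
full_word_embed = torch.randn(padded, 16)
out_word_embed = torch.chunk(full_word_embed, target_tp, dim=0)
print([t.shape[0] for t in out_word_embed])  # [12672, 12672, 12672, 12672]
```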