OpenDAS / Megatron-LM · Commits

Commit 748a9a7d, authored Jul 11, 2022 by Jared Casper
Add error checking by wrapping all weights and biases into named messages with named tensors.
Parent: 2755bcb8

Showing 3 changed files with 187 additions and 130 deletions (+187 -130)
tools/checkpoint_loader_megatron.py   +56 -50
tools/checkpoint_saver_megatron.py    +77 -47
tools/checkpoint_util.py              +54 -33
tools/checkpoint_loader_megatron.py  (view file @ 748a9a7d)
...
@@ -170,18 +170,20 @@ def _load_checkpoint(queue, args):
     md.consumed_valid_samples = consumed_valid_samples
     queue.put(md)

+    def queue_put(name, msg):
+        print(f"sending {name}")
+        msg["name"] = name
+        queue.put(msg)
+
     # Send embeddings
-    word_embed = []
-    for tp_rank in range(tp_size):
-        if tp_rank == 0:
-            print("Sending position embeddings")
-            queue.put(models[tp_rank].language_model.embedding.position_embeddings.weight.data)
-        word_embed.append(models[tp_rank].language_model.embedding.word_embeddings.weight.data)
-    full_word_embed = torch.cat(word_embed, dim=0)
-    print("Sending word embeddings")
-    queue.put(full_word_embed)
+    message = {
+        "position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
+        "word embeddings": torch.cat(
+            [models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
+            dim=0)
+    }
+
+    queue_put("embeddings", message)

     total_layer_num = 0
     for pp_rank in range(pp_size):
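For reference, a minimal self-contained sketch of the pattern this hunk introduces: the producer tags each dict with a "name" before putting it on the queue, and the consumer verifies both the name and that every entry was consumed. The real scripts share a multiprocessing queue between the loader and saver processes; queue.Queue and the toy payload below are only there to keep the sketch runnable in one process.

    # Sketch of the named-message pattern (not the actual loader/saver scripts).
    import queue

    q = queue.Queue()

    def queue_put(name, msg):
        msg["name"] = name          # tag the message so the receiver can check ordering
        q.put(msg)

    def queue_get(expected_name):
        msg = q.get()
        assert msg.pop("name") == expected_name, "messages arrived out of order"
        return msg

    queue_put("embeddings", {"word embeddings": [1.0, 2.0], "position embeddings": [0.5]})
    msg = queue_get("embeddings")
    word = msg.pop("word embeddings")
    pos = msg.pop("position embeddings")
    assert not msg, f"unexpected leftover keys: {list(msg)}"   # mirrors check_message
    print("received embeddings:", word, pos)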
...
@@ -190,23 +192,24 @@ def _load_checkpoint(queue, args):
             post_process = pp_rank == pp_size - 1
             models = get_models(tp_size, md.params_dtype, False, post_process)

         for layer_num in range(len(models[0].language_model.encoder.layers)):
-            qkv_weight = []
-            qkv_bias = []
-            dense_weight = []
-            mlp_l0_weight = []
-            mlp_l0_bias = []
-            mlp_l1_weight = []
+            message = {}

             # Get non-parallel tensors from tp_rank 0
             layer = models[0].language_model.encoder.layers[layer_num]
-            input_layernorm_weight = layer.input_layernorm.weight.data
-            input_layernorm_bias = layer.input_layernorm.bias.data
-            dense_bias = layer.self_attention.dense.bias.data
-            post_layernorm_weight = layer.post_attention_layernorm.weight.data
-            post_layernorm_bias = layer.post_attention_layernorm.bias.data
-            mlp_l1_bias = layer.mlp.dense_4h_to_h.bias.data
+            message["input layernorm weight"] = layer.input_layernorm.weight.data
+            message["input layernorm bias"] = layer.input_layernorm.bias.data
+            message["dense bias"] = layer.self_attention.dense.bias.data
+            message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
+            message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
+            message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data

             # Grab all parallel tensors for this layer
+            qkv_weight = []
+            qkv_bias = []
+            dense_weight = []
+            mlp_l0_weight = []
+            mlp_l0_bias = []
+            mlp_l1_weight = []
             for tp_rank, model in enumerate(models):
                 layer = model.language_model.encoder.layers[layer_num]
                 qkv_weight.append(layer.self_attention.query_key_value.weight.data)
...
@@ -216,47 +219,50 @@ def _load_checkpoint(queue, args):
                 mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
                 mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)

-            # send everything in order while concatenating them
-            print(f"Sending layer {layer_num} of pipeline rank {pp_rank} (total layer {total_layer_num})")
-            queue.put(input_layernorm_weight)
-            queue.put(input_layernorm_bias)
-            queue.put(torch.cat(qkv_weight, dim=0))
-            queue.put(torch.cat(qkv_bias, dim=0))
-            queue.put(torch.cat(dense_weight, dim=1))
-            queue.put(dense_bias)
-            queue.put(post_layernorm_weight)
-            queue.put(post_layernorm_bias)
-            queue.put(torch.cat(mlp_l0_weight, dim=0))
-            queue.put(torch.cat(mlp_l0_bias, dim=0))
-            queue.put(torch.cat(mlp_l1_weight, dim=1))
-            queue.put(mlp_l1_bias)
+            # concat them
+            message["qkv weight"] = torch.cat(qkv_weight, dim=0)
+            message["qkv bias"] = torch.cat(qkv_bias, dim=0)
+            message["dense weight"] = torch.cat(dense_weight, dim=1)
+            message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
+            message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
+            message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
+
+            queue_put(f"transformer layer {total_layer_num}", message)

             total_layer_num = total_layer_num + 1

     # Send final layernorm from tp_rank 0
-    print("Sending final layernorm")
-    queue.put(models[0].language_model.encoder.final_layernorm.weight.data)
-    queue.put(models[0].language_model.encoder.final_layernorm.bias.data)
+    message = {
+        "weight": models[0].language_model.encoder.final_layernorm.weight.data,
+        "bias": models[0].language_model.encoder.final_layernorm.bias.data
+    }
+    queue_put("final layernorm", message)

     # Send BERT lm head and binary head if it exists
     if md.model_type == 'BERT':
         print("Sending LM Pooler")
-        queue.put("pooler")
-        queue.put(models[0].language_model.pooler.dense.weight.data)
-        queue.put(models[0].language_model.pooler.dense.bias.data)
+        message = {
+            "weight": models[0].language_model.pooler.dense.weight.data,
+            "bias": models[0].language_model.pooler.dense.bias.data
+        }
+        queue_put("pooler", message)

-        print("Sending BERT LM head")
-        queue.put("lm head")
-        queue.put(models[0].lm_head.dense.weight.data)
-        queue.put(models[0].lm_head.dense.bias.data)
-        queue.put(models[0].lm_head.layernorm.weight.data)
-        queue.put(models[0].lm_head.layernorm.bias.data)
+        message = {
+            "dense weight": models[0].lm_head.dense.weight.data,
+            "dense bias": models[0].lm_head.dense.bias.data,
+            "layernorm weight": models[0].lm_head.layernorm.weight.data,
+            "layernorm bias": models[0].lm_head.layernorm.bias.data
+        }
+        queue_put("lm head", message)

         if md.bert_binary_head:
             print("Sending BERT Binary head")
             queue.put("binary head")
-            queue.put(models[0].binary_head.weight.data)
-            queue.put(models[0].binary_head.bias.data)
+            message = {
+                "weight": models[0].binary_head.weight.data,
+                "bias": models[0].binary_head.bias.data
+            }
+            queue_put("binary head", message)

     queue.put("done")

 def load_checkpoint(queue, args):
...
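A side note on the dim arguments used in this file: column-parallel layers (query_key_value, dense_h_to_4h) shard their output dimension across tensor-parallel ranks, so the loader rebuilds the full weight with torch.cat(..., dim=0); row-parallel layers (self_attention.dense, dense_4h_to_h) shard their input dimension, hence dim=1. A toy illustration with made-up shapes (hidden size 4, tensor-parallel size 2):

    import torch

    hidden, tp = 4, 2
    # column-parallel: each rank holds [out/tp, in] -> full weight is cat(dim=0)
    col_shards = [torch.randn(2 * hidden // tp, hidden) for _ in range(tp)]
    full_col = torch.cat(col_shards, dim=0)          # shape [2*hidden, hidden]
    # row-parallel: each rank holds [out, in/tp] -> full weight is cat(dim=1)
    row_shards = [torch.randn(hidden, 2 * hidden // tp) for _ in range(tp)]
    full_row = torch.cat(row_shards, dim=1)          # shape [hidden, 2*hidden]
    print(full_col.shape, full_row.shape)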
tools/checkpoint_saver_megatron.py  (view file @ 748a9a7d)

 import argparse
+from collections.abc import Mapping
 import concurrent.futures
 import os
 import sys
...
@@ -38,13 +39,31 @@ def save_checkpoint(queue, args):
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         exit(1)

-    def queue_get():
+    def queue_get(name=None):
         val = queue.get()
         if val == "exit":
             print("Loader exited, exiting saver")
             exit(1)
+        if name is not None and args.checking and val["name"] != name:
+            val_name = val["name"]
+            print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
+            exit(1)
+        if name is not None:
+            print(f"received {name}")
         return val

+    def check_message(msg):
+        if not args.checking:
+            return
+        msg_name = msg.pop("name")
+        if len(msg.keys()) > 0:
+            print(f"Unexpected values in {msg_name}:")
+            for key in msg.keys():
+                print(f"   {key}")
+            print(f"Exiting. If you want to ignore this, use the argument --no-checking.")
+            exit(1)
+
     md = queue_get()

     if args.target_tensor_parallel_size is None:
...
@@ -141,8 +160,11 @@ def save_checkpoint(queue, args):
     # Embeddings
     #-----------
-    pos_embed = queue_get()
-    orig_word_embed = queue_get()
+    embeddings_msg = queue_get("embeddings")
+
+    pos_embed = embeddings_msg.pop("position embeddings")
+    orig_word_embed = embeddings_msg.pop("word embeddings")
+    check_message(embeddings_msg)

     # Deal with padding
     if md.true_vocab_size is not None:
...
@@ -185,6 +207,7 @@ def save_checkpoint(queue, args):
     # Transformer layers
     #-------------------
+    total_layer_num = 0
     for pp_rank in range(args.target_pipeline_parallel_size):
         # For later pipeline parallel ranks, make the new models
         if pp_rank > 0:
...
@@ -193,47 +216,47 @@ def save_checkpoint(queue, args):
             models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)

         for layer in range(len(models[0].language_model.encoder.layers)):
-            # get full tensors
-            input_layernorm_weight = queue_get()
-            input_layernorm_bias = queue_get()
-            full_qkv_weight = queue_get()
-            full_qkv_bias = queue_get()
-            full_dense_weight = queue_get()
-            dense_bias = queue_get()
-            post_layernorm_weight = queue_get()
-            post_layernorm_bias = queue_get()
-            full_mlp_l0_weight = queue_get()
-            full_mlp_l0_bias = queue_get()
-            full_mlp_l1_weight = queue_get()
-            mlp_l1_bias = queue_get()
+            msg = queue_get(f"transformer layer {total_layer_num}")
+
+            # duplicated tensors
+            input_layernorm_weight = msg.pop("input layernorm weight")
+            input_layernorm_bias = msg.pop("input layernorm bias")
+            dense_bias = msg.pop("dense bias")
+            post_layernorm_weight = msg.pop("post layernorm weight")
+            post_layernorm_bias = msg.pop("post layernorm bias")
+            mlp_l1_bias = msg.pop("mlp l1 bias")

             # Split up the parallel tensors
-            out_qkv_weight = torch.chunk(full_qkv_weight, args.target_tensor_parallel_size, dim=0)
-            out_qkv_bias = torch.chunk(full_qkv_bias, args.target_tensor_parallel_size, dim=0)
-            out_dense_weight = torch.chunk(full_dense_weight, args.target_tensor_parallel_size, dim=1)
-            out_mlp_l0_weight = torch.chunk(full_mlp_l0_weight, args.target_tensor_parallel_size, dim=0)
-            out_mlp_l0_bias = torch.chunk(full_mlp_l0_bias, args.target_tensor_parallel_size, dim=0)
-            out_mlp_l1_weight = torch.chunk(full_mlp_l1_weight, args.target_tensor_parallel_size, dim=1)
+            qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
+            qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
+            dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
+            mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
+            mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
+            mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)

             # Save them to the model
             for tp_rank in range(args.target_tensor_parallel_size):
                 l = models[tp_rank].language_model.encoder.layers[layer]
                 l.input_layernorm.weight.data.copy_(input_layernorm_weight)
                 l.input_layernorm.bias.data.copy_(input_layernorm_bias)
-                l.self_attention.query_key_value.weight.data.copy_(out_qkv_weight[tp_rank])
-                l.self_attention.query_key_value.bias.data.copy_(out_qkv_bias[tp_rank])
-                l.self_attention.dense.weight.data.copy_(out_dense_weight[tp_rank])
+                l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
+                l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
+                l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
                 l.self_attention.dense.bias.data.copy_(dense_bias)
                 l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
                 l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
-                l.mlp.dense_h_to_4h.weight.data.copy_(out_mlp_l0_weight[tp_rank])
-                l.mlp.dense_h_to_4h.bias.data.copy_(out_mlp_l0_bias[tp_rank])
-                l.mlp.dense_4h_to_h.weight.data.copy_(out_mlp_l1_weight[tp_rank])
+                l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
+                l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
+                l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
                 l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)

+            total_layer_num = total_layer_num + 1
+            check_message(msg)
+
         if post_process:
-            final_layernorm_weight = queue_get()
-            final_layernorm_bias = queue_get()
+            msg = queue_get("final layernorm")
+            final_layernorm_weight = msg.pop("weight")
+            final_layernorm_bias = msg.pop("bias")
             for tp_rank in range(args.target_tensor_parallel_size):
                 models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
                 models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
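The saver performs the inverse of the loader's concatenation: it pops the full tensor out of the message and re-splits it with torch.chunk for the target tensor-parallel size, which need not equal the source size. A small sketch with made-up shapes:

    import torch

    full_qkv_weight = torch.randn(12, 8)             # pretend full (unsplit) weight
    target_tp = 4
    shards = torch.chunk(full_qkv_weight, target_tp, dim=0)
    assert len(shards) == target_tp and shards[0].shape == (3, 8)
    assert torch.equal(torch.cat(shards, dim=0), full_qkv_weight)   # chunk/cat round-trip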
...
@@ -242,49 +265,56 @@ def save_checkpoint(queue, args):
                 models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
             del final_layernorm_weight
             del final_layernorm_bias
+            check_message(msg)

-            name = queue_get()
-            if name == "pooler":
+            msg = queue_get()
+            if msg != "done" and msg["name"] == "pooler":
                 if not hasattr(models[0].language_model, 'pooler'):
                     print("ERROR: got a pooler, but model does not have one")
                     exit(1)
-                pooler_weight = queue_get()
-                pooler_bias = queue_get()
+                print("received pooler")
+                pooler_weight = msg.pop("weight")
+                pooler_bias = msg.pop("bias")
                 for tp_rank in range(args.target_tensor_parallel_size):
                     models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
                     models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
-                name = queue_get()
                 del pooler_weight
                 del pooler_bias
+                check_message(msg)
+                msg = queue_get()

-            if name == "lm head":
+            if msg != "done" and msg["name"] == "lm head":
                 if not hasattr(models[0], 'lm_head'):
                     print("ERROR: got an lm head, but model does not have one")
                     exit(1)
-                lm_head_dense_weight = queue_get()
-                lm_head_dense_bias = queue_get()
-                lm_head_layernorm_weight = queue_get()
-                lm_head_layernorm_bias = queue_get()
+                print("received lm head")
+                lm_head_dense_weight = msg.pop("dense weight")
+                lm_head_dense_bias = msg.pop("dense bias")
+                lm_head_layernorm_weight = msg.pop("layernorm weight")
+                lm_head_layernorm_bias = msg.pop("layernorm bias")
                 for tp_rank in range(args.target_tensor_parallel_size):
                     models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
                     models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
                     models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
                     models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
-                name = queue_get()
+                check_message(msg)
+                msg = queue_get()

-            if name == "binary head":
+            if msg != "done" and msg["name"] == "binary head":
                 if not hasattr(models[0], 'binary_head'):
                     print("ERROR: got a binary head, but model does not have one")
                     exit(1)
-                binary_head_weight = queue_get()
-                binary_head_bias = queue_get()
+                print("received binary head")
+                binary_head_weight = msg.pop("weight")
+                binary_head_bias = msg.pop("bias")
                 for tp_rank in range(args.target_tensor_parallel_size):
                     models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
                     models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
-                name = queue_get()
+                check_message(msg)
+                msg = queue_get()

-            if name != "done":
-                print("ERROR: got some more data but were expecting to be done")
+            if msg != "done":
+                print("ERROR: got some more data but was expecting to be done")

         for tp_rank in range(args.target_tensor_parallel_size):
             mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
...
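The tail of the saver is a sentinel-terminated stream: each optional BERT head arrives as a named dict, and the plain string "done" marks the end. A stripped-down sketch of that control flow, with a made-up in-memory stream standing in for the queue:

    stream = [{"name": "pooler", "weight": 1, "bias": 2},
              {"name": "lm head", "dense weight": 3},
              "done"]
    it = iter(stream)

    msg = next(it)
    if msg != "done" and msg["name"] == "pooler":
        print("handling pooler:", msg)
        msg = next(it)
    if msg != "done" and msg["name"] == "lm head":
        print("handling lm head:", msg)
        msg = next(it)
    if msg != "done":
        print("ERROR: got some more data but was expecting to be done")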
tools/checkpoint_util.py  (view file @ 748a9a7d)
...
@@ -12,10 +12,12 @@ import sys
 # load_checkpoint
 # The loader and saver process are each given a queue, the loader
-# should load the checkpoint and send the weights in the following
-# order, the saver should receive them in this order and save the
-# checkpoints. Note that the weight sent over the queue are the full
-# model weights, nothing split.
+# should load the checkpoint and send the weights in messages in the
+# following order, the saver should receive them in this order and
+# save the checkpoints. A message consists of a python dictionary with
+# a "name" for error checking and an entry for each tensor as
+# indicated below. Note that the weight sent over the queue are the
+# full model weights, nothing split.
 # If the loader ever sends "exit" to the queue, that means something
 # went wrong and it is exiting.
...
@@ -37,35 +39,51 @@ import sys
 # make_vocab_size_divisble_by
 # consumed_train_samples
 # consumed_valid_samples
-# - Position embeddings
-# - Word embeddings
-# - For each transformer layer:
-#   - input layernorm weights
-#   - input layernorm bias
-#   - qkv weight
-#   - qkv bias
-#   - dense weight
-#   - dense bias
-#   - post attention layernorm weight
-#   - post attention layernorm bias
-#   - mlp layer 0 (h to 4h) weight
-#   - mlp layer 0 (h to 4h) bias
-#   - mlp layer 1 (4h to h) weight
-#   - mlp layer 1 (4h to h) bias
-# - final layer norm weight
-# - final layer norm bias
-# - if present (i.e. for BERT):
-#   - "pooler"
-#   - LM Pooler weight
-#   - LM Pooler bias
-#   - "lm head"
-#   - LM head dense weight
-#   - LM head dense bias
-#   - LM head layernorm weight
-#   - LM head layernorm bias
-#   - "binary head"
-#   - BERT Binary head weight
-#   - BERT Binary head bias
+# messages
+# {
+#   "name": "embeddings"
+#   "position embeddings"
+#   "word embeddings"
+# }
+# (for each transformer layer):
+# {
+#   "name": "transformer layer N"
+#   "input layernorm weight"
+#   "input layernorm bias"
+#   "qkv weight"
+#   "qkv bias"
+#   "dense weight"
+#   "dense bias"
+#   "post layernorm weight"
+#   "post layernorm bias"
+#   "mlp l0 weight"
+#   "mlp l0 bias"
+#   "mlp l1 weight"
+#   "mlp l1 bias"
+# }
+# {
+#   "name": "final layer norm"
+#   "weight"
+#   "bias"
+# }
+# if present (i.e. for BERT):
+# {
+#   "name": "pooler"
+#   "weight"
+#   "bias"
+# }
+# {
+#   "name": "lm head"
+#   "dense weight"
+#   "dense bias"
+#   "layernorm weight"
+#   "layernorm bias"
+# }
+# {
+#   "name": "binary head"
+#   "weight"
+#   "bias"
+# }
 # - "done"

 def load_plugin(plugin_type, name):
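For concreteness, a transformer-layer message as documented above is a flat dict of named full tensors plus the "name" tag added by the loader's queue_put. The keys below come from the comment block; the tensor shapes are placeholders, not the real model dimensions:

    import torch

    layer_message = {
        "name": "transformer layer 0",
        "input layernorm weight": torch.empty(8),
        "input layernorm bias": torch.empty(8),
        "qkv weight": torch.empty(24, 8),
        "qkv bias": torch.empty(24),
        "dense weight": torch.empty(8, 8),
        "dense bias": torch.empty(8),
        "post layernorm weight": torch.empty(8),
        "post layernorm bias": torch.empty(8),
        "mlp l0 weight": torch.empty(32, 8),
        "mlp l0 bias": torch.empty(32),
        "mlp l1 weight": torch.empty(8, 32),
        "mlp l1 bias": torch.empty(8),
    }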
...
@@ -103,6 +121,9 @@ def main():
                         help='Directory to save model checkpoint to')
     parser.add_argument('--max-queue-size', type=int, default=50,
                         help='Maximum number of tensors in the queue')
+    parser.add_argument('--no-checking', action='store_false',
+                        help='Do not perform checking on the name and ordering of weights',
+                        dest='checking')

     known_args, _ = parser.parse_known_args()
     loader = load_plugin('loader', known_args.loader)
...
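The new flag relies on argparse's store_false action with dest='checking': args.checking defaults to True, and passing --no-checking flips it to False. A standalone illustration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')

    print(parser.parse_args([]).checking)                 # True  (checking enabled)
    print(parser.parse_args(['--no-checking']).checking)  # False (checking disabled)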