OpenDAS / Megatron-LM · Commits

Commit 748a9a7d, authored Jul 11, 2022 by Jared Casper
Parent: 2755bcb8

Add error checking by wrapping all weights and biases into named messages with named tensors.
Showing 3 changed files with 187 additions and 130 deletions:

    tools/checkpoint_loader_megatron.py   +56  -50
    tools/checkpoint_saver_megatron.py    +77  -47
    tools/checkpoint_util.py              +54  -33
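The pattern this commit introduces, tagging each queue message with a name on the sending side and verifying it on the receiving side, can be sketched in isolation as follows. This is an illustrative, self-contained example using a plain queue.Queue and small tensors; it mirrors the helper names in the diffs below but is not the tool's actual loader/saver processes, which run in separate processes and call exit(1) on errors instead of raising.

    import queue
    import torch

    q = queue.Queue()

    # Loader side: tag each message with a name before enqueueing.
    def queue_put(name, msg):
        msg["name"] = name
        q.put(msg)

    # Saver side: verify the expected name before unpacking.
    def queue_get(name=None):
        val = q.get()
        if name is not None and val["name"] != name:
            raise RuntimeError(f'expected "{name}" but got "{val["name"]}"')
        return val

    # Toy round trip shaped like the "embeddings" message in the diff below.
    queue_put("embeddings", {
        "position embeddings": torch.zeros(8, 4),
        "word embeddings": torch.zeros(16, 4),
    })
    msg = queue_get("embeddings")
    pos_embed = msg.pop("position embeddings")
    word_embed = msg.pop("word embeddings")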
tools/checkpoint_loader_megatron.py
@@ -170,18 +170,20 @@ def _load_checkpoint(queue, args):
     md.consumed_valid_samples = consumed_valid_samples
     queue.put(md)
 
+    def queue_put(name, msg):
+        print(f"sending {name}")
+        msg["name"] = name
+        queue.put(msg)
+
     # Send embeddings
-    word_embed = []
-    for tp_rank in range(tp_size):
-        if tp_rank == 0:
-            print("Sending position embeddings")
-            queue.put(models[tp_rank].language_model.embedding.position_embeddings.weight.data)
-        word_embed.append(models[tp_rank].language_model.embedding.word_embeddings.weight.data)
-    full_word_embed = torch.cat(word_embed, dim=0)
-    print("Sending word embeddings")
-    queue.put(full_word_embed)
+    message = {
+        "position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
+        "word embeddings": torch.cat(
+            [models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
+            dim = 0)
+    }
+    queue_put("embeddings", message)
 
     total_layer_num = 0
     for pp_rank in range(pp_size):
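For intuition on the dim=0 concatenation above: each tensor-parallel rank holds a contiguous slice of the word-embedding table along the vocabulary dimension, so stacking the shards along dim 0 recovers the full, unsplit table that goes into the message. A minimal sketch with assumed toy shapes (vocabulary of 16 split across 2 ranks, hidden size 4):

    import torch

    tp_size = 2
    hidden = 4
    # Assumed toy shards: each rank holds 8 of the 16 vocabulary rows.
    shards = [torch.randn(8, hidden) for _ in range(tp_size)]

    full_word_embed = torch.cat(shards, dim=0)
    print(full_word_embed.shape)  # torch.Size([16, 4]) -- the full table sent over the queue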
@@ -190,23 +192,24 @@ def _load_checkpoint(queue, args):
             post_process = pp_rank == pp_size - 1
             models = get_models(tp_size, md.params_dtype, False, post_process)
 
         for layer_num in range(len(models[0].language_model.encoder.layers)):
+            message = {}
+
             # Get non-parallel tensors from tp_rank 0
             layer = models[0].language_model.encoder.layers[layer_num]
-            input_layernorm_weight = layer.input_layernorm.weight.data
-            input_layernorm_bias = layer.input_layernorm.bias.data
-            dense_bias = layer.self_attention.dense.bias.data
-            post_layernorm_weight = layer.post_attention_layernorm.weight.data
-            post_layernorm_bias = layer.post_attention_layernorm.bias.data
-            mlp_l1_bias = layer.mlp.dense_4h_to_h.bias.data
+            message["input layernorm weight"] = layer.input_layernorm.weight.data
+            message["input layernorm bias"] = layer.input_layernorm.bias.data
+            message["dense bias"] = layer.self_attention.dense.bias.data
+            message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
+            message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
+            message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
 
             # Grab all parallel tensors for this layer
             qkv_weight = []
             qkv_bias = []
             dense_weight = []
             mlp_l0_weight = []
             mlp_l0_bias = []
             mlp_l1_weight = []
             for tp_rank, model in enumerate(models):
                 layer = model.language_model.encoder.layers[layer_num]
                 qkv_weight.append(layer.self_attention.query_key_value.weight.data)
@@ -216,47 +219,50 @@ def _load_checkpoint(queue, args):
                 mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
                 mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
 
-            # send everything in order while concatenating them
-            print(f"Sending layer {layer_num} of pipeline rank {pp_rank} (total layer {total_layer_num})")
-            queue.put(input_layernorm_weight)
-            queue.put(input_layernorm_bias)
-            queue.put(torch.cat(qkv_weight, dim=0))
-            queue.put(torch.cat(qkv_bias, dim=0))
-            queue.put(torch.cat(dense_weight, dim=1))
-            queue.put(dense_bias)
-            queue.put(post_layernorm_weight)
-            queue.put(post_layernorm_bias)
-            queue.put(torch.cat(mlp_l0_weight, dim=0))
-            queue.put(torch.cat(mlp_l0_bias, dim=0))
-            queue.put(torch.cat(mlp_l1_weight, dim=1))
-            queue.put(mlp_l1_bias)
+            # concat them
+            message["qkv weight"] = torch.cat(qkv_weight, dim=0)
+            message["qkv bias"] = torch.cat(qkv_bias, dim=0)
+            message["dense weight"] = torch.cat(dense_weight, dim=1)
+            message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
+            message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
+            message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
+
+            queue_put(f"transformer layer {total_layer_num}", message)
 
             total_layer_num = total_layer_num + 1
 
     # Send final layernorm from tp_rank 0
-    print("Sending final layernorm")
-    queue.put(models[0].language_model.encoder.final_layernorm.weight.data)
-    queue.put(models[0].language_model.encoder.final_layernorm.bias.data)
+    message = {
+        "weight": models[0].language_model.encoder.final_layernorm.weight.data,
+        "bias": models[0].language_model.encoder.final_layernorm.bias.data
+    }
+    queue_put("final layernorm", message)
 
     # Send BERT lm head and binary head if it exists
     if md.model_type == 'BERT':
-        print("Sending LM Pooler")
-        queue.put("pooler")
-        queue.put(models[0].language_model.pooler.dense.weight.data)
-        queue.put(models[0].language_model.pooler.dense.bias.data)
-
-        print("Sending BERT LM head")
-        queue.put("lm head")
-        queue.put(models[0].lm_head.dense.weight.data)
-        queue.put(models[0].lm_head.dense.bias.data)
-        queue.put(models[0].lm_head.layernorm.weight.data)
-        queue.put(models[0].lm_head.layernorm.bias.data)
+        message = {
+            "weight": models[0].language_model.pooler.dense.weight.data,
+            "bias": models[0].language_model.pooler.dense.bias.data
+        }
+        queue_put("pooler", message)
+
+        message = {
+            "dense weight": models[0].lm_head.dense.weight.data,
+            "dense bias": models[0].lm_head.dense.bias.data,
+            "layernorm weight": models[0].lm_head.layernorm.weight.data,
+            "layernorm bias": models[0].lm_head.layernorm.bias.data
+        }
+        queue_put("lm head", message)
 
         if md.bert_binary_head:
-            print("Sending BERT Binary head")
-            queue.put("binary head")
-            queue.put(models[0].binary_head.weight.data)
-            queue.put(models[0].binary_head.bias.data)
+            message = {
+                "weight": models[0].binary_head.weight.data,
+                "bias": models[0].binary_head.bias.data
+            }
+            queue_put("binary head", message)
 
     queue.put("done")
 
 def load_checkpoint(queue, args):
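Taken together, the loader now emits a fixed sequence of named messages followed by a bare "done" string. A small sketch of the expected name order for a toy configuration; the helper below is hypothetical and only summarizes the branches in the loader above, it is not part of the tools:

    def expected_message_names(num_layers, model_type="GPT", bert_binary_head=False):
        """Names the saver should see, in order, per the loader code above."""
        names = ["embeddings"]
        names += [f"transformer layer {i}" for i in range(num_layers)]
        names.append("final layernorm")
        if model_type == "BERT":
            names += ["pooler", "lm head"]
            if bert_binary_head:
                names.append("binary head")
        names.append("done")  # sent as a bare string, not a dict
        return names

    print(expected_message_names(2, model_type="BERT", bert_binary_head=True))
    # ['embeddings', 'transformer layer 0', 'transformer layer 1', 'final layernorm',
    #  'pooler', 'lm head', 'binary head', 'done']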
tools/checkpoint_saver_megatron.py
 import argparse
 from collections.abc import Mapping
 import concurrent.futures
 import os
 import sys
 
@@ -38,13 +39,31 @@ def save_checkpoint(queue, args):
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         exit(1)
 
-    def queue_get():
+    def queue_get(name=None):
         val = queue.get()
         if val == "exit":
             print("Loader exited, exiting saver")
             exit(1)
+        if name is not None and args.checking and val["name"] != name:
+            val_name = val["name"]
+            print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
+            exit(1)
+        if name is not None:
+            print(f"received {name}")
         return val
 
+    def check_message(msg):
+        if not args.checking:
+            return
+        msg_name = msg.pop("name")
+        if len(msg.keys()) > 0:
+            print(f"Unexpected values in {msg_name}:")
+            for key in msg.keys():
+                print(f"   {key}")
+            print(f"Exiting. If you want to ignore this, use the argument --no-checking.")
+            exit(1)
+
     md = queue_get()
 
     if args.target_tensor_parallel_size is None:
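To make the failure mode concrete: check_message flags any keys the saver did not consume, which catches a loader and saver that disagree about a message's contents. An illustrative stand-alone sketch of just that check, raising instead of calling exit(1) as the real tool does:

    def check_message(msg, checking=True):
        # Same idea as the saver's helper: after the expected tensors have been
        # popped, anything left in the dict besides "name" is unexpected.
        if not checking:
            return
        msg_name = msg.pop("name")
        if len(msg) > 0:
            raise RuntimeError(f"Unexpected values in {msg_name}: {list(msg.keys())}")

    msg = {"name": "pooler", "weight": 1.0, "bias": 2.0, "extra tensor": 3.0}
    weight = msg.pop("weight")
    bias = msg.pop("bias")
    check_message(msg)  # raises: Unexpected values in pooler: ['extra tensor']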
@@ -141,8 +160,11 @@ def save_checkpoint(queue, args):
     # Embeddings
     #-----------
-    pos_embed = queue_get()
-    orig_word_embed = queue_get()
+    embeddings_msg = queue_get("embeddings")
+
+    pos_embed = embeddings_msg.pop("position embeddings")
+    orig_word_embed = embeddings_msg.pop("word embeddings")
+    check_message(embeddings_msg)
 
     # Deal with padding
     if md.true_vocab_size is not None:
@@ -185,6 +207,7 @@ def save_checkpoint(queue, args):
     # Transformer layers
     #-------------------
+    total_layer_num = 0
     for pp_rank in range(args.target_pipeline_parallel_size):
         # For later pipeline parallel ranks, make the new models
         if pp_rank > 0:
@@ -193,47 +216,47 @@ def save_checkpoint(queue, args):
             models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
 
         for layer in range(len(models[0].language_model.encoder.layers)):
-            # get full tensors
-            input_layernorm_weight = queue_get()
-            input_layernorm_bias = queue_get()
-            full_qkv_weight = queue_get()
-            full_qkv_bias = queue_get()
-            full_dense_weight = queue_get()
-            dense_bias = queue_get()
-            post_layernorm_weight = queue_get()
-            post_layernorm_bias = queue_get()
-            full_mlp_l0_weight = queue_get()
-            full_mlp_l0_bias = queue_get()
-            full_mlp_l1_weight = queue_get()
-            mlp_l1_bias = queue_get()
+            msg = queue_get(f"transformer layer {total_layer_num}")
+
+            # duplicated tensors
+            input_layernorm_weight = msg.pop("input layernorm weight")
+            input_layernorm_bias = msg.pop("input layernorm bias")
+            dense_bias = msg.pop("dense bias")
+            post_layernorm_weight = msg.pop("post layernorm weight")
+            post_layernorm_bias = msg.pop("post layernorm bias")
+            mlp_l1_bias = msg.pop("mlp l1 bias")
 
             # Split up the parallel tensors
-            out_qkv_weight = torch.chunk(full_qkv_weight, args.target_tensor_parallel_size, dim=0)
-            out_qkv_bias = torch.chunk(full_qkv_bias, args.target_tensor_parallel_size, dim=0)
-            out_dense_weight = torch.chunk(full_dense_weight, args.target_tensor_parallel_size, dim=1)
-            out_mlp_l0_weight = torch.chunk(full_mlp_l0_weight, args.target_tensor_parallel_size, dim=0)
-            out_mlp_l0_bias = torch.chunk(full_mlp_l0_bias, args.target_tensor_parallel_size, dim=0)
-            out_mlp_l1_weight = torch.chunk(full_mlp_l1_weight, args.target_tensor_parallel_size, dim=1)
+            qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
+            qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
+            dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
+            mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
+            mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
+            mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)
 
             # Save them to the model
             for tp_rank in range(args.target_tensor_parallel_size):
                 l = models[tp_rank].language_model.encoder.layers[layer]
                 l.input_layernorm.weight.data.copy_(input_layernorm_weight)
                 l.input_layernorm.bias.data.copy_(input_layernorm_bias)
-                l.self_attention.query_key_value.weight.data.copy_(out_qkv_weight[tp_rank])
-                l.self_attention.query_key_value.bias.data.copy_(out_qkv_bias[tp_rank])
-                l.self_attention.dense.weight.data.copy_(out_dense_weight[tp_rank])
+                l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
+                l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
+                l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
                 l.self_attention.dense.bias.data.copy_(dense_bias)
                 l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
                 l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
-                l.mlp.dense_h_to_4h.weight.data.copy_(out_mlp_l0_weight[tp_rank])
-                l.mlp.dense_h_to_4h.bias.data.copy_(out_mlp_l0_bias[tp_rank])
-                l.mlp.dense_4h_to_h.weight.data.copy_(out_mlp_l1_weight[tp_rank])
+                l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
+                l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
+                l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
                 l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
 
+            total_layer_num = total_layer_num + 1
+            check_message(msg)
+
         if post_process:
-            final_layernorm_weight = queue_get()
-            final_layernorm_bias = queue_get()
+            msg = queue_get("final layernorm")
+            final_layernorm_weight = msg.pop("weight")
+            final_layernorm_bias = msg.pop("bias")
             for tp_rank in range(args.target_tensor_parallel_size):
                 models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
                 models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
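The saver reverses the loader's concatenation: torch.chunk splits each full tensor into target_tensor_parallel_size pieces along the same dimension it was concatenated on, and copy_ writes each piece into the corresponding rank's parameter. A stand-alone sketch with assumed toy shapes (hidden size 4, 2 target ranks), independent of the actual Megatron modules:

    import torch

    target_tp = 2
    hidden = 4

    # Full (unsplit) tensors as they arrive in a message; shapes are illustrative.
    full_qkv_weight = torch.randn(3 * hidden, hidden)   # split along dim 0
    full_dense_weight = torch.randn(hidden, hidden)     # split along dim 1

    qkv_shards = torch.chunk(full_qkv_weight, target_tp, dim=0)
    dense_shards = torch.chunk(full_dense_weight, target_tp, dim=1)

    # Per-rank destination parameters (stand-ins for the model's weights).
    rank_qkv = [torch.empty(3 * hidden // target_tp, hidden) for _ in range(target_tp)]
    rank_dense = [torch.empty(hidden, hidden // target_tp) for _ in range(target_tp)]

    for tp_rank in range(target_tp):
        rank_qkv[tp_rank].copy_(qkv_shards[tp_rank])
        rank_dense[tp_rank].copy_(dense_shards[tp_rank])

    print(rank_qkv[0].shape, rank_dense[0].shape)  # torch.Size([6, 4]) torch.Size([4, 2])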
@@ -242,49 +265,56 @@ def save_checkpoint(queue, args):
                     models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
+            del final_layernorm_weight
+            del final_layernorm_bias
+            check_message(msg)
 
-            name = queue_get()
-            if name == "pooler":
+            msg = queue_get()
+            if msg != "done" and msg["name"] == "pooler":
                 if not hasattr(models[0].language_model, 'pooler'):
                     print("ERROR: got a pooler, but model does not have one")
                     exit(1)
-                pooler_weight = queue_get()
-                pooler_bias = queue_get()
+                print("received pooler")
+                pooler_weight = msg.pop("weight")
+                pooler_bias = msg.pop("bias")
                 for tp_rank in range(args.target_tensor_parallel_size):
                     models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
                     models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
-                name = queue_get()
+                del pooler_weight
+                del pooler_bias
+                check_message(msg)
+                msg = queue_get()
 
-            if name == "lm head":
+            if msg != "done" and msg["name"] == "lm head":
                 if not hasattr(models[0], 'lm_head'):
                     print("ERROR: got an lm head, but model does not have one")
                     exit(1)
-                lm_head_dense_weight = queue_get()
-                lm_head_dense_bias = queue_get()
-                lm_head_layernorm_weight = queue_get()
-                lm_head_layernorm_bias = queue_get()
+                print("received lm head")
+                lm_head_dense_weight = msg.pop("dense weight")
+                lm_head_dense_bias = msg.pop("dense bias")
+                lm_head_layernorm_weight = msg.pop("layernorm weight")
+                lm_head_layernorm_bias = msg.pop("layernorm bias")
                 for tp_rank in range(args.target_tensor_parallel_size):
                     models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
                     models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
                     models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
                     models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
-                name = queue_get()
+                check_message(msg)
+                msg = queue_get()
 
-            if name == "binary head":
+            if msg != "done" and msg["name"] == "binary head":
                 if not hasattr(models[0], 'binary_head'):
                     print("ERROR: got a binary head, but model does not have one")
                     exit(1)
-                binary_head_weight = queue_get()
-                binary_head_bias = queue_get()
+                print("received binary head")
+                binary_head_weight = msg.pop("weight")
+                binary_head_bias = msg.pop("bias")
                 for tp_rank in range(args.target_tensor_parallel_size):
                     models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
                     models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
-                name = queue_get()
+                check_message(msg)
+                msg = queue_get()
 
-            if name != "done":
-                print("ERROR: got some more data but were expecting to be done")
+            if msg != "done":
+                print("ERROR: got some more data but was expecting to be done")
 
         for tp_rank in range(args.target_tensor_parallel_size):
             mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
tools/checkpoint_util.py
@@ -12,10 +12,12 @@ import sys
 # load_checkpoint
 # The loader and saver process are each given a queue, the loader
-# should load the checkpoint and send the weights in the following
-# order, the saver should receive them in this order and save the
-# checkpoints. Note that the weight sent over the queue are the full
-# model weights, nothing split.
+# should load the checkpoint and send the weights in messages in the
+# following order, the saver should receive them in this order and
+# save the checkpoints. A message consists of a python dictionary with
+# a "name" for error checking and an entry for each tensor as
+# indicated below. Note that the weight sent over the queue are the
+# full model weights, nothing split.
 # If the loader ever sends "exit" to the queue, that means something
 # went wrong and it is exiting.
@@ -37,35 +39,51 @@ import sys
 #   make_vocab_size_divisble_by
 #   consumed_train_samples
 #   consumed_valid_samples
-# - Position embeddings
-# - Word embeddings
-# - For each transformer layer:
-#     - input layernorm weights
-#     - input layernorm bias
-#     - qkv weight
-#     - qkv bias
-#     - dense weight
-#     - dense bias
-#     - post attention layernorm weight
-#     - post attention layernorm bias
-#     - mlp layer 0 (h to 4h) weight
-#     - mlp layer 0 (h to 4h) bias
-#     - mlp layer 1 (4h to h) weight
-#     - mlp layer 1 (4h to h) bias
-# - final layer norm weight
-# - final layer norm bias
-# - if present (i.e. for BERT):
-#     - "pooler"
-#     - LM Pooler weight
-#     - LM Pooler bias
-#     - "lm head"
-#     - LM head dense weight
-#     - LM head dense bias
-#     - LM head layernorm weight
-#     - LM head layernorm bias
-#     - "binary head"
-#     - BERT Binary head weight
-#     - BERT Binary head bias
+# messages
+# {
+#   "name": "embeddings"
+#   "position embeddings"
+#   "word embeddings"
+# }
+# (for each transformer layer):
+# {
+#   "name": "transformer layer N"
+#   "input layernorm weight"
+#   "input layernorm bias"
+#   "qkv weight"
+#   "qkv bias"
+#   "dense weight"
+#   "dense bias"
+#   "post layernorm weight"
+#   "post layernorm bias"
+#   "mlp l0 weight"
+#   "mlp l0 bias"
+#   "mlp l1 weight"
+#   "mlp l1 bias"
+# }
+# {
+#   "name": "final layer norm"
+#   "weight"
+#   "bias"
+# }
+# if present (i.e. for BERT):
+# {
+#   "name": "pooler"
+#   "weight"
+#   "bias"
+# }
+# {
+#   "name": "lm head"
+#   "dense weight"
+#   "dense bias"
+#   "layernorm weight"
+#   "layernorm bias"
+# }
+# {
+#   "name": "binary head"
+#   "weight"
+#   "bias"
+# }
 # - "done"
 
 def load_plugin(plugin_type, name):
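For reference, a message conforming to the "transformer layer N" entry documented above would look like the dictionary below. The tensor shapes are assumed for illustration only (hidden size h, tensors unsplit across tensor-parallel ranks); they are not taken from the commit, but the split dimensions mirror the torch.cat calls in the loader.

    import torch

    h = 8  # assumed hidden size, for illustration
    layer_message = {
        "name": "transformer layer 0",
        "input layernorm weight": torch.ones(h),
        "input layernorm bias": torch.zeros(h),
        "qkv weight": torch.randn(3 * h, h),      # concatenated along dim 0 by the loader
        "qkv bias": torch.randn(3 * h),
        "dense weight": torch.randn(h, h),        # concatenated along dim 1 by the loader
        "dense bias": torch.randn(h),
        "post layernorm weight": torch.ones(h),
        "post layernorm bias": torch.zeros(h),
        "mlp l0 weight": torch.randn(4 * h, h),   # dense_h_to_4h
        "mlp l0 bias": torch.randn(4 * h),
        "mlp l1 weight": torch.randn(h, 4 * h),   # dense_4h_to_h
        "mlp l1 bias": torch.randn(h),
    }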
@@ -103,6 +121,9 @@ def main():
                         help='Directory to save model checkpoint to')
     parser.add_argument('--max-queue-size', type=int, default=50,
                         help='Maximum number of tensors in the queue')
+    parser.add_argument('--no-checking', action='store_false',
+                        help='Do not perform checking on the name and ordering of weights',
+                        dest='checking')
 
     known_args, _ = parser.parse_known_args()
     loader = load_plugin('loader', known_args.loader)
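A note on the new flag: because add_argument uses action='store_false' with dest='checking', args.checking defaults to True and passing --no-checking turns it off, which is what the saver's queue_get and check_message consult. A minimal sketch of just that argparse behaviour (standalone, not the full checkpoint_util.py parser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')

    print(parser.parse_args([]).checking)                 # True  (checking enabled by default)
    print(parser.parse_args(['--no-checking']).checking)  # False (checking disabled)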