OpenDAS / Megatron-LM · Commits

Commit be4dda7b, authored Apr 17, 2025 by wxj

    Update convert.py

Parent: 56bf70a2 · Pipeline #2650 passed · Changes: 1 · Pipelines: 1
Showing 1 changed file with 173 additions and 170 deletions.

tools/checkpoint/convert.py (+173, -170) @ be4dda7b
```python
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import argparse
import importlib
import torch.multiprocessing as mp
import sys


# A loader is a python file with at least two functions:
# - add_arguments - takes in a parser and adds any arguments needed
# - load_checkpoint - takes in the queue and parsed arguments
#
# A saver is similar but has save_checkpoint instead of
# load_checkpoint.
#
# The loader and saver processes are each given a queue. The loader
# should load the checkpoint and send the weights in messages in the
# following order; the saver should receive them in this order and
# save the checkpoints. A message consists of a python dictionary with
# a "name" for error checking and an entry for each tensor as
# indicated below. Note that the weights sent over the queue are the
# full model weights, nothing split.
#
# If the loader ever sends "exit" to the queue, that means something
# went wrong and it is exiting.
#
# - Metadata Namespace with the following attributes:
#     model_type - GPT, BERT, T5, etc. (Part of protocol to allow this to be deduced later instead of given on command line)
#     num_layers - Number of transformer layers
#     hidden_size
#     seq_length
#     num_attention_heads
#     max_position_embeddings
#     tokenizer_type
#     iteration
#     params_dtype
#     bert_binary_head - Used only if model_type is BERT
#     previous_tensor_parallel_size - Optional
#     previous_pipeline_parallel_size - Optional
#     true_vocab_size
#     make_vocab_size_divisible_by
#     consumed_train_samples
#     consumed_valid_samples
# - messages:
#   {
#     "name": "embeddings"
#     "position embeddings"
#     "word embeddings"
#   }
#   (for each transformer layer):
#   {
#     "name": "transformer layer N"
#     "input norm weight"
#     "input norm bias"
#     "qkv weight"
#     "qkv bias"
#     "dense weight"
#     "dense bias"
#     "post norm weight"
#     "post norm bias"
#     "mlp l0 weight"
#     "mlp l0 bias"
#     "mlp l1 weight"
#     "mlp l1 bias"
#   }
#   {
#     "name": "final layer norm"
#     "weight"
#     "bias"
#   }
#   if present (i.e. for BERT):
#   {
#     "name": "pooler"
#     "weight"
#     "bias"
#   }
#   {
#     "name": "lm head"
#     "dense weight"
#     "dense bias"
#     "norm weight"
#     "norm bias"
#   }
#   {
#     "name": "binary head"
#     "weight"
#     "bias"
#   }
# - "done"
```
```python
def load_plugin(plugin_type, name):
    module_name = f"{plugin_type}_{name}"
    try:
        plugin = importlib.import_module(module_name)
    except ModuleNotFoundError as e:
        print(e)
        module_name = name
        try:
            plugin = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            print(e)
            sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")

    if not hasattr(plugin, 'add_arguments'):
        sys.exit(f"{module_name} module is not a plugin. Exiting.")

    print(f"Loaded {module_name} as the {plugin_type}.")
    return plugin
```
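For example, `load_plugin('loader', 'legacy')` first tries to import a module named `loader_legacy` and, failing that, falls back to importing `legacy` directly; whichever module is found must expose `add_arguments` to be accepted as a plugin.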
```python
def main():
    import argparse
    parser = argparse.ArgumentParser(description="Megatron Checkpoint Converter Arguments",
                                     allow_abbrev=False, conflict_handler='resolve')

    parser.add_argument('--model-type', type=str, required=True,
                        choices=['GPT', 'BERT'],
                        help='Type of the model')
    parser.add_argument('--loader', type=str, default='megatron',
                        help='Module name to load checkpoint, should be on python path')
    parser.add_argument('--saver', type=str, default='megatron',
                        help='Module name to save checkpoint, should be on python path')
    parser.add_argument('--load-dir', type=str, required=True,
                        help='Directory to load model checkpoint from')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='Directory to save model checkpoint to')
    parser.add_argument('--max-queue-size', type=int, default=50,
                        help='Maximum number of tensors in the queue')
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')

    known_args, _ = parser.parse_known_args()

    # Handle old arg values.
    def update_loader_saver(key):
        old_value = getattr(known_args, key)
        if old_value == "megatron":
            setattr(known_args, key, "legacy")
        if old_value == "mcore":
            setattr(known_args, key, "core")
    update_loader_saver("loader")
    update_loader_saver("saver")

    # Load loader/saver plugins.
    loader = load_plugin('loader', known_args.loader)
    saver = load_plugin('saver', known_args.saver)

    # Parse loader/saver args.
    loader.add_arguments(parser)
    saver.add_arguments(parser)
    args = parser.parse_args()
```
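Two details worth noting in the argument setup: `--no-checking` is declared with `action='store_false'` and `dest='checking'`, so `args.checking` defaults to True and passing the flag turns name/order verification off; and `conflict_handler='resolve'` lets the loader and saver plugins re-declare option strings via their `add_arguments` hooks without argparse raising a conflict.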
The remainder of `main()` is where this commit's change lands, shown here as a unified diff:

```diff
     # Initialize queue
-    queue = mp.Queue(maxsize=args.max_queue_size)
+    ctx = mp.get_context("spawn")
+    queue = ctx.Queue(maxsize=args.max_queue_size)
+    # queue = mp.Queue(maxsize=args.max_queue_size)

     # Start saver process.
     print("Starting saver...")
-    saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
+    saver_proc = ctx.Process(target=saver.save_checkpoint, args=(queue, args))
+    # saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
     saver_proc.start()

     # Run loader.
     print("Starting loader...")
     loader.load_checkpoint(queue, args)

     # Finish saver process.
     print("Waiting for saver to complete...")
     saver_proc.join()


 if __name__ == '__main__':
     main()
```
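The substantive change is the switch to an explicit "spawn" multiprocessing context for both the queue and the saver process (the old `mp.Queue` / `mp.Process` calls are left commented out). A likely motivation: PyTorch's CUDA runtime does not survive a fork(), and sharing CUDA tensors with a subprocess generally requires the spawn (or forkserver) start method, so creating the queue and the process from the same spawn context keeps them compatible and avoids CUDA re-initialization errors in the saver.

For reference, a typical invocation might look like the following; the checkpoint paths are illustrative, and the `legacy`/`core` plugin names assume the loader/saver modules that ship alongside this script (matching the `update_loader_saver` mapping above):

```bash
python tools/checkpoint/convert.py \
    --model-type GPT \
    --loader legacy \
    --saver core \
    --load-dir /path/to/legacy/checkpoint \
    --save-dir /path/to/core/checkpoint \
    --max-queue-size 50
```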