# Source provenance: commit be4dda7b ("Update convert.py", parent 56bf70a2),
# captured from a repository web view; pipeline #2650 passed.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import argparse
import importlib
import sys

import torch.multiprocessing as mp
# A loader is a python file with at least two functions
# - add_arguments - takes in a parser and adds any arguments needed
# - load_checkpoint - takes in the queue and parsed arguments
#
# A saver is similar but has save_checkpoint instead of
# load_checkpoint
#
# The loader and saver process are each given a queue, the loader
# should load the checkpoint and send the weights in messages in the
# following order, the saver should receive them in this order and
# save the checkpoints. A message consists of a python dictionary with
# a "name" for error checking and an entry for each tensor as
# indicated below. Note that the weights sent over the queue are the
# full model weights, nothing split.
#
# If the loader ever sends "exit" to the queue, that means something
# went wrong and it is exiting.
#
# - Metadata Namespace with the following attributes:
#     model_type - GPT, BERT, T5, etc. (Part of protocol to allow this to be deduced later instead of given on command line)
#     num_layers - Number of transformer layers
#     hidden_size
#     seq_length
#     num_attention_heads
#     max_position_embeddings
#     tokenizer_type
#     iteration
#     params_dtype
#     bert_binary_head - Used only if model_type is BERT
#     previous_tensor_parallel_size - Optional
#     previous_pipeline_parallel_size - Optional
#     true_vocab_size
#     make_vocab_size_divisible_by
#     consumed_train_samples
#     consumed_valid_samples
# messages
# {
#   "name": "embeddings"
#   "position embeddings"
#   "word embeddings"
# }
# (for each transformer layer):
# {
#   "name": "transformer layer N"
#   "input norm weight"
#   "input norm bias"
#   "qkv weight"
#   "qkv bias"
#   "dense weight"
#   "dense bias"
#   "post norm weight"
#   "post norm bias"
#   "mlp l0 weight"
#   "mlp l0 bias"
#   "mlp l1 weight"
#   "mlp l1 bias"
# }
# {
#   "name": "final layer norm"
#   "weight"
#   "bias"
# }
# if present (i.e. for BERT):
# {
#   "name": "pooler"
#   "weight"
#   "bias"
# }
# {
#   "name": "lm head"
#   "dense weight"
#   "dense bias"
#   "norm weight"
#   "norm bias"
# }
# {
#   "name": "binary head"
#   "weight"
#   "bias"
# }
# - "done"
def load_plugin(plugin_type, name):
    """Import a loader/saver plugin module and return it.

    Tries "<plugin_type>_<name>" first, then falls back to importing
    "<name>" directly. Exits the process if neither import succeeds or
    if the imported module lacks the plugin entry point
    ``add_arguments``.
    """
    plugin = None
    loaded_name = None
    # Try the prefixed module name first, then the bare name.
    for candidate in (f"{plugin_type}_{name}", name):
        try:
            plugin = importlib.import_module(candidate)
            loaded_name = candidate
            break
        except ModuleNotFoundError as e:
            # Surface the import failure before trying the fallback.
            print(e)
    if plugin is None:
        sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")
    if not hasattr(plugin, 'add_arguments'):
        sys.exit(f"{loaded_name} module is not a plugin. Exiting.")
    print(f"Loaded {loaded_name} as the {plugin_type}.")
    return plugin
def main():
    """Convert a Megatron checkpoint.

    Parses converter arguments, loads the loader/saver plugins, then
    streams model weights from the loader (run in this process) to the
    saver (run in a separate process) over a bounded multiprocessing
    queue, following the message protocol described at the top of this
    file.
    """
    parser = argparse.ArgumentParser(description="Megatron Checkpoint Converter Arguments",
                                     allow_abbrev=False, conflict_handler='resolve')

    parser.add_argument('--model-type', type=str, required=True,
                        choices=['GPT', 'BERT'],
                        help='Type of the model')
    parser.add_argument('--loader', type=str, default='megatron',
                        help='Module name to load checkpoint, should be on python path')
    parser.add_argument('--saver', type=str, default='megatron',
                        help='Module name to save checkpoint, should be on python path')
    parser.add_argument('--load-dir', type=str, required=True,
                        help='Directory to load model checkpoint from')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='Directory to save model checkpoint to')
    parser.add_argument('--max-queue-size', type=int, default=50,
                        help='Maximum number of tensors in the queue')
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')

    known_args, _ = parser.parse_known_args()

    # Handle old arg values: map deprecated plugin names onto the
    # current module names before resolving plugins.
    def update_loader_saver(key):
        old_value = getattr(known_args, key)
        if old_value == "megatron":
            setattr(known_args, key, "legacy")
        if old_value == "mcore":
            setattr(known_args, key, "core")

    update_loader_saver("loader")
    update_loader_saver("saver")

    # Load loader/saver plugins and let each register its own arguments
    # before the final parse.
    loader = load_plugin('loader', known_args.loader)
    saver = load_plugin('saver', known_args.saver)

    loader.add_arguments(parser)
    saver.add_arguments(parser)
    args = parser.parse_args()

    # Initialize the queue in an explicit "spawn" context
    # (presumably to avoid fork-related issues, e.g. with CUDA state —
    # TODO confirm against the loader/saver plugins).
    ctx = mp.get_context("spawn")
    queue = ctx.Queue(maxsize=args.max_queue_size)

    # Start saver process.
    print("Starting saver...")
    saver_proc = ctx.Process(target=saver.save_checkpoint, args=(queue, args))
    saver_proc.start()

    # Run loader in this process; it feeds weights into the queue.
    print("Starting loader...")
    loader.load_checkpoint(queue, args)

    # Finish saver process.
    print("Waiting for saver to complete...")
    saver_proc.join()


if __name__ == '__main__':
    main()