Unverified Commit 0af901e8 authored by Stas Bekman, committed by GitHub

[megatron_gpt2] checkpoint v3 (#13508)

* [megatron_gpt2] checkpoint v3

* bug fix

* fixes

* switch to default "gelu" from "gelu_new" - which is what the current megatron-lm uses

* cleanup

* back compat
parent 936b3fde
@@ -80,6 +80,22 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
     # The converted output model.
     output_state_dict = {}
 
+    # old versions did not store training args
+    if "args" in input_state_dict:
+        # do not make the user write a config file when the exact dimensions/sizes are already in the checkpoint
+        train_args = input_state_dict["args"]
+        # from pprint import pprint
+        # pprint(vars(train_args))
+
+        config.vocab_size = train_args.padded_vocab_size
+        config.n_positions = train_args.max_position_embeddings
+        config.n_ctx = train_args.seq_length
+        config.n_embd = train_args.hidden_size
+        config.n_layer = train_args.num_layers
+        config.n_head = train_args.num_attention_heads
+        config.n_inner = train_args.ffn_hidden_size
+        # pprint(config)
+
     # The number of heads.
     heads = config.n_head
     # The hidden_size per head.
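For reference, the added block above corresponds to the following standalone sketch. The checkpoint path is illustrative, and it assumes a Megatron-LM checkpoint that stores its training arguments as an argparse Namespace under the "args" key, which older checkpoints may lack:

import torch
from transformers import GPT2Config

sd = torch.load("model_optim_rng.pt", map_location="cpu")  # illustrative path
config = GPT2Config()
if "args" in sd:
    # the argparse.Namespace that Megatron-LM saved at training time
    train_args = sd["args"]
    # uncomment to inspect everything the checkpoint recorded:
    # from pprint import pprint; pprint(vars(train_args))
    config.vocab_size = train_args.padded_vocab_size
    config.n_positions = train_args.max_position_embeddings
    config.n_ctx = train_args.seq_length
    config.n_embd = train_args.hidden_size
    config.n_layer = train_args.num_layers
    config.n_head = train_args.num_attention_heads
    config.n_inner = train_args.ffn_hidden_size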
@@ -106,9 +122,11 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
     # The position embeddings.
     pos_embeddings = embeddings["position_embeddings"]["weight"]
     # Read the hidden dimension.
-    n_embed = pos_embeddings.size(0)
+    n_embed = pos_embeddings.size(1)
     # DEBUG.
-    assert n_embed == heads * hidden_size_per_head
+    assert (
+        n_embed == heads * hidden_size_per_head
+    ), f"detected mismatch n_embed={n_embed} != heads={heads}*hidden_size_per_head={hidden_size_per_head}"
 
     # Store the position embeddings.
     output_state_dict["transformer.wpe.weight"] = pos_embeddings
@@ -215,7 +233,7 @@ def main():
     parser.add_argument(
         "path_to_checkpoint",
         type=str,
-        help="Path to the ZIP file containing the checkpoint",
+        help="Path to the checkpoint file (.zip archive or direct .pt file)",
     )
     parser.add_argument(
         "--config_file",
@@ -229,10 +247,14 @@ def main():
     basename = os.path.dirname(args.path_to_checkpoint)
 
     # Load the model.
+    # the .zip is very optional, let's keep it for backward compatibility
     print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}")
-    with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
-        with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
-            input_state_dict = torch.load(pytorch_dict, map_location="cpu")
+    if args.path_to_checkpoint.endswith(".zip"):
+        with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
+            with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
+                input_state_dict = torch.load(pytorch_dict, map_location="cpu")
+    else:
+        input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")
 
     # Read the config, or default to the model released by NVIDIA.
     if args.config_file == "":
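Taken on its own, the new loading logic accepts either packaging of the checkpoint. The file names below are illustrative; the member path inside the zip is the one this script expects:

import zipfile
import torch

path_to_checkpoint = "checkpoint.zip"    # or e.g. ".../mp_rank_00/model_optim_rng.pt"
if path_to_checkpoint.endswith(".zip"):
    # NVIDIA's released checkpoint ships as a zip with the state dict nested inside it
    with zipfile.ZipFile(path_to_checkpoint, "r") as checkpoint:
        with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
            # torch.load accepts the file-like object returned by ZipFile.open
            input_state_dict = torch.load(pytorch_dict, map_location="cpu")
else:
    # a .pt file saved directly by Megatron-LM
    input_state_dict = torch.load(path_to_checkpoint, map_location="cpu")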
...@@ -245,7 +267,7 @@ def main(): ...@@ -245,7 +267,7 @@ def main():
n_layer=24, n_layer=24,
n_head=16, n_head=16,
n_inner=4096, n_inner=4096,
activation_function="gelu_new", activation_function="gelu", # used to be "gelu_new" in earlier versions
resid_pdrop=0.1, resid_pdrop=0.1,
embd_pdrop=0.1, embd_pdrop=0.1,
attn_pdrop=0.1, attn_pdrop=0.1,
......
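On the changed default: in transformers, activation_function="gelu" selects the exact (erf-based) GELU, while "gelu_new" is the tanh approximation used by the original GPT-2 code; current Megatron-LM trains with the exact form, hence the new default. A quick sketch of the numerical difference, with both formulas written out for illustration rather than pulled from the library:

import math
import torch

def gelu_exact(x):
    # erf-based GELU, which activation_function="gelu" is expected to resolve to
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_new(x):
    # tanh approximation used by the original GPT-2 code ("gelu_new")
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

x = torch.linspace(-4.0, 4.0, steps=9)
# small but nonzero difference: the two variants are not interchangeable for a trained checkpoint
print((gelu_exact(x) - gelu_new(x)).abs().max())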