Unverified Commit 0af901e8 authored by Stas Bekman, committed by GitHub

[megatron_gpt2] checkpoint v3 (#13508)

* [megatron_gpt2] checkpoint v3

* bug fix

* fixes

* switch to default `gelu` from `gelu_new` - which is what the current megatron-lm uses

* cleanup

* back compat
parent 936b3fde
@@ -80,6 +80,22 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
     # The converted output model.
     output_state_dict = {}
 
+    # old versions did not store training args
+    if "args" in input_state_dict:
+        # do not make the user write a config file when the exact dimensions/sizes are already in the checkpoint
+        train_args = input_state_dict["args"]
+        # from pprint import pprint
+        # pprint(vars(train_args))
+        config.vocab_size = train_args.padded_vocab_size
+        config.n_positions = train_args.max_position_embeddings
+        config.n_ctx = train_args.seq_length
+        config.n_embd = train_args.hidden_size
+        config.n_layer = train_args.num_layers
+        config.n_head = train_args.num_attention_heads
+        config.n_inner = train_args.ffn_hidden_size
+        # pprint(config)
+
     # The number of heads.
     heads = config.n_head
     # The hidden_size per head.
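Note: a minimal sketch (not part of this diff) of how the "args" payload used above can be inspected before running the converter; `show_megatron_args` is a hypothetical helper and assumes a plain .pt checkpoint (for the .zip layout, see the loading hunk further down).

```python
import torch


def show_megatron_args(path_to_checkpoint):
    """Print the training args that newer Megatron-LM checkpoints store under "args"."""
    state_dict = torch.load(path_to_checkpoint, map_location="cpu")
    if "args" in state_dict:
        # an argparse.Namespace saved at training time
        for name, value in sorted(vars(state_dict["args"]).items()):
            print(f"{name} = {value}")
    else:
        print("old-style checkpoint: no 'args' key, pass --config_file to the converter")
```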
@@ -106,9 +122,11 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
     # The position embeddings.
     pos_embeddings = embeddings["position_embeddings"]["weight"]
     # Read the hidden dimension.
-    n_embed = pos_embeddings.size(0)
+    n_embed = pos_embeddings.size(1)
     # DEBUG.
-    assert n_embed == heads * hidden_size_per_head
+    assert (
+        n_embed == heads * hidden_size_per_head
+    ), f"detected mismatch n_embed={n_embed} != heads={heads}*hidden_size_per_head={hidden_size_per_head}"
 
     # Store the position embeddings.
     output_state_dict["transformer.wpe.weight"] = pos_embeddings
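Note: the size(0) -> size(1) fix matters because an nn.Embedding weight is laid out as [num_embeddings, embedding_dim], i.e. [max_position_embeddings, hidden_size] here; for the released 345M model both happen to be 1024, which is why the old indexing went unnoticed. A small sketch, with made-up sizes, of the check the new assert performs:

```python
import torch

# made-up 2048-position, 1024-wide model: position embeddings are stored as
# [max_position_embeddings, hidden_size], so the hidden dimension is axis 1
pos_embeddings = torch.zeros(2048, 1024)
heads, hidden_size_per_head = 16, 64

n_embed = pos_embeddings.size(1)  # size(0) would wrongly return 2048 here
assert (
    n_embed == heads * hidden_size_per_head
), f"detected mismatch n_embed={n_embed} != heads={heads}*hidden_size_per_head={hidden_size_per_head}"
```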
@@ -215,7 +233,7 @@ def main():
     parser.add_argument(
         "path_to_checkpoint",
         type=str,
-        help="Path to the ZIP file containing the checkpoint",
+        help="Path to the checkpoint file (.zip archive or direct .pt file)",
     )
     parser.add_argument(
         "--config_file",
@@ -229,10 +247,14 @@ def main():
     basename = os.path.dirname(args.path_to_checkpoint)
 
     # Load the model.
+    # the .zip is very optional, let's keep it for backward compatibility
     print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}")
-    with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
-        with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
-            input_state_dict = torch.load(pytorch_dict, map_location="cpu")
+    if args.path_to_checkpoint.endswith(".zip"):
+        with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
+            with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
+                input_state_dict = torch.load(pytorch_dict, map_location="cpu")
+    else:
+        input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")
 
     # Read the config, or default to the model released by NVIDIA.
     if args.config_file == "":
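Note: the two accepted inputs are the NVIDIA-released .zip archives, which wrap release/mp_rank_00/model_optim_rng.pt, and checkpoints saved directly by Megatron-LM as plain .pt files. A standalone sketch of the same back-compat logic, with a hypothetical function name and example paths:

```python
import zipfile

import torch


def load_megatron_state_dict(path_to_checkpoint):
    """Load a Megatron-LM GPT2 checkpoint from either a .zip archive or a plain .pt file."""
    if path_to_checkpoint.endswith(".zip"):
        with zipfile.ZipFile(path_to_checkpoint, "r") as checkpoint:
            with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
                return torch.load(pytorch_dict, map_location="cpu")
    return torch.load(path_to_checkpoint, map_location="cpu")


# hypothetical paths:
# sd = load_megatron_state_dict("megatron_gpt2_345m_v0.0.zip")
# sd = load_megatron_state_dict("checkpoints/gpt2/iter_0500000/mp_rank_00/model_optim_rng.pt")
```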
@@ -245,7 +267,7 @@ def main():
             n_layer=24,
             n_head=16,
             n_inner=4096,
-            activation_function="gelu_new",
+            activation_function="gelu",  # used to be "gelu_new" in earlier versions
             resid_pdrop=0.1,
             embd_pdrop=0.1,
             attn_pdrop=0.1,
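Note: in transformers, "gelu_new" is the tanh-approximated GELU that GPT-2 originally used, while "gelu" maps to PyTorch's exact erf-based GELU; the commit message notes the latter matches what current megatron-lm uses, hence the new default for checkpoints converted without a config file. A quick comparison via the library's own activation registry:

```python
import torch
from transformers.activations import ACT2FN

x = torch.linspace(-3.0, 3.0, steps=7)

# the two functions agree closely but are not identical
print(ACT2FN["gelu"](x))      # exact erf-based GELU
print(ACT2FN["gelu_new"](x))  # tanh approximation
```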