Unverified Commit bfd81766 authored by Stas Bekman, committed by GitHub

[megatron_gpt2] dynamic gelu, add tokenizer, save config (#13928)



* [megatron_gpt2] dynamic gelu, add tokenizer, save config

* cleanup

* Update src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply suggestions
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 919a964b
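
In short: instead of hard-coding activation_function, the converter now derives it from the Megatron training args stored in the checkpoint, picks the tokenizer from those args as well, and writes config.json plus the tokenizer files with save_pretrained(). The activation selection restated as a minimal standalone sketch (not part of the commit; pick_activation_function is a hypothetical helper name, the attribute names come from the diff below):

    # Minimal sketch of the activation selection added below; ds_args stands for
    # the Megatron training args stored in the checkpoint (may be absent in very
    # old checkpoints), and pick_activation_function is a hypothetical helper name.
    def pick_activation_function(ds_args):
        if ds_args is None:
            # in the very early days the converter hard-coded "gelu_new"
            return "gelu_new"
        if ds_args.bias_gelu_fusion:
            return "gelu_fast"
        if ds_args.openai_gelu:
            return "gelu_new"
        return "gelu"
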
src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py
@@ -17,14 +17,13 @@
 ####################################################################################################
 import argparse
-import json
 import os
 import re
 import zipfile
 import torch
-from transformers import GPT2Config
+from transformers import AutoTokenizer, GPT2Config
 ####################################################################################################
@@ -81,19 +80,19 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
     output_state_dict = {}
     # old versions did not store training args
-    if "args" in input_state_dict:
+    ds_args = input_state_dict.get("args", None)
+    if ds_args is not None:
         # do not make the user write a config file when the exact dimensions/sizes are already in the checkpoint
-        train_args = input_state_dict["args"]
         # from pprint import pprint
-        # pprint(vars(train_args))
+        # pprint(vars(ds_args))
-        config.vocab_size = train_args.padded_vocab_size
-        config.n_positions = train_args.max_position_embeddings
-        config.n_ctx = train_args.seq_length
-        config.n_embd = train_args.hidden_size
-        config.n_layer = train_args.num_layers
-        config.n_head = train_args.num_attention_heads
-        config.n_inner = train_args.ffn_hidden_size
+        config.vocab_size = ds_args.padded_vocab_size
+        config.n_positions = ds_args.max_position_embeddings
+        config.n_ctx = ds_args.seq_length
+        config.n_embd = ds_args.hidden_size
+        config.n_layer = ds_args.num_layers
+        config.n_head = ds_args.num_attention_heads
+        config.n_inner = ds_args.ffn_hidden_size
         # pprint(config)
     # The number of heads.
@@ -255,8 +254,22 @@ def main():
     else:
         input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")
+    ds_args = input_state_dict.get("args", None)
     # Read the config, or default to the model released by NVIDIA.
     if args.config_file == "":
+        if ds_args is not None:
+            if ds_args.bias_gelu_fusion:
+                activation_function = "gelu_fast"
+            elif ds_args.openai_gelu:
+                activation_function = "gelu_new"
+            else:
+                activation_function = "gelu"
+        else:
+            # in the very early days this used to be "gelu_new"
+            activation_function = "gelu_new"
         # Spell out all parameters in case the defaults change.
         config = GPT2Config(
             vocab_size=50257,
@@ -266,7 +279,7 @@ def main():
             n_layer=24,
             n_head=16,
             n_inner=4096,
-            activation_function="gelu",  # used to be "gelu_new" in earlier versions
+            activation_function=activation_function,
             resid_pdrop=0.1,
             embd_pdrop=0.1,
             attn_pdrop=0.1,
@@ -285,6 +298,8 @@ def main():
     else:
         config = GPT2Config.from_json_file(args.config_file)
+    config.architectures = ["GPT2LMHeadModel"]
     # Convert.
     print("Converting")
     output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)
@@ -293,14 +308,30 @@ def main():
     if args.print_checkpoint_structure:
         recursive_print(None, output_state_dict)
+    # Add tokenizer class info to config
+    # see https://github.com/huggingface/transformers/issues/13906)
+    if ds_args is not None:
+        tokenizer_type = ds_args.tokenizer_type
+        if tokenizer_type == "GPT2BPETokenizer":
+            tokenizer_model_name = "gpt2"
+        elif tokenizer_type == "PretrainedFromHF":
+            tokenizer_model_name = ds_args.tokenizer_name_or_path
+        else:
+            raise ValueError(f"Unrecognized tokenizer_type {tokenizer_type}")
+    else:
+        tokenizer_model_name = "gpt2"
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name)
+    tokenizer_class = type(tokenizer).__name__
+    config.tokenizer_class = tokenizer_class
     # Store the config to file.
-    output_config_file = os.path.join(basename, "config.json")
-    output_config = config.to_dict()
-    output_config["architectures"] = ["GPT2LMHeadModel"]
-    output_config["model_type"] = "gpt2"
-    print(f'Saving config to "{output_config_file}"')
-    with open(output_config_file, "w") as f:
-        json.dump(output_config, f)
+    print("Saving config")
+    config.save_pretrained(basename)
+    # Save tokenizer based on args
+    print(f"Adding {tokenizer_class} tokenizer files")
+    tokenizer.save_pretrained(basename)
     # Store the state_dict to file.
     output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
...
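
Because the converter now saves config.json (with architectures, tokenizer_class and the derived activation function) and the tokenizer files alongside pytorch_model.bin, the output directory can be loaded directly. A minimal usage sketch, assuming the script was run and wrote its output to a hypothetical ./megatron_gpt2 directory:

    # Hedged usage sketch; "./megatron_gpt2" is a hypothetical directory that
    # already holds the converted pytorch_model.bin, config.json and tokenizer files.
    from transformers import AutoTokenizer, GPT2LMHeadModel

    checkpoint_dir = "./megatron_gpt2"  # placeholder path, adjust to your output

    # config.json now records architectures=["GPT2LMHeadModel"], tokenizer_class
    # and the activation function derived from the Megatron args, so both pieces
    # load without extra arguments.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    model = GPT2LMHeadModel.from_pretrained(checkpoint_dir)

    inputs = tokenizer("Megatron-LM is", return_tensors="pt")
    output_ids = model.generate(**inputs, max_length=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))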