Unverified commit d5b0a0e2, authored by Sam Shleifer, committed by GitHub

mBART Conversion script (#6230)

parent 268bf346
@@ -78,19 +78,6 @@ def load_xsum_checkpoint(checkpoint_path):
    return hub_interface


def convert_checkpoint_from_disk(checkpoint_path, **config_kwargs):
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    remove_ignore_keys_(state_dict)
    vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
    state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
    mbart_config = BartConfig(vocab_size=vocab_size, **config_kwargs)
    model = BartForConditionalGeneration(mbart_config)
    model.model.load_state_dict(state_dict)
    if hasattr(model, "lm_head"):
        model.lm_head = _make_linear_from_emb(model.model.shared)
    return model


@torch.no_grad()
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
    """
......
import argparse

import torch

from transformers import BartForConditionalGeneration, MBartConfig

from .convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_


def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):
    # Load the fairseq checkpoint on CPU and drop keys the Hugging Face model does not use.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    remove_ignore_keys_(state_dict)
    # Size the config's vocabulary from the checkpoint's embedding matrix.
    vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
    mbart_config = MBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size)
    # Embeddings are tied, so the decoder embedding weights serve as the shared table.
    state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
    model = BartForConditionalGeneration(mbart_config)
    model.model.load_state_dict(state_dict)
    return model
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
    )
    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument(
        "--hf_config",
        default="facebook/mbart-large-cc25",
        type=str,
        help="Which huggingface architecture to use: bart-large-xsum",
    )
    args = parser.parse_args()
    model = convert_fairseq_mbart_checkpoint_from_disk(args.fairseq_path, hf_config_path=args.hf_config)
    model.save_pretrained(args.pytorch_dump_folder_path)
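
As a rough usage sketch (not part of the diff): the script filename, checkpoint path, and output directory below are assumptions for illustration, and the reload step simply confirms the exported weights can be read back with the standard from_pretrained API.

# Hypothetical invocation of the new conversion script (paths are placeholders):
#   python convert_mbart_original_checkpoint_to_pytorch.py /path/to/fairseq/model.pt ./mbart-converted \
#       --hf_config facebook/mbart-large-cc25
from transformers import BartForConditionalGeneration

# Reload the exported directory to check that the conversion round-trips.
model = BartForConditionalGeneration.from_pretrained("./mbart-converted")
print(model.config.vocab_size)  # should match the embedding size read from the fairseq checkpoint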