Unverified Commit f6a23d19 authored by Sam Shleifer, committed by GitHub

[BART] add bart-large-xsum weights (#3422)

parent 601ac5b1
@@ -26,6 +26,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
    "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
    "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
    "bart-large-xsum": "https://s3.amazonaws.com/models.huggingco/bert/facebook/bart-large-xsum/config.json",
}
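With the config URL registered, the new checkpoint name resolves the same way as the existing ones. A minimal sketch of the effect (assumes this branch is installed and the S3 object exists; the printed value matches the config checked by the new test further down):

from transformers import BartConfig

# Downloads and caches the config.json registered above
config = BartConfig.from_pretrained("bart-large-xsum")
print(config.num_beams)  # 6 -- generation defaults ship inside this config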
@@ -17,6 +17,7 @@
import argparse
import logging
import os
from pathlib import Path

import fairseq
@@ -30,10 +31,11 @@ from transformers import (
    BartModel,
    BartTokenizer,
)
from transformers.modeling_bart import _make_linear_from_emb

FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}

if version.parse(fairseq.__version__) < version.parse("0.9.0"):
    raise Exception("requires fairseq >= 0.9.0")
@@ -57,62 +59,79 @@ def rename_key(dct, old, new):
    dct[new] = val


def load_xsum_checkpoint(checkpoint_path):
    """Checkpoint path should end in model.pt"""
    sd = torch.load(checkpoint_path, map_location="cpu")
    hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
    hub_interface.model.load_state_dict(sd["model"])
    return hub_interface


@torch.no_grad()
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
    """
    Copy/paste/tweak model's weights to our BERT structure.
    """
    if not os.path.exists(checkpoint_path):
        bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()  # eval disables dropout
    else:
        bart = load_xsum_checkpoint(checkpoint_path)
    bart.model.upgrade_state_dict(bart.model.state_dict())
    if hf_checkpoint_name is None:
        hf_checkpoint_name = checkpoint_path.replace(".", "-")
    config = BartConfig.from_pretrained(hf_checkpoint_name)
    tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
    tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
    assert torch.eq(tokens, tokens2).all()

    if checkpoint_path == "bart.large.mnli":
        state_dict = bart.state_dict()
        remove_ignore_keys_(state_dict)
        state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
        for src, dest in rename_keys:
            rename_key(state_dict, src, dest)
        model = BartForSequenceClassification(config).eval()
        model.load_state_dict(state_dict)
        fairseq_output = bart.predict("mnli", tokens, return_logits=True)
        new_model_outputs = model(tokens)[0]  # logits
    else:  # no classification heads to worry about
        state_dict = bart.model.state_dict()
        remove_ignore_keys_(state_dict)
        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
        fairseq_output = bart.extract_features(tokens)
        if hf_checkpoint_name == "bart-large":
            model = BartModel(config).eval()
            model.load_state_dict(state_dict)
            new_model_outputs = model(tokens)[0]
        else:
            model = BartForConditionalGeneration(config).eval()  # an existing summarization ckpt
            model.model.load_state_dict(state_dict)
            if hasattr(model, "lm_head"):
                model.lm_head = _make_linear_from_emb(model.model.shared)
            new_model_outputs = model.model(tokens)[0]

    # Check results
    assert fairseq_output.shape == new_model_outputs.shape
    assert (fairseq_output == new_model_outputs).all().item()
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)


def remove_ignore_keys_(state_dict):
    for k in IGNORE_KEYS:
        state_dict.pop(k, None)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
    )
    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument(
        "--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
    )
    args = parser.parse_args()
    convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)
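For reference, a hypothetical use of the updated converter; the checkpoint path and output directories below are illustrative, not part of the commit:

from convert_bart_original_pytorch_checkpoint_to_pytorch import convert_bart_checkpoint

# Local fairseq checkpoint: pass the HF config name explicitly
# (the --hf_config flag when run from the CLI)
convert_bart_checkpoint("bart_xsum/model.pt", "converted-xsum", hf_checkpoint_name="bart-large-xsum")

# Hub name: hf_checkpoint_name defaults to "bart-large-cnn" via the replace(".", "-") rule
convert_bart_checkpoint("bart.large.cnn", "converted-cnn")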
@@ -34,6 +34,7 @@ BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin",
    "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
    "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
    "bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/pytorch_model.bin",
}
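With the weights URL in place as well, the full model loads by name; a one-line sketch (assumes the S3 object exists):

from transformers import BartForConditionalGeneration

# Pulls the pytorch_model.bin registered above, together with the matching config
model = BartForConditionalGeneration.from_pretrained("bart-large-xsum")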
BART_START_DOCSTRING = r"""
@@ -19,7 +19,7 @@ from .tokenization_roberta import RobertaTokenizer
# vocab and merges same as roberta
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn", "bart-large-xsum"]


class BartTokenizer(RobertaTokenizer):
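Because every BART checkpoint reuses RoBERTa's vocab and merges (the vocab_url/merges_url above), adding the name to _all_bart_models is all the tokenizer needs; a minimal sketch:

from transformers import BartTokenizer

# "bart-large-xsum" resolves to the shared roberta-large vocab.json and merges.txt
tok = BartTokenizer.from_pretrained("bart-large-xsum")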
@@ -450,6 +450,38 @@ class BartModelIntegrationTests(unittest.TestCase):
        model = BartModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
        self.assertIsNotNone(model)

    @slow
    def test_xsum_summarization_same_as_fairseq(self):
        model = BartForConditionalGeneration.from_pretrained("bart-large-xsum").to(torch_device)
        tok = BartTokenizer.from_pretrained("bart-large")

        PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
        EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
        dct = tok.batch_encode_plus([PGE_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt")

        hypotheses_batch = model.generate(
            input_ids=dct["input_ids"].to(torch_device),
            attention_mask=dct["attention_mask"].to(torch_device),
            num_beams=2,
            max_length=62,
            min_length=11,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
            early_stopping=True,
            decoder_start_token_id=model.config.eos_token_ids[0],
        )

        decoded = [
            tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
        ]
        self.assertEqual(EXPECTED_SUMMARY, decoded[0])

    def test_xsum_config_generation_params(self):
        config = BartConfig.from_pretrained("bart-large-xsum")
        expected_params = dict(num_beams=6, do_sample=False, early_stopping=True, length_penalty=1.0)
        config_params = {k: getattr(config, k, "MISSING") for k in expected_params}
        self.assertDictEqual(expected_params, config_params)
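The second test pins the beam-search defaults to the config rather than to call sites. A hedged sketch of why that matters (in this codebase, generate() falls back to config values for arguments left unset; the input text is illustrative):

from transformers import BartForConditionalGeneration, BartTokenizer

tok = BartTokenizer.from_pretrained("bart-large")
model = BartForConditionalGeneration.from_pretrained("bart-large-xsum")
batch = tok.batch_encode_plus(["PG&E scheduled the blackouts to reduce wildfire risk."], return_tensors="pt")
# No num_beams/length_penalty passed: generate() reads num_beams=6,
# early_stopping=True and length_penalty=1.0 from the model's config
summary_ids = model.generate(input_ids=batch["input_ids"])
print(tok.decode(summary_ids[0], skip_special_tokens=True))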
    @slow
    def test_cnn_summarization_same_as_fairseq(self):
        hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True).to(torch_device)