Unverified Commit b54ef78d authored by Sam Shleifer, committed by GitHub

Bart-CNN (#3059)

`generate` code that produces 99% identical summarizations to fairseq on CNN test data, with caching.
parent 6b1ff250
......@@ -4,20 +4,27 @@ Bart
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
@sshleifer
The Bart model was `proposed <https://arxiv.org/abs/1910.13461>`_ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov, Luke Zettlemoyer on 29 Oct, 2019.
It is a sequence-to-sequence model where both the encoder and the decoder are transformers. The paper also introduces a novel pretraining objective and demonstrates excellent summarization results.
The authors released their code `here <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_.
Paper
~~~~~
The Bart model was `proposed <https://arxiv.org/abs/1910.13461>`_ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 Oct, 2019.
According to the abstract:
**Abstract:**
- Bart uses a standard seq2seq/machine translation architecture with a bidirectional encoder (like BERT) and a left-to-right decoder (like GPT).
- The pretraining task involves randomly shuffling the order of the original sentences and a novel in-filling scheme, where spans of text are replaced with a single mask token.
- BART is particularly effective when fine-tuned for text generation but also works well for comprehension tasks. It matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, and achieves new state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains of up to 6 ROUGE.
*We present BART, a denoising autoencoder for pretraining sequence-to-sequence models. BART is trained by (1) corrupting text with an arbitrary noising function, and (2) learning a model to reconstruct the original text. It uses a standard Transformer-based neural machine translation architecture which, despite its simplicity, can be seen as generalizing BERT (due to the bidirectional encoder), GPT (with the left-to-right decoder), and many other more recent pretraining schemes. We evaluate a number of noising approaches, finding the best performance by both randomly shuffling the order of the original sentences and using a novel in-filling scheme, where spans of text are replaced with a single mask token. BART is particularly effective when fine tuned for text generation but also works well for comprehension tasks. It matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, achieves new state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains of up to 6 ROUGE. BART also provides a 1.1 BLEU increase over a back-translation system for machine translation, with only target language pretraining. We also report ablation experiments that replicate other pretraining schemes within the BART framework, to better measure which factors most influence end-task performance.*
`BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension`
The authors' code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_.
Notes:
Implementation Notes
~~~~~~~~~~~~~~~~~~~~
- Bart doesn't use :obj:`token_type_ids`; for sequence classification, just use ``BartTokenizer.encode`` to get the proper splitting.
- Inputs to the decoder are created by ``BartModel.forward`` if they are not passed. This is different from some other model APIs.
- Model predictions are intended to be identical to those of the original implementation. This only holds, however, if the string you pass to ``fairseq.encode`` starts with a space.
- Decoder inputs are created automatically by the helper function ``transformers.modeling_bart._prepare_bart_decoder_inputs``
- ``BartForMaskedLM.generate`` should be used for summarization; see the example in that docstring and the sketch below.
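A minimal summarization sketch, assuming the ``bart-large-cnn`` weights are available and using this commit's ``BartForMaskedLM.generate`` API (the article text and the beam-search hyperparameters below are illustrative)::

    import torch
    from transformers import BartForMaskedLM, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("bart-large-cnn")
    model = BartForMaskedLM.from_pretrained("bart-large-cnn")
    model.eval()

    # Note the leading space (see the note above about matching fairseq.encode).
    ARTICLE = " New York (CNN) The city saw record rainfall over the weekend ..."
    input_ids = tokenizer.encode(ARTICLE, return_tensors="pt")

    # Beam search with caching; decoder inputs are created internally by the model.
    summary_ids = model.generate(input_ids, num_beams=4, max_length=140, do_sample=False)
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))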
BartModel
~~~~~~~~~~~~~~~~~~~~
......@@ -30,7 +37,7 @@ BartForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BartForMaskedLM
:members: forward
:members: forward, generate
BartForSequenceClassification
......
......@@ -280,7 +280,10 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
| | | | bart-large base architecture with a classification head |
| | | | bart-large base architecture with a classification head, finetuned on MNLI |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) |
| | | | bart-large base architecture finetuned on cnn summarization task |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
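A short sketch of using one of the checkpoints above for MNLI-style classification (the premise/hypothesis pair and the label handling are illustrative; ``bart-large-mnli`` predicts over 3 MNLI classes)::

    import torch
    from transformers import BartForSequenceClassification, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("bart-large-mnli")
    model = BartForSequenceClassification.from_pretrained("bart-large-mnli")
    model.eval()

    # Bart does not use token_type_ids; the tokenizer handles the pair encoding.
    input_ids = tokenizer.encode(
        "A soccer game with multiple males playing.",
        "Some men are playing a sport.",
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(input_ids)[0]
    predicted_class = logits.argmax(dim=-1).item()  # index into the 3 MNLI classes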
......
......@@ -26,7 +26,7 @@ _bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bart-large": _bart_large_url,
"bart-large-mnli": _bart_large_url, # fine as same
"bart-cnn": None, # not done
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
}
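# With the bart-large-cnn entry above, BartConfig.from_pretrained("bart-large-cnn") resolves the hosted config via this map.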
......@@ -59,6 +59,7 @@ class BartConfig(PretrainedConfig):
classifier_dropout=0.0,
output_past=False,
num_labels=3,
bos_token_id=0,
**common_kwargs
):
r"""
......@@ -67,12 +68,16 @@ class BartConfig(PretrainedConfig):
config = BartConfig.from_pretrained('bart-large')
model = BartModel(config)
"""
super().__init__(num_labels=num_labels, output_past=output_past, pad_token_id=pad_token_id, **common_kwargs)
super().__init__(
num_labels=num_labels,
output_past=output_past,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
**common_kwargs,
)
self.vocab_size = vocab_size
self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
self.eos_token_id = eos_token_id
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = self.num_hidden_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
......
......@@ -23,9 +23,11 @@ import fairseq
import torch
from packaging import version
from transformers import BartConfig, BartForSequenceClassification, BartModel, BartTokenizer
from transformers import BartConfig, BartForMaskedLM, BartForSequenceClassification, BartModel, BartTokenizer
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn"]
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
raise Exception("requires fairseq >= 0.9.0")
......@@ -33,7 +35,7 @@ if version.parse(fairseq.__version__) < version.parse("0.9.0"):
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = "Hello world! cécé herlolip"
SAMPLE_TEXT = " Hello world! cécé herlolip"
rename_keys = [
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
......@@ -41,7 +43,7 @@ rename_keys = [
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
]
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version"]
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version", "_float_tensor"]
def rename_key(dct, old, new):
......@@ -53,36 +55,45 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
"""
Copy/paste/tweak the fairseq model's weights into our BART structure.
"""
b2 = torch.hub.load("pytorch/fairseq", checkpoint_path)
b2.eval() # disable dropout
b2.model.upgrade_state_dict(b2.model.state_dict())
config = BartConfig()
tokens = b2.encode(SAMPLE_TEXT).unsqueeze(0)
tokens2 = BartTokenizer.from_pretrained("bart-large").encode(SAMPLE_TEXT).unsqueeze(0)
bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
bart.eval() # disable dropout
bart.model.upgrade_state_dict(bart.model.state_dict())
hf_model_name = checkpoint_path.replace(".", "-")
config = BartConfig.from_pretrained(hf_model_name)
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
tokens2 = BartTokenizer.from_pretrained(hf_model_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
assert torch.eq(tokens, tokens2).all()
# assert their_output.size() == (1, 11, 1024)
if checkpoint_path == "bart.large":
state_dict = b2.model.state_dict()
if checkpoint_path in ["bart.large", "bart.large.cnn"]:
state_dict = bart.model.state_dict()
for k in IGNORE_KEYS:
state_dict.pop(k, None)
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
model = BartModel(config)
their_output = b2.extract_features(tokens)
their_output = bart.extract_features(tokens)
else: # MNLI Case
state_dict = b2.state_dict()
state_dict = bart.state_dict()
for k in IGNORE_KEYS:
state_dict.pop(k, None)
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
state_dict.pop("_float_tensor", None)
model = BartForSequenceClassification(config)
their_output = b2.predict("mnli", tokens, return_logits=True)
for k in IGNORE_KEYS:
state_dict.pop(k, None)
their_output = bart.predict("mnli", tokens, return_logits=True)
# Load state dict
model.load_state_dict(state_dict)
model.eval()
our_outputs = model.forward(tokens)[0]
# Check results
if checkpoint_path == "bart.large.cnn": # generate doesnt work yet
model = BartForMaskedLM(config, base_model=model)
assert "lm_head.weight" in model.state_dict()
assert model.lm_head.out_features == config.max_position_embeddings
model.eval()
our_outputs = model.model.forward(tokens)[0]
else:
our_outputs = model.forward(tokens)[0]
assert their_output.shape == our_outputs.shape
assert (their_output == our_outputs).all().item()
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
......@@ -92,7 +103,8 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("fairseq_path", choices=["bart.large", "bart.large.mnli"], type=str, help="")
parser.add_argument("fairseq_path", choices=FAIRSEQ_MODELS, type=str, help="")
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
convert_bart_checkpoint(
......
This diff is collapsed.
......@@ -171,7 +171,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
else:
output_embeddings.weight = input_embeddings.weight
if hasattr(output_embeddings, "bias") and output_embeddings.bias is not None:
if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = torch.nn.functional.pad(
output_embeddings.bias.data,
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
......@@ -558,7 +558,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
model.__class__.__name__, "\n\t".join(error_msgs)
)
)
model.tie_weights() # make sure word embedding weights are still tied if needed
# Set model in evaluation mode to deactivate Dropout modules by default
......@@ -574,16 +573,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return {"input_ids": input_ids}
def _do_output_past(self, outputs):
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
if has_output_past and not has_mem_len and len(outputs) > 1:
return True
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
has_output_past = getattr(self.config, "output_past", False)
mem_len = getattr(self.config, "mem_len", 0)
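# mem_len is set by Transformer-XL/XLNet-style configs; other models signal caching via output_past.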
if len(outputs) <= 1:
return False
if mem_len > 0 or has_output_past:
return True
return False
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
for i in range(batch_size * num_beams):
for previous_token in set(prev_output_tokens[i].tolist()):
# if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
if lprobs[i, previous_token] < 0:
lprobs[i, previous_token] *= repetition_penalty
else:
lprobs[i, previous_token] /= repetition_penalty
@torch.no_grad()
def generate(
self,
......@@ -761,7 +769,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(
batch_size * num_return_sequences, cur_len
) # (batch_size * num_return_sequences, cur_len)
) # shape: (batch_size * num_return_sequences, cur_len)
effective_batch_size = batch_size * num_return_sequences
else:
effective_batch_size = batch_size
......@@ -822,9 +830,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
......@@ -834,13 +842,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if next_token_logits[i, previous_token] < 0:
next_token_logits[i, previous_token] *= repetition_penalty
else:
next_token_logits[i, previous_token] /= repetition_penalty
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
......@@ -911,6 +913,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
""" Generate sequences for each example with beam search.
"""
# Expand input to num beams
# assert input_ids.shape == (batch_size * num_beams, cur_len)
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
......@@ -941,13 +944,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size * num_beams):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if scores[i, previous_token] < 0:
scores[i, previous_token] *= repetition_penalty
else:
scores[i, previous_token] /= repetition_penalty
self.enforce_repetition_penalty_(scores, batch_size, num_beams, input_ids, repetition_penalty)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
......@@ -1039,16 +1036,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# re-order internal states
if past:
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
past = self._reorder_cache(past, beam_idx)
# update current length
cur_len = cur_len + 1
......@@ -1096,6 +1084,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return decoded
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
return past
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
......@@ -1164,17 +1166,22 @@ class BeamHypotheses(object):
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs):
def is_done(self, best_sum_logprobs, cur_len=None):
"""
If there are enough hypotheses and none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
if cur_len is None:
cur_len = self.max_length
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
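# Done when even the best possible (length-normalised) continuation cannot beat the worst kept hypothesis.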
ret = self.worst_score >= cur_score
return ret
class Conv1D(nn.Module):
......
......@@ -19,11 +19,7 @@ from .tokenization_roberta import RobertaTokenizer
# vocab and merges same as roberta
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
_all_bart_models = [
"bart-large",
"bart-large-mnli",
# "bart-large-cnn"
]
_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn"]
class BartTokenizer(RobertaTokenizer):
......
This diff is collapsed.