Unverified Commit 857e0a0d authored by Sam Shleifer, committed by GitHub

Rename BartForMaskedLM -> BartForConditionalGeneration (#3114)

* improved documentation
parent fa2aa699
......@@ -7,7 +7,7 @@ file a `Github Issue <https://github.com/huggingface/transformers/issues/new?ass
Paper
~~~~~
The Bart model was `proposed <https://arxiv.org/abs/1910.13461>`_ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 Oct, 2019.
According to the abstract:
According to the abstract,
- Bart uses a standard seq2seq/machine translation architecture with a bidirectional encoder (like BERT) and a left-to-right decoder (like GPT).
- The pretraining task involves randomly shuffling the order of the original sentences and a novel in-filling scheme, where spans of text are replaced with a single mask token.
......@@ -18,26 +18,28 @@ The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/ma
Implementation Notes
~~~~~~~~~~~~~~~~~~~~
- Bart doesn't use :obj:`token_type_ids`, for sequence classification just use BartTokenizer.encode to get the proper splitting.
- Inputs to the decoder are created by BartModel.forward if they are not passed. This is different than some other model APIs.
- Model predictions are intended to be identical to the original implementation. This only works, however, if the string you pass to fairseq.encode starts with a space.
- Decoder inputs are created automatically by the helper function ``transformers.modeling_bart._prepare_bart_decoder_inputs``
- ``MaskedLM.generate`` should be used for summarization, see the example in that docstrings
- Bart doesn't use :obj:`token_type_ids` for sequence classification. Use BartTokenizer.encode to get the proper splitting.
- The forward pass of ``BartModel`` will create decoder inputs (using the helper function ``transformers.modeling_bart._prepare_bart_decoder_inputs``) if they are not passed. This is different from some other modeling APIs (see the sketch after these notes).
- Model predictions are intended to be identical to the original implementation. This only works, however, if the string you pass to ``fairseq.encode`` starts with a space.
- ``BartForConditionalGeneration.generate`` should be used for conditional generation tasks like summarization; see the example in that method's docstring.
- Models that load the ``"bart-large-cnn"`` weights will not have a ``mask_token_id`` and cannot perform mask-filling tasks.
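A minimal sketch (not part of this diff) illustrating the notes above; it assumes the ``bart-large`` shortcut weights and the tuple-returning forward API of this version of the library:

import torch
from transformers import BartModel, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("bart-large")
model = BartModel.from_pretrained("bart-large")

# No token_type_ids: BartTokenizer.encode already produces the splitting Bart expects.
# The leading space keeps tokenization consistent with fairseq.encode.
input_ids = torch.tensor([tokenizer.encode(" My friends are cool but they eat too many carbs.")])

# decoder_input_ids are not passed, so BartModel.forward builds them internally
# via transformers.modeling_bart._prepare_bart_decoder_inputs.
last_hidden_state = model(input_ids)[0]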
BartModel
~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~
.. autoclass:: transformers.BartModel
:members: forward
.. autofunction:: transformers.modeling_bart._prepare_bart_decoder_inputs
BartForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BartForMaskedLM
:members: forward, generate
BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BartForConditionalGeneration
:members: generate, forward
BartForSequenceClassification
......@@ -52,8 +54,3 @@ BartConfig
.. autoclass:: transformers.BartConfig
:members:
Automatic Creation of Decoder Inputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This is enabled by default
.. autofunction:: transformers.modeling_bart._prepare_bart_decoder_inputs
......@@ -4,7 +4,7 @@ from pathlib import Path
import torch
from tqdm import tqdm
from transformers import BartForMaskedLM, BartTokenizer
from transformers import BartForConditionalGeneration, BartTokenizer
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
......@@ -18,7 +18,7 @@ def chunks(lst, n):
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
fout = Path(out_file).open("w")
model = BartForMaskedLM.from_pretrained("bart-large-cnn", output_past=True,)
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,)
tokenizer = BartTokenizer.from_pretrained("bart-large")
for batch in tqdm(list(chunks(lns, batch_size))):
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
......
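The body of the loop is elided in the diff above. As a rough, hypothetical sketch of the pattern the renamed class supports (not the actual file contents), the generation and decoding step might look like this:

# Hypothetical continuation for illustration only; assumes model and tensors share `device`.
summaries = model.generate(
    dct["input_ids"].to(device),
    attention_mask=dct["attention_mask"].to(device),
    num_beams=4,
    max_length=140,
)
for g in summaries:
    fout.write(tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) + "\n")
    fout.flush()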
......@@ -206,7 +206,11 @@ if is_torch_available():
XLMForQuestionAnsweringSimple,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_bart import BartForSequenceClassification, BartModel, BartForMaskedLM
from .modeling_bart import (
BartForSequenceClassification,
BartModel,
BartForConditionalGeneration,
)
from .modeling_roberta import (
RobertaForMaskedLM,
RobertaModel,
......
......@@ -23,7 +23,13 @@ import fairseq
import torch
from packaging import version
from transformers import BartConfig, BartForMaskedLM, BartForSequenceClassification, BartModel, BartTokenizer
from transformers import (
BartConfig,
BartForConditionalGeneration,
BartForSequenceClassification,
BartModel,
BartTokenizer,
)
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn"]
......@@ -86,14 +92,14 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
model.eval()
# Check results
if checkpoint_path == "bart.large.cnn": # generate doesnt work yet
model = BartForMaskedLM(config, base_model=model)
if checkpoint_path == "bart.large.cnn":
model = BartForConditionalGeneration(config, base_model=model)
assert "lm_head.weight" in model.state_dict()
assert model.lm_head.out_features == config.max_position_embeddings
model.eval()
our_outputs = model.model.forward(tokens)[0]
our_outputs = model.model(tokens)[0]
else:
our_outputs = model.forward(tokens)[0]
our_outputs = model(tokens)[0]
assert their_output.shape == our_outputs.shape
assert (their_output == our_outputs).all().item()
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
......
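For reference, a hedged usage sketch of the converter above; the output folder below is only an illustrative placeholder:

# checkpoint_path must be one of FAIRSEQ_MODELS, e.g. "bart.large.cnn".
convert_bart_checkpoint("bart.large.cnn", "/tmp/bart-large-cnn-pytorch")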
......@@ -45,7 +45,12 @@ from .modeling_albert import (
AlbertForTokenClassification,
AlbertModel,
)
from .modeling_bart import BART_PRETRAINED_MODEL_ARCHIVE_MAP, BartForMaskedLM, BartForSequenceClassification, BartModel
from .modeling_bart import (
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
BartForConditionalGeneration,
BartForSequenceClassification,
BartModel,
)
from .modeling_bert import (
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BertForMaskedLM,
......@@ -166,7 +171,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(AlbertConfig, AlbertForMaskedLM),
(CamembertConfig, CamembertForMaskedLM),
(XLMRobertaConfig, XLMRobertaForMaskedLM),
(BartConfig, BartForMaskedLM),
(BartConfig, BartForConditionalGeneration),
(RobertaConfig, RobertaForMaskedLM),
(BertConfig, BertForPreTraining),
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
......@@ -186,7 +191,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(AlbertConfig, AlbertForMaskedLM),
(CamembertConfig, CamembertForMaskedLM),
(XLMRobertaConfig, XLMRobertaForMaskedLM),
(BartConfig, BartForMaskedLM),
(BartConfig, BartForConditionalGeneration),
(RobertaConfig, RobertaForMaskedLM),
(BertConfig, BertForMaskedLM),
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
......
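With the mapping updated, ``AutoModelWithLMHead`` resolves a Bart checkpoint to the renamed class. A minimal sketch (not part of this diff), assuming the ``bart-large`` shortcut name:

from transformers import AutoModelWithLMHead, BartForConditionalGeneration

# MODEL_WITH_LM_HEAD_MAPPING now maps BartConfig to BartForConditionalGeneration,
# so the auto class instantiates the renamed model.
model = AutoModelWithLMHead.from_pretrained("bart-large")
assert isinstance(model, BartForConditionalGeneration)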
......@@ -778,21 +778,6 @@ def _filter_out_falsey_values(tup) -> Tuple:
return tuple(x for x in tup if isinstance(x, torch.Tensor) or x)
RET_DOCSTRING = r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
# Public API
......@@ -863,10 +848,9 @@ class BartModel(PretrainedBartModel):
@add_start_docstrings(
"The bare BART Model with a language modeling head. This is the model used for summarization.",
BART_START_DOCSTRING,
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING,
)
class BartForMaskedLM(PretrainedBartModel):
class BartForConditionalGeneration(PretrainedBartModel):
base_model_prefix = "model"
def __init__(self, config: BartConfig):
......@@ -919,11 +903,18 @@ class BartForMaskedLM(PretrainedBartModel):
Examples::
tokenizer = BartTokenizer.from_pretrained('bart-large')
model = BartForMaskedLM.from_pretrained('bart-large')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids=input_ids, lm_labels=input_ids)
loss, prediction_scores = outputs[:2]
# Mask filling only works for bart-large
from transformers import BartTokenizer, BartForConditionalGeneration
tokenizer = BartTokenizer.from_pretrained('bart-large')
TXT = "My friends are <mask> but they eat too many carbs."
model = BartForConditionalGeneration.from_pretrained('bart-large')
input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt')['input_ids']
logits = model(input_ids)[0]
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5)
tokenizer.decode(predictions).split()
# ['good', 'great', 'all', 'really', 'very']
"""
outputs = self.model(
input_ids,
......@@ -992,8 +983,7 @@ class BartForMaskedLM(PretrainedBartModel):
min_len=0,
no_repeat_ngram_size=0,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
r""" Generates summaries using the lm-head and greedy beam search
Adapted in part from Facebook's `XLM beam search code`_ and `Fairseq beam search code`_.
......@@ -1031,16 +1021,16 @@ class BartForMaskedLM(PretrainedBartModel):
sequence_length is <= max_length (examples can finish early)
Examples::
config = BartConfig(vocab_size=50264, output_past=True)
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config)
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
# see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
config = BartConfig(vocab_size=50264, output_past=True) # no mask_token_id
model = BartForConditionalGeneration.from_pretrained('bart-large-cnn', config=config)
tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
# Generate Summary
generated_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in generated_ids])
summary_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
"""
bos_token_id = self.config.bos_token_id
pad_token_id = self.config.pad_token_id
......
......@@ -29,7 +29,7 @@ if is_torch_available():
from transformers import (
AutoModelForSequenceClassification,
BartModel,
BartForMaskedLM,
BartForConditionalGeneration,
BartForSequenceClassification,
BartConfig,
)
......@@ -97,7 +97,9 @@ def prepare_bart_inputs_dict(
@require_torch
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (BartModel, BartForMaskedLM, BartForSequenceClassification) if is_torch_available() else ()
all_model_classes = (
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
)
is_encoder_decoder = True
# TODO(SS): fix the below in a separate PR
test_pruning = False
......@@ -221,8 +223,8 @@ class BartHeadTests(unittest.TestCase):
def test_lm_forward(self):
config, input_ids, batch_size = self._get_config_and_data(output_past=False)
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
lm_model = BartForMaskedLM(config)
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
lm_model = BartForConditionalGeneration(config)
lm_model.to(torch_device)
loss, logits, enc_features = lm_model.forward(
input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
......@@ -243,15 +245,15 @@ class BartHeadTests(unittest.TestCase):
decoder_ffn_dim=32,
max_position_embeddings=48,
)
lm_model = BartForMaskedLM(config)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long()
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long()
logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary)
lm_model = BartForConditionalGeneration(config).to(torch_device)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
loss, logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
def test_generate_beam_search(self):
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long().to(torch_device)
config = BartConfig(
vocab_size=self.vocab_size,
d_model=24,
......@@ -264,7 +266,7 @@ class BartHeadTests(unittest.TestCase):
max_position_embeddings=48,
output_past=True,
)
lm_model = BartForMaskedLM(config)
lm_model = BartForConditionalGeneration(config).to(torch_device)
lm_model.eval()
new_input_ids = lm_model.generate(
......@@ -376,7 +378,7 @@ class BartModelIntegrationTest(unittest.TestCase):
@slow
def test_cnn_summarization_same_as_fairseq(self):
hf = BartForMaskedLM.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
tok = BartTokenizer.from_pretrained("bart-large")
text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
......