Unverified Commit 715aa5b1 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[Bart] Replace config.output_past with use_cache kwarg (#3632)

parent e344e3d4
......@@ -20,7 +20,7 @@ def generate_summaries(
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
):
fout = Path(out_file).open("w")
model = BartForConditionalGeneration.from_pretrained(model_name, output_past=True,).to(device)
model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = BartTokenizer.from_pretrained("bart-large")
max_length = 140
......
......@@ -56,7 +56,6 @@ class BartConfig(PretrainedConfig):
max_position_embeddings=1024,
init_std=0.02,
classifier_dropout=0.0,
output_past=False,
num_labels=3,
is_encoder_decoder=True,
pad_token_id=1,
......@@ -72,7 +71,6 @@ class BartConfig(PretrainedConfig):
"""
super().__init__(
num_labels=num_labels,
output_past=output_past,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
......
......@@ -388,7 +388,6 @@ class BartDecoder(nn.Module):
def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
super().__init__()
self.output_past = config.output_past
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.dropout = config.dropout
......@@ -412,7 +411,7 @@ class BartDecoder(nn.Module):
decoder_padding_mask,
decoder_causal_mask,
decoder_cached_states=None,
generation_mode=False,
use_cache=False,
**unused
):
"""
......@@ -438,9 +437,9 @@ class BartDecoder(nn.Module):
encoder_padding_mask = invert_mask(encoder_padding_mask)
# embed positions
positions = self.embed_positions(input_ids, generation_mode=generation_mode)
positions = self.embed_positions(input_ids, use_cache=use_cache)
if generation_mode:
if use_cache:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
assert input_ids.ne(self.padding_idx).any()
......@@ -476,7 +475,7 @@ class BartDecoder(nn.Module):
causal_mask=decoder_causal_mask,
)
if self.output_past:
if use_cache:
next_decoder_cache.append(layer_past.copy())
if self.output_hidden_states:
all_hidden_states += (x,)
......@@ -488,7 +487,7 @@ class BartDecoder(nn.Module):
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
if self.output_past:
if use_cache:
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
else:
next_cache = None
......@@ -710,9 +709,9 @@ class LearnedPositionalEmbedding(nn.Embedding):
num_embeddings += padding_idx + 1 # WHY?
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def forward(self, input, generation_mode=False):
def forward(self, input, use_cache=False):
"""Input is expected to be of size [bsz x seqlen]."""
if generation_mode: # the position is our current step in the decoded sequence
if use_cache: # the position is our current step in the decoded sequence
pos = int(self.padding_idx + input.size(1))
positions = input.data.new(1, 1).fill_(pos)
else:
......@@ -772,11 +771,11 @@ class BartModel(PretrainedBartModel):
encoder_outputs=None, # type: Tuple
decoder_attention_mask=None,
decoder_cached_states=None,
generation_mode=False,
use_cache=False,
):
# make masks if user doesn't supply
if not generation_mode:
if not use_cache:
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
self.config,
input_ids,
......@@ -799,7 +798,7 @@ class BartModel(PretrainedBartModel):
decoder_padding_mask,
decoder_causal_mask=causal_mask,
decoder_cached_states=decoder_cached_states,
generation_mode=generation_mode,
use_cache=use_cache,
)
# Attention and hidden_states will be [] or None if they aren't needed
decoder_outputs = _filter_out_falsey_values(decoder_outputs) # type: tuple
......@@ -841,7 +840,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
decoder_attention_mask=None,
decoder_cached_states=None,
lm_labels=None,
generation_mode=False,
use_cache=False,
**unused
):
r"""
......@@ -892,7 +891,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
generation_mode=generation_mode,
use_cache=use_cache,
)
lm_logits = F.linear(outputs[0], self.model.shared.weight)
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
......@@ -918,7 +917,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"generation_mode": True,
"use_cache": True, # change this to avoid caching (presumably for debugging)
}
def prepare_scores_for_generation(self, scores, cur_len, max_length):
......@@ -951,6 +950,10 @@ class BartForConditionalGeneration(PretrainedBartModel):
def get_output_embeddings(self):
return _make_linear_from_emb(self.model.shared) # make it on the fly
def _do_output_past(self, *args, **kwargs):
""" We should always use the cache in generate."""
return True
@add_start_docstrings(
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
......
......@@ -200,7 +200,7 @@ class BartHeadTests(unittest.TestCase):
vocab_size = 99
def _get_config_and_data(self, output_past=False):
def _get_config_and_data(self):
input_ids = torch.tensor(
[
[71, 82, 18, 33, 46, 91, 2],
......@@ -232,7 +232,6 @@ class BartHeadTests(unittest.TestCase):
encoder_ffn_dim=32,
decoder_ffn_dim=32,
max_position_embeddings=48,
output_past=output_past,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
......@@ -252,7 +251,7 @@ class BartHeadTests(unittest.TestCase):
self.assertIsInstance(loss.item(), float)
def test_lm_forward(self):
config, input_ids, batch_size = self._get_config_and_data(output_past=False)
config, input_ids, batch_size = self._get_config_and_data()
lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
lm_model = BartForConditionalGeneration(config)
lm_model.to(torch_device)
......@@ -292,7 +291,6 @@ class BartHeadTests(unittest.TestCase):
encoder_ffn_dim=32,
decoder_ffn_dim=32,
max_position_embeddings=48,
output_past=True,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
......@@ -335,20 +333,20 @@ class BartHeadTests(unittest.TestCase):
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_generate_fp16(self):
config, input_ids, batch_size = self._get_config_and_data(output_past=True)
config, input_ids, batch_size = self._get_config_and_data()
attention_mask = input_ids.ne(1).to(torch_device)
model = BartForConditionalGeneration(config).eval().to(torch_device).half()
model.generate(input_ids, attention_mask=attention_mask, do_sample=False, early_stopping=True)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_base_model_fp16(self):
config, input_ids, batch_size = self._get_config_and_data(output_past=False)
config, input_ids, batch_size = self._get_config_and_data()
attention_mask = input_ids.ne(1).to(torch_device)
lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
lm_model(input_ids, attention_mask=attention_mask)
def test_default_generate_kwargs(self):
config, input_ids, _ = self._get_config_and_data(output_past=True)
config, input_ids, _ = self._get_config_and_data()
model = BartForConditionalGeneration(config).eval().to(torch_device)
model.generate(input_ids)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
......@@ -359,7 +357,7 @@ class BartHeadTests(unittest.TestCase):
model(**model.dummy_inputs)
def test_prepare_bart_decoder_inputs(self):
config, *_ = self._get_config_and_data(output_past=False)
config, *_ = self._get_config_and_data()
input_ids = _long_tensor(([4, 4, 2]))
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
ignore = float("-inf")
......@@ -495,7 +493,7 @@ class BartModelIntegrationTests(unittest.TestCase):
@slow
def test_cnn_summarization_same_as_fairseq(self):
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn").to(torch_device)
tok = BartTokenizer.from_pretrained("bart-large")
FRANCE_ARTICLE = ' Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noqa
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment