Unverified commit 92487a1d authored by Sam Shleifer, committed by GitHub

Bart: fix layerdrop and cached decoder_input_ids for generation (#2969)

parent c36416e5
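
This commit makes two related fixes. First, the LayerDrop condition in `BartDecoder` was inverted: layers were skipped when the sampled probability *exceeded* `self.layerdrop`, i.e. with probability `1 - layerdrop` rather than `layerdrop`. Second, the last-token slicing used during cached generation moves out of the decoder body and into `prepare_inputs_for_generation`, so the decoder always embeds exactly the `decoder_input_ids` it is given. A test is tightened by putting the model in eval mode before `generate`.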
@@ -470,10 +470,6 @@ class BartDecoder(nn.Module):
         """
         # embed positions
         positions = self.embed_positions(input_ids)
-        if decoder_cached_states is not None:
-            input_ids = input_ids[:, -1:]
-            positions = positions[:, -1:]
-
         x = self.embed_tokens(input_ids)
         if positions is not None:
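
With this hunk the decoder no longer second-guesses its inputs: when a cache is present, the caller is responsible for passing only the newest token (see the `prepare_inputs_for_generation` hunk below). A toy illustration of that slicing contract, with made-up token ids:

```python
import torch

# Toy illustration (ids are made up): when a key/value cache exists,
# only the newest token needs to be re-embedded; earlier positions
# already live in the cache.
input_ids = torch.tensor([[101, 7, 42, 9]])  # (batch, seq_len)
newest = input_ids[:, -1:]                   # (batch, 1)
assert newest.tolist() == [[9]]
```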
@@ -491,7 +487,7 @@ class BartDecoder(nn.Module):
             decoder_layer  # type: DecoderLayer
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability > self.layerdrop):
+            if self.training and (dropout_probability < self.layerdrop):
                 continue
             layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
             x, layer_self_attn, layer_past = decoder_layer.forward(
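
LayerDrop (Fan et al., 2019, the paper linked in the comment) skips each layer during training with probability `layerdrop`. Since `random.uniform(0, 1) < p` is true with probability `p`, the old `>` comparison dropped layers with probability `1 - layerdrop`, the exact opposite of the intended rate. A minimal sketch of the corrected behavior, with illustrative names:

```python
import random

def run_layers(x, layers, layerdrop, training):
    """Apply each layer to x, skipping it with probability `layerdrop`
    during training (LayerDrop). `layers` is any iterable of callables."""
    for layer in layers:
        # uniform(0, 1) < layerdrop holds with probability `layerdrop`,
        # so each layer is dropped at exactly the configured rate.
        if training and random.uniform(0, 1) < layerdrop:
            continue
        x = layer(x)
    return x
```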
@@ -940,7 +936,7 @@ class BartForMaskedLM(PretrainedBartModel):
     @staticmethod
     def prepare_inputs_for_generation(input_ids, past, **kwargs):
-        return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids}
+        return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids[:, -1:]}

     def get_output_embeddings(self):
         return self.lm_head
@@ -251,6 +251,7 @@ class BartHeadTests(unittest.TestCase):
             output_past=True,
         )
         lm_model = BartForMaskedLM(config)
+        lm_model.eval()
         new_input_ids = lm_model.generate(input_ids)
         self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 20))
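
The new `lm_model.eval()` call matters because of the LayerDrop fix above: in training mode, the now-correct drop branch and regular dropout both fire at random, making generation nondeterministic. `eval()` recursively sets `training = False`, so the `self.training and ...` guard is never taken. A quick demonstration on a stand-in module:

```python
import torch.nn as nn

# Stand-in module: eval() flips `training` to False on every submodule,
# disabling both dropout and any branch guarded by `self.training`.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1))
model.eval()
assert all(not m.training for m in model.modules())
```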