Unverified Commit 11573231 authored by Sam Shleifer, committed by GitHub

[BART] generation_mode as a kwarg not a class attribute (#3278)

parent de697935
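In short: callers no longer flip a generation_mode flag on the decoder object before decoding; the mode is now passed per call as a keyword argument to forward. A minimal before/after sketch, assuming model is a BartForConditionalGeneration and the input tensors are already prepared (variable names here are illustrative, not taken from the diff):

    # Before: generation mode was toggled by mutating decoder state,
    # which could silently leak into later, unrelated forward passes.
    model.model.decoder.generation_mode = True
    outputs = model(input_ids, decoder_input_ids=decoder_input_ids)

    # After: the mode is an explicit, per-call keyword argument.
    outputs = model(input_ids, decoder_input_ids=decoder_input_ids, generation_mode=True)

Making the flag a kwarg keeps forward stateless, so interleaved calls (say, scoring and generating with the same model) can no longer corrupt each other.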
src/transformers/modeling_bart.py
@@ -437,7 +437,6 @@ class BartDecoder(nn.Module):
             [DecoderLayer(config) for _ in range(config.decoder_layers)]
         )  # type: List[DecoderLayer]
         self.layernorm_embedding = LayerNorm(config.d_model)
-        self.generation_mode = False

     def forward(
         self,
@@ -446,6 +445,7 @@ class BartDecoder(nn.Module):
         encoder_padding_mask,
         combined_mask,
         decoder_cached_states=None,
+        generation_mode=False,
         **unused
     ):
         """
@@ -474,9 +474,9 @@ class BartDecoder(nn.Module):
         assert encoder_padding_mask.max() <= 0
         # embed positions
-        positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
+        positions = self.embed_positions(input_ids, generation_mode=generation_mode)

-        if self.generation_mode:
+        if generation_mode:
             input_ids = input_ids[:, -1:]
             positions = positions[:, -1:]  # happens after we embed them

         assert input_ids.ne(self.padding_idx).any()
@@ -820,10 +820,11 @@ class BartModel(PretrainedBartModel):
         encoder_outputs=None,  # type: Tuple
         decoder_attention_mask=None,
         decoder_cached_states=None,
+        generation_mode=False,
     ):
         # make masks if user doesn't supply
-        if not self.decoder.generation_mode:
+        if not generation_mode:
             decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
                 self.config,
                 input_ids,
@@ -842,6 +843,7 @@ class BartModel(PretrainedBartModel):
             attention_mask,
             decoder_attention_mask,
             decoder_cached_states=decoder_cached_states,
+            generation_mode=generation_mode,
         )
         # Attention and hidden_states will be [] or None if they aren't needed
         decoder_outputs = _filter_out_falsey_values(decoder_outputs)  # type: tuple
@@ -886,6 +888,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         decoder_attention_mask=None,
         decoder_cached_states=None,
         lm_labels=None,
+        generation_mode=False,
         **unused
     ):
         r"""
@@ -936,6 +939,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
             decoder_cached_states=decoder_cached_states,
+            generation_mode=generation_mode,
         )
         lm_logits = self.lm_head(outputs[0])
         outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
@@ -963,6 +967,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             "decoder_cached_states": decoder_cached_states,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
+            "generation_mode": True,
         }

     def prepare_scores_for_generation(self, scores, cur_len, max_length):
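The generation_mode branch in BartDecoder.forward above exists because, during incremental decoding with decoder_cached_states, every earlier position is already covered by the cache; only the newest token has to be embedded and attended. A self-contained sketch of that slicing, using hypothetical token ids (not taken from the diff):

    import torch

    input_ids = torch.tensor([[0, 31414, 232, 2]])  # hypothetical token ids

    generation_mode = True
    if generation_mode:
        step_ids = input_ids[:, -1:]  # (batch, 1): only the most recent token
    else:
        step_ids = input_ids          # (batch, seq_len): full teacher forcing

    print(step_ids.shape)  # torch.Size([1, 1]) in generation mode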
src/transformers/modeling_utils.py
@@ -846,7 +846,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             attention_mask = attention_mask.contiguous().view(
                 effective_batch_size * num_beams, input_ids_len
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
-
         if self.config.is_encoder_decoder:
             assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
             # encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
@@ -859,9 +858,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             )
             cur_len = 1

-            # put model in generation mode if it has one
-            if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "generation_mode"):
-                self.model.decoder.generation_mode = True
         else:
             encoder_inputs = None
             cur_len = input_ids.shape[-1]
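With the hasattr hook removed from generate(), the flag instead travels through prepare_inputs_for_generation, which generate() calls at each decoding step to assemble the forward kwargs; the BART override above now injects "generation_mode": True into that dict. A hedged end-to-end sketch (the checkpoint name and input sentence are illustrative, not from the diff):

    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("bart-large-cnn")
    model = BartForConditionalGeneration.from_pretrained("bart-large-cnn")

    input_ids = tokenizer.encode("My friends are cool.", return_tensors="pt")
    # generate() no longer mutates model.model.decoder; each decoding step
    # receives generation_mode=True via prepare_inputs_for_generation.
    summary_ids = model.generate(input_ids, max_length=20)
    print(tokenizer.decode(summary_ids[0]))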