Unverified Commit 11573231 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[BART] generation_mode as a kwarg not a class attribute (#3278)

parent de697935
......@@ -437,7 +437,6 @@ class BartDecoder(nn.Module):
[DecoderLayer(config) for _ in range(config.decoder_layers)]
) # type: List[DecoderLayer]
self.layernorm_embedding = LayerNorm(config.d_model)
self.generation_mode = False
def forward(
self,
......@@ -446,6 +445,7 @@ class BartDecoder(nn.Module):
encoder_padding_mask,
combined_mask,
decoder_cached_states=None,
generation_mode=False,
**unused
):
"""
......@@ -474,9 +474,9 @@ class BartDecoder(nn.Module):
assert encoder_padding_mask.max() <= 0
# embed positions
positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
positions = self.embed_positions(input_ids, generation_mode=generation_mode)
if self.generation_mode:
if generation_mode:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
assert input_ids.ne(self.padding_idx).any()
......@@ -820,10 +820,11 @@ class BartModel(PretrainedBartModel):
encoder_outputs=None, # type: Tuple
decoder_attention_mask=None,
decoder_cached_states=None,
generation_mode=False,
):
# make masks if user doesn't supply
if not self.decoder.generation_mode:
if not generation_mode:
decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
self.config,
input_ids,
......@@ -842,6 +843,7 @@ class BartModel(PretrainedBartModel):
attention_mask,
decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
generation_mode=generation_mode,
)
# Attention and hidden_states will be [] or None if they aren't needed
decoder_outputs = _filter_out_falsey_values(decoder_outputs) # type: tuple
......@@ -886,6 +888,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
decoder_attention_mask=None,
decoder_cached_states=None,
lm_labels=None,
generation_mode=False,
**unused
):
r"""
......@@ -936,6 +939,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
generation_mode=generation_mode,
)
lm_logits = self.lm_head(outputs[0])
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
......@@ -963,6 +967,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"generation_mode": True,
}
def prepare_scores_for_generation(self, scores, cur_len, max_length):
......
......@@ -846,7 +846,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
attention_mask = attention_mask.contiguous().view(
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
# encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
......@@ -859,9 +858,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
)
cur_len = 1
# put model in generation mode if it has one
if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "generation_mode"):
self.model.decoder.generation_mode = True
else:
encoder_inputs = None
cur_len = input_ids.shape[-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment