Unverified Commit 8be9cb0a authored by Lysandre Debut, committed by GitHub

Tiny TF Bart fixes (#8023)

parent 07747863
@@ -822,7 +822,7 @@ class TFBartModel(TFPretrainedBartModel):
         if decoder_attn_mask is None:
             decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
         else:
-            decoder_padding_mask = invert_mask(tf.Tensor)
+            decoder_padding_mask = invert_mask(decoder_attn_mask)
         causal_lm_mask = causal_attention_mask(tgt_len, tgt_len, mask_dtype)
         return decoder_input_ids, decoder_padding_mask, causal_lm_mask
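The first hunk fixes a bug in the decoder-input preparation: the removed line handed the tf.Tensor class itself to invert_mask instead of the user-supplied decoder_attn_mask tensor. Below is a minimal sketch, not the library source, of what the two mask helpers are assumed to do; it shows why only a real mask tensor makes sense at that call site.

import tensorflow as tf

def make_padding_mask(input_ids, pad_token_id):
    # Assumed behaviour: True wherever the token is padding.
    return tf.math.equal(input_ids, pad_token_id)

def invert_mask(attention_mask):
    # Assumed behaviour: flip a 1/0 (or True/False) attention mask so that
    # True marks the positions to be ignored.
    return tf.math.logical_not(tf.cast(attention_mask, tf.bool))

# The removed line called invert_mask(tf.Tensor), passing the class object,
# which cannot be converted to a tensor; the fix passes the actual
# decoder_attn_mask provided by the caller.
decoder_attn_mask = tf.constant([[1, 1, 1, 0]])
decoder_padding_mask = invert_mask(decoder_attn_mask)  # [[False, False, False, True]]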
@@ -903,6 +903,7 @@ class TFBartModel(TFPretrainedBartModel):
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=True,
+                training=training,
             )
         decoder_outputs = self.decoder(
             decoder_input_ids,
@@ -915,6 +916,7 @@ class TFBartModel(TFPretrainedBartModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            training=training,
         )
         if not return_dict:
             # Attention and hidden_states will be [] or None if they aren't needed
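The other two hunks forward the training flag from the model's call into the encoder and decoder calls, so that layers such as dropout behave correctly in training versus inference mode instead of relying on defaults or implicit propagation. A minimal sketch with hypothetical toy layers (not the Bart classes) shows the pattern:

import tensorflow as tf

class ToyEncoder(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.dropout = tf.keras.layers.Dropout(0.5)

    def call(self, x, training=False):
        # Dropout only zeroes activations when training=True.
        return self.dropout(x, training=training)

class ToyModel(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.encoder = ToyEncoder()

    def call(self, x, training=False):
        # Forwarding training=training is what the commit adds: passing the
        # flag explicitly, rather than leaving Keras to infer it, keeps
        # dropout active during training and off at inference time.
        return self.encoder(x, training=training)

x = tf.ones((1, 4))
print(ToyModel()(x, training=True))   # roughly half the entries zeroed by dropout
print(ToyModel()(x, training=False))  # unchanged input, no dropout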