Unverified Commit e9d51387 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: Merge PT and TF behavior for Bart when no decoder_input_ids are passed (#17593)

* Merge PT and TF behavior
parent e160a5dd
...@@ -1073,14 +1073,16 @@ class TFBartMainLayer(tf.keras.layers.Layer): ...@@ -1073,14 +1073,16 @@ class TFBartMainLayer(tf.keras.layers.Layer):
**kwargs **kwargs
) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]: ) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:
# different to other models, Bart automatically creates decoder_input_ids from
# input_ids if no decoder_input_ids are provided
if decoder_input_ids is None and decoder_inputs_embeds is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
use_cache = False if input_ids is None:
raise ValueError(
output_hidden_states = ( "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states "passed, `input_ids` cannot be `None`. Please pass either "
) "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
)
if decoder_input_ids is None and input_ids is not None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
input_ids, self.config.pad_token_id, self.config.decoder_start_token_id input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment