Commit 741e4930 authored by Sanchit Gandhi, committed by GitHub

Fix Bug in Flax Seq2Seq Models (#16021)

* Fix Bug in Flax Seq2Seq Models

* incorporate suggested changes
parent b7018abf
@@ -104,9 +104,9 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
             [What are decoder input IDs?](../glossary#decoder-input-ids)
-            For sequence to sequence training, `decoder_input_ids` should be provided. If no `decoder_input_ids` is
-            provided, the model will create this tensor by shifting the `input_ids` to the right for denoising
-            pre-training.
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
         decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
             Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
             be used by default.
@@ -169,9 +169,9 @@ ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
             `past_key_values`).
-            For sequence to sequence training, `decoder_input_ids` should be provided. If no `decoder_input_ids` is
-            provided, the model will create this tensor by shifting the `input_ids` to the right for denoising
-            pre-training.
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
         encoder_outputs (`tuple(tuple(jnp.ndarray)`):
             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
             `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
@@ -670,6 +670,11 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
             batch_size, sequence_length = input_ids.shape
             position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
         # prepare decoder inputs
+        if decoder_input_ids is None:
+            raise ValueError(
+                "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_input_ids` must be specified as an input argument."
+            )
         if decoder_attention_mask is None:
             decoder_attention_mask = jnp.ones_like(decoder_input_ids)
         if decoder_position_ids is None:
@@ -108,8 +108,9 @@ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
             `past_key_values`).
-            For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the
-            right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`.
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
         decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
             Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
             be used by default.
@@ -161,9 +162,9 @@ SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
             `past_key_values`).
-            For sequence to sequence training, `decoder_input_ids` should be provided. If no `decoder_input_ids` is
-            provided, the model will create this tensor by shifting the `input_ids` to the right for denoising
-            pre-training.
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
         encoder_outputs (`tuple(tuple(jnp.ndarray)`):
             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
             `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
@@ -681,6 +682,10 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
             attention_mask = jnp.ones_like(inputs)
         # prepare decoder inputs
+        if decoder_input_ids is None:
+            raise ValueError(
+                "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_input_ids` must be specified as an input argument."
+            )
         if decoder_attention_mask is None:
             decoder_attention_mask = jnp.ones_like(decoder_input_ids)
         if decoder_position_ids is None:
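
Note on the updated docstrings: since the models now raise when `decoder_input_ids` is `None`, callers have to build that tensor themselves from `labels`. The following is a minimal sketch of the three steps the docstring describes (shift right, prepend the start token, replace -100 with the pad token). The helper name `shift_tokens_right` and the token id values used here are assumptions for illustration, not part of this diff.

```python
import jax.numpy as jnp


def shift_tokens_right(labels, pad_token_id, decoder_start_token_id):
    """Build decoder_input_ids from labels (hypothetical helper, illustration only)."""
    # Shift every sequence one position to the right, dropping the last token.
    shifted = jnp.zeros_like(labels)
    shifted = shifted.at[:, 1:].set(labels[:, :-1])
    # Prepend the decoder start token in the first position.
    shifted = shifted.at[:, 0].set(decoder_start_token_id)
    # -100 marks label positions ignored by the loss; the decoder needs a real token id there.
    return jnp.where(shifted == -100, pad_token_id, shifted)


# Assumed example values: pad_token_id=0, decoder_start_token_id=2.
labels = jnp.array([[5, 6, 7, -100, -100], [8, 9, -100, -100, -100]])
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2)
```

The resulting `decoder_input_ids` has the same shape as `labels` and can be passed to the model's call alongside the encoder inputs, which avoids the new `ValueError`.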