Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
741e4930
Unverified
Commit
741e4930
authored
Mar 10, 2022
by
Sanchit Gandhi
Committed by
GitHub
Mar 10, 2022
Browse files
Fix Bug in Flax Seq2Seq Models (#16021)
* Fix Bug in Flax Seq2Seq Models * incorporate suggested changes
parent
b7018abf
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
11 deletions
+21
-11
src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py
...s/models/encoder_decoder/modeling_flax_encoder_decoder.py
+11
-6
src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
...h_encoder_decoder/modeling_flax_speech_encoder_decoder.py
+10
-5
No files found.
src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py
View file @
741e4930
...
@@ -104,9 +104,9 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
...
@@ -104,9 +104,9 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
[What are decoder input IDs?](../glossary#decoder-input-ids)
[What are decoder input IDs?](../glossary#decoder-input-ids)
For sequence to sequence training, `decoder_input_ids` should be provided.
If no
`decoder_input_ids`
i
s
For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` s
hould be
provided, the model will create this tensor
by shifting the `
input_id
s` to the right
for denoising
created outside of the model
by shifting the `
label
s` to the right
, replacing -100 by the `pad_token_id`
pre-training
.
and prepending them with the `decoder_start_token_id`
.
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
be used by default.
...
@@ -169,9 +169,9 @@ ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
...
@@ -169,9 +169,9 @@ ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
`past_key_values`).
For sequence to sequence training, `decoder_input_ids` should be provided.
If no
`decoder_input_ids`
i
s
For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` s
hould be
provided, the model will create this tensor
by shifting the `
input_id
s` to the right
for denoising
created outside of the model
by shifting the `
label
s` to the right
, replacing -100 by the `pad_token_id`
pre-training
.
and prepending them with the `decoder_start_token_id`
.
encoder_outputs (`tuple(tuple(jnp.ndarray)`):
encoder_outputs (`tuple(tuple(jnp.ndarray)`):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
...
@@ -670,6 +670,11 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
...
@@ -670,6 +670,11 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
batch_size
,
sequence_length
=
input_ids
.
shape
batch_size
,
sequence_length
=
input_ids
.
shape
position_ids
=
jnp
.
broadcast_to
(
jnp
.
arange
(
sequence_length
)[
None
,
:],
(
batch_size
,
sequence_length
))
position_ids
=
jnp
.
broadcast_to
(
jnp
.
arange
(
sequence_length
)[
None
,
:],
(
batch_size
,
sequence_length
))
# prepare decoder inputs
if
decoder_input_ids
is
None
:
raise
ValueError
(
"`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
)
if
decoder_attention_mask
is
None
:
if
decoder_attention_mask
is
None
:
decoder_attention_mask
=
jnp
.
ones_like
(
decoder_input_ids
)
decoder_attention_mask
=
jnp
.
ones_like
(
decoder_input_ids
)
if
decoder_position_ids
is
None
:
if
decoder_position_ids
is
None
:
...
...
src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
View file @
741e4930
...
@@ -108,8 +108,9 @@ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
...
@@ -108,8 +108,9 @@ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
`past_key_values`).
For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the
For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`.
created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
and prepending them with the `decoder_start_token_id`.
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
be used by default.
...
@@ -161,9 +162,9 @@ SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
...
@@ -161,9 +162,9 @@ SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
`past_key_values`).
For sequence to sequence training, `decoder_input_ids` should be provided.
If no
`decoder_input_ids`
i
s
For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` s
hould be
provided, the model will create this tensor
by shifting the `
input_id
s` to the right
for denoising
created outside of the model
by shifting the `
label
s` to the right
, replacing -100 by the `pad_token_id`
pre-training
.
and prepending them with the `decoder_start_token_id`
.
encoder_outputs (`tuple(tuple(jnp.ndarray)`):
encoder_outputs (`tuple(tuple(jnp.ndarray)`):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
...
@@ -681,6 +682,10 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
...
@@ -681,6 +682,10 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
attention_mask
=
jnp
.
ones_like
(
inputs
)
attention_mask
=
jnp
.
ones_like
(
inputs
)
# prepare decoder inputs
# prepare decoder inputs
if
decoder_input_ids
is
None
:
raise
ValueError
(
"`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
)
if
decoder_attention_mask
is
None
:
if
decoder_attention_mask
is
None
:
decoder_attention_mask
=
jnp
.
ones_like
(
decoder_input_ids
)
decoder_attention_mask
=
jnp
.
ones_like
(
decoder_input_ids
)
if
decoder_position_ids
is
None
:
if
decoder_position_ids
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment