Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ad0cba08
Unverified
Commit
ad0cba08
authored
Apr 04, 2022
by
Patrick von Platen
Committed by
GitHub
Apr 04, 2022
Browse files
[FlaxSpeechEncoderDecoder] Fix dtype bug (#16581)
* [FlaxSpeechEncoderDecoder] Fix dtype bug * more fixes
parent
60d27b1f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
7 deletions
+7
-7
src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
...h_encoder_decoder/modeling_flax_speech_encoder_decoder.py
+7
-7
No files found.
src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
View file @
ad0cba08
...
@@ -310,7 +310,7 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
...
@@ -310,7 +310,7 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
decoder_hidden_states
=
decoder_outputs
.
hidden_states
,
decoder_hidden_states
=
decoder_outputs
.
hidden_states
,
decoder_attentions
=
decoder_outputs
.
attentions
,
decoder_attentions
=
decoder_outputs
.
attentions
,
cross_attentions
=
decoder_outputs
.
cross_attentions
,
cross_attentions
=
decoder_outputs
.
cross_attentions
,
encoder_last_hidden_state
=
encoder_
outputs
.
last_
hidden_state
,
encoder_last_hidden_state
=
encoder_hidden_state
s
,
encoder_hidden_states
=
encoder_outputs
.
hidden_states
,
encoder_hidden_states
=
encoder_outputs
.
hidden_states
,
encoder_attentions
=
encoder_outputs
.
attentions
,
encoder_attentions
=
encoder_outputs
.
attentions
,
)
)
...
@@ -363,8 +363,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
...
@@ -363,8 +363,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
encoder_input_shape
,
decoder_input_shape
=
input_shape
encoder_input_shape
,
decoder_input_shape
=
input_shape
# init input DeviceArrays
# init input DeviceArrays
inputs
=
jnp
.
zeros
(
encoder_input_shape
,
dtype
=
"
i
4"
)
inputs
=
jnp
.
zeros
(
encoder_input_shape
,
dtype
=
"
f
4"
)
attention_mask
=
jnp
.
ones_like
(
inputs
)
attention_mask
=
jnp
.
ones_like
(
inputs
,
dtype
=
"i4"
)
decoder_input_ids
=
jnp
.
zeros
(
decoder_input_shape
,
dtype
=
"i4"
)
decoder_input_ids
=
jnp
.
zeros
(
decoder_input_shape
,
dtype
=
"i4"
)
decoder_attention_mask
=
jnp
.
ones_like
(
decoder_input_ids
)
decoder_attention_mask
=
jnp
.
ones_like
(
decoder_input_ids
)
...
@@ -472,7 +472,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
...
@@ -472,7 +472,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
return_dict
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
return_dict
if
attention_mask
is
None
:
if
attention_mask
is
None
:
attention_mask
=
jnp
.
ones_like
(
inputs
)
attention_mask
=
jnp
.
ones_like
(
inputs
,
dtype
=
"i4"
)
# Handle any PRNG if needed
# Handle any PRNG if needed
rngs
=
{}
rngs
=
{}
...
@@ -485,7 +485,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
...
@@ -485,7 +485,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
outputs
=
self
.
module
.
apply
(
outputs
=
self
.
module
.
apply
(
{
"params"
:
params
or
self
.
params
},
{
"params"
:
params
or
self
.
params
},
inputs
=
jnp
.
array
(
inputs
,
dtype
=
"
i
4"
),
inputs
=
jnp
.
array
(
inputs
,
dtype
=
"
f
4"
),
attention_mask
=
jnp
.
array
(
attention_mask
,
dtype
=
"i4"
),
attention_mask
=
jnp
.
array
(
attention_mask
,
dtype
=
"i4"
),
output_attentions
=
output_attentions
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
output_hidden_states
=
output_hidden_states
,
...
@@ -680,7 +680,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
...
@@ -680,7 +680,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
# prepare encoder inputs
# prepare encoder inputs
if
attention_mask
is
None
:
if
attention_mask
is
None
:
attention_mask
=
jnp
.
ones_like
(
inputs
)
attention_mask
=
jnp
.
ones_like
(
inputs
,
dtype
=
"i4"
)
# prepare decoder inputs
# prepare decoder inputs
if
decoder_input_ids
is
None
:
if
decoder_input_ids
is
None
:
...
@@ -700,7 +700,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
...
@@ -700,7 +700,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
return
self
.
module
.
apply
(
return
self
.
module
.
apply
(
{
"params"
:
params
or
self
.
params
},
{
"params"
:
params
or
self
.
params
},
inputs
=
jnp
.
array
(
inputs
,
dtype
=
"
i
4"
),
inputs
=
jnp
.
array
(
inputs
,
dtype
=
"
f
4"
),
attention_mask
=
jnp
.
array
(
attention_mask
,
dtype
=
"i4"
),
attention_mask
=
jnp
.
array
(
attention_mask
,
dtype
=
"i4"
),
decoder_input_ids
=
jnp
.
array
(
decoder_input_ids
,
dtype
=
"i4"
),
decoder_input_ids
=
jnp
.
array
(
decoder_input_ids
,
dtype
=
"i4"
),
decoder_attention_mask
=
jnp
.
array
(
decoder_attention_mask
,
dtype
=
"i4"
),
decoder_attention_mask
=
jnp
.
array
(
decoder_attention_mask
,
dtype
=
"i4"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment