Commit ad0cba08 in chenpangpang/transformers (unverified)
Authored Apr 04, 2022 by Patrick von Platen; committed via GitHub Apr 04, 2022
Parent: 60d27b1f

[FlaxSpeechEncoderDecoder] Fix dtype bug (#16581)

* [FlaxSpeechEncoderDecoder] Fix dtype bug
* more fixes
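In short: the model's dummy inputs and runtime casts used integer dtype "i4", but a speech encoder consumes continuous audio features, which must be float32 ("f4"); attention masks and decoder token ids remain integer. A minimal sketch of the corrected dtypes, assuming a batch of one 16000-sample waveform (the shape is illustrative, not from the diff):

    import jax.numpy as jnp

    # Audio features are continuous, so the dummy encoder input is float32.
    inputs = jnp.zeros((1, 16000), dtype="f4")          # was "i4" before this fix

    # Masks and token ids are discrete and stay int32. The dtype is explicit
    # because ones_like would otherwise inherit float32 from `inputs`.
    attention_mask = jnp.ones_like(inputs, dtype="i4")
    decoder_input_ids = jnp.zeros((1, 1), dtype="i4")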
Showing 1 changed file with 7 additions and 7 deletions:
src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py (+7 −7)
--- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
@@ -310,7 +310,7 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_last_hidden_state=encoder_hidden_states,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
         )
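The one-line change above returns the module-local encoder_hidden_states instead of the raw encoder_outputs.last_hidden_state. In this module the local variable appears to be the encoder output after the optional encoder-to-decoder projection, so the reported encoder_last_hidden_state then matches what the decoder actually cross-attends to. A runnable sketch of such a projection (the module name enc_to_dec_proj and all sizes are assumptions, not taken from this diff):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # Hypothetical widths: a 768-d encoder feeding a 512-d decoder.
    enc_to_dec_proj = nn.Dense(features=512)

    last_hidden_state = jnp.zeros((1, 99, 768))  # raw encoder output (assumed shape)
    params = enc_to_dec_proj.init(jax.random.PRNGKey(0), last_hidden_state)
    encoder_hidden_states = enc_to_dec_proj.apply(params, last_hidden_state)

    print(encoder_hidden_states.shape)  # (1, 99, 512): the tensor the decoder sees,
                                        # and now also the reported last hidden state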
@@ -363,8 +363,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
         encoder_input_shape, decoder_input_shape = input_shape

         # init input DeviceArrays
-        inputs = jnp.zeros(encoder_input_shape, dtype="i4")
-        attention_mask = jnp.ones_like(inputs)
+        inputs = jnp.zeros(encoder_input_shape, dtype="f4")
+        attention_mask = jnp.ones_like(inputs, dtype="i4")
         decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
         decoder_attention_mask = jnp.ones_like(decoder_input_ids)
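The explicit dtype="i4" on the new attention_mask line is required because jnp.ones_like inherits the dtype of its argument, and inputs is now float32. A quick check:

    import jax.numpy as jnp

    inputs = jnp.zeros((1, 16000), dtype="f4")
    print(jnp.ones_like(inputs).dtype)               # float32, inherited from inputs
    print(jnp.ones_like(inputs, dtype="i4").dtype)   # int32, as a mask should be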
@@ -472,7 +472,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.return_dict

         if attention_mask is None:
-            attention_mask = jnp.ones_like(inputs)
+            attention_mask = jnp.ones_like(inputs, dtype="i4")

         # Handle any PRNG if needed
         rngs = {}
@@ -485,7 +485,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
         outputs = self.module.apply(
             {"params": params or self.params},
-            inputs=jnp.array(inputs, dtype="i4"),
+            inputs=jnp.array(inputs, dtype="f4"),
             attention_mask=jnp.array(attention_mask, dtype="i4"),
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
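The same one-character dtype change matters at call time: casting a normalized waveform with jnp.array(inputs, dtype="i4") truncates every sample in [-1, 1] to zero and silently destroys the signal. A small demonstration:

    import jax.numpy as jnp

    waveform = jnp.array([0.25, -0.8, 0.9])   # typical normalized audio samples
    print(jnp.array(waveform, dtype="i4"))    # [0 0 0]   (the old cast zeroed the audio)
    print(jnp.array(waveform, dtype="f4"))    # [ 0.25 -0.8   0.9 ]   (the fixed cast)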
@@ -680,7 +680,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
         # prepare encoder inputs
         if attention_mask is None:
-            attention_mask = jnp.ones_like(inputs)
+            attention_mask = jnp.ones_like(inputs, dtype="i4")

         # prepare decoder inputs
         if decoder_input_ids is None:
@@ -700,7 +700,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
         return self.module.apply(
             {"params": params or self.params},
-            inputs=jnp.array(inputs, dtype="i4"),
+            inputs=jnp.array(inputs, dtype="f4"),
             attention_mask=jnp.array(attention_mask, dtype="i4"),
             decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
             decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
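With the fix in place, an end-to-end call would look like the sketch below; the checkpoint names are illustrative assumptions, not taken from this commit:

    import jax.numpy as jnp
    from transformers import FlaxSpeechEncoderDecoderModel

    # Hypothetical encoder/decoder pair: a Flax speech encoder plus a
    # decoder with cross-attention.
    model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
        "facebook/wav2vec2-base", "gpt2"
    )

    inputs = jnp.zeros((1, 16000), dtype="f4")         # float waveform (this fix)
    decoder_input_ids = jnp.zeros((1, 1), dtype="i4")  # integer token ids
    outputs = model(inputs=inputs, decoder_input_ids=decoder_input_ids)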