Unverified Commit ad0cba08 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[FlaxSpeechEncoderDecoder] Fix dtype bug (#16581)

* [FlaxSpeechEncoderDecoder] Fix dtype bug

* more fixes
parent 60d27b1f
......@@ -310,7 +310,7 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_last_hidden_state=encoder_hidden_states,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
......@@ -363,8 +363,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
encoder_input_shape, decoder_input_shape = input_shape
# init input DeviceArrays
inputs = jnp.zeros(encoder_input_shape, dtype="i4")
attention_mask = jnp.ones_like(inputs)
inputs = jnp.zeros(encoder_input_shape, dtype="f4")
attention_mask = jnp.ones_like(inputs, dtype="i4")
decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
......@@ -472,7 +472,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.return_dict
if attention_mask is None:
attention_mask = jnp.ones_like(inputs)
attention_mask = jnp.ones_like(inputs, dtype="i4")
# Handle any PRNG if needed
rngs = {}
......@@ -485,7 +485,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
outputs = self.module.apply(
{"params": params or self.params},
inputs=jnp.array(inputs, dtype="i4"),
inputs=jnp.array(inputs, dtype="f4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......@@ -680,7 +680,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
# prepare encoder inputs
if attention_mask is None:
attention_mask = jnp.ones_like(inputs)
attention_mask = jnp.ones_like(inputs, dtype="i4")
# prepare decoder inputs
if decoder_input_ids is None:
......@@ -700,7 +700,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
return self.module.apply(
{"params": params or self.params},
inputs=jnp.array(inputs, dtype="i4"),
inputs=jnp.array(inputs, dtype="f4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment