Unverified Commit c87bbe1f authored by Sylvain Gugger

Fix quality

parent 011cc17a
@@ -821,7 +821,7 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
         _do_init: bool = True,
-        **kwargs
+        **kwargs,
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
@@ -1348,7 +1348,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
         task=None,
         language=None,
         is_multilingual=None,
-        **kwargs
+        **kwargs,
     ):
         if generation_config is None:
             generation_config = self.generation_config
@@ -1411,7 +1411,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
         attention_mask: Optional[jnp.DeviceArray] = None,
         decoder_attention_mask: Optional[jnp.DeviceArray] = None,
         encoder_outputs=None,
-        **kwargs
+        **kwargs,
     ):
         # initializing the cache
         batch_size, seq_length = decoder_input_ids.shape
...
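The hunks above touch the FlaxWhisperPreTrainedModel constructor and the FlaxWhisperForConditionalGeneration generation path. For orientation, a minimal usage sketch of that API follows; it is not part of this commit, and the checkpoint name, processor calls, and dummy waveform are assumptions for illustration.

# Minimal sketch (not from this commit): transcription with the Flax Whisper generate() API.
import numpy as np

from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
# from_pt=True assumes only PyTorch weights are published for the checkpoint and that torch is installed
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True)

audio = np.zeros(16000, dtype=np.float32)  # stand-in for one second of 16 kHz mono audio
input_features = processor(audio, sampling_rate=16000, return_tensors="np").input_features

# generate() is the method whose signature gains the trailing comma above
output = model.generate(input_features)
transcription = processor.batch_decode(output.sequences, skip_special_tokens=True)
print(transcription)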
@@ -34,11 +34,11 @@ if is_datasets_available():
     from datasets import load_dataset

 if is_flax_available():
-    import numpy as np
     import jax
+    import numpy as np
     from flax.core.frozen_dict import unfreeze
     from flax.traverse_util import flatten_dict

     from transformers import (
         FLAX_MODEL_MAPPING,
         FlaxWhisperForConditionalGeneration,
...
@@ -51,6 +51,7 @@ if is_torch_available():

 if is_flax_available():
     import jax.numpy as jnp
+
     from transformers.modeling_flax_pytorch_utils import (
         convert_pytorch_state_dict_to_flax,
         load_flax_weights_in_pytorch_model,
...
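The commit title and the nature of these changes (trailing commas, alphabetized imports, a blank line before the first-party transformers import) suggest they were produced by the repository's automated style checks (black and isort via make style / make quality; this is an assumption, the tooling is not named in the commit). A tiny illustration of the trailing-comma convention, using a hypothetical function:

# Hypothetical function, not from the diff: with a trailing comma after **kwargs,
# the formatter keeps a multi-line argument list exploded one argument per line.
def example(
    seed: int = 0,
    do_init: bool = True,
    **kwargs,  # trailing comma here is valid syntax on the Python versions the repository supports
):
    return seed, do_init, kwargs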