"docs/vscode:/vscode.git/clone" did not exist on "71b19ee2518c9d0bf20e75599d19461be4c85b91"
Unverified commit c87bbe1f authored by Sylvain Gugger
Browse files

Fix quality

parent 011cc17a
......@@ -821,7 +821,7 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
......@@ -1348,7 +1348,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
task=None,
language=None,
is_multilingual=None,
**kwargs
**kwargs,
):
if generation_config is None:
generation_config = self.generation_config
......@@ -1411,7 +1411,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
attention_mask: Optional[jnp.DeviceArray] = None,
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
encoder_outputs=None,
**kwargs
**kwargs,
):
# initializing the cache
batch_size, seq_length = decoder_input_ids.shape
......
......@@ -34,11 +34,11 @@ if is_datasets_available():
from datasets import load_dataset
if is_flax_available():
import numpy as np
import jax
import numpy as np
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict
from transformers import (
FLAX_MODEL_MAPPING,
FlaxWhisperForConditionalGeneration,
......
......@@ -51,6 +51,7 @@ if is_torch_available():
if is_flax_available():
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment