Unverified Commit 7b175cfa authored by Sanchit Gandhi, committed by GitHub

[Flax Whisper] large-v3 compatibility (#27360)

parent 845aa832
@@ -867,7 +867,7 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
     def __init__(
         self,
         config: WhisperConfig,
-        input_shape: Tuple[int] = (1, 80, 3000),
+        input_shape: Tuple[int] = None,
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
         _do_init: bool = True,
@@ -875,6 +875,8 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
         **kwargs,
     ):
         module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
+        if input_shape is None:
+            input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

     def enable_gradient_checkpointing(self):
...
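For context: Whisper large-v3 uses 128 mel bins where earlier checkpoints use 80, so the previously hardcoded default input shape of (1, 80, 3000) only fit the 80-bin models. The patch derives the default from the model config instead. A minimal sketch of the resulting behaviour, mirroring the logic added above (the config values here are illustrative; max_source_positions defaults to 1500, which gives the familiar 3000-frame input):

from transformers import WhisperConfig

# A large-v3 style config: 128 mel bins instead of the 80 used by earlier checkpoints.
config = WhisperConfig(num_mel_bins=128, max_source_positions=1500)

# Same computation as the new default in the patched __init__.
input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)
print(input_shape)  # (1, 128, 3000)

With an 80-bin config the same expression recovers the old (1, 80, 3000) default, so existing checkpoints are unaffected.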