"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "3ddb1c467979eb13afc629506ea80806935390e8"
Unverified Commit 7b175cfa authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[Flax Whisper] large-v3 compatibility (#27360)

parent 845aa832
......@@ -867,7 +867,7 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
def __init__(
self,
config: WhisperConfig,
input_shape: Tuple[int] = (1, 80, 3000),
input_shape: Tuple[int] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
......@@ -875,6 +875,8 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
if input_shape is None:
input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def enable_gradient_checkpointing(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment