Unverified Commit ab96bf02 authored by raghavanone's avatar raghavanone Committed by GitHub
Browse files

Add gradient_checkpointing parameter to FlaxWhisperEncoder (#23300)

Add gradient_checkpointing parameter
parent 83eda643
......@@ -1515,7 +1515,9 @@ class FlaxWhisperForAudioClassificationModule(nn.Module):
gradient_checkpointing: bool = False
def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
self.encoder = FlaxWhisperEncoder(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.config.is_encoder_decoder = False
num_layers = self.config.num_hidden_layers + 1
if self.config.use_weighted_layer_sum:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment