"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1ef152eb48e777f9ff848d55e3ade9a47705745f"
Unverified Commit ab96bf02 authored by raghavanone's avatar raghavanone Committed by GitHub
Browse files

Add gradient_checkpointing parameter to FlaxWhisperEncoder (#23300)

Add gradient_checkpointing parameter
parent 83eda643
...@@ -1515,7 +1515,9 @@ class FlaxWhisperForAudioClassificationModule(nn.Module): ...@@ -1515,7 +1515,9 @@ class FlaxWhisperForAudioClassificationModule(nn.Module):
gradient_checkpointing: bool = False gradient_checkpointing: bool = False
def setup(self) -> None: def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype) self.encoder = FlaxWhisperEncoder(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.config.is_encoder_decoder = False self.config.is_encoder_decoder = False
num_layers = self.config.num_hidden_layers + 1 num_layers = self.config.num_hidden_layers + 1
if self.config.use_weighted_layer_sum: if self.config.use_weighted_layer_sum:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment