Unverified Commit 48fbd8fa authored by Sourab Mangrulkar's avatar Sourab Mangrulkar Committed by GitHub
Browse files

fix `_no_split_modules` for Whisper model (#22486)

parent 90067748
...@@ -577,7 +577,7 @@ class WhisperPreTrainedModel(PreTrainedModel): ...@@ -577,7 +577,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
base_model_prefix = "model" base_model_prefix = "model"
main_input_name = "input_features" main_input_name = "input_features"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["WhisperEncoderLayer"] _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.init_std std = self.config.init_std
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment