Unverified Commit af556a09 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

add `accelerate` support for `Whisper` (#19697)

parent fb0bd7b7
......@@ -446,6 +446,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
main_input_name = "input_features"
supports_gradient_checkpointing = True
_no_split_modules = ["WhisperEncoderLayer"]
def _init_weights(self, module):
std = self.config.init_std
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment