Unverified commit d979cf6e authored by Younes Belkada, committed by GitHub

[`Whisper`] add `get_input_embeddings` to `WhisperForAudioClassification` (#22133)



* add `get_input_embeddings` to `WhisperForAudioClassification`

* add common tests

* fix another common test

* Update tests/models/whisper/test_modeling_whisper.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix style

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 98797237
@@ -767,6 +767,12 @@ class WhisperEncoder(WhisperPreTrainedModel):
             param.requires_grad = False
         self._requires_grad = False
 
+    def get_input_embeddings(self) -> nn.Module:
+        return self.conv1
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.conv1 = value
+
     def forward(
         self,
         input_features,
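Note for readers: the Whisper encoder has no token embedding table, so its first Conv1d over the log-mel features stands in as the "input embedding". A minimal sketch of what the new accessors expose, assuming the openai/whisper-tiny checkpoint (the checkpoint choice and the printed shapes are illustrative, not part of this diff):

    import torch
    from transformers import WhisperModel

    model = WhisperModel.from_pretrained("openai/whisper-tiny")
    encoder = model.get_encoder()

    conv1 = encoder.get_input_embeddings()   # the first Conv1d, not an nn.Embedding
    print(conv1)                             # e.g. Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))

    # set_input_embeddings swaps the module in place
    encoder.set_input_embeddings(torch.nn.Conv1d(conv1.in_channels, conv1.out_channels, 3, padding=1))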
@@ -1023,7 +1029,10 @@ class WhisperDecoder(WhisperPreTrainedModel):
         )
 
         # embed positions
-        positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+        if input_ids is not None:
+            positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+        else:
+            positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
 
         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
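With this change, the decoder's positional embeddings can be computed from the embeddings themselves when input_ids is None, so the model can be driven by decoder_inputs_embeds alone. A hedged sketch, assuming the openai/whisper-tiny checkpoint and a random feature tensor (both are assumptions for illustration):

    import torch
    from transformers import WhisperModel

    model = WhisperModel.from_pretrained("openai/whisper-tiny").eval()

    input_features = torch.randn(1, 80, 3000)   # (batch, mel bins, frames)
    decoder_input_ids = torch.tensor([[1]])     # any valid token id works for the sketch
    wte = model.get_input_embeddings()          # decoder token embedding table
    decoder_inputs_embeds = wte(decoder_input_ids)

    with torch.no_grad():
        out = model(input_features=input_features, decoder_inputs_embeds=decoder_inputs_embeds)
    print(out.last_hidden_state.shape)          # e.g. torch.Size([1, 1, 384])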
@@ -1330,6 +1339,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.proj_out = new_embeddings
 
+    def get_input_embeddings(self) -> nn.Module:
+        return self.model.get_input_embeddings()
+
     def freeze_encoder(self):
         """
         Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
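This accessor simply forwards to the wrapped WhisperModel, whose get_input_embeddings returns the decoder's token embedding. A one-line check (the checkpoint name is an assumption):

    from transformers import WhisperForConditionalGeneration

    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    assert model.get_input_embeddings() is model.model.get_input_embeddings()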
@@ -1635,6 +1647,12 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
         """
         self.encoder._freeze_parameters()
 
+    def get_input_embeddings(self) -> nn.Module:
+        return self.encoder.get_input_embeddings()
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.encoder.set_input_embeddings(value)
+
     @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
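These mirror the encoder accessors above, which is what lets the common input-embedding tests cover this head. A short usage sketch (the checkpoint name is an assumption; the classification head will be freshly initialized when loading a base checkpoint):

    import torch
    from transformers import WhisperForAudioClassification

    model = WhisperForAudioClassification.from_pretrained("openai/whisper-tiny")

    conv1 = model.get_input_embeddings()   # delegates to encoder.conv1, an nn.Conv1d
    model.set_input_embeddings(torch.nn.Conv1d(conv1.in_channels, conv1.out_channels, 3, padding=1))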
@@ -357,9 +357,24 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         return config, input_ids, None, max_length
 
-    # not implemented currently
     def test_inputs_embeds(self):
-        pass
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+            decoder_input_ids = inputs.pop("decoder_input_ids", None)
+            inputs.pop("decoder_attention_mask", None)
+
+            wte = model.get_input_embeddings()
+            inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
+
+            with torch.no_grad():
+                model(**inputs)[0]
 
     # training is not supported yet
     def test_training(self):
@@ -1566,9 +1581,16 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
         self.assertTrue((outputs_embeds == outputs).all())
 
-    # WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented
+    # Needs to override as the encoder input embedding is a Conv1d
     def test_model_common_attributes(self):
-        pass
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Conv1d))
+            model.set_input_embeddings(torch.nn.Conv1d(10, 10, 3))
+            x = model.get_output_embeddings()
+            self.assertTrue(x is None or isinstance(x, torch.nn.Conv1d))
 
     # WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
     def test_resize_tokens_embeddings(self):