Unverified commit a8eb4f79, authored by Younes Belkada and committed by GitHub

[`CLAP`] Fix few broken things (#21670)



* add `is_longer`

* fix docstring

* fix config class

* fix loss

* fix all doctests

* fix order

* fix last failing tests

---------
Co-authored-by: arthur.zucker@gmail.com <arthur.zucker@gmail.com>
parent 3668ec17
--- a/src/transformers/models/clap/modeling_clap.py
+++ b/src/transformers/models/clap/modeling_clap.py
@@ -898,8 +898,8 @@ class ClapAudioEncoder(nn.Module):
     def forward(
         self,
         input_features,
-        head_mask: Optional[torch.FloatTensor] = None,
         is_longer: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
         output_hidden_states_before_downsampling: Optional[bool] = False,
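The reorder above matters for positional callers: anything passing `is_longer` directly after `input_features` would previously have bound it to `head_mask`. A toy illustration of that failure mode (hypothetical function names, not the real `ClapAudioEncoder.forward`):

```python
# Hypothetical sketch of the argument-order bug fixed above.
def forward_old(input_features, head_mask=None, is_longer=None):
    return head_mask, is_longer

def forward_new(input_features, is_longer=None, head_mask=None):
    return head_mask, is_longer

feats, longer = "features", [True]
assert forward_old(feats, longer) == (longer, None)  # is_longer silently lands in head_mask
assert forward_new(feats, longer) == (None, longer)  # positional call now binds as intended
```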
@@ -1673,7 +1673,7 @@ class ClapPreTrainedModel(PreTrainedModel):
     models.
     """
-    config_class = ClapTextConfig
+    config_class = ClapConfig
     base_model_prefix = "clap"
     supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_missing = [r"position_ids", r"logit_scale_a", r"logit_scale_t"]
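`ClapConfig` is the composite configuration that nests both sub-configs, so the shared base class has to reference it rather than the text-only `ClapTextConfig`, which is what the joint `ClapModel` needs when loading. A quick sketch of that structure (plain attribute access on the public config API):

```python
from transformers import ClapConfig

config = ClapConfig()  # composite config wrapping both modalities
print(type(config.text_config).__name__)   # ClapTextConfig
print(type(config.audio_config).__name__)  # ClapAudioConfig
```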
@@ -1746,7 +1746,7 @@ class ClapAudioModel(ClapPreTrainedModel):
         >>> inputs = processor(audios=audio_sample, return_tensors="pt")

         >>> outputs = model(**inputs)
-        >>> last_hidden_state = outputs.audio_emmbeds
+        >>> last_hidden_state = outputs.last_hidden_state
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -2069,6 +2069,7 @@ class ClapModel(ClapPreTrainedModel):
         self,
         input_ids: Optional[torch.LongTensor] = None,
         input_features: Optional[torch.FloatTensor] = None,
+        is_longer: Optional[torch.BoolTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         return_loss: Optional[bool] = None,
@@ -2108,6 +2109,7 @@
         audio_outputs = self.audio_model(
             input_features=input_features,
+            is_longer=is_longer,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
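`is_longer` is produced by the CLAP feature extractor alongside `input_features`, so with the new parameter in place a plain `model(**inputs)` threads it through to the audio tower. A hedged usage sketch (random waveform; the checkpoint name is the one the docstrings above already use):

```python
import numpy as np
from transformers import ClapModel, ClapProcessor

model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")

audio = np.random.randn(48000)  # dummy 1-second clip at CLAP's 48 kHz rate
inputs = processor(
    text=["a sound of a cat"], audios=audio, sampling_rate=48000, return_tensors="pt", padding=True
)
print("is_longer" in inputs)  # the feature extractor emits the flag itself
outputs = model(**inputs)     # is_longer is forwarded to self.audio_model
```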
@@ -2141,7 +2143,7 @@
         loss = None
         if return_loss:
             caption_loss = contrastive_loss(logits_per_text)
-            audio_loss = contrastive_loss(logits_per_text.t())
+            audio_loss = contrastive_loss(logits_per_audio.t())
             loss = (caption_loss + audio_loss) / 2.0

         if not return_dict:
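The fix above keeps the contrastive objective genuinely symmetric: CLAP scales the two similarity views with separate learned temperatures (note `logit_scale_a` and `logit_scale_t` in the keys listed earlier), so reusing `logits_per_text.t()` for the audio direction silently dropped the audio-side scale. A minimal sketch, assuming `contrastive_loss` is the usual cross-entropy against the diagonal as in the library's other CLIP-style models:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    # row i of `logits` should score highest at column i
    return F.cross_entropy(logits, torch.arange(len(logits), device=logits.device))

text_embeds = F.normalize(torch.randn(4, 512), dim=-1)
audio_embeds = F.normalize(torch.randn(4, 512), dim=-1)
logit_scale_t, logit_scale_a = 100.0, 50.0  # separate learned scales per direction

logits_per_text = logit_scale_t * text_embeds @ audio_embeds.t()
logits_per_audio = logit_scale_a * audio_embeds @ text_embeds.t()

caption_loss = contrastive_loss(logits_per_text)
audio_loss = contrastive_loss(logits_per_audio.t())  # was logits_per_text.t()
loss = (caption_loss + audio_loss) / 2.0
```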
@@ -2203,7 +2205,7 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
         >>> model = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
         >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

-        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+        >>> inputs = tokenizer(["a sound of a cat", "a sound of a dog"], padding=True, return_tensors="pt")

         >>> outputs = model(**inputs)
         >>> text_embeds = outputs.text_embeds
--- a/tests/models/clap/test_modeling_clap.py
+++ b/tests/models/clap/test_modeling_clap.py
@@ -268,7 +268,7 @@ class ClapAudioModelTest(ModelTesterMixin, unittest.TestCase):
         for model_name in CLAP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
             model = ClapAudioModelWithProjection.from_pretrained(model_name)
             self.assertIsNotNone(model)
-            self.assertTrue(hasattr(model, "visual_projection"))
+            self.assertTrue(hasattr(model, "audio_projection"))


 class ClapTextModelTester: