Unverified Commit 77cb2ab7 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

️ [CLAP] Fix dtype of logit scales in init (#25682)

[CLAP] Fix dtype of logit scales
parent 2cf87e2b
......@@ -18,7 +18,6 @@ import math
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
......@@ -1956,8 +1955,8 @@ class ClapModel(ClapPreTrainedModel):
text_config = config.text_config
audio_config = config.audio_config
self.logit_scale_a = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
self.logit_scale_t = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
self.logit_scale_a = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
self.logit_scale_t = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
self.projection_dim = config.projection_dim
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment