"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e3ff165aa54c07c0371deb09671e3c7dd5666a99"
Unverified Commit 77cb2ab7 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

️ [CLAP] Fix dtype of logit scales in init (#25682)

[CLAP] Fix dtype of logit scales
parent 2cf87e2b
...@@ -18,7 +18,6 @@ import math ...@@ -18,7 +18,6 @@ import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
...@@ -1956,8 +1955,8 @@ class ClapModel(ClapPreTrainedModel): ...@@ -1956,8 +1955,8 @@ class ClapModel(ClapPreTrainedModel):
text_config = config.text_config text_config = config.text_config
audio_config = config.audio_config audio_config = config.audio_config
self.logit_scale_a = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value))) self.logit_scale_a = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
self.logit_scale_t = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value))) self.logit_scale_t = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
self.projection_dim = config.projection_dim self.projection_dim = config.projection_dim
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment