chenpangpang / transformers
Commit c164c651 (unverified)
Authored Sep 08, 2021 by Suraj Patil; committed via GitHub on Sep 08, 2021
Parent: f667d5b2

[CLIP] fix logit_scale init (#13436)

* fix logit_scale init
* add logit_scale_init_value as config param
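For context, the new default of 2.6592 is ln(1 / 0.07), the inverse-temperature initialization used in the original CLIP implementation (the config default stores a slightly truncated value, which is why the new test below compares with delta=1e-3 rather than exact equality). A quick sanity check, assuming only NumPy is installed:

import numpy as np

# ln(1 / 0.07) ~ 2.65926; the new CLIPConfig default stores it as 2.6592
print(np.log(1 / 0.07))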
Showing 4 changed files, with 42 additions and 3 deletions:

src/transformers/models/clip/configuration_clip.py   +11 -1
src/transformers/models/clip/modeling_clip.py          +1 -1
src/transformers/models/clip/modeling_flax_clip.py     +4 -1
tests/test_modeling_clip.py                            +26 -0
src/transformers/models/clip/configuration_clip.py

@@ -230,6 +230,8 @@ class CLIPConfig(PretrainedConfig):
             Dictionary of configuration options used to initialize :class:`~transformers.CLIPVisionConfig`.
         projection_dim (:obj:`int`, `optional`, defaults to 512):
             Dimensionality of text and vision projection layers.
+        logit_scale_init_value (:obj:`float`, `optional`, defaults to 2.6592):
+            The initial value of the `logit_scale` parameter. Default is used as per the original CLIP implementation.
         kwargs (`optional`):
             Dictionary of keyword arguments.
     """

@@ -237,7 +239,14 @@ class CLIPConfig(PretrainedConfig):
     model_type = "clip"
     is_composition = True

-    def __init__(self, text_config_dict=None, vision_config_dict=None, projection_dim=512, **kwargs):
+    def __init__(
+        self,
+        text_config_dict=None,
+        vision_config_dict=None,
+        projection_dim=512,
+        logit_scale_init_value=2.6592,
+        **kwargs
+    ):
         super().__init__(text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs)

         if text_config_dict is None:

@@ -252,6 +261,7 @@ class CLIPConfig(PretrainedConfig):
         self.vision_config = CLIPVisionConfig(**vision_config_dict)

         self.projection_dim = projection_dim
+        self.logit_scale_init_value = logit_scale_init_value
         self.initializer_factor = 1.0

     @classmethod
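A minimal usage sketch of the new config field, assuming a transformers build that already includes this commit (the model below is randomly initialized; nothing is downloaded):

from transformers import CLIPConfig, CLIPModel

# logit_scale_init_value can now be set explicitly; leaving it out uses the new 2.6592 default
config = CLIPConfig(projection_dim=512, logit_scale_init_value=2.6592)
model = CLIPModel(config)
print(model.logit_scale.item())  # ~2.6592 instead of the old 1.0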
src/transformers/models/clip/modeling_clip.py

@@ -858,7 +858,7 @@ class CLIPModel(CLIPPreTrainedModel):
         self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
         self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
-        self.logit_scale = nn.Parameter(torch.ones([]))
+        self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)

         self.init_weights()
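Why the init value matters: CLIPModel.forward exponentiates logit_scale before scaling the image-text similarity matrix, so the old init of 1.0 gave a scale of e^1 ~ 2.7 instead of the intended 1 / 0.07 ~ 14.3. A standalone sketch of that scaling step (it mirrors the forward pass, but is not the library code itself; the embeddings here are random stand-ins):

import torch
import torch.nn.functional as F

logit_scale = torch.nn.Parameter(torch.ones([]) * 2.6592)   # the fixed init
image_embeds = F.normalize(torch.randn(4, 512), dim=-1)     # stand-in for projected image embeddings
text_embeds = F.normalize(torch.randn(4, 512), dim=-1)      # stand-in for projected text embeddings

# cosine similarities scaled by exp(logit_scale), as in CLIPModel.forward
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale.exp()
print(logits_per_text.shape, logit_scale.exp().item())      # torch.Size([4, 4]) ~14.28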
src/transformers/models/clip/modeling_flax_clip.py

@@ -1041,7 +1041,10 @@ class FlaxCLIPModule(nn.Module):
             kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
             use_bias=False,
         )

-        self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
+        self.logit_scale = self.param(
+            "logit_scale",
+            lambda _, shape: jnp.ones(shape, dtype=self.dtype) * self.config.logit_scale_init_value, []
+        )

     def __call__(
         self,
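On the Flax side the parameter is created through self.param, so the change swaps the plain ones initializer for a custom init callable that bakes in the configured value. An isolated sketch of the same pattern, using a hypothetical ScaledLogit module rather than the library's FlaxCLIPModule (assumes jax and flax are installed):

import jax
import jax.numpy as jnp
import flax.linen as nn


class ScaledLogit(nn.Module):
    # hypothetical stand-in for the config value used in FlaxCLIPModule
    logit_scale_init_value: float = 2.6592

    def setup(self):
        # the init callable receives (rng, shape); the rng is unused for a constant init
        self.logit_scale = self.param(
            "logit_scale", lambda _, shape: jnp.ones(shape) * self.logit_scale_init_value, []
        )

    def __call__(self):
        return self.logit_scale


variables = ScaledLogit().init(jax.random.PRNGKey(0))
print(variables["params"]["logit_scale"])  # 2.6592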
tests/test_modeling_clip.py

@@ -20,6 +20,8 @@ import os
 import tempfile
 import unittest

+import numpy as np
+
 import requests

 from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
 from transformers.file_utils import is_torch_available, is_vision_available

@@ -478,6 +480,30 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
     def test_model_common_attributes(self):
         pass

+    # override as the `logit_scale` parameter initialization is different for CLIP
+    def test_initialization(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            for name, param in model.named_parameters():
+                if param.requires_grad:
+                    # check if `logit_scale` is initialized as per the original implementation
+                    if name == "logit_scale":
+                        self.assertAlmostEqual(
+                            param.data.item(),
+                            np.log(1 / 0.07),
+                            delta=1e-3,
+                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                        )
+                    else:
+                        self.assertIn(
+                            ((param.data.mean() * 1e9).round() / 1e9).item(),
+                            [0.0, 1.0],
+                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                        )
+
     def _create_and_check_torchscript(self, config, inputs_dict):
         if not self.test_torchscript:
             return