Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0b86e330
Unverified
Commit
0b86e330
authored
Jan 19, 2023
by
Younes Belkada
Committed by
GitHub
Jan 19, 2023
Browse files
[`CVT`] Fix module initialization issue (#21193)
fix cvt init
parent
b9403e95
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
5 deletions
+6
-5
src/transformers/models/cvt/modeling_cvt.py
src/transformers/models/cvt/modeling_cvt.py
+6
-5
No files found.
src/transformers/models/cvt/modeling_cvt.py
View file @
0b86e330
...
@@ -451,11 +451,7 @@ class CvtStage(nn.Module):
...
@@ -451,11 +451,7 @@ class CvtStage(nn.Module):
self
.
config
=
config
self
.
config
=
config
self
.
stage
=
stage
self
.
stage
=
stage
if
self
.
config
.
cls_token
[
self
.
stage
]:
if
self
.
config
.
cls_token
[
self
.
stage
]:
self
.
cls_token
=
nn
.
Parameter
(
self
.
cls_token
=
nn
.
Parameter
(
torch
.
randn
(
1
,
1
,
self
.
config
.
embed_dim
[
-
1
]))
nn
.
init
.
trunc_normal_
(
torch
.
zeros
(
1
,
1
,
self
.
config
.
embed_dim
[
-
1
]),
mean
=
0.0
,
std
=
config
.
initializer_range
)
)
self
.
embedding
=
CvtEmbeddings
(
self
.
embedding
=
CvtEmbeddings
(
patch_size
=
config
.
patch_sizes
[
self
.
stage
],
patch_size
=
config
.
patch_sizes
[
self
.
stage
],
...
@@ -557,6 +553,11 @@ class CvtPreTrainedModel(PreTrainedModel):
...
@@ -557,6 +553,11 @@ class CvtPreTrainedModel(PreTrainedModel):
elif
isinstance
(
module
,
nn
.
LayerNorm
):
elif
isinstance
(
module
,
nn
.
LayerNorm
):
module
.
bias
.
data
.
zero_
()
module
.
bias
.
data
.
zero_
()
module
.
weight
.
data
.
fill_
(
1.0
)
module
.
weight
.
data
.
fill_
(
1.0
)
elif
isinstance
(
module
,
CvtStage
):
if
self
.
config
.
cls_token
[
module
.
stage
]:
module
.
cls_token
.
data
=
nn
.
init
.
trunc_normal_
(
torch
.
zeros
(
1
,
1
,
self
.
config
.
embed_dim
[
-
1
]),
mean
=
0.0
,
std
=
self
.
config
.
initializer_range
)
CVT_START_DOCSTRING
=
r
"""
CVT_START_DOCSTRING
=
r
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment