Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0c7f93f5
Unverified
Commit
0c7f93f5
authored
Feb 27, 2023
by
fxmarty
Committed by
GitHub
Feb 27, 2023
Browse files
Fix nn.init.trunc_normal_ call on torch.float16 data (#21789)
fix nn.init.trunc_normal_ call on half data
parent
ebf84f07
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
12 deletions
+12
-12
src/transformers/models/vit/modeling_vit.py
src/transformers/models/vit/modeling_vit.py
+6
-6
src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
+6
-6
No files found.
src/transformers/models/vit/modeling_vit.py
View file @
0c7f93f5
...
@@ -449,17 +449,17 @@ class ViTPreTrainedModel(PreTrainedModel):
...
@@ -449,17 +449,17 @@ class ViTPreTrainedModel(PreTrainedModel):
module
.
bias
.
data
.
zero_
()
module
.
bias
.
data
.
zero_
()
module
.
weight
.
data
.
fill_
(
1.0
)
module
.
weight
.
data
.
fill_
(
1.0
)
elif
isinstance
(
module
,
ViTEmbeddings
):
elif
isinstance
(
module
,
ViTEmbeddings
):
nn
.
init
.
trunc_normal_
(
module
.
position_embeddings
.
data
=
nn
.
init
.
trunc_normal_
(
module
.
position_embeddings
,
module
.
position_embeddings
.
data
.
to
(
torch
.
float32
)
,
mean
=
0.0
,
mean
=
0.0
,
std
=
self
.
config
.
initializer_range
,
std
=
self
.
config
.
initializer_range
,
)
)
.
to
(
module
.
position_embeddings
.
dtype
)
nn
.
init
.
trunc_normal_
(
module
.
cls_token
.
data
=
nn
.
init
.
trunc_normal_
(
module
.
cls_token
,
module
.
cls_token
.
data
.
to
(
torch
.
float32
)
,
mean
=
0.0
,
mean
=
0.0
,
std
=
self
.
config
.
initializer_range
,
std
=
self
.
config
.
initializer_range
,
)
)
.
to
(
module
.
cls_token
.
dtype
)
def
_set_gradient_checkpointing
(
self
,
module
:
ViTEncoder
,
value
:
bool
=
False
)
->
None
:
def
_set_gradient_checkpointing
(
self
,
module
:
ViTEncoder
,
value
:
bool
=
False
)
->
None
:
if
isinstance
(
module
,
ViTEncoder
):
if
isinstance
(
module
,
ViTEncoder
):
...
...
src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
View file @
0c7f93f5
...
@@ -474,17 +474,17 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
...
@@ -474,17 +474,17 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
module
.
bias
.
data
.
zero_
()
module
.
bias
.
data
.
zero_
()
module
.
weight
.
data
.
fill_
(
1.0
)
module
.
weight
.
data
.
fill_
(
1.0
)
elif
isinstance
(
module
,
ViTHybridEmbeddings
):
elif
isinstance
(
module
,
ViTHybridEmbeddings
):
nn
.
init
.
trunc_normal_
(
module
.
position_embeddings
.
data
=
nn
.
init
.
trunc_normal_
(
module
.
position_embeddings
,
module
.
position_embeddings
.
data
.
to
(
torch
.
float32
)
,
mean
=
0.0
,
mean
=
0.0
,
std
=
self
.
config
.
initializer_range
,
std
=
self
.
config
.
initializer_range
,
)
)
.
to
(
module
.
position_embeddings
.
dtype
)
nn
.
init
.
trunc_normal_
(
module
.
cls_token
.
data
=
nn
.
init
.
trunc_normal_
(
module
.
cls_token
,
module
.
cls_token
.
data
.
to
(
torch
.
float32
)
,
mean
=
0.0
,
mean
=
0.0
,
std
=
self
.
config
.
initializer_range
,
std
=
self
.
config
.
initializer_range
,
)
)
.
to
(
module
.
cls_token
.
dtype
)
def
_set_gradient_checkpointing
(
self
,
module
:
ViTHybridEncoder
,
value
:
bool
=
False
)
->
None
:
def
_set_gradient_checkpointing
(
self
,
module
:
ViTHybridEncoder
,
value
:
bool
=
False
)
->
None
:
if
isinstance
(
module
,
ViTHybridEncoder
):
if
isinstance
(
module
,
ViTHybridEncoder
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment