chenpangpang / transformers · Commits

Unverified commit 56b8d49d, authored May 03, 2023 by Alara Dirik, committed via GitHub on May 03, 2023.
Fix ConvNext V2 parameter naming issue (#23122)

Fixes the parameter naming issue in the ConvNextV2GRN module: its learnable parameters were named `gamma` and `beta`, which collide with a legacy key-renaming shim in `from_pretrained` that rewrites `gamma` to `weight` and `beta` to `bias` when loading state dicts. The parameters are renamed to `weight` and `bias`, and the checkpoint conversion script is updated to match.
Parent: b53004fd

Changes: 2 changed files with 7 additions and 3 deletions

src/transformers/models/convnextv2/convert_convnextv2_to_pytorch.py  +4 -0
src/transformers/models/convnextv2/modeling_convnextv2.py  +3 -3
src/transformers/models/convnextv2/convert_convnextv2_to_pytorch.py

@@ -99,6 +99,10 @@ def rename_key(name):
     if "stages" in name and "downsampling_layer" not in name:
         # stages.0.0. for instance should be renamed to stages.0.layers.0.
         name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
+    if "gamma" in name:
+        name = name.replace("gamma", "weight")
+    if "beta" in name:
+        name = name.replace("beta", "bias")
     if "stages" in name:
         name = name.replace("stages", "encoder.stages")
     if "norm" in name:
         ...
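Taken together, the visible branches of `rename_key` map original ConvNeXt V2 checkpoint keys onto the Hugging Face naming scheme. A minimal, self-contained sketch covering only the branches shown in the hunk above (the real function continues with a `norm` branch that is truncated here, so this is illustrative, not the full conversion logic):

```python
def rename_key(name):
    """Sketch: rename an original ConvNeXt V2 checkpoint key to the HF scheme."""
    if "stages" in name and "downsampling_layer" not in name:
        # stages.0.0. for instance should be renamed to stages.0.layers.0.
        name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
    if "gamma" in name:
        # GRN parameters are now called weight/bias in modeling_convnextv2.py
        name = name.replace("gamma", "weight")
    if "beta" in name:
        name = name.replace("beta", "bias")
    if "stages" in name:
        name = name.replace("stages", "encoder.stages")
    return name


print(rename_key("stages.0.0.grn.gamma"))
# the downsampling_layer guard skips the ".layers" insertion for these keys:
print(rename_key("stages.1.downsampling_layer.0.weight"))
```

Note the branch order matters: `gamma`/`beta` are rewritten before `stages` is prefixed with `encoder.`, and keys containing `downsampling_layer` deliberately skip the `.layers` insertion.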
src/transformers/models/convnextv2/modeling_convnextv2.py

@@ -100,14 +100,14 @@ class ConvNextV2GRN(nn.Module):
     def __init__(self, dim: int):
         super().__init__()
-        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
-        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+        self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim))
+        self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim))

     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         # Compute and normalize global spatial feature maps
         global_features = torch.norm(hidden_states, p=2, dim=(1, 2), keepdim=True)
         norm_features = global_features / (global_features.mean(dim=-1, keepdim=True) + 1e-6)
-        hidden_states = self.gamma * (hidden_states * norm_features) + self.beta + hidden_states
+        hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states
         return hidden_states
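The renamed parameters don't change the math: Global Response Normalization takes the per-channel L2 norm over the spatial dimensions, divides by its mean across channels, and applies a learned per-channel scale and shift around a residual connection. A dependency-free sketch of the same computation on nested Python lists in `(N, H, W, C)` layout (matching the `(1, 1, 1, dim)` parameter shapes above); `grn` is a hypothetical helper for illustration, not part of the library:

```python
import math


def grn(x, weight, bias, eps=1e-6):
    """Global Response Normalization on a nested list of shape (N, H, W, C).

    Mirrors ConvNextV2GRN.forward: gx = L2 norm of x over (H, W) per channel,
    nx = gx / (mean over channels of gx + eps), out = weight * (x * nx) + bias + x.
    weight and bias are per-channel lists (the renamed gamma/beta parameters).
    """
    n_batch, h, w, c = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    out = [[[[0.0] * c for _ in range(w)] for _ in range(h)] for _ in range(n_batch)]
    for n in range(n_batch):
        # Per-channel L2 norm over the spatial dimensions: shape (C,)
        gx = [
            math.sqrt(sum(x[n][i][j][k] ** 2 for i in range(h) for j in range(w)))
            for k in range(c)
        ]
        mean_gx = sum(gx) / c
        nx = [g / (mean_gx + eps) for g in gx]
        for i in range(h):
            for j in range(w):
                for k in range(c):
                    v = x[n][i][j][k]
                    out[n][i][j][k] = weight[k] * (v * nx[k]) + bias[k] + v
    return out


x = [[[[1.0, 2.0], [3.0, 4.0]]]]  # shape (1, 1, 2, 2)
# Both parameters are zero-initialized, so at init GRN is the identity map:
assert grn(x, weight=[0.0, 0.0], bias=[0.0, 0.0]) == x
```

The zero initialization is why the layer is safe to insert into a pretrained-style block: until `weight` and `bias` are learned, only the residual term passes through.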