Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Uni-Core
Commits
f4ce5889
Commit
f4ce5889
authored
Aug 26, 2022
by
Guolin Ke
Browse files
fix a possible type bug in layernorm
parent
f913d977
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
unicore/modules/layer_norm.py
unicore/modules/layer_norm.py
+3
-3
No files found.
unicore/modules/layer_norm.py
View file @
f4ce5889
...
@@ -60,14 +60,14 @@ class LayerNorm(torch.nn.Module):
...
@@ -60,14 +60,14 @@ class LayerNorm(torch.nn.Module):
self
.
reset_parameters
()
self
.
reset_parameters
()
def
torch_layer_norm
(
input
):
def
torch_layer_norm
(
input
):
return
F
.
layer_norm
(
return
F
.
layer_norm
(
input
,
self
.
normalized_shape
,
self
.
weight
,
self
.
bias
,
self
.
eps
)
input
,
self
.
normalized_shape
,
self
.
weight
.
type
(
input
.
dtype
),
self
.
bias
.
type
(
input
.
dtype
)
,
self
.
eps
)
def
fused_layer_norm
(
input
):
def
fused_layer_norm
(
input
):
if
input
.
is_cuda
:
if
input
.
is_cuda
:
return
FusedLayerNormFastFunction
.
apply
(
return
FusedLayerNormFastFunction
.
apply
(
input
,
self
.
weight
,
self
.
bias
,
self
.
normalized_shape
,
self
.
eps
)
input
,
self
.
weight
.
type
(
input
.
dtype
),
self
.
bias
.
type
(
input
.
dtype
)
,
self
.
normalized_shape
,
self
.
eps
)
else
:
else
:
return
F
.
layer_norm
(
return
F
.
layer_norm
(
input
,
self
.
normalized_shape
,
self
.
weight
,
self
.
bias
,
self
.
eps
)
input
,
self
.
normalized_shape
,
self
.
weight
.
type
(
input
.
dtype
),
self
.
bias
.
type
(
input
.
dtype
)
,
self
.
eps
)
self
.
func
=
torch_layer_norm
if
(
not
HAS_LAYER_NORM
or
normalized_shape
[
0
]
not
in
FUSED_LAYER_NORM_SUPPORT_DIM
)
else
fused_layer_norm
self
.
func
=
torch_layer_norm
if
(
not
HAS_LAYER_NORM
or
normalized_shape
[
0
]
not
in
FUSED_LAYER_NORM_SUPPORT_DIM
)
else
fused_layer_norm
def
reset_parameters
(
self
):
def
reset_parameters
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment