Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
d130ec1f
Commit
d130ec1f
authored
Apr 10, 2019
by
Lam Dang
Browse files
quick fix: make FusedLayerNorm compatible with cpu
parent
683b6e0e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
0 deletions
+45
-0
apex/normalization/fused_layer_norm.py
apex/normalization/fused_layer_norm.py
+4
-0
tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
+41
-0
No files found.
apex/normalization/fused_layer_norm.py
View file @
d130ec1f
...
@@ -3,6 +3,7 @@ import torch
...
@@ -3,6 +3,7 @@ import torch
import
numbers
import
numbers
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
torch.nn
import
init
from
torch.nn
import
init
from
torch.nn
import
functional
as
F
import
importlib
import
importlib
class
FusedLayerNormAffineFunction
(
torch
.
autograd
.
Function
):
class
FusedLayerNormAffineFunction
(
torch
.
autograd
.
Function
):
...
@@ -144,6 +145,9 @@ class FusedLayerNorm(torch.nn.Module):
...
@@ -144,6 +145,9 @@ class FusedLayerNorm(torch.nn.Module):
init
.
zeros_
(
self
.
bias
)
init
.
zeros_
(
self
.
bias
)
def forward(self, input):
    """Normalize *input*, working on both CPU and CUDA tensors.

    CPU tensors are handled by ``torch.nn.functional.layer_norm``; CUDA
    tensors go through the fused kernel autograd function.
    """
    # The fused kernel is CUDA-only, so fall back to the stock
    # framework implementation for any non-CUDA input.
    if not input.is_cuda:
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps)
    if self.elementwise_affine:
        return FusedLayerNormAffineFunction(self.normalized_shape, self.eps)(
            input, self.weight, self.bias)
    # NOTE(review): the diff view is truncated here ("..."); the full
    # source continues with the non-affine CUDA path after this branch.
...
...
tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
0 → 100644
View file @
d130ec1f
import
unittest
import
os
import
random
import
torch
import
apex
class TestFusedLayerNorm(unittest.TestCase):
    """Check that FusedLayerNorm runs on both CPU and CUDA and that the
    two device paths produce numerically matching outputs
    (``elementwise_affine=False``)."""

    def setUp(self):
        self.module = apex.normalization.FusedLayerNorm(
            normalized_shape=[32, 64], elementwise_affine=False)
        self.input_ = torch.randn(16, 32, 64)
        # Seed the CUDA RNG for reproducibility of any CUDA-side
        # randomness. (This does not affect the CPU randn above.)
        torch.cuda.manual_seed(42)

    def forward_cpu(self, input_):
        # Move the module to CPU and exercise the CPU fallback path.
        self.module.cpu()
        return self.module(input_.cpu())

    def forward_cuda(self, input_):
        # Move the module to CUDA and exercise the fused-kernel path.
        self.module.cuda()
        return self.module(input_.cuda())

    def test_forward_cuda(self):
        out_ = self.forward_cuda(self.input_)
        # Fix: use unittest assertions instead of bare `assert ... == True`;
        # bare asserts are silently stripped under `python -O`.
        self.assertTrue(out_.is_cuda)

    def test_forward_cpu(self):
        out_ = self.forward_cpu(self.input_)
        self.assertFalse(out_.is_cuda)

    def test_same_output(self):
        # The CPU fallback and the fused CUDA kernel must agree.
        out_cpu = self.forward_cpu(self.input_)
        out_cuda = self.forward_cuda(self.input_)
        torch.testing.assert_allclose(out_cpu, out_cuda.cpu())
class TestFusedLayerNormElemWise(TestFusedLayerNorm):
    """Same device-compatibility checks as TestFusedLayerNorm, but with
    learnable affine parameters enabled (``elementwise_affine=True``)."""

    def setUp(self):
        # Only the fixture differs; every test_* method is inherited
        # from TestFusedLayerNorm.
        norm_shape = [32, 64]
        self.module = apex.normalization.FusedLayerNorm(
            normalized_shape=norm_shape, elementwise_affine=True)
        self.input_ = torch.randn(16, 32, 64)
        torch.cuda.manual_seed(42)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment