OpenDAS / apex · Commits · d1626ccc

Commit d1626ccc authored Jul 20, 2019 by Myle Ott
Parent: 574fe244

Update FusedLayerNorm for new function API
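The "new function API" here is PyTorch's static-method torch.autograd.Function interface: per-call state lives on the ctx object inside a @staticmethod forward, callers go through .apply(...), and backward must return one gradient per forward argument (None for non-tensor arguments such as normalized_shape and eps). The following is a minimal, self-contained sketch of that pattern, independent of apex; ScaledSquare and its arguments are made up for illustration.

import torch

class ScaledSquare(torch.autograd.Function):
    """Illustrative only: computes scale * x**2 with the static-method Function API."""

    @staticmethod
    def forward(ctx, input, scale):
        ctx.scale = scale                 # non-tensor state is stashed on ctx
        ctx.save_for_backward(input)      # tensors go through save_for_backward
        return scale * input * input

    @staticmethod
    def backward(ctx, grad_output):
        (input_,) = ctx.saved_tensors
        # one gradient per forward() argument; None for the non-tensor `scale`
        return grad_output * 2.0 * ctx.scale * input_, None

x = torch.randn(4, requires_grad=True)
ScaledSquare.apply(x, 3.0).sum().backward()   # callers use .apply(...) instead of instantiating
print(x.grad)                                  # equals 6.0 * x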
Showing 1 changed file with 39 additions and 34 deletions

apex/normalization/fused_layer_norm.py  +39 −34
--- a/apex/normalization/fused_layer_norm.py
+++ b/apex/normalization/fused_layer_norm.py
@@ -6,60 +6,66 @@ from torch.nn import init
 from torch.nn import functional as F
 
 import importlib
 
+global fused_layer_norm_cuda
+fused_layer_norm_cuda = None
+
 class FusedLayerNormAffineFunction(torch.autograd.Function):
 
-  def __init__(self, normalized_shape, eps=1e-6):
-    global fused_layer_norm_cuda
-    fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
-    self.normalized_shape = normalized_shape
-    self.eps = eps
-
-  def forward(self, input, weight, bias):
+  @staticmethod
+  def forward(ctx, input, weight, bias, normalized_shape, eps):
+    global fused_layer_norm_cuda
+    if fused_layer_norm_cuda is None:
+        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+    ctx.normalized_shape = normalized_shape
+    ctx.eps = eps
     input_ = input.contiguous()
     weight_ = weight.contiguous()
     bias_ = bias.contiguous()
     output, mean, invvar = fused_layer_norm_cuda.forward_affine(
-        input_, self.normalized_shape, weight_, bias_, self.eps)
-    self.save_for_backward(input_, weight_, bias_, mean, invvar)
+        input_, ctx.normalized_shape, weight_, bias_, eps)
+    ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
     return output
 
-  def backward(self, grad_output):
-    input_, weight_, bias_, mean, invvar = self.saved_tensors
+  @staticmethod
+  def backward(ctx, grad_output):
+    input_, weight_, bias_, mean, invvar = ctx.saved_tensors
     grad_input = grad_weight = grad_bias = None
     grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
         grad_output.contiguous(), mean, invvar,
-        input_, self.normalized_shape,
-        weight_, bias_, self.eps)
-    return grad_input, grad_weight, grad_bias;
+        input_, ctx.normalized_shape,
+        weight_, bias_, ctx.eps)
+    return grad_input, grad_weight, grad_bias, None, None
 
 class FusedLayerNormFunction(torch.autograd.Function):
 
-  def __init__(self, normalized_shape, eps=1e-6):
-    global fused_layer_norm_cuda
-    fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
-    self.normalized_shape = normalized_shape
-    self.eps = eps
-
-  def forward(self, input):
+  @staticmethod
+  def forward(ctx, input, normalized_shape, eps):
+    global fused_layer_norm_cuda
+    if fused_layer_norm_cuda is None:
+        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+    ctx.normalized_shape = normalized_shape
+    ctx.eps = eps
     input_ = input.contiguous()
     output, mean, invvar = fused_layer_norm_cuda.forward(
-        input_, self.normalized_shape, self.eps)
-    self.save_for_backward(input_, mean, invvar)
+        input_, ctx.normalized_shape, ctx.eps)
+    ctx.save_for_backward(input_, mean, invvar)
     return output
 
-  def backward(self, grad_output):
-    input_, mean, invvar = self.saved_tensors
+  @staticmethod
+  def backward(ctx, grad_output):
+    input_, mean, invvar = ctx.saved_tensors
     grad_input = None
     grad_input = fused_layer_norm_cuda.backward(
         grad_output.contiguous(), mean, invvar,
-        input_, self.normalized_shape,
-        self.eps)
-    return grad_input
+        input_, ctx.normalized_shape,
+        ctx.eps)
+    return grad_input, None, None
 
 def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
-    return FusedLayerNormAffineFunction(normalized_shape, eps)(input, weight, bias)
+    return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps)
 
 def fused_layer_norm(input, normalized_shape, eps=1e-6):
-    return FusedLayerNormFunction(normalized_shape, eps)(input)
+    return FusedLayerNormFunction.apply(input, normalized_shape, eps)
 
 class FusedLayerNorm(torch.nn.Module):
     r"""Applies Layer Normalization over a mini-batch of inputs as described in
@@ -149,11 +155,10 @@ class FusedLayerNorm(torch.nn.Module):
             return F.layer_norm(
                 input, self.normalized_shape, self.weight, self.bias, self.eps)
         if self.elementwise_affine:
-          return FusedLayerNormAffineFunction(self.normalized_shape, self.eps)(
-              input, self.weight, self.bias)
+          return FusedLayerNormAffineFunction.apply(
+              input, self.weight, self.bias, self.normalized_shape, self.eps)
         else:
-          return FusedLayerNormFunction(self.normalized_shape, self.eps)(
-              input)
+          return FusedLayerNormFunction.apply(
+              input, self.normalized_shape, self.eps)
 
     def extra_repr(self):
         return '{normalized_shape}, eps={eps}, ' \
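Note that the public wrappers fused_layer_norm_affine and fused_layer_norm, as well as the FusedLayerNorm module, keep their external signatures; only code that instantiated the Function classes directly needs to switch to the .apply(...) form shown above. A quick sketch of the unchanged wrapper call (assuming a CUDA device and a built fused_layer_norm_cuda extension; the shapes and values are made up for illustration):

import torch
from apex.normalization.fused_layer_norm import fused_layer_norm_affine

x = torch.randn(8, 512, device="cuda")
weight = torch.ones(512, device="cuda")
bias = torch.zeros(512, device="cuda")
# Same call before and after this commit; only the Function internals changed.
out = fused_layer_norm_affine(x, (512,), weight, bias, eps=1e-6)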