OpenDAS / TransformerEngine · Commits · 3cc2c1d2

Commit 3cc2c1d2 (unverified), authored Aug 10, 2023 by Kirthi Shankar Sivamani, committed by GitHub on Aug 10, 2023

    AMP support for LN and RMSNorm (#371)

    Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

Parent: 88c88654
Showing 3 changed files with 74 additions and 3 deletions:

    tests/pytorch/test_sanity.py                        +46  -0
    transformer_engine/pytorch/module/layernorm.py      +14  -2
    transformer_engine/pytorch/module/rmsnorm.py        +14  -1
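Taken together, the changes below let the standalone LayerNorm and RMSNorm modules run under PyTorch native AMP: the autograd forward now receives an activation dtype, casts its inputs and weights with cast_if_needed, and a new sanity test checks the resulting output and gradient dtypes. A minimal usage sketch of what the change enables; the hidden size, tensor shapes, and bfloat16 dtype here are illustrative and not taken from the commit:

    import torch
    from transformer_engine.pytorch import LayerNorm, RMSNorm

    # Modules keep fp32 parameters; autocast decides the compute dtype.
    ln = LayerNorm(1024, eps=1e-5).to(dtype=torch.float32).cuda()
    rms = RMSNorm(1024, eps=1e-5).to(dtype=torch.float32).cuda()

    x = torch.randn(128, 2, 1024, device="cuda", requires_grad=True)

    with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        y = ln(x)    # y.dtype == torch.bfloat16
        z = rms(x)   # z.dtype == torch.bfloat16

    (y.sum() + z.sum()).backward()
    print(x.grad.dtype)  # torch.float32, as asserted by the new sanity test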
tests/pytorch/test_sanity.py    (view file @ 3cc2c1d2)
@@ -15,6 +15,8 @@ from transformer_engine.pytorch import (
     Linear,
     LayerNormMLP,
     TransformerLayer,
+    RMSNorm,
+    LayerNorm,
 )
 from transformer_engine.common import recipe
@@ -308,6 +310,50 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_d
     torch.cuda.synchronize()


+def _test_sanity_normalization_amp(block, bs, dtype, config, skip_wgrad, skip_dgrad):
+    if skip_dgrad and skip_wgrad:
+        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
+
+    te_inp = torch.randn(config.seq_len, bs, config.hidden_size, requires_grad=True).cuda()
+    te_inp.retain_grad()
+
+    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
+        te_out = block(te_inp)
+        loss = te_out.sum()
+
+    loss.backward()
+    torch.cuda.synchronize()
+
+    assert te_out.dtype == dtype, "AMP wrong output type."
+    assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
+    for name, p in block.named_parameters():
+        if p.requires_grad:
+            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."
+
+
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("skip_dgrad", all_boolean)
+@pytest.mark.parametrize("normalization", all_normalizations)
+def test_sanity_normalization_amp(dtype, bs, model, skip_wgrad, skip_dgrad, normalization):
+    config = model_configs[model]
+    module = RMSNorm if normalization == "RMSNorm" else LayerNorm
+
+    block = (
+        module(
+            config.hidden_size,
+            eps=config.eps,
+        )
+        .to(dtype=torch.float32)
+        .cuda()
+    )
+
+    _test_sanity_normalization_amp(block, bs, dtype, config, skip_wgrad, skip_dgrad)
+
+
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
transformer_engine/pytorch/module/layernorm.py    (view file @ 3cc2c1d2)
@@ -11,10 +11,12 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import transformer_engine_extensions as tex
+from .base import TransformerEngineBaseModule
 from ..cpp_extensions import (
     layernorm_fwd_inf,
 )
 from ..jit import no_torch_dynamo
+from ..utils import cast_if_needed

 __all__ = ["LayerNorm"]
@@ -33,6 +35,7 @@ class _LayerNorm(torch.autograd.Function):
         bwd_ln_sm_margin: int,
         zero_centered_gamma: bool,
         is_grad_enabled: bool,
+        activation_dtype: torch.dtype,
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -40,6 +43,11 @@ class _LayerNorm(torch.autograd.Function):
         assert inp.shape[-1] == in_features, "LayerNorm not possible"
         inputmat = inp.view((-1, in_features))

+        # Cast for native AMP
+        inputmat = cast_if_needed(inputmat, activation_dtype)
+        ln_weight = cast_if_needed(ln_weight, activation_dtype)
+        ln_bias = cast_if_needed(ln_bias, activation_dtype)
+
         if is_grad_enabled:
             ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight,
                 ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma)
@@ -63,7 +71,7 @@ class _LayerNorm(torch.autograd.Function):
             d_ln_out, inputmat, mu, rsigma, ln_weight,
             ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
         )
-        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None
+        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None


 class LayerNorm(torch.nn.Module):
@@ -170,6 +178,9 @@ class LayerNorm(torch.nn.Module):
         if hasattr(self, "layer_norm_bias"):
             setattr(self, "bias", self.layer_norm_bias)

+        # Set the activation type for AMP.
+        TransformerEngineBaseModule.set_activation_dtype(self, inp)
+
         if torch.is_grad_enabled():
             fwd_fn = _LayerNorm.apply
             args = []
@@ -185,7 +196,8 @@ class LayerNorm(torch.nn.Module):
             self.fwd_ln_sm_margin,
             self.bwd_ln_sm_margin,
             self.zero_centered_gamma,
-            torch.is_grad_enabled()
+            torch.is_grad_enabled(),
+            self.activation_dtype,
         )

         return fwd_fn(*args)
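The casts added above rely on cast_if_needed from the package's ..utils module, which this diff imports but does not show. Based on how it is called, it presumably casts a tensor only when its dtype differs from the requested activation dtype; a rough sketch of such a helper (the real implementation is not part of this commit and may differ):

    import torch

    def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Cast `tensor` to `dtype` only if it is not already that dtype.

        Hypothetical sketch of the helper imported from ..utils; it may
        handle additional cases (e.g. disabling autocast around the cast).
        """
        if tensor is None or tensor.dtype == dtype:
            return tensor
        return tensor.to(dtype)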
transformer_engine/pytorch/module/rmsnorm.py    (view file @ 3cc2c1d2)
@@ -10,8 +10,10 @@ import torch
 from torch.nn.parameter import Parameter
 from torch.nn import init
+from .base import TransformerEngineBaseModule
 from .. import cpp_extensions as tex
 from ..jit import no_torch_dynamo
+from ..utils import cast_if_needed

 __all__ = ["RMSNorm"]
@@ -30,6 +32,7 @@ class _RMSNorm(torch.autograd.Function):
         bwd_rmsnorm_sm_margin: int,
         zero_centered_gamma: bool,
         is_grad_enabled: bool,
+        activation_dtype: torch.dtype,
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
         in_features = rmsnorm_weight.numel()
@@ -37,6 +40,10 @@ class _RMSNorm(torch.autograd.Function):
         assert inp.shape[-1] == in_features, "RMSNorm not possible"
         inputmat = inp.view((-1, in_features))

+        # Cast for native AMP
+        inputmat = cast_if_needed(inputmat, activation_dtype)
+        rmsnorm_weight = cast_if_needed(rmsnorm_weight, activation_dtype)
+
         if is_grad_enabled:
             rmsnorm_out, rsigma = tex.rmsnorm_fwd(inputmat, rmsnorm_weight,
                 eps, fwd_rmsnorm_sm_margin,
@@ -70,6 +77,7 @@ class _RMSNorm(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -148,6 +156,10 @@ class RMSNorm(torch.nn.Module):
     @no_torch_dynamo
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """RMSNorm FWD"""
+
+        # Set the activation type for AMP.
+        TransformerEngineBaseModule.set_activation_dtype(self, inp)
+
         if torch.is_grad_enabled():
             fwd_fn = _RMSNorm.apply
             args = []
@@ -162,7 +174,8 @@ class RMSNorm(torch.nn.Module):
             self.fwd_rmsnorm_sm_margin,
             self.bwd_rmsnorm_sm_margin,
             self.zero_centered_gamma,
-            torch.is_grad_enabled()
+            torch.is_grad_enabled(),
+            self.activation_dtype,
         )

         return fwd_fn(*args)
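In both modules the new AMP path hinges on TransformerEngineBaseModule.set_activation_dtype, which is defined elsewhere in the package and not changed by this commit. Judging by how it is used here, it stores the dtype the forward should compute in on self.activation_dtype, deferring to the autocast state when AMP is active. A hedged sketch of that assumed behavior, using PyTorch's standard autocast queries:

    import torch

    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Assumed behavior of the base-module hook called above (not from this diff).

        Under torch.autocast the activation dtype follows the autocast dtype;
        otherwise it falls back to the input tensor's own dtype.
        """
        if torch.is_autocast_enabled():
            self.activation_dtype = torch.get_autocast_gpu_dtype()
        else:
            self.activation_dtype = inp.dtype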