gaoqiong / flash-attention · Commits · 393882bc

Commit 393882bc, authored Mar 29, 2023 by Tri Dao

[LayerNorm] Implement LN with parallel residual, support dim 8k

Parent: 009a3e71

Showing 6 changed files with 499 additions and 67 deletions (+499 −67)
flash_attn/modules/block.py            +30  −15
flash_attn/ops/layer_norm.py          +109   −2
flash_attn/ops/rms_norm.py             +14   −0
tests/models/test_gpt_neox.py           +1   −1
tests/models/test_gptj.py               +1   −1
tests/ops/test_dropout_layer_norm.py  +344  −48
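The commit adds a fused dropout + residual-add + LayerNorm kernel for "parallel residual" blocks (the GPT-J / GPT-NeoX layout, where the attention and MLP branches read the same residual stream and are summed back into it) and extends the supported hidden dimension to 8k. As orientation before the per-file diffs, the unfused computation such a block performs looks roughly like the sketch below; the function and argument names are illustrative, not part of this commit.

import torch.nn.functional as F

def parallel_block_reference(x, norm1, norm2, attn, mlp, dropout_p, training=True):
    # Both branches read the (possibly weight-tied) normalized residual stream ...
    hidden1 = norm1(x)
    hidden2 = norm2(x)
    # ... and their dropped-out outputs are added back in a single residual update.
    out1 = F.dropout(attn(hidden1), dropout_p, training=training)
    out2 = F.dropout(mlp(hidden2), dropout_p, training=training)
    return x + out1 + out2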
flash_attn/modules/block.py

@@ -18,6 +18,11 @@ try:
 except ImportError:
     dropout_add_layer_norm = None
 
+try:
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
+except ImportError:
+    dropout_add_layer_norm_parallel_residual = None
+
 
 class Block(nn.Module):

@@ -64,7 +69,7 @@ class Block(nn.Module):
             self.norm2 = norm_cls(dim)
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
+            assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
             assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
         # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,

@@ -214,7 +219,6 @@ class ParallelBlock(nn.Module):
         super().__init__()
         self.tied_norm = tied_norm
         self.fused_dropout_add_ln = fused_dropout_add_ln
-        assert not self.fused_dropout_add_ln, 'This is not implemented for ParallelBlock yet'
         self.residual_in_fp32 = residual_in_fp32
         if mixer_cls is None:
             mixer_cls = partial(MHA, num_heads=dim // 64)

@@ -229,7 +233,7 @@ class ParallelBlock(nn.Module):
             self.norm2 = norm_cls(dim)
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
+            assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
             assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
         # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,

@@ -262,19 +266,30 @@ class ParallelBlock(nn.Module):
             hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
             residual.
         """
-        dropped1 = self.dropout1(hidden_states1)
-        # For the very 1st block, we only want 1 dropout, not two different dropouts
-        if hidden_states2 is not None:
-            dropped2 = self.dropout2(hidden_states2)
-            residual = ((residual + dropped1 + dropped2)
-                        if residual is not None else dropped1 + dropped2)
-        else:
-            residual = (residual + dropped1) if residual is not None else dropped1
-        hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
-        hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
-                          if not self.tied_norm else hidden_states1)
-        if self.residual_in_fp32:
-            residual = residual.to(torch.float32)
+        if not self.fused_dropout_add_ln:
+            dropped1 = self.dropout1(hidden_states1)
+            # For the very 1st block, we only want 1 dropout, not two different dropouts
+            if hidden_states2 is not None:
+                dropped2 = self.dropout2(hidden_states2)
+                residual = ((residual + dropped1 + dropped2)
+                            if residual is not None else dropped1 + dropped2)
+            else:
+                residual = (residual + dropped1) if residual is not None else dropped1
+            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
+            hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                              if not self.tied_norm else hidden_states1)
+            if self.residual_in_fp32:
+                residual = residual.to(torch.float32)
+        else:
+            weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
+                              if not self.tied_norm else (None, None))
+            hidden_states1, hidden_states2, residual = dropout_add_layer_norm_parallel_residual(
+                hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
+                weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
+                prenorm=True, residual_in_fp32=self.residual_in_fp32
+            )
+            if self.tied_norm:
+                hidden_states2 = hidden_states1
         if mixer_kwargs is None:
             mixer_kwargs = {}
         hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
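Both the fused and unfused branches above keep the running residual in float32 when residual_in_fp32 is set, even though the LayerNorm inputs are cast back to the weight dtype. A small numerical sketch of why that matters (illustrative only, not from this commit): repeatedly accumulating a half-precision residual drifts away from an fp32 accumulation as depth grows.

import torch

# Compare accumulating the residual stream in fp16 vs fp32 over many "blocks".
x = torch.randn(4, 1024).half()
residual_fp16 = x.clone()
residual_fp32 = x.float()
for _ in range(64):  # pretend each iteration is one transformer block's update
    update = (torch.randn(4, 1024) * 1e-2).half()
    residual_fp16 = residual_fp16 + update
    residual_fp32 = residual_fp32 + update.float()
drift = (residual_fp16.float() - residual_fp32).abs().max().item()
print(f'max drift of the fp16 residual after 64 blocks: {drift:.3e}')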
flash_attn/ops/layer_norm.py

@@ -99,6 +99,46 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
     return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
 
 
+def _dropout_add_layer_norm_parallel_residual_forward(
+    x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+    residual_in_fp32=False, is_rms_norm=False
+):
+    """ Assume that arguments are contiguous """
+    hidden_size = gamma0.numel()
+    x0mat = x0.view((-1, hidden_size))
+    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
+    z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
+        x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+        None, residual_in_fp32, is_rms_norm
+    )
+    # dmask0 and dmask1 are None if dropout_p == 0.0
+    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
+    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
+
+
+def _dropout_add_layer_norm_parallel_residual_backward(
+    dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p,
+    has_x1, has_residual, is_rms_norm=False
+):
+    """ Assume that arguments are contiguous
+    dx == None means that it was a post-norm architecture
+    (x = drop(x0) + residual was not returned in the fwd).
+    """
+    hidden_size = gamma0.numel()
+    xmat = x.view((-1, hidden_size))
+    dz0mat = dz0.view(xmat.shape)
+    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
+    dxmat = dx.view(xmat.shape) if dx is not None else None
+    dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest = \
+        dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
+            dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
+            dropout_p, has_x1, has_residual, is_rms_norm
+        )
+    # dresidualmat is None if not has_residual
+    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
+
+
 class DropoutAddLayerNormFn(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,

@@ -115,7 +155,7 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
         x0_saved = x0 if colscale is not None else None
-        ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
+        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale)
         ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p
         ctx.has_residual = residual is not None

@@ -168,7 +208,7 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         # Only need to save x0 if we need to compute gradient wrt colscale
         x0_saved = x0 if colscale is not None else None
         x_shape = (-1, *x0.shape[1:])
-        ctx.save_for_backward(xmat.view(x_shape), x0, dmask, gamma, mu, rsigma, colscale,
+        ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
                               x0_subset, out_subset)
         ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p

@@ -208,6 +248,60 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
                 None, None, None, None, None, None, None, None)
 
 
+class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
+        x0 = x0.contiguous()
+        x1 = x1.contiguous() if x1 is not None else None
+        residual = residual.contiguous() if residual is not None else None
+        gamma0 = gamma0.contiguous()
+        beta0 = beta0.contiguous() if beta0 is not None else None
+        gamma1 = gamma1.contiguous() if gamma1 is not None else None
+        beta1 = beta1.contiguous() if beta1 is not None else None
+        z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
+            x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+            residual_in_fp32, is_rms_norm
+        )
+        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
+        ctx.prenorm = prenorm
+        ctx.dropout_p = dropout_p
+        ctx.has_x1 = x1 is not None
+        ctx.has_residual = residual is not None
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_beta = beta0 is not None
+        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
+        if not return_dmask:
+            return z if not prenorm else (*z, xmat.view(x0.shape))
+        else:
+            dmask0 = (dmask0.view(x0.shape) if dropout_p > 0.
+                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            dmask1 = (dmask1.view(x0.shape) if dropout_p > 0. and x1 is not None
+                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            ctx.mark_non_differentiable(dmask0)
+            ctx.mark_non_differentiable(dmask1)
+            return ((*z, dmask0, dmask1) if not prenorm
+                    else (*z, xmat.view(x0.shape), dmask0, dmask1))
+
+    @staticmethod
+    def backward(ctx, dz0, dz1, *args):
+        dz0 = dz0.contiguous()  # this happens!
+        dz1 = dz1.contiguous() if dz1 is not None else None
+        dx = args[0].contiguous() if ctx.prenorm else None
+        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
+        dropout_p = ctx.dropout_p
+        has_x1 = ctx.has_x1
+        has_residual = ctx.has_residual
+        dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = (
+            _dropout_add_layer_norm_parallel_residual_backward(
+                dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p,
+                has_x1, has_residual, ctx.is_rms_norm
+            )
+        )
+        dx0 = dx0mat.view(x.shape)
+        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
+        return (dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1,
+                dbeta1 if ctx.has_beta else None, None, None, None, None, None, None)
+
+
 def layer_norm(x, weight, bias, epsilon):
     return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)

@@ -237,6 +331,19 @@ def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon
     )
 
 
+def dropout_add_layer_norm_parallel_residual(
+    x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False,
+    residual_in_fp32=False, return_dropout_mask=False
+):
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
+    """
+    return DropoutAddLayerNormParallelResidualFn.apply(
+        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon,
+        residual_in_fp32, prenorm, False, return_dropout_mask
+    )
+
+
 class DropoutAddLayerNorm(torch.nn.Module):
     def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                  device=None, dtype=None):
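A hedged usage sketch of the new public entry point: it requires the compiled dropout_layer_norm CUDA extension and CUDA tensors, and the shapes, dtypes, and values below are purely illustrative. With prenorm=True the call additionally returns the updated residual stream; with tied norms, weight1 and bias1 may be passed as None so the second output simply mirrors the first.

import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual

hidden_size = 1024
x0 = torch.randn(2, 512, hidden_size, device='cuda', dtype=torch.float16)  # attention branch
x1 = torch.randn(2, 512, hidden_size, device='cuda', dtype=torch.float16)  # MLP branch
residual = torch.randn(2, 512, hidden_size, device='cuda', dtype=torch.float32)
weight0 = torch.randn(hidden_size, device='cuda', dtype=torch.float16)
bias0 = torch.randn(hidden_size, device='cuda', dtype=torch.float16)
weight1 = torch.randn(hidden_size, device='cuda', dtype=torch.float16)
bias1 = torch.randn(hidden_size, device='cuda', dtype=torch.float16)

out0, out1, new_residual = dropout_add_layer_norm_parallel_residual(
    x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p=0.1, epsilon=1e-5,
    prenorm=True, residual_in_fp32=True
)
# out0 and out1 are the two normalized views of the same updated residual stream.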
flash_attn/ops/rms_norm.py

@@ -5,6 +5,7 @@ import torch
 from torch.nn import init
 
 from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
+from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn
 
 
 def rms_norm(x, weight, epsilon):

@@ -37,6 +38,19 @@ def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon,
     )
 
 
+def dropout_add_rms_norm_parallel_residual(
+    x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False,
+    residual_in_fp32=False, return_dropout_mask=False
+):
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
+    """
+    return DropoutAddLayerNormParallelResidualFn.apply(
+        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon,
+        residual_in_fp32, prenorm, True, return_dropout_mask
+    )
+
+
 class DropoutAddRMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                  device=None, dtype=None):
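The RMSNorm variant is a thin wrapper: it routes through the same DropoutAddLayerNormParallelResidualFn with the is_rms_norm apply argument set to True, so only the normalization inside the kernel changes. For reference, the unfused RMSNorm it is expected to match looks roughly like the sketch below (a sketch, not this repo's kernel).

import torch

def rms_norm_ref(x, weight, eps=1e-5):
    # RMSNorm scales by the root-mean-square over the feature dimension;
    # unlike LayerNorm there is no mean subtraction and no bias term.
    rms = torch.sqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() / rms * weight.float()).to(x.dtype)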
tests/models/test_gpt_neox.py

@@ -35,7 +35,7 @@ def test_gpt_neox_optimized(model_name):
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
-    config.fused_dropout_add_ln = False  # We don't support parallel block yet
+    config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
tests/models/test_gptj.py

@@ -36,7 +36,7 @@ def test_gptj_optimized(model_name):
     config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
     config.fused_bias_fc = True
     config.fused_mlp = True
-    config.fused_dropout_add_ln = False  # We don't support parallel block yet
+    config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
tests/ops/test_dropout_layer_norm.py

@@ -10,11 +10,14 @@ from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_nor
 from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
 from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm
 from flash_attn.ops.rms_norm import dropout_add_rms_norm_subset
+from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
+from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
 
 try:
     from apex.normalization import FusedRMSNorm
+    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
 except:
-    FusedRMSNorm = None
+    FusedRMSNorm, fused_rms_norm_affine = None, None
 
 
 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@@ -35,8 +38,8 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
                           (torch.float32, torch.float32)]
                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                             if is_sm8x else []))
 # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
-# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
-@pytest.mark.parametrize('hidden_size', [256])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [256])
 def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                      dropout_p, has_residual, has_rowscale, has_colscale,
                                      is_rms_norm):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
@@ -64,11 +67,11 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
     else:
         colscale = None
     if has_residual:
-        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
-        x1 = x1_pt.detach().clone().requires_grad_()
-        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        res = res_pt.detach().clone().requires_grad_()
+        res_ref = res_pt.detach().clone().float().requires_grad_()
     else:
-        x1 = None
+        res = None
     if has_rowscale:
         rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
         survival_rate = 0.87
@@ -95,14 +98,14 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
         model.bias.copy_(model_pt.bias)
         model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
-    out, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
+    out, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
                                      model.epsilon, rowscale=rowscale, layerscale=colscale,
                                      residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
     assert out.dtype == input_dtype
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
     if has_residual:
-        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
-        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
+        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
     else:
         residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
         residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
@@ -116,8 +119,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
     out_ref.backward(g)
     assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
     if has_residual:
-        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
-    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
+        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
+    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
     if not is_rms_norm:
         assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
     if has_colscale:
@@ -145,9 +148,9 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
                        requires_grad=True)
     x0 = x0_pt.detach().clone().requires_grad_()
     x0_ref = x0_pt.detach().clone().float().requires_grad_()
-    x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
-    x1 = x1_pt.detach().clone().requires_grad_()
-    x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+    res = res_pt.detach().clone().requires_grad_()
+    res_ref = res_pt.detach().clone().float().requires_grad_()
     model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
     torch.nn.init.normal_(model_pt.weight)
     torch.nn.init.normal_(model_pt.bias)

@@ -161,9 +164,9 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
     model_pt.eval()
     model.eval()
     model_ref.eval()
-    out = model(x0, x1)
-    residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
-    residual_ref = x0_ref + x1_ref
+    out = model(x0, res)
+    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
+    residual_ref = x0_ref + res_ref
     out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
     out_ref = model_ref(residual_ref)
     assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@@ -215,11 +218,11 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
     else:
         colscale = None
     if has_residual:
-        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
-        x1 = x1_pt.detach().clone().requires_grad_()
-        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        res = res_pt.detach().clone().requires_grad_()
+        res_ref = res_pt.detach().clone().float().requires_grad_()
     else:
-        x1 = None
+        res = None
     if has_rowscale:
         rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
         survival_rate = 0.87

@@ -247,15 +250,15 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
         model.bias.copy_(model_pt.bias)
         model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
-    out, residual, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
+    out, residual, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
                                                model.epsilon, rowscale=rowscale,
                                                layerscale=colscale, prenorm=True,
                                                residual_in_fp32=residual_in_fp32,
                                                return_dropout_mask=True)
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
     if has_residual:
-        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
-        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
+        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
     else:
         residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
         residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)

@@ -272,7 +275,7 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
     (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
     assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
     if has_residual:
-        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
     if not is_rms_norm:
         assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
@@ -301,9 +304,9 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
                        requires_grad=True)
     x0 = x0_pt.detach().clone().requires_grad_()
     x0_ref = x0_pt.detach().clone().float().requires_grad_()
-    x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
-    x1 = x1_pt.detach().clone().requires_grad_()
-    x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+    res = res_pt.detach().clone().requires_grad_()
+    res_ref = res_pt.detach().clone().float().requires_grad_()
     model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
     torch.nn.init.normal_(model_pt.weight)
     torch.nn.init.normal_(model_pt.bias)

@@ -318,9 +321,9 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
     model_pt.eval()
     model.eval()
     model_ref.eval()
-    out, residual = model(x0, x1)
-    residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
-    residual_ref = x0_ref + x1_ref
+    out, residual = model(x0, res)
+    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
+    residual_ref = x0_ref + res_ref
     out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
     out_ref = model_ref(residual_ref)
     assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@@ -382,11 +385,11 @@ def test_dropout_layer_norm_subset_training(
     else:
         colscale = None
     if has_residual:
-        x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
-        x1 = x1_pt.detach().clone().requires_grad_()
-        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
+        res = res_pt.detach().clone().requires_grad_()
+        res_ref = res_pt.detach().clone().float().requires_grad_()
     else:
-        x1 = None
+        res = None
     if has_colscale:
         x0_scaled_pt = x0_pt * colscale_pt

@@ -409,7 +412,7 @@ def test_dropout_layer_norm_subset_training(
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
-    out, dmask = dropout_add_layer_norm_subset(x0, x1, model.weight, model.bias, model.p,
-                                               model.epsilon, layerscale=colscale,
+    out, dmask = dropout_add_layer_norm_subset(x0, res, model.weight, model.bias, model.p,
+                                               model.epsilon, layerscale=colscale,
                                                x0_subset=x0_subset, out_subset=out_subset,
                                                rowscale_const=drop_path_scale,
                                                out_numrows=out_numrows, prenorm=False,
                                                residual_in_fp32=residual_in_fp32,
                                                return_dropout_mask=True)

@@ -424,8 +427,8 @@ def test_dropout_layer_norm_subset_training(
     dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
     dmask_expanded[x0_mask_batch] = dmask
     if has_residual:
-        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
-        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
+        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
     else:
         residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
         residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)

@@ -440,7 +443,7 @@ def test_dropout_layer_norm_subset_training(
     out_ref.backward(g)
     assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
     if has_residual:
-        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
     assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
     if has_colscale:
else
:
colscale
=
None
if
has_residual
:
x1
_pt
=
torch
.
randn_like
(
x0_pt
,
dtype
=
residual_dtype
,
requires_grad
=
True
)
x1
=
x1
_pt
.
detach
().
clone
().
requires_grad_
()
x1
_ref
=
x1
_pt
.
detach
().
clone
().
float
().
requires_grad_
()
res
_pt
=
torch
.
randn_like
(
x0_pt
,
dtype
=
residual_dtype
,
requires_grad
=
True
)
res
=
res
_pt
.
detach
().
clone
().
requires_grad_
()
res
_ref
=
res
_pt
.
detach
().
clone
().
float
().
requires_grad_
()
else
:
x1
=
None
res
=
None
if
has_colscale
:
x0_scaled_pt
=
x0_pt
*
colscale_pt
...
...
@@ -529,7 +532,7 @@ def test_dropout_layer_norm_subset_prenorm_training(
residual_in_fp32
=
(
not
has_residual
)
and
residual_dtype
==
torch
.
float32
out
,
residual
,
dmask
=
dropout_add_layer_norm_subset
(
x0
,
x1
,
model
.
weight
,
model
.
bias
,
model
.
p
,
model
.
epsilon
,
layerscale
=
colscale
,
x0
,
res
,
model
.
weight
,
model
.
bias
,
model
.
p
,
model
.
epsilon
,
layerscale
=
colscale
,
x0_subset
=
x0_subset
,
out_subset
=
out_subset
,
rowscale_const
=
drop_path_scale
,
out_numrows
=
out_numrows
,
prenorm
=
True
,
residual_in_fp32
=
residual_in_fp32
,
return_dropout_mask
=
True
)
...
...
@@ -544,8 +547,8 @@ def test_dropout_layer_norm_subset_prenorm_training(
dmask_expanded
=
torch
.
zeros_like
(
x0_pt
,
dtype
=
torch
.
uint8
)
dmask_expanded
[
x0_mask_batch
]
=
dmask
if
has_residual
:
residual_pt
=
((
x0_scaled_pt
.
float
()
*
dmask_expanded
.
float
())
/
(
1
-
dropout_p
)
+
x1
_pt
.
float
()).
to
(
dtype
=
residual_dtype
)
residual_ref
=
(
x0_scaled_ref
*
dmask_expanded
.
float
())
/
(
1
-
dropout_p
)
+
x1
_ref
residual_pt
=
((
x0_scaled_pt
.
float
()
*
dmask_expanded
.
float
())
/
(
1
-
dropout_p
)
+
res
_pt
.
float
()).
to
(
dtype
=
residual_dtype
)
residual_ref
=
(
x0_scaled_ref
*
dmask_expanded
.
float
())
/
(
1
-
dropout_p
)
+
res
_ref
else
:
residual_pt
=
((
x0_scaled_pt
.
float
()
*
dmask_expanded
.
float
())
/
(
1
-
dropout_p
)).
to
(
dtype
=
residual_dtype
)
residual_ref
=
(
x0_scaled_ref
*
dmask_expanded
.
float
())
/
(
1
-
dropout_p
)
...
...
@@ -562,8 +565,301 @@ def test_dropout_layer_norm_subset_prenorm_training(
     (out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype))
      + residual_ref.mean(0, keepdim=True)).backward(g)
     assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
     if has_residual:
-        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
     assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
     if has_colscale:
         assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
+
+
+@pytest.mark.parametrize('is_rms_norm', [False, True])
+# @pytest.mark.parametrize('is_rms_norm', [False])
+@pytest.mark.parametrize('tied_norm', [False, True])
+# @pytest.mark.parametrize('tied_norm', [False])
+@pytest.mark.parametrize('has_residual', [True, False])
+# @pytest.mark.parametrize('has_residual', [False])
+@pytest.mark.parametrize('has_x1', [True, False])
+# @pytest.mark.parametrize('has_x1', [True])
+@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+# @pytest.mark.parametrize('weight_dtype', [torch.float16])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
+                            if is_sm8x else []))
+# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [256])
+def test_dropout_layer_norm_parallel_residual_training(
+    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_x1, has_residual,
+    tied_norm, is_rms_norm
+):
+    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
+        pytest.skip()  # Not supported
+    if is_rms_norm and fused_rms_norm_affine is None:
+        pytest.skip()  # We need Apex's FusedRMSNorm to test
+    our_layer_norm_func = (dropout_add_layer_norm_parallel_residual if not is_rms_norm
+                           else dropout_add_rms_norm_parallel_residual)
+    device = 'cuda'
+    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
+    rtol, atol = (1e-3, 1e-4)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                        requires_grad=True)
+    x0 = x0_pt.detach().clone().requires_grad_()
+    x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    if has_x1:
+        x1_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                            requires_grad=True)
+        x1 = x1_pt.detach().clone().requires_grad_()
+        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    else:
+        x1 = None
+    if has_residual:
+        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        res = res_pt.detach().clone().requires_grad_()
+        res_ref = res_pt.detach().clone().float().requires_grad_()
+    else:
+        res = None
+    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+    bias0 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+             if not is_rms_norm else None)
+    weight0_pt = weight0.detach().clone().requires_grad_()
+    weight0_ref = weight0.detach().clone().float().requires_grad_()
+    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
+    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
+    if not tied_norm:
+        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype,
+                              requires_grad=True)
+        bias1 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+                 if not is_rms_norm else None)
+        weight1_pt = weight1.detach().clone().requires_grad_()
+        weight1_ref = weight1.detach().clone().float().requires_grad_()
+        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
+        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
+    else:
+        weight1, bias1 = None, None
+    epsilon = 1e-5
+    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
+    out0, out1, dmask0, dmask1 = our_layer_norm_func(
+        x0, x1, res, weight0, bias0, weight1, bias1, dropout_p, epsilon,
+        residual_in_fp32=residual_in_fp32, return_dropout_mask=True
+    )
+    assert out0.dtype == input_dtype
+    if not tied_norm:
+        assert out1.dtype == input_dtype
+    print(f'Actual dropout fraction: {1 - dmask0.float().mean().item()}')
+    if has_residual:
+        if has_x1:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+                           + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
+                           + res_pt.float()).to(dtype=residual_dtype)
+            residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+                            + (x1_ref * dmask1.float()) / (1 - dropout_p)) + res_ref
+        else:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+                           + res_pt.float()).to(dtype=residual_dtype)
+            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
+    else:
+        if has_x1:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+                           + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
+            residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+                            + (x1_ref * dmask1.float()) / (1 - dropout_p))
+        else:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
+            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
+    if not is_rms_norm:
+        out0_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt,
+                               bias0_pt, eps=epsilon).to(dtype=input_dtype)
+        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
+        if not tied_norm:
+            out1_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,),
+                                   weight1_pt, bias1_pt, eps=epsilon).to(dtype=input_dtype)
+            out1_ref = F.layer_norm(residual_ref, (hidden_size,), weight1_ref, bias1_ref,
+                                    eps=epsilon)
+    else:
+        out0_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight0_pt,
+                                        (hidden_size,), eps=epsilon).to(dtype=input_dtype)
+        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
+        if not tied_norm:
+            out1_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight1_pt,
+                                            (hidden_size,), eps=epsilon).to(dtype=input_dtype)
+            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,),
+                                             eps=epsilon)
+    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
+    if not tied_norm:
+        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
+
+    g0 = torch.randn_like(out0) / batch_size
+    if tied_norm:
+        out0.backward(g0)
+        out0_pt.backward(g0)
+        out0_ref.backward(g0)
+    else:
+        g1 = torch.randn_like(out1) / batch_size
+        (out0 * g0 + out1 * g1).sum().backward()
+        (out0_pt * g0 + out1_pt * g1).sum().backward()
+        (out0_ref * g0 + out1_ref * g1).sum().backward()
+    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
+    if has_x1:
+        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+    if has_residual:
+        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
+    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (weight0_pt.grad - weight0_ref.grad).abs().max() + 3e-5
+    if not is_rms_norm:
+        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (bias0_pt.grad - bias0_ref.grad).abs().max() + 3e-5
+    if not tied_norm:
+        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (weight1_pt.grad - weight1_ref.grad).abs().max() + 3e-5
+        if not is_rms_norm:
+            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (bias1_pt.grad - bias1_ref.grad).abs().max() + 3e-5
+
+
+@pytest.mark.parametrize('is_rms_norm', [False, True])
+# @pytest.mark.parametrize('is_rms_norm', [False])
+@pytest.mark.parametrize('tied_norm', [False, True])
+# @pytest.mark.parametrize('tied_norm', [False])
+@pytest.mark.parametrize('has_residual', [True, False])
+# @pytest.mark.parametrize('has_residual', [False])
+@pytest.mark.parametrize('has_x1', [True, False])
+# @pytest.mark.parametrize('has_x1', [True])
+@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+# @pytest.mark.parametrize('weight_dtype', [torch.float16])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
+                            if is_sm8x else []))
+# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [256])
+def test_dropout_layer_norm_parallel_residual_prenorm_training(
+    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_x1, has_residual,
+    tied_norm, is_rms_norm
+):
+    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
+        pytest.skip()  # Not supported
+    if is_rms_norm and fused_rms_norm_affine is None:
+        pytest.skip()  # We need Apex's FusedRMSNorm to test
+    our_layer_norm_func = (dropout_add_layer_norm_parallel_residual if not is_rms_norm
+                           else dropout_add_rms_norm_parallel_residual)
+    device = 'cuda'
+    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
+    rtol, atol = (1e-3, 1e-4)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                        requires_grad=True)
+    x0 = x0_pt.detach().clone().requires_grad_()
+    x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    if has_x1:
+        x1_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                            requires_grad=True)
+        x1 = x1_pt.detach().clone().requires_grad_()
+        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    else:
+        x1 = None
+    if has_residual:
+        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        res = res_pt.detach().clone().requires_grad_()
+        res_ref = res_pt.detach().clone().float().requires_grad_()
+    else:
+        res = None
+    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+    bias0 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+             if not is_rms_norm else None)
+    weight0_pt = weight0.detach().clone().requires_grad_()
+    weight0_ref = weight0.detach().clone().float().requires_grad_()
+    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
+    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
+    if not tied_norm:
+        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype,
+                              requires_grad=True)
+        bias1 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+                 if not is_rms_norm else None)
+        weight1_pt = weight1.detach().clone().requires_grad_()
+        weight1_ref = weight1.detach().clone().float().requires_grad_()
+        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
+        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
+    else:
+        weight1, bias1 = None, None
+    epsilon = 1e-5
+    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
+    out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(
+        x0, x1, res, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=True,
+        residual_in_fp32=residual_in_fp32, return_dropout_mask=True
+    )
+    assert out0.dtype == input_dtype
+    if not tied_norm:
+        assert out1.dtype == input_dtype
+    print(f'Actual dropout fraction: {1 - dmask0.float().mean().item()}')
+    if has_residual:
+        if has_x1:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+                           + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
+                           + res_pt.float()).to(dtype=residual_dtype)
+            residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+                            + (x1_ref * dmask1.float()) / (1 - dropout_p)) + res_ref
+        else:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+                           + res_pt.float()).to(dtype=residual_dtype)
+            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
+    else:
+        if has_x1:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+                           + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
+            residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+                            + (x1_ref * dmask1.float()) / (1 - dropout_p))
+        else:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
+            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
+    if not is_rms_norm:
+        out0_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt,
+                               bias0_pt, eps=epsilon).to(dtype=input_dtype)
+        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
+        if not tied_norm:
+            out1_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,),
+                                   weight1_pt, bias1_pt, eps=epsilon).to(dtype=input_dtype)
+            out1_ref = F.layer_norm(residual_ref, (hidden_size,), weight1_ref, bias1_ref,
+                                    eps=epsilon)
+    else:
+        out0_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight0_pt,
+                                        (hidden_size,), eps=epsilon).to(dtype=input_dtype)
+        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
+        if not tied_norm:
+            out1_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight1_pt,
+                                            (hidden_size,), eps=epsilon).to(dtype=input_dtype)
+            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,),
+                                             eps=epsilon)
+    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
+    if not tied_norm:
+        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
+    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
+
+    g0 = torch.randn_like(out0) / batch_size
+    if tied_norm:
+        (out0 * F.sigmoid(residual)).backward(g0)
+        (out0_pt * F.sigmoid(residual_pt)).backward(g0)
+        (out0_ref * F.sigmoid(residual_ref)).backward(g0)
+    else:
+        g1 = torch.randn_like(out1) / batch_size
+        (out0 * F.sigmoid(residual) * g0 + out1 * g1).sum().backward()
+        (out0_pt * F.sigmoid(residual_pt) * g0 + out1_pt * g1).sum().backward()
+        (out0_ref * F.sigmoid(residual_ref) * g0 + out1_ref * g1).sum().backward()
+    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
+    if has_x1:
+        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+    if has_residual:
+        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
+    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (weight0_pt.grad - weight0_ref.grad).abs().max() + 3e-5
+    if not is_rms_norm:
+        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (bias0_pt.grad - bias0_ref.grad).abs().max() + 3e-5
+    if not tied_norm:
+        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (weight1_pt.grad - weight1_ref.grad).abs().max() + 3e-5
+        if not is_rms_norm:
+            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (bias1_pt.grad - bias1_ref.grad).abs().max() + 3e-5
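The new tests follow the same comparison pattern as the existing ones: the fused kernel is checked against a float32 reference, and its error is bounded by a small constant times the error that a plain PyTorch implementation shows at the same precision, plus a small slack term, rather than by a fixed absolute tolerance. The helper below restates that pattern; the function name is hypothetical and not part of the test suite.

def assert_close_relative_to_pt(ours, pt, ref, factor=4, slack=1e-4):
    # "ours" (fused kernel), "pt" (plain PyTorch at the same precision), and
    # "ref" (float32 reference) are tensors of the same shape.
    ours_err = (ours - ref).abs().max()
    pt_err = (pt - ref).abs().max()
    assert ours_err <= factor * pt_err + slack, (ours_err, pt_err)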