gaoqiong / flash-attention — Commits

Commit 665b55e2
authored Jan 04, 2024 by Tri Dao

[LayerNorm] Implement parallel layer norm in Triton

parent aa5c6438
Showing 2 changed files with 381 additions and 32 deletions (+381 −32).
flash_attn/ops/triton/layernorm.py   +290 −23
tests/ops/triton/test_layer_norm.py  +91 −9
flash_attn/ops/triton/layernorm.py (view file @ 665b55e2)

This diff is collapsed and its contents are not shown on this page.
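Since the implementation diff above is collapsed, the kernel change itself is not visible here. As a rough orientation only, the following pure-PyTorch sketch shows what "parallel layer norm" appears to mean based on the test changes below: one normalization pass shared by two (weight, bias) pairs, optionally over the sum of two input streams. The helper name parallel_layer_norm_ref and the exact handling of residual, dropout, and rowscale are assumptions, not the actual Triton kernel or its API.

    import torch
    import torch.nn.functional as F

    def parallel_layer_norm_ref(x0, weight, bias, weight1, bias1, x1=None, residual=None, eps=1e-6):
        # Hypothetical reference, not the flash-attn kernel: combine the input
        # streams and the residual, normalize once, then apply both affine pairs.
        x = x0 if x1 is None else x0 + x1
        if residual is not None:
            x = x + residual
        normed = F.layer_norm(x.float(), (x.shape[-1],), eps=eps)
        out = (normed * weight.float() + bias.float()).to(x0.dtype)
        out1 = (normed * weight1.float() + bias1.float()).to(x0.dtype)
        return out, out1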
tests/ops/triton/test_layer_norm.py (view file @ 665b55e2)
@@ -16,12 +16,16 @@ from flash_attn.ops.triton.layernorm import (
 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
+@pytest.mark.parametrize("has_weight1", [False, True])
+# @pytest.mark.parametrize("has_weight1", [True])
+@pytest.mark.parametrize("has_x1", [False, True])
+# @pytest.mark.parametrize("has_x1", [False])
 @pytest.mark.parametrize("has_rowscale", [False, True])
 # @pytest.mark.parametrize("has_rowscale", [True])
 # @pytest.mark.parametrize("has_rowscale", [False])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.27])
 # @pytest.mark.parametrize("dropout_p", [0.0])
 @pytest.mark.parametrize("prenorm", [True, False])
 # @pytest.mark.parametrize("prenorm", [True])
 # @pytest.mark.parametrize("prenorm", [False])
 @pytest.mark.parametrize("is_rms_norm", [False, True])
 # @pytest.mark.parametrize("is_rms_norm", [True])
 @pytest.mark.parametrize("has_residual", [True, False])
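The two new decorators (has_weight1 and has_x1) multiply the existing parameter grid by four, because stacked @pytest.mark.parametrize decorators generate the cross-product of their value lists. A minimal standalone illustration of that mechanism (not part of this commit):

    import pytest

    @pytest.mark.parametrize("has_weight1", [False, True])
    @pytest.mark.parametrize("has_x1", [False, True])
    def test_combinations(has_weight1, has_x1):
        # pytest collects 2 x 2 = 4 cases, one per combination of the two flags.
        assert isinstance(has_weight1, bool) and isinstance(has_x1, bool)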
@@ -48,7 +52,11 @@ def test_layer_norm(
     prenorm,
     dropout_p,
     has_rowscale,
+    has_x1,
+    has_weight1,
 ):
+    if has_rowscale and has_x1:
+        pytest.skip("Not supported")
     device = "cuda"
     if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
         atol = 5e-2
@@ -62,9 +70,16 @@ def test_layer_norm(
     seqlen = 512
     layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
     allclose = (
         # Sometimes x0_pt.grad is NaN
         lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
         <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
+        or (
+            # Sometimes x_pt and x_ref are the same (e.g. bfloat16) so we want to perturb is a bit
+            # by multiply and divide by 0.3
+            (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0
+            and (x - x_ref).abs().max()
+            <= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol
+        )
     )
     x0 = torch.randn(
         batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
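The allclose lambda above is dense; an equivalent standalone version of the same check (a readability sketch, not part of the commit; the atol default here is a placeholder, since the test captures its own tolerance) would be:

    def allclose(x, x_pt, x_ref, atol=1e-4):
        # Compare the Triton output against the fp32 reference, ignoring entries
        # where the low-precision PyTorch baseline produced NaN.
        mask = ~x_pt.isnan()
        baseline_err = (x_pt[mask] - x_ref[mask]).abs().max()
        err = (x - x_ref).abs().max()
        # Pass if the Triton error is within 2x the baseline's own error.
        if err <= 2 * baseline_err + atol:
            return True
        # If baseline and reference happen to be bit-identical (e.g. bfloat16),
        # perturb the baseline by * 0.3 / 0.3 to obtain a nonzero rounding error.
        if baseline_err == 0.0:
            perturbed_err = (x_pt[mask] * 0.3 / 0.3 - x_ref[mask]).abs().max()
            return err <= 2 * perturbed_err + atol
        return False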
@@ -86,8 +101,35 @@ def test_layer_norm(
     weight_ref = weight.detach().clone().requires_grad_()
     bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
     bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
+    if has_x1:
+        x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True)
+        x1_pt = x1.detach().clone().requires_grad_()
+        x1_ref = x1.detach().clone().requires_grad_()
+    else:
+        x1, x1_pt, x1_ref = None, None, None
+    if has_weight1:
+        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+        weight1_pt = weight1.detach().clone().requires_grad_()
+        weight1_ref = weight1.detach().clone().requires_grad_()
+        if not is_rms_norm:
+            bias1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+        else:
+            bias1 = None
+        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
+        bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None
+    else:
+        weight1, weight1_pt, weight1_ref = None, None, None
+        bias1, bias1_pt, bias1_ref = None, None, None
-    rowscale = torch.randn(batch_size, seqlen, dtype=input_dtype, device=device) if has_rowscale else None
+    rowscale = (
+        torch.randn(batch_size, seqlen, dtype=input_dtype, device=device)
+        if has_rowscale
+        else None
+    )
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, *rest = layer_norm_fn(
@@ -95,6 +137,9 @@ def test_layer_norm(
         weight,
         bias,
         residual=res,
+        x1=x1,
+        weight1=weight1,
+        bias1=bias1,
         eps=1e-6,
         dropout_p=dropout_p,
         rowscale=rowscale,
@@ -103,44 +148,75 @@ def test_layer_norm(
         is_rms_norm=is_rms_norm,
         return_dropout_mask=True,
     )
-    dropout_mask = rest[-1] if dropout_p > 0.0 else None
+    dropout_mask = rest[-2] if dropout_p > 0.0 else None
+    dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None
     out_pt = layer_norm_ref_fn(
         x0_pt,
         weight_pt,
         bias_pt,
         residual=res_pt,
+        x1=x1_pt,
+        weight1=weight1_pt,
+        bias1=bias1_pt,
         eps=1e-6,
         dropout_p=dropout_p,
         rowscale=rowscale,
         prenorm=prenorm,
         dropout_mask=dropout_mask,
+        dropout_mask1=dropout_mask1,
     )
     out_ref = layer_norm_ref_fn(
         x0_ref,
         weight_ref,
         bias_ref,
         residual=res_ref,
+        x1=x1_ref,
+        weight1=weight1_ref,
+        bias1=bias1_ref,
         eps=1e-6,
         dropout_p=dropout_p,
         rowscale=rowscale,
         prenorm=prenorm,
         dropout_mask=dropout_mask,
+        dropout_mask1=dropout_mask1,
         upcast=True,
     )
-    if prenorm:
-        residual = rest[0]
-        out_pt, residual_pt = out_pt
-        out_ref, residual_ref = out_ref
+    if not has_weight1:
+        if prenorm:
+            residual = rest[0]
+            out_pt, residual_pt = out_pt
+            out_ref, residual_ref = out_ref
+        out1, out1_pt, out1_ref = None, None, None
+    else:
+        out1 = rest.pop(0)
+        if prenorm:
+            residual = rest[0]
+            out_pt, out1_pt, residual_pt = out_pt
+            out_ref, out1_ref, residual_ref = out_ref
+        else:
+            out_pt, out1_pt = out_pt
+            out_ref, out1_ref = out_ref
     assert out.dtype == input_dtype
     if prenorm:
         assert residual.dtype == residual_dtype
         assert allclose(residual, residual_pt, residual_ref)
     assert allclose(out, out_pt, out_ref)
+    if out1 is not None:
+        assert out1.dtype == input_dtype
+        assert allclose(out1, out1_pt, out1_ref)
     if dropout_mask is not None:
         dropout_fraction = 1.0 - dropout_mask.float().mean()
         assert abs(dropout_fraction - dropout_p) < 0.01
+    if dropout_mask1 is not None:
+        dropout_fraction = 1.0 - dropout_mask1.float().mean()
+        assert abs(dropout_fraction - dropout_p) < 0.01
+        assert not torch.equal(dropout_mask, dropout_mask1)
     g = torch.randn_like(out) / batch_size
+    if has_weight1:
+        out = out * F.gelu(out1)
+        out_pt = out_pt * F.gelu(out1_pt)
+        out_ref = out_ref * F.gelu(out1_ref)
     if not prenorm:
         out.backward(g)
         out_pt.backward(g)
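Taken together, the call and the unpacking above suggest the following minimal usage of the new parallel path. This is a sketch assuming the same calling convention as the test; the arguments and the return layout are inferred only from this diff, and the shapes and dtypes are illustrative.

    import torch
    from flash_attn.ops.triton.layernorm import layer_norm_fn

    hidden = 1024
    x0 = torch.randn(2, 512, hidden, device="cuda", dtype=torch.float16)
    weight = torch.randn(hidden, device="cuda", dtype=torch.float32)
    bias = torch.randn(hidden, device="cuda", dtype=torch.float32)
    weight1 = torch.randn(hidden, device="cuda", dtype=torch.float32)
    bias1 = torch.randn(hidden, device="cuda", dtype=torch.float32)

    # With weight1/bias1 supplied, the kernel produces a second normalized output.
    out, *rest = layer_norm_fn(
        x0, weight, bias, weight1=weight1, bias1=bias1, eps=1e-6, dropout_p=0.0, prenorm=False
    )
    out1 = rest.pop(0)  # second output from the weight1/bias1 branch, as unpacked in the test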
@@ -152,9 +228,15 @@ def test_layer_norm(
     assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
     if has_residual:
         assert allclose(res.grad, res_pt.grad, res_ref.grad)
+    if has_x1:
+        assert allclose(x1.grad, x1_pt.grad, x1_ref.grad)
     assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
     if bias is not None:
         assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
+    if has_weight1:
+        assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad)
+        if bias1 is not None:
+            assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad)


 @pytest.mark.parametrize("prenorm", [True, False])
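The new cases run as part of the existing test module; a standard invocation such as "pytest tests/ops/triton/test_layer_norm.py -q" (on a machine with a CUDA GPU and Triton installed) exercises the full parameter grid, including the new has_weight1 and has_x1 combinations.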