gaoqiong / flash-attention · Commits

Commit 9356a1c0
Authored Nov 30, 2023 by Tri Dao

[LayerNorm] Implement layer_norm_linear
Parent: 92dd5703

Showing 2 changed files with 279 additions and 9 deletions:
  flash_attn/ops/triton/layernorm.py   (+138, -2)
  tests/ops/triton/test_layer_norm.py  (+141, -7)
flash_attn/ops/triton/layernorm.py
@@ -10,6 +10,7 @@ import math

 import torch
 import torch.nn.functional as F
+from torch.cuda.amp import custom_fwd, custom_bwd

 import triton
 import triton.language as tl
@@ -119,7 +120,9 @@ def _layer_norm_fwd_1pass_kernel(
     tl.store(Y + cols, y, mask=mask)


-def _layer_norm_fwd(x, weight, bias, eps, residual=None, residual_dtype=None, is_rms_norm=False):
+def _layer_norm_fwd(
+    x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
+):
     if residual is not None:
         residual_dtype = residual.dtype
     M, N = x.shape
@@ -133,7 +136,7 @@ def _layer_norm_fwd(x, weight, bias, eps, residual=None, residual_dtype=None, is_rms_norm=False):
         assert bias.stride(-1) == 1
         assert bias.shape == (N,)
     # allocate output
-    y = torch.empty_like(x)
+    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
     assert y.stride(-1) == 1
     if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
         residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
@@ -498,3 +501,136 @@ class RMSNorm(torch.nn.Module):
             residual_in_fp32=residual_in_fp32,
             is_rms_norm=True,
         )
+
+
+class LayerNormLinearFn(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd
+    def forward(
+        ctx,
+        x,
+        norm_weight,
+        norm_bias,
+        linear_weight,
+        linear_bias,
+        residual=None,
+        eps=1e-6,
+        prenorm=False,
+        residual_in_fp32=False,
+        is_rms_norm=False,
+    ):
+        x_shape_og = x.shape
+        # reshape input data into 2D tensor
+        x = x.reshape(-1, x.shape[-1])
+        if x.stride(-1) != 1:
+            x = x.contiguous()
+        if residual is not None:
+            assert residual.shape == x_shape_og
+            residual = residual.reshape(-1, residual.shape[-1])
+            if residual.stride(-1) != 1:
+                residual = residual.contiguous()
+        norm_weight = norm_weight.contiguous()
+        if norm_bias is not None:
+            norm_bias = norm_bias.contiguous()
+        residual_dtype = (
+            residual.dtype
+            if residual is not None
+            else (torch.float32 if residual_in_fp32 else None)
+        )
+        y, mean, rstd, residual_out = _layer_norm_fwd(
+            x,
+            norm_weight,
+            norm_bias,
+            eps,
+            residual,
+            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
+            residual_dtype=residual_dtype,
+            is_rms_norm=is_rms_norm,
+        )
+        y = y.reshape(x_shape_og)
+        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
+        linear_weight = linear_weight.to(dtype)
+        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
+        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
+        # We don't store y, will be recomputed in the backward pass to save memory
+        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
+        ctx.x_shape_og = x_shape_og
+        ctx.eps = eps
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_residual = residual is not None
+        ctx.prenorm = prenorm
+        ctx.x_dtype = x.dtype
+        ctx.linear_bias_is_none = linear_bias is None
+        return out if not prenorm else (out, residual_out.reshape(x_shape_og))
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, dout, *args):
+        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
+        dout = dout.reshape(-1, dout.shape[-1])
+        dy = F.linear(dout, linear_weight.t())
+        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
+        if dy.stride(-1) != 1:
+            dy = dy.contiguous()
+        assert dy.shape == x.shape
+        if ctx.prenorm:
+            dresidual = args[0]
+            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
+            if dresidual.stride(-1) != 1:
+                dresidual = dresidual.contiguous()
+            assert dresidual.shape == x.shape
+        else:
+            dresidual = None
+        dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
+            dy,
+            x,
+            norm_weight,
+            norm_bias,
+            ctx.eps,
+            mean,
+            rstd,
+            dresidual,
+            ctx.has_residual,
+            ctx.is_rms_norm,
+            x_dtype=ctx.x_dtype,
+            recompute_output=True,
+        )
+        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
+        return (
+            dx.reshape(ctx.x_shape_og),
+            dnorm_weight,
+            dnorm_bias,
+            dlinear_weight,
+            dlinear_bias,
+            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def layer_norm_linear_fn(
+    x,
+    norm_weight,
+    norm_bias,
+    linear_weight,
+    linear_bias,
+    residual=None,
+    eps=1e-6,
+    prenorm=False,
+    residual_in_fp32=False,
+    is_rms_norm=False,
+):
+    return LayerNormLinearFn.apply(
+        x,
+        norm_weight,
+        norm_bias,
+        linear_weight,
+        linear_bias,
+        residual,
+        eps,
+        prenorm,
+        residual_in_fp32,
+        is_rms_norm,
+    )
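For orientation, below is a minimal usage sketch of the new fused norm + linear entry point. It is not part of this commit; the shapes, dtypes, and the choice to pass no linear bias are illustrative assumptions, and a CUDA device is required since the kernel is Triton-based.

# Illustrative sketch (not from this commit): LayerNorm followed by a linear
# projection, fused into one autograd.Function call.
import torch
from flash_attn.ops.triton.layernorm import layer_norm_linear_fn

batch, seqlen, hidden = 2, 128, 1024  # hypothetical shapes
x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
norm_weight = torch.ones(hidden, device="cuda", dtype=torch.float32)
norm_bias = torch.zeros(hidden, device="cuda", dtype=torch.float32)
linear_weight = torch.randn(4 * hidden, hidden, device="cuda", dtype=torch.float16)

# Normalizes over the last dim of x, then applies the linear projection.
out = layer_norm_linear_fn(
    x, norm_weight, norm_bias, linear_weight, None,  # linear_bias omitted
    eps=1e-6, prenorm=False, residual_in_fp32=False, is_rms_norm=False,
)
print(out.shape)  # torch.Size([2, 128, 4096])

With prenorm=True the call instead returns a tuple (out, residual_out), mirroring the existing layer_norm_fn interface.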
tests/ops/triton/test_layer_norm.py
 import math
 from functools import partial

 # Copyright (c) 2023, Tri Dao.

 import pytest
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat

-from flash_attn.ops.triton.layernorm import layer_norm_fn, layer_norm_ref, rms_norm_ref
+from flash_attn.ops.triton.layernorm import (
+    layer_norm_fn,
+    layer_norm_ref,
+    rms_norm_ref,
+    layer_norm_linear_fn,
+)

 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@@ -18,13 +23,13 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 @pytest.mark.parametrize("has_residual", [True, False])
 # @pytest.mark.parametrize("has_residual", [False])
 @pytest.mark.parametrize(
     "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
 )
 # @pytest.mark.parametrize("weight_dtype", [torch.float32])
 @pytest.mark.parametrize(
     "input_dtype,residual_dtype",
     [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
     + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
 )
 # @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
 @pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 8192])
@@ -113,3 +118,132 @@ def test_layer_norm(
     assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
     if bias is not None:
         assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
+
+
+@pytest.mark.parametrize("prenorm", [True, False])
+# @pytest.mark.parametrize("prenorm", [True])
+@pytest.mark.parametrize("is_rms_norm", [False, True])
+# @pytest.mark.parametrize("is_rms_norm", [True])
+@pytest.mark.parametrize("has_residual", [True, False])
+# @pytest.mark.parametrize("has_residual", [False])
+@pytest.mark.parametrize("weight_dtype", [torch.float32])
+@pytest.mark.parametrize(
+    "input_dtype,residual_dtype",
+    [(torch.float16, torch.float16), (torch.float16, torch.float32)]
+    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
+)
+# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
+@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])
+# @pytest.mark.parametrize("hidden_size", [256])
+def test_layer_norm_linear(
+    hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
+):
+    device = "cuda"
+    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
+        atol = 5e-2
+    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
+        atol = 1e-2
+    else:
+        atol = 1e-4
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 4
+    seqlen = 512
+    # batch_size = 1
+    # seqlen = 1
+    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
+    allclose = (
+        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
+        <= 2 * (x_pt - x_ref).abs().max() + atol
+    )
+    x0 = torch.randn(
+        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
+    )
+    x0_pt = x0.detach().clone().requires_grad_()
+    x0_ref = x0.detach().clone().requires_grad_()
+    if has_residual:
+        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        res_pt = res.detach().clone().requires_grad_()
+        res_ref = res.detach().clone().requires_grad_()
+    else:
+        res, res_pt, res_ref = None, None, None
+    norm_weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+    if not is_rms_norm:
+        norm_bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+    else:
+        norm_bias = None
+    norm_weight_pt = norm_weight.detach().clone().requires_grad_()
+    norm_weight_ref = norm_weight.detach().clone().requires_grad_()
+    norm_bias_pt = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
+    norm_bias_ref = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
+    linear_weight = torch.empty(
+        2 * hidden_size, hidden_size, device=device, dtype=weight_dtype, requires_grad=True
+    )
+    torch.nn.init.xavier_uniform_(linear_weight)
+    if not is_rms_norm:
+        linear_bias = torch.randn(
+            2 * hidden_size, device=device, dtype=weight_dtype, requires_grad=True
+        )
+    else:
+        linear_bias = None
+    linear_weight_pt = linear_weight.detach().clone().requires_grad_()
+    linear_weight_ref = linear_weight.detach().clone().requires_grad_()
+    linear_bias_pt = (
+        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
+    )
+    linear_bias_ref = (
+        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
+    )
+    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
+    with torch.autocast(device_type="cuda", dtype=input_dtype):
+        out, *rest = layer_norm_linear_fn(
+            x0,
+            norm_weight,
+            norm_bias,
+            linear_weight,
+            linear_bias,
+            residual=res,
+            eps=1e-6,
+            prenorm=prenorm,
+            residual_in_fp32=residual_in_fp32,
+            is_rms_norm=is_rms_norm,
+        )
+    out_pt, *rest_pt = layer_norm_ref_fn(
+        x0_pt, norm_weight_pt, norm_bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
+    )
+    with torch.autocast(device_type="cuda", dtype=input_dtype):
+        out_pt = F.linear(out_pt, linear_weight_pt, linear_bias_pt)
+    out_ref, *rest_ref = layer_norm_ref_fn(
+        x0_ref,
+        norm_weight_ref,
+        norm_bias_ref,
+        residual=res_ref,
+        eps=1e-6,
+        prenorm=prenorm,
+        upcast=True,
+    )
+    out_ref = F.linear(out_ref.to(linear_weight_ref.dtype), linear_weight_ref, linear_bias_ref)
+    if prenorm:
+        residual = rest[0]
+        residual_pt = rest_pt[0]
+        residual_ref = rest_ref[0]
+    assert out.dtype == input_dtype
+    if prenorm:
+        assert residual.dtype == residual_dtype
+        assert allclose(residual, residual_pt, residual_ref)
+    assert allclose(out, out_pt, out_ref)
+
+    g = torch.randn_like(out) / batch_size
+    out.backward(g)
+    out_pt.backward(g)
+    out_ref.backward(g)
+    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
+    if has_residual:
+        assert allclose(res.grad, res_pt.grad, res_ref.grad)
+    assert allclose(norm_weight.grad, norm_weight_pt.grad, norm_weight_ref.grad)
+    if norm_bias is not None:
+        assert allclose(norm_bias.grad, norm_bias_pt.grad, norm_bias_ref.grad)
+    assert allclose(linear_weight.grad, linear_weight_pt.grad, linear_weight_ref.grad)
+    if linear_bias is not None:
+        assert allclose(linear_bias.grad, linear_bias_pt.grad, linear_bias_ref.grad)
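A note on the accuracy check in this test: the allclose helper is relative rather than purely absolute — the fused Triton output is accepted when its worst-case error against the upcast fp32 reference is at most twice the error of the plain PyTorch path, plus atol. Below is a small standalone sketch of that criterion; the tensors and atol value are made-up stand-ins, not values from the test.

# Sketch of the test's comparison rule with illustrative stand-in tensors.
import torch

atol = 1e-2
x_ref = torch.randn(4, 8, dtype=torch.float32)   # stand-in for the fp32 reference output
x_pt = x_ref + 1e-3 * torch.randn_like(x_ref)    # stand-in for the PyTorch low-precision path
x = x_ref + 1e-3 * torch.randn_like(x_ref)       # stand-in for the fused Triton output

# Accept x if its max error is within 2x the PyTorch baseline error plus atol.
ok = (x - x_ref).abs().max() <= 2 * (x_pt - x_ref).abs().max() + atol
assert ok

On a CUDA machine, running something like pytest tests/ops/triton/test_layer_norm.py -k test_layer_norm_linear should exercise all the parametrized dtype and prenorm combinations of the new test.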