gaoqiong / flash-attention · Commits

Commit cd089597, authored Dec 19, 2023 by Tri Dao
Parent: 713bd3aa

[LayerNorm] Implement dropout in fused residual + LN/RMSNorm

Showing 2 changed files with 206 additions and 37 deletions:
  flash_attn/ops/triton/layernorm.py     +162 -21
  tests/ops/triton/test_layer_norm.py    +44  -16
flash_attn/ops/triton/layernorm.py  (+162 -21)

 # Copyright (c) 2023, Tri Dao.
-# Implement residual + layer_norm / rms_norm.
+# Implement dropout + residual + layer_norm / rms_norm.
 # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
 # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
@@ -16,7 +16,17 @@ import triton
 import triton.language as tl


-def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
+def layer_norm_ref(
+    x,
+    weight,
+    bias,
+    residual=None,
+    eps=1e-6,
+    dropout_p=0.0,
+    prenorm=False,
+    dropout_mask=None,
+    upcast=False,
+):
     dtype = x.dtype
     if upcast:
         weight = weight.float()

@@ -24,6 +34,11 @@ def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upca
     if upcast:
         x = x.float()
         residual = residual.float() if residual is not None else residual
+    if dropout_p > 0.0:
+        if dropout_mask is not None:
+            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
+        else:
+            x = F.dropout(x, p=dropout_p)
     if residual is not None:
         x = (x + residual).to(x.dtype)
     out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(

@@ -32,7 +47,17 @@ def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upca
     return out if not prenorm else (out, x)


-def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
+def rms_norm_ref(
+    x,
+    weight,
+    bias,
+    residual=None,
+    eps=1e-6,
+    dropout_p=0.0,
+    prenorm=False,
+    dropout_mask=None,
+    upcast=False,
+):
     dtype = x.dtype
     if upcast:
         weight = weight.float()

@@ -40,6 +65,11 @@ def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast
     if upcast:
         x = x.float()
         residual = residual.float() if residual is not None else residual
+    if dropout_p > 0.0:
+        if dropout_mask is not None:
+            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
+        else:
+            x = F.dropout(x, p=dropout_p)
     if residual is not None:
         x = (x + residual).to(x.dtype)
     rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
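The reference implementations apply standard inverted dropout before the residual add: elements are zeroed with probability dropout_p and the survivors are rescaled by 1 / (1 - dropout_p), so the activation is unchanged in expectation and inference needs no extra scaling. A minimal standalone sketch of that identity (illustrative only, not part of the diff):

import torch

torch.manual_seed(0)
p = 0.27
x = torch.randn(512, 2048)

# Inverted dropout: zero with probability p, rescale survivors by 1 / (1 - p).
keep = torch.rand_like(x) > p
y = x.masked_fill(~keep, 0.0) / (1.0 - p)

# In expectation the output matches the input, so eval time needs no rescaling.
print((y.mean() - x.mean()).abs())           # close to 0
print(abs(1.0 - keep.float().mean() - p))    # empirical drop fraction is close to p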
@@ -69,6 +99,8 @@ def _layer_norm_fwd_1pass_kernel(
     B,  # pointer to the biases
     RESIDUAL,  # pointer to the residual
     RESIDUAL_OUT,  # pointer to the residual
+    SEEDS,  # Dropout seeds for each row
+    DROPOUT_MASK,
     Mean,  # pointer to the mean
     Rstd,  # pointer to the 1/std
     stride_x_row,  # how much to increase the pointer when moving by 1 row

@@ -77,11 +109,14 @@ def _layer_norm_fwd_1pass_kernel(
     stride_res_out_row,
     N,  # number of columns in X
     eps,  # epsilon to avoid division by zero
+    dropout_p,  # Dropout probability
     IS_RMS_NORM: tl.constexpr,
     BLOCK_N: tl.constexpr,
     HAS_RESIDUAL: tl.constexpr,
     STORE_RESIDUAL_OUT: tl.constexpr,
     HAS_BIAS: tl.constexpr,
+    HAS_DROPOUT: tl.constexpr,
+    STORE_DROPOUT_MASK: tl.constexpr,
 ):
     # Map the program id to the row of X and Y it should compute.
     row = tl.program_id(0)

@@ -94,6 +129,13 @@ def _layer_norm_fwd_1pass_kernel(
     # Compute mean and variance
     cols = tl.arange(0, BLOCK_N)
     x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+    if HAS_DROPOUT:
+        # Compute dropout mask
+        # 7 rounds is good enough, and reduces register pressure
+        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
+        if STORE_DROPOUT_MASK:
+            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
     if HAS_RESIDUAL:
         residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
         x += residual
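The fused forward kernel does not materialize an N-element mask per row unless STORE_DROPOUT_MASK is set; it keeps one integer seed per row and regenerates the keep-mask from it with tl.rand. The same seed can be replayed later, which is what the backward pass relies on. A rough PyTorch analogy of that idea (the kernel's actual RNG is Triton's tl.rand, not torch.Generator):

import torch

def keep_mask_from_seed(seed: int, n: int, p: float) -> torch.Tensor:
    # Regenerate the same Bernoulli keep-mask deterministically from a per-row seed.
    g = torch.Generator().manual_seed(seed)
    return torch.rand(n, generator=g) > p

row_seed, n, p = 1234, 4096, 0.27
fwd_mask = keep_mask_from_seed(row_seed, n, p)   # sampled in the forward pass
bwd_mask = keep_mask_from_seed(row_seed, n, p)   # replayed in the backward pass
assert torch.equal(fwd_mask, bwd_mask)           # identical, so no N-element mask needs saving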
@@ -121,7 +163,16 @@ def _layer_norm_fwd_1pass_kernel(
 def _layer_norm_fwd(
-    x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
+    x,
+    weight,
+    bias,
+    eps,
+    residual=None,
+    dropout_p=0.0,
+    out_dtype=None,
+    residual_dtype=None,
+    is_rms_norm=False,
+    return_dropout_mask=False,
 ):
     if residual is not None:
         residual_dtype = residual.dtype

@@ -138,13 +189,27 @@ def _layer_norm_fwd(
     # allocate output
     y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
     assert y.stride(-1) == 1
-    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
-        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
+    if (
+        residual is not None
+        or (residual_dtype is not None and residual_dtype != x.dtype)
+        or dropout_p > 0.0
+    ):
+        residual_out = torch.empty(
+            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
+        )
         assert residual_out.stride(-1) == 1
     else:
         residual_out = None
     mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
     rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
+    if dropout_p > 0.0:
+        seeds = torch.randint(2**32, (M,), device=x.device, dtype=torch.int64)
+    else:
+        seeds = None
+    if return_dropout_mask and dropout_p > 0.0:
+        dropout_mask = torch.empty_like(x, dtype=torch.bool)
+    else:
+        dropout_mask = None
     # Less than 64KB per feature: enqueue fused kernel
     MAX_FUSED_SIZE = 65536 // x.element_size()
     BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
@@ -159,6 +224,8 @@ def _layer_norm_fwd(
             bias,
             residual,
             residual_out,
+            seeds,
+            dropout_mask,
             mean,
             rstd,
             x.stride(0),

@@ -167,14 +234,17 @@ def _layer_norm_fwd(
             residual_out.stride(0) if residual_out is not None else 0,
             N,
             eps,
+            dropout_p,
             is_rms_norm,
             BLOCK_N,
             residual is not None,
             residual_out is not None,
             bias is not None,
+            dropout_p > 0.0,
+            dropout_mask is not None,
         )
-    # residual_out is None if residual is None and residual_dtype == input_dtype
-    return y, mean, rstd, residual_out if residual_out is not None else x
+    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
+    return y, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask


 @triton.autotune(

@@ -186,7 +256,7 @@ def _layer_norm_fwd(
         triton.Config({}, num_warps=16),
         triton.Config({}, num_warps=32),
     ],
-    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
+    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
 )
 # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
 # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
@@ -204,6 +274,7 @@ def _layer_norm_bwd_kernel(
     DB,  # pointer to the partial sum of biases gradient
     DRESIDUAL,
     DRESIDUAL_IN,
+    SEEDS,
     Mean,  # pointer to the mean
     Rstd,  # pointer to the 1/std
     stride_x_row,  # how much to increase the pointer when moving by 1 row

@@ -215,12 +286,14 @@ def _layer_norm_bwd_kernel(
     M,  # number of rows in X
     N,  # number of columns in X
     eps,  # epsilon to avoid division by zero
+    dropout_p,
     rows_per_program,
     IS_RMS_NORM: tl.constexpr,
     BLOCK_N: tl.constexpr,
     HAS_DRESIDUAL: tl.constexpr,
     STORE_DRESIDUAL: tl.constexpr,
     HAS_BIAS: tl.constexpr,
+    HAS_DROPOUT: tl.constexpr,
     RECOMPUTE_OUTPUT: tl.constexpr,
 ):
     # Map the program id to the elements of X, DX, and DY it should compute.

@@ -274,6 +347,9 @@ def _layer_norm_bwd_kernel(
         # Write dx
         if STORE_DRESIDUAL:
             tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
+        if HAS_DROPOUT:
+            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
         tl.store(DX + cols, dx, mask=mask)

         X += stride_x_row
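The backward kernel replays the keep-mask from the same per-row seed and applies it to the incoming gradient with the same 1 / (1 - dropout_p) scaling: inverted dropout is an elementwise linear map, so its backward reuses the forward mask. A small autograd check of that identity (standalone sketch, not part of the diff):

import torch

torch.manual_seed(0)
p = 0.27
x = torch.randn(8, 16, requires_grad=True)
keep = torch.rand_like(x) > p

# Forward: inverted dropout as an elementwise linear map.
y = torch.where(keep, x / (1.0 - p), torch.zeros_like(x))
dy = torch.randn_like(y)
y.backward(dy)

# Backward: the gradient is the incoming gradient masked and rescaled the same way.
expected_dx = torch.where(keep, dy / (1.0 - p), torch.zeros_like(dy))
assert torch.allclose(x.grad, expected_dx)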
@@ -299,6 +375,8 @@ def _layer_norm_bwd(
     mean,
     rstd,
     dresidual=None,
+    seeds=None,
+    dropout_p=0.0,
     has_residual=False,
     is_rms_norm=False,
     x_dtype=None,

@@ -316,13 +394,18 @@ def _layer_norm_bwd(
     if bias is not None:
         assert bias.stride(-1) == 1
         assert bias.shape == (N,)
+    if seeds is not None:
+        assert seeds.is_contiguous()
+        assert seeds.shape == (M,)
     # allocate output
     dx = (
         torch.empty_like(x)
         if x_dtype is None
         else torch.empty(M, N, dtype=x_dtype, device=x.device)
     )
-    dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
+    dresidual_in = (
+        torch.empty_like(x) if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0) else None
+    )
     y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
     # Less than 64KB per feature: enqueue fused kernel

@@ -351,6 +434,7 @@ def _layer_norm_bwd(
             _db,
             dresidual,
             dresidual_in,
+            seeds,
             mean,
             rstd,
             x.stride(0),

@@ -362,17 +446,19 @@ def _layer_norm_bwd(
             M,
             N,
             eps,
+            dropout_p,
             rows_per_program,
             is_rms_norm,
             BLOCK_N,
             dresidual is not None,
             dresidual_in is not None,
             bias is not None,
+            dropout_p > 0.0,
         )
     dw = _dw.sum(0).to(weight.dtype)
     db = _db.sum(0).to(bias.dtype) if bias is not None else None
     # Don't need to compute dresidual_in separately in this case
-    if has_residual and dx.dtype == x.dtype:
+    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0:
         dresidual_in = dx
     return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
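Note the changed shortcut at the end of _layer_norm_bwd: with dropout active, the gradient returned for the residual input is the gradient of the post-dropout sum, while dx is additionally masked and rescaled, so dresidual_in = dx only remains valid when dropout_p == 0.0. A small autograd sketch of why the two gradients diverge (standalone, illustrative only):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = 0.27
x0 = torch.randn(4, 8, requires_grad=True)        # branch that goes through dropout
residual = torch.randn(4, 8, requires_grad=True)  # residual branch, no dropout

keep = torch.rand_like(x0) > p
dropped = torch.where(keep, x0 / (1.0 - p), torch.zeros_like(x0))
out = F.layer_norm(dropped + residual, x0.shape[-1:])
out.backward(torch.randn_like(out))

# residual.grad is the gradient of the summed stream; x0.grad is that same gradient
# masked and rescaled by the dropout backward, so the two no longer coincide.
assert not torch.equal(x0.grad, residual.grad)
expected = torch.where(keep, residual.grad / (1.0 - p), torch.zeros_like(residual.grad))
assert torch.allclose(x0.grad, expected)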
@@ -386,9 +472,11 @@ class LayerNormFn(torch.autograd.Function):
         bias,
         residual=None,
         eps=1e-6,
+        dropout_p=0.0,
         prenorm=False,
         residual_in_fp32=False,
         is_rms_norm=False,
+        return_dropout_mask=False,
     ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor

@@ -408,22 +496,36 @@ class LayerNormFn(torch.autograd.Function):
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
-        y, mean, rstd, residual_out = _layer_norm_fwd(
-            x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
+        y, mean, rstd, residual_out, seeds, dropout_mask = _layer_norm_fwd(
+            x,
+            weight,
+            bias,
+            eps,
+            residual,
+            dropout_p=dropout_p,
+            residual_dtype=residual_dtype,
+            is_rms_norm=is_rms_norm,
+            return_dropout_mask=return_dropout_mask,
         )
-        ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
+        ctx.save_for_backward(residual_out, weight, bias, seeds, mean, rstd)
         ctx.x_shape_og = x_shape_og
         ctx.eps = eps
+        ctx.dropout_p = dropout_p
         ctx.is_rms_norm = is_rms_norm
         ctx.has_residual = residual is not None
         ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
         y = y.reshape(x_shape_og)
-        return y if not prenorm else (y, residual_out.reshape(x_shape_og))
+        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
+        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
+        if not return_dropout_mask:
+            return y if not prenorm else (y, residual_out)
+        else:
+            return (y, dropout_mask) if not prenorm else (y, residual_out, dropout_mask)

     @staticmethod
     def backward(ctx, dy, *args):
-        x, weight, bias, mean, rstd = ctx.saved_tensors
+        x, weight, bias, seeds, mean, rstd = ctx.saved_tensors
         dy = dy.reshape(-1, dy.shape[-1])
         if dy.stride(-1) != 1:
             dy = dy.contiguous()

@@ -445,6 +547,8 @@ class LayerNormFn(torch.autograd.Function):
             mean,
             rstd,
             dresidual,
+            seeds,
+            ctx.dropout_p,
             ctx.has_residual,
             ctx.is_rms_norm,
             x_dtype=ctx.x_dtype,

@@ -458,6 +562,8 @@ class LayerNormFn(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )
@@ -467,22 +573,57 @@ def layer_norm_fn(
     bias,
     residual=None,
     eps=1e-6,
+    dropout_p=0.0,
     prenorm=False,
     residual_in_fp32=False,
     is_rms_norm=False,
+    return_dropout_mask=False,
 ):
-    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
+    return LayerNormFn.apply(
+        x,
+        weight,
+        bias,
+        residual,
+        eps,
+        dropout_p,
+        prenorm,
+        residual_in_fp32,
+        is_rms_norm,
+        return_dropout_mask,
+    )


-def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
-    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
+def rms_norm_fn(
+    x,
+    weight,
+    bias,
+    residual=None,
+    eps=1e-6,
+    dropout_p=0.0,
+    prenorm=False,
+    residual_in_fp32=False,
+    return_dropout_mask=False,
+):
+    return LayerNormFn.apply(
+        x,
+        weight,
+        bias,
+        residual,
+        eps,
+        dropout_p,
+        prenorm,
+        residual_in_fp32,
+        True,
+        return_dropout_mask,
+    )


 class RMSNorm(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
+    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.eps = eps
+        self.dropout_p = dropout_p
         self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
         self.register_parameter("bias", None)
         self.reset_parameters()

@@ -497,9 +638,9 @@ class RMSNorm(torch.nn.Module):
             self.bias,
             residual=residual,
             eps=self.eps,
+            dropout_p=self.dropout_p if self.training else 0.0,
             prenorm=prenorm,
             residual_in_fp32=residual_in_fp32,
             is_rms_norm=True,
         )
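Taken together, the public entry points layer_norm_fn / rms_norm_fn and the RMSNorm module now accept dropout_p and can optionally return the sampled keep-mask. A hedged usage sketch against the signatures in this diff (shapes and dtypes are illustrative; requires a CUDA device and a flash-attn build containing this change):

import torch
from flash_attn.ops.triton.layernorm import rms_norm_fn

x = torch.randn(8, 512, 2048, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x, dtype=torch.float32)
weight = torch.ones(2048, device="cuda", dtype=torch.float32)

# Fused dropout(x) + residual + RMSNorm; prenorm also returns the updated residual
# stream, and return_dropout_mask additionally returns the boolean keep-mask.
out, new_residual, mask = rms_norm_fn(
    x,
    weight,
    None,                      # RMSNorm has no bias
    residual=residual,
    eps=1e-6,
    dropout_p=0.1,
    prenorm=True,
    residual_in_fp32=True,
    return_dropout_mask=True,
)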
tests/ops/triton/test_layer_norm.py  (+44 -16)
@@ -16,12 +16,14 @@ from flash_attn.ops.triton.layernorm import (
 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


+@pytest.mark.parametrize("dropout_p", [0.0, 0.27])
+# @pytest.mark.parametrize("dropout_p", [0.27])
 @pytest.mark.parametrize("prenorm", [True, False])
-# @pytest.mark.parametrize("prenorm", [True])
+# @pytest.mark.parametrize("prenorm", [False])
 @pytest.mark.parametrize("is_rms_norm", [False, True])
 # @pytest.mark.parametrize("is_rms_norm", [True])
 @pytest.mark.parametrize("has_residual", [True, False])
-# @pytest.mark.parametrize("has_residual", [False])
+# @pytest.mark.parametrize("has_residual", [True])
 @pytest.mark.parametrize(
     "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
 )

@@ -31,11 +33,18 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
     [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
     + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
 )
-# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
+# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.float16, torch.float16)])
-@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 8192])
+@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 4096])
 # @pytest.mark.parametrize("hidden_size", [256])
 def test_layer_norm(
-    hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
+    hidden_size,
+    input_dtype,
+    residual_dtype,
+    weight_dtype,
+    has_residual,
+    is_rms_norm,
+    prenorm,
+    dropout_p,
 ):
     device = "cuda"
     if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
@@ -48,8 +57,6 @@ def test_layer_norm(
     torch.random.manual_seed(0)
     batch_size = 8
     seqlen = 512
-    # batch_size = 1
-    # seqlen = 1
     layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
     allclose = (
         lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()

@@ -83,25 +90,46 @@ def test_layer_norm(
         bias,
         residual=res,
         eps=1e-6,
+        dropout_p=dropout_p,
         prenorm=prenorm,
         residual_in_fp32=residual_in_fp32,
         is_rms_norm=is_rms_norm,
+        return_dropout_mask=True,
     )
-    out_pt, *rest_pt = layer_norm_ref_fn(
-        x0_pt, weight_pt, bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
+    dropout_mask = rest[-1] if dropout_p > 0.0 else None
+    out_pt = layer_norm_ref_fn(
+        x0_pt,
+        weight_pt,
+        bias_pt,
+        residual=res_pt,
+        eps=1e-6,
+        dropout_p=dropout_p,
+        prenorm=prenorm,
+        dropout_mask=dropout_mask,
     )
-    out_ref, *rest_ref = layer_norm_ref_fn(
-        x0_ref, weight_ref, bias_ref, residual=res_ref, eps=1e-6, prenorm=prenorm, upcast=True
+    out_ref = layer_norm_ref_fn(
+        x0_ref,
+        weight_ref,
+        bias_ref,
+        residual=res_ref,
+        eps=1e-6,
+        dropout_p=dropout_p,
+        prenorm=prenorm,
+        dropout_mask=dropout_mask,
+        upcast=True,
     )
     if prenorm:
         residual = rest[0]
-        residual_pt = rest_pt[0]
-        residual_ref = rest_ref[0]
+        out_pt, residual_pt = out_pt
+        out_ref, residual_ref = out_ref
     assert out.dtype == input_dtype
     if prenorm:
         assert residual.dtype == residual_dtype
         assert allclose(residual, residual_pt, residual_ref)
     assert allclose(out, out_pt, out_ref)
+    if dropout_mask is not None:
+        dropout_fraction = 1.0 - dropout_mask.float().mean()
+        assert abs(dropout_fraction - dropout_p) < 0.01

     g = torch.randn_like(out) / batch_size
     if not prenorm:
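The test asks the fused op for its keep-mask via return_dropout_mask=True and feeds it back into the reference through dropout_mask=..., so both paths drop exactly the same elements and stay directly comparable; it then checks the empirical drop fraction against dropout_p with a 0.01 tolerance. A rough standalone sketch of why that tolerance is comfortable at this problem size (shapes taken from the test; illustrative only):

import torch

torch.manual_seed(0)
p, batch_size, seqlen, hidden_size = 0.27, 8, 512, 2048
n = batch_size * seqlen * hidden_size

# Empirical drop fraction over ~8.4M Bernoulli draws.
keep = torch.rand(n) > p
dropout_fraction = 1.0 - keep.float().mean()

# Standard deviation of the mean of n Bernoulli(p) draws: sqrt(p * (1 - p) / n).
stddev = (p * (1 - p) / n) ** 0.5
print(float(dropout_fraction), stddev)   # fraction ~0.27, stddev ~1.5e-4, far below the 0.01 tolerance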
@@ -128,9 +156,9 @@ def test_layer_norm(
 # @pytest.mark.parametrize("has_residual", [False])
 @pytest.mark.parametrize("weight_dtype", [torch.float32])
 @pytest.mark.parametrize(
     "input_dtype,residual_dtype",
     [(torch.float16, torch.float16), (torch.float16, torch.float32)]
     + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
 )
 # @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
 @pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])