gaoqiong / flash-attention · Commit aa5c6438

[LayerNorm] Implement rowscale in Triton layernorm

Authored Jan 04, 2024 by Tri Dao
Parent: 386e3911
Showing 2 changed files with 62 additions and 10 deletions (+62 -10)

flash_attn/ops/triton/layernorm.py    +48 -5
tests/ops/triton/test_layer_norm.py   +14 -5
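The new rowscale argument attaches one scalar per row (i.e. per token) that is multiplied into x before dropout, the residual add, and the normalization, in both the Triton kernels and the PyTorch reference functions. A minimal usage sketch, assuming a CUDA device with Triton available; the shapes follow the updated test below and are otherwise arbitrary:

import torch
from flash_attn.ops.triton.layernorm import layer_norm_fn

batch, seqlen, hidden = 2, 128, 512
device, dtype = "cuda", torch.float16

x = torch.randn(batch, seqlen, hidden, device=device, dtype=dtype)
weight = torch.ones(hidden, device=device, dtype=dtype)
bias = torch.zeros(hidden, device=device, dtype=dtype)
# One scalar per (batch, seqlen) row; it scales that row of x before normalization.
rowscale = torch.randn(batch, seqlen, device=device, dtype=dtype)

out = layer_norm_fn(x, weight, bias, rowscale=rowscale, eps=1e-6)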
flash_attn/ops/triton/layernorm.py (view file @ aa5c6438)

-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 # Implement dropout + residual + layer_norm / rms_norm.
 # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
@@ -23,6 +23,7 @@ def layer_norm_ref(
     residual=None,
     eps=1e-6,
     dropout_p=0.0,
+    rowscale=None,
     prenorm=False,
     dropout_mask=None,
     upcast=False,
@@ -34,6 +35,8 @@ def layer_norm_ref(
     if upcast:
         x = x.float()
         residual = residual.float() if residual is not None else residual
+    if rowscale is not None:
+        x = x * rowscale[..., None]
     if dropout_p > 0.0:
         if dropout_mask is not None:
             x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
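In the reference implementations, rowscale holds one scalar per row, and rowscale[..., None] appends a trailing dimension so the multiply broadcasts across the hidden dimension. A tiny standalone sketch of just that broadcast, with made-up shapes:

import torch

x = torch.randn(4, 16, 8)        # (batch, seqlen, hidden)
rowscale = torch.randn(4, 16)    # one scalar per (batch, seqlen) row

scaled = x * rowscale[..., None]  # broadcasts over the hidden dimension

# Equivalent to scaling each row explicitly
assert torch.equal(scaled[1, 3], x[1, 3] * rowscale[1, 3])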
@@ -54,6 +57,7 @@ def rms_norm_ref(
     residual=None,
     eps=1e-6,
     dropout_p=0.0,
+    rowscale=None,
     prenorm=False,
     dropout_mask=None,
     upcast=False,
@@ -65,6 +69,8 @@ def rms_norm_ref(
     if upcast:
         x = x.float()
         residual = residual.float() if residual is not None else residual
+    if rowscale is not None:
+        x = x * rowscale[..., None]
     if dropout_p > 0.0:
         if dropout_mask is not None:
             x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
@@ -99,6 +105,7 @@ def _layer_norm_fwd_1pass_kernel(
     B,  # pointer to the biases
     RESIDUAL,  # pointer to the residual
     RESIDUAL_OUT,  # pointer to the residual
+    ROWSCALE,
     SEEDS,  # Dropout seeds for each row
     DROPOUT_MASK,
     Mean,  # pointer to the mean
@@ -117,6 +124,7 @@ def _layer_norm_fwd_1pass_kernel(
     HAS_BIAS: tl.constexpr,
     HAS_DROPOUT: tl.constexpr,
     STORE_DROPOUT_MASK: tl.constexpr,
+    HAS_ROWSCALE: tl.constexpr,
 ):
     # Map the program id to the row of X and Y it should compute.
     row = tl.program_id(0)
@@ -129,6 +137,9 @@ def _layer_norm_fwd_1pass_kernel(
     # Compute mean and variance
     cols = tl.arange(0, BLOCK_N)
     x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+    if HAS_ROWSCALE:
+        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
+        x *= rowscale
     if HAS_DROPOUT:
         # Compute dropout mask
         # 7 rounds is good enough, and reduces register pressure
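In the forward kernel, each Triton program handles one row: when HAS_ROWSCALE is set it loads that row's scalar and scales the row before the dropout, residual add, and normalization that follow. A rough PyTorch sketch of the per-row order of operations (an illustration only; the real kernel also stores the pre-norm residual, dropout seeds, mean and rstd, and has an RMSNorm path):

import torch

def one_row_forward(x_row, weight, bias, rowscale=None, dropout_p=0.0, residual_row=None, eps=1e-6):
    # Hypothetical helper mirroring the kernel's per-row order of operations, in fp32.
    x = x_row.float()
    if rowscale is not None:
        x = x * rowscale                      # per-row scalar, applied first
    if dropout_p > 0.0:
        keep = torch.rand_like(x) > dropout_p
        x = torch.where(keep, x / (1.0 - dropout_p), torch.zeros_like(x))
    if residual_row is not None:
        x = x + residual_row.float()
    mean = x.mean()
    rstd = torch.rsqrt(((x - mean) ** 2).mean() + eps)
    return (x - mean) * rstd * weight.float() + bias.float()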
@@ -169,6 +180,7 @@ def _layer_norm_fwd(
     eps,
     residual=None,
     dropout_p=0.0,
+    rowscale=None,
     out_dtype=None,
     residual_dtype=None,
     is_rms_norm=False,
@@ -186,6 +198,9 @@ def _layer_norm_fwd(
     if bias is not None:
         assert bias.stride(-1) == 1
         assert bias.shape == (N,)
+    if rowscale is not None:
+        assert rowscale.is_contiguous()
+        assert rowscale.shape == (M,)
     # allocate output
     y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
     assert y.stride(-1) == 1
@@ -193,6 +208,7 @@ def _layer_norm_fwd(
         residual is not None
         or (residual_dtype is not None and residual_dtype != x.dtype)
         or dropout_p > 0.0
+        or rowscale is not None
     ):
         residual_out = torch.empty(
             M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
@@ -224,6 +240,7 @@ def _layer_norm_fwd(
             bias,
             residual,
             residual_out,
+            rowscale,
             seeds,
             dropout_mask,
             mean,
@@ -242,6 +259,7 @@ def _layer_norm_fwd(
             bias is not None,
             dropout_p > 0.0,
             dropout_mask is not None,
+            rowscale is not None,
         )
     # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
     return y, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask
@@ -261,6 +279,7 @@ def _layer_norm_fwd(
 # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
 # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
 # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
+@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
 @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
 @triton.jit
 def _layer_norm_bwd_kernel(
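The @triton.heuristics decorator derives a constexpr flag from the launch arguments, so the branch on HAS_ROWSCALE is resolved at compile time and the no-rowscale specialization carries no extra loads. A toy sketch of the same pattern with a hypothetical kernel (assuming a recent Triton that, like the kernels in this file, accepts None for a pointer argument guarded by such a flag):

import torch
import triton
import triton.language as tl


@triton.heuristics({"HAS_SCALE": lambda args: args["SCALE"] is not None})
@triton.jit
def _scaled_copy_kernel(X, Y, SCALE, N, BLOCK_N: tl.constexpr, HAS_SCALE: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(X + row * N + cols, mask=mask, other=0.0)
    if HAS_SCALE:
        # This branch is compiled away when SCALE is None at launch.
        x *= tl.load(SCALE + row)
    tl.store(Y + row * N + cols, x, mask=mask)


def scaled_copy(x, scale=None):
    # x: contiguous (M, N) tensor on CUDA; scale: optional (M,) tensor or None.
    M, N = x.shape
    y = torch.empty_like(x)
    _scaled_copy_kernel[(M,)](x, y, scale, N, BLOCK_N=triton.next_power_of_2(N))
    return y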
@@ -274,6 +293,7 @@ def _layer_norm_bwd_kernel(
     DB,  # pointer to the partial sum of biases gradient
     DRESIDUAL,
     DRESIDUAL_IN,
+    ROWSCALE,
     SEEDS,
     Mean,  # pointer to the mean
     Rstd,  # pointer to the 1/std
@@ -294,11 +314,14 @@ def _layer_norm_bwd_kernel(
     STORE_DRESIDUAL: tl.constexpr,
     HAS_BIAS: tl.constexpr,
     HAS_DROPOUT: tl.constexpr,
+    HAS_ROWSCALE: tl.constexpr,
     RECOMPUTE_OUTPUT: tl.constexpr,
 ):
     # Map the program id to the elements of X, DX, and DY it should compute.
     row_block_id = tl.program_id(0)
     row_start = row_block_id * rows_per_program
+    if row_start >= M:
+        return
     cols = tl.arange(0, BLOCK_N)
     mask = cols < N
     X += row_start * stride_x_row
@@ -350,6 +373,9 @@ def _layer_norm_bwd_kernel(
         if HAS_DROPOUT:
             keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
             dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
+        if HAS_ROWSCALE:
+            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
+            dx *= rowscale
         tl.store(DX + cols, dx, mask=mask)
         X += stride_x_row
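The backward change follows from the chain rule: the forward scales each row of x by its scalar before everything else, so the gradient reaching the original x is the pre-scaling gradient times that same scalar, which is why dx is multiplied by rowscale after the dropout rescaling. A small autograd sketch of that identity, with made-up shapes:

import torch

x = torch.randn(4, 8, requires_grad=True)
rowscale = torch.randn(4)

# Forward scales each row, then reduces (a stand-in for the rest of the fused op).
out = (x * rowscale[:, None]).sum()
out.backward()

# The gradient w.r.t. x is the per-row scalar broadcast across the row.
assert torch.allclose(x.grad, rowscale[:, None].expand_as(x))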
@@ -377,6 +403,7 @@ def _layer_norm_bwd(
     dresidual=None,
     seeds=None,
     dropout_p=0.0,
+    rowscale=None,
     has_residual=False,
     is_rms_norm=False,
     x_dtype=None,
@@ -397,6 +424,9 @@ def _layer_norm_bwd(
     if seeds is not None:
         assert seeds.is_contiguous()
         assert seeds.shape == (M,)
+    if rowscale is not None:
+        assert rowscale.is_contiguous()
+        assert rowscale.shape == (M,)
     # allocate output
     dx = (
         torch.empty_like(x)
@@ -404,7 +434,9 @@ def _layer_norm_bwd(
         else torch.empty(M, N, dtype=x_dtype, device=x.device)
     )
     dresidual_in = (
-        torch.empty_like(x) if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0) else None
+        torch.empty_like(x)
+        if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None)
+        else None
     )
     y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
@@ -434,6 +466,7 @@ def _layer_norm_bwd(
             _db,
             dresidual,
             dresidual_in,
+            rowscale,
             seeds,
             mean,
             rstd,
@@ -458,7 +491,7 @@ def _layer_norm_bwd(
     dw = _dw.sum(0).to(weight.dtype)
     db = _db.sum(0).to(bias.dtype) if bias is not None else None
     # Don't need to compute dresidual_in separately in this case
-    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0:
+    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
         dresidual_in = dx
     return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
@@ -473,6 +506,7 @@ class LayerNormFn(torch.autograd.Function):
         residual=None,
         eps=1e-6,
         dropout_p=0.0,
+        rowscale=None,
         prenorm=False,
         residual_in_fp32=False,
         is_rms_norm=False,
@@ -491,6 +525,8 @@ class LayerNormFn(torch.autograd.Function):
         weight = weight.contiguous()
         if bias is not None:
             bias = bias.contiguous()
+        if rowscale is not None:
+            rowscale = rowscale.reshape(-1).contiguous()
         residual_dtype = (
             residual.dtype
             if residual is not None
@@ -503,11 +539,12 @@ class LayerNormFn(torch.autograd.Function):
             eps,
             residual,
             dropout_p=dropout_p,
+            rowscale=rowscale,
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
             return_dropout_mask=return_dropout_mask,
         )
-        ctx.save_for_backward(residual_out, weight, bias, seeds, mean, rstd)
+        ctx.save_for_backward(residual_out, weight, bias, rowscale, seeds, mean, rstd)
         ctx.x_shape_og = x_shape_og
         ctx.eps = eps
         ctx.dropout_p = dropout_p
@@ -525,7 +562,7 @@ class LayerNormFn(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dy, *args):
-        x, weight, bias, seeds, mean, rstd = ctx.saved_tensors
+        x, weight, bias, rowscale, seeds, mean, rstd = ctx.saved_tensors
         dy = dy.reshape(-1, dy.shape[-1])
         if dy.stride(-1) != 1:
             dy = dy.contiguous()
@@ -549,6 +586,7 @@ class LayerNormFn(torch.autograd.Function):
             dresidual,
             seeds,
             ctx.dropout_p,
+            rowscale,
             ctx.has_residual,
             ctx.is_rms_norm,
             x_dtype=ctx.x_dtype,
@@ -564,6 +602,7 @@ class LayerNormFn(torch.autograd.Function):
             None,
             None,
             None,
+            None,
             None,
             None,
         )
@@ -574,6 +613,7 @@ def layer_norm_fn(
     residual=None,
     eps=1e-6,
     dropout_p=0.0,
+    rowscale=None,
     prenorm=False,
     residual_in_fp32=False,
     is_rms_norm=False,
@@ -586,6 +626,7 @@ def layer_norm_fn(
         residual,
         eps,
         dropout_p,
+        rowscale,
         prenorm,
         residual_in_fp32,
         is_rms_norm,
@@ -600,6 +641,7 @@ def rms_norm_fn(
     residual=None,
     eps=1e-6,
     dropout_p=0.0,
+    rowscale=None,
     prenorm=False,
     residual_in_fp32=False,
     return_dropout_mask=False,
@@ -611,6 +653,7 @@ def rms_norm_fn(
         residual,
         eps,
         dropout_p,
+        rowscale,
         prenorm,
         residual_in_fp32,
         True,
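A quick way to exercise the new path end to end is to compare the fused function against the reference with a rowscale tensor, much as the updated test below does. A minimal sketch, assuming a CUDA device with Triton available:

import torch
from flash_attn.ops.triton.layernorm import layer_norm_fn, layer_norm_ref

torch.manual_seed(0)
batch, seqlen, hidden = 2, 64, 256
x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
weight = torch.randn(hidden, device="cuda", dtype=torch.float16)
bias = torch.randn(hidden, device="cuda", dtype=torch.float16)
rowscale = torch.randn(batch, seqlen, device="cuda", dtype=torch.float16)

out = layer_norm_fn(x, weight, bias, rowscale=rowscale, eps=1e-6)
out_ref = layer_norm_ref(x, weight, bias, rowscale=rowscale, eps=1e-6, upcast=True)
print((out.float() - out_ref.float()).abs().max())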
tests/ops/triton/test_layer_norm.py (view file @ aa5c6438)

-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 import pytest
 import torch
@@ -16,14 +16,16 @@ from flash_attn.ops.triton.layernorm import (
 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


+@pytest.mark.parametrize("has_rowscale", [False, True])
+# @pytest.mark.parametrize("has_rowscale", [True])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.27])
-# @pytest.mark.parametrize("dropout_p", [0.27])
+# @pytest.mark.parametrize("dropout_p", [0.0])
 @pytest.mark.parametrize("prenorm", [True, False])
-# @pytest.mark.parametrize("prenorm", [False])
+# @pytest.mark.parametrize("prenorm", [True])
 @pytest.mark.parametrize("is_rms_norm", [False, True])
 # @pytest.mark.parametrize("is_rms_norm", [True])
 @pytest.mark.parametrize("has_residual", [True, False])
-# @pytest.mark.parametrize("has_residual", [True])
+# @pytest.mark.parametrize("has_residual", [False])
 @pytest.mark.parametrize(
     "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
 )
@@ -45,6 +47,7 @@ def test_layer_norm(
     is_rms_norm,
     prenorm,
     dropout_p,
+    has_rowscale,
 ):
     device = "cuda"
     if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
@@ -60,7 +63,8 @@ def test_layer_norm(
     layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
     allclose = (
         lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
-        <= 2 * (x_pt - x_ref).abs().max() + atol
+        # Sometimes x0_pt.grad is NaN
+        <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
     )
     x0 = torch.randn(
         batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
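The tolerance check allows the Triton output to be off by at most twice the error of the plain PyTorch baseline, plus atol; indexing with ~x_pt.isnan() drops positions where the baseline itself produced NaN (e.g. a NaN x0_pt.grad) so they cannot poison the bound. A toy illustration of that masking:

import torch

x_pt = torch.tensor([1.0, float("nan"), 3.0])
x_ref = torch.tensor([1.1, 2.0, 3.1])

mask = ~x_pt.isnan()
err = (x_pt[mask] - x_ref[mask]).abs().max()  # the NaN entry is excluded
print(err)  # roughly tensor(0.1000)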
@@ -83,6 +87,8 @@ def test_layer_norm(
     bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
     bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
+    rowscale = torch.randn(batch_size, seqlen, dtype=input_dtype, device=device) if has_rowscale else None
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, *rest = layer_norm_fn(
         x0,
@@ -91,6 +97,7 @@ def test_layer_norm(
         residual=res,
         eps=1e-6,
         dropout_p=dropout_p,
+        rowscale=rowscale,
         prenorm=prenorm,
         residual_in_fp32=residual_in_fp32,
         is_rms_norm=is_rms_norm,
@@ -104,6 +111,7 @@ def test_layer_norm(
         residual=res_pt,
         eps=1e-6,
         dropout_p=dropout_p,
+        rowscale=rowscale,
         prenorm=prenorm,
         dropout_mask=dropout_mask,
     )
@@ -114,6 +122,7 @@ def test_layer_norm(
         residual=res_ref,
         eps=1e-6,
         dropout_p=dropout_p,
+        rowscale=rowscale,
         prenorm=prenorm,
         dropout_mask=dropout_mask,
         upcast=True,