gaoqiong / flash-attention

Commit 01771645
Authored Nov 13, 2023 by Tri Dao
Parent: 79bd1a2d

[LayerNorm] Add postnorm residual + LayerNorm/RMSNorm in Triton
Showing 2 changed files with 205 additions and 88 deletions:

    flash_attn/ops/triton/layernorm.py      +178  -73
    tests/ops/triton/test_layer_norm.py      +27   -15
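Before the per-file diffs, a minimal usage sketch of the new public API this commit introduces (the prenorm and residual_in_fp32 keyword arguments on layer_norm_fn). Shapes, dtypes, and tensor values below are placeholders chosen for illustration, not taken from the commit:

import torch

from flash_attn.ops.triton.layernorm import layer_norm_fn

# Placeholder sizes for illustration only.
batch, seqlen, hidden = 2, 128, 2048
x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
residual = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float32)
weight = torch.ones(hidden, device="cuda", dtype=torch.float32)
bias = torch.zeros(hidden, device="cuda", dtype=torch.float32)

# Postnorm (prenorm=False, the default): only the normalized output is
# returned, even when a residual is passed in and added before the norm.
out = layer_norm_fn(x, weight, bias, residual=residual, eps=1e-6)

# Prenorm: the updated residual stream (x + residual) is returned as well,
# so the next block can consume it.
out, residual_out = layer_norm_fn(x, weight, bias, residual=residual, eps=1e-6, prenorm=True)

# With no incoming residual, residual_in_fp32=True asks the kernel to hand
# the residual stream back upcast to float32.
out, residual_out = layer_norm_fn(
    x, weight, bias, residual=None, eps=1e-6, prenorm=True, residual_in_fp32=True
)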
flash_attn/ops/triton/layernorm.py
@@ -15,7 +15,7 @@ import triton
 import triton.language as tl


-def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False):
+def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
     dtype = x.dtype
     if upcast:
         weight = weight.float()
@@ -25,11 +25,13 @@ def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False):
         residual = residual.float() if residual is not None else residual
     if residual is not None:
         x = (x + residual).to(x.dtype)
-    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(dtype)
-    return out if residual is None else (out, x)
+    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
+        dtype
+    )
+    return out if not prenorm else (out, x)


-def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False):
+def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
     dtype = x.dtype
     if upcast:
         weight = weight.float()
@@ -42,7 +44,7 @@ def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False):
     rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
     out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
     out = out.to(dtype)
-    return out if residual is None else (out, x)
+    return out if not prenorm else (out, x)


 @triton.autotune(
@@ -54,7 +56,7 @@ def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False):
         triton.Config({}, num_warps=16),
         triton.Config({}, num_warps=32),
     ],
-    key=["N", "HAS_RESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
+    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
 )
 # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
 # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@@ -77,6 +79,7 @@ def _layer_norm_fwd_1pass_kernel(
     IS_RMS_NORM: tl.constexpr,
     BLOCK_N: tl.constexpr,
     HAS_RESIDUAL: tl.constexpr,
+    STORE_RESIDUAL_OUT: tl.constexpr,
     HAS_BIAS: tl.constexpr,
 ):
     # Map the program id to the row of X and Y it should compute.
@@ -85,21 +88,23 @@ def _layer_norm_fwd_1pass_kernel(
     Y += row * stride_y_row
     if HAS_RESIDUAL:
         RESIDUAL += row * stride_res_row
+    if STORE_RESIDUAL_OUT:
         RESIDUAL_OUT += row * stride_res_out_row
     # Compute mean and variance
     cols = tl.arange(0, BLOCK_N)
-    x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
     if HAS_RESIDUAL:
-        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.).to(tl.float32)
+        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
         x += residual
+    if STORE_RESIDUAL_OUT:
         tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
     if not IS_RMS_NORM:
         mean = tl.sum(x, axis=0) / N
         tl.store(Mean + row, mean)
-        xbar = tl.where(cols < N, x - mean, 0.)
+        xbar = tl.where(cols < N, x - mean, 0.0)
         var = tl.sum(xbar * xbar, axis=0) / N
     else:
-        xbar = tl.where(cols < N, x, 0.)
+        xbar = tl.where(cols < N, x, 0.0)
         var = tl.sum(xbar * xbar, axis=0) / N
     rstd = 1 / tl.sqrt(var + eps)
     tl.store(Rstd + row, rstd)
@@ -114,7 +119,9 @@ def _layer_norm_fwd_1pass_kernel(
     tl.store(Y + cols, y, mask=mask)


-def _layer_norm_fwd(x, weight, bias, eps, residual=None, is_rms_norm=False):
+def _layer_norm_fwd(x, weight, bias, eps, residual=None, residual_dtype=None, is_rms_norm=False):
+    if residual is not None:
+        residual_dtype = residual.dtype
     M, N = x.shape
     assert x.stride(-1) == 1
     if residual is not None:
@@ -128,13 +135,13 @@ def _layer_norm_fwd(x, weight, bias, eps, residual=None, is_rms_norm=False):
     # allocate output
     y = torch.empty_like(x)
     assert y.stride(-1) == 1
-    if residual is not None:
-        residual_out = torch.empty_like(residual)
+    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
+        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
         assert residual_out.stride(-1) == 1
     else:
         residual_out = None
-    mean = torch.empty((M, ), dtype=torch.float32, device='cuda') if not is_rms_norm else None
-    rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
+    mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
+    rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
     # Less than 64KB per feature: enqueue fused kernel
     MAX_FUSED_SIZE = 65536 // x.element_size()
     BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
@@ -142,18 +149,29 @@ def _layer_norm_fwd(x, weight, bias, eps, residual=None, is_rms_norm=False):
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
     # heuristics for number of warps
     with torch.cuda.device(x.device.index):
-        _layer_norm_fwd_1pass_kernel[(M,)](x, y, weight, bias, residual, residual_out,
-                                           mean, rstd,
-                                           x.stride(0), y.stride(0),
-                                           residual.stride(0) if residual is not None else 0,
-                                           residual_out.stride(0) if residual is not None else 0,
-                                           N, eps,
-                                           is_rms_norm,
-                                           BLOCK_N,
-                                           residual is not None,
-                                           bias is not None,
-                                           )
-    return y, mean, rstd, residual_out
+        _layer_norm_fwd_1pass_kernel[(M,)](
+            x,
+            y,
+            weight,
+            bias,
+            residual,
+            residual_out,
+            mean,
+            rstd,
+            x.stride(0),
+            y.stride(0),
+            residual.stride(0) if residual is not None else 0,
+            residual_out.stride(0) if residual_out is not None else 0,
+            N,
+            eps,
+            is_rms_norm,
+            BLOCK_N,
+            residual is not None,
+            residual_out is not None,
+            bias is not None,
+        )
+    # residual_out is None if residual is None and residual_dtype == input_dtype
+    return y, mean, rstd, residual_out if residual_out is not None else x


 @triton.autotune(
@@ -218,7 +236,7 @@ def _layer_norm_bwd_kernel(
     Y += row_start * stride_y_row
     w = tl.load(W + cols, mask=mask).to(tl.float32)
     if RECOMPUTE_OUTPUT and HAS_BIAS:
-        b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
+        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
     dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
     if HAS_BIAS:
         db = tl.zeros((BLOCK_N,), dtype=tl.float32)
@@ -232,7 +250,7 @@ def _layer_norm_bwd_kernel(
         rstd = tl.load(Rstd + row)
         # Compute dx
         xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
-        xhat = tl.where(mask, xhat, 0.)
+        xhat = tl.where(mask, xhat, 0.0)
         if RECOMPUTE_OUTPUT:
             y = xhat * w + b if HAS_BIAS else xhat * w
             tl.store(Y + cols, y, mask=mask)
@@ -269,8 +287,20 @@ def _layer_norm_bwd_kernel(
         tl.store(DB + row_block_id * N + cols, db, mask=mask)


-def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, dresidual=None, is_rms_norm=False, x_dtype=None,
-                    recompute_output=False):
+def _layer_norm_bwd(
+    dy,
+    x,
+    weight,
+    bias,
+    eps,
+    mean,
+    rstd,
+    dresidual=None,
+    has_residual=False,
+    is_rms_norm=False,
+    x_dtype=None,
+    recompute_output=False,
+):
     M, N = x.shape
     assert x.stride(-1) == 1
     assert dy.stride(-1) == 1
@@ -284,8 +314,12 @@ def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, dresidual=None, is_rms
         assert bias.stride(-1) == 1
         assert bias.shape == (N,)
     # allocate output
-    dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)
-    dresidual_in = torch.empty_like(dresidual) if dresidual is not None and dx.dtype != dresidual.dtype else None
+    dx = (
+        torch.empty_like(x)
+        if x_dtype is None
+        else torch.empty(M, N, dtype=x_dtype, device=x.device)
+    )
+    dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
     y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None

     # Less than 64KB per feature: enqueue fused kernel
@@ -295,37 +329,64 @@ def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, dresidual=None, is_rms
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
     sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
     _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
-    _db = torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None
+    _db = (
+        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
+        if bias is not None
+        else None
+    )
     rows_per_program = math.ceil(M / sm_count)
     grid = (sm_count,)
     with torch.cuda.device(x.device.index):
-        _layer_norm_bwd_kernel[grid](x, weight, bias, y,
-                                     dy, dx, _dw, _db,
-                                     dresidual, dresidual_in,
-                                     mean, rstd,
-                                     x.stride(0),
-                                     0 if not recompute_output else y.stride(0),
-                                     dy.stride(0), dx.stride(0),
-                                     dresidual.stride(0) if dresidual is not None else 0,
-                                     dresidual_in.stride(0) if dresidual_in is not None else 0,
-                                     M, N, eps,
-                                     rows_per_program,
-                                     is_rms_norm,
-                                     BLOCK_N,
-                                     dresidual is not None,
-                                     dresidual_in is not None,
-                                     bias is not None)
+        _layer_norm_bwd_kernel[grid](
+            x,
+            weight,
+            bias,
+            y,
+            dy,
+            dx,
+            _dw,
+            _db,
+            dresidual,
+            dresidual_in,
+            mean,
+            rstd,
+            x.stride(0),
+            0 if not recompute_output else y.stride(0),
+            dy.stride(0),
+            dx.stride(0),
+            dresidual.stride(0) if dresidual is not None else 0,
+            dresidual_in.stride(0) if dresidual_in is not None else 0,
+            M,
+            N,
+            eps,
+            rows_per_program,
+            is_rms_norm,
+            BLOCK_N,
+            dresidual is not None,
+            dresidual_in is not None,
+            bias is not None,
+        )
     dw = _dw.sum(0).to(weight.dtype)
     db = _db.sum(0).to(bias.dtype) if bias is not None else None
     # Don't need to compute dresidual_in separately in this case
-    if dresidual is not None and dx.dtype == dresidual.dtype:
+    if has_residual and dx.dtype == x.dtype:
         dresidual_in = dx
     return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)


 class LayerNormFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, weight, bias, residual=None, eps=1e-6, is_rms_norm=False):
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias,
+        residual=None,
+        eps=1e-6,
+        prenorm=False,
+        residual_in_fp32=False,
+        is_rms_norm=False,
+    ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor
         x = x.reshape(-1, x.shape[-1])
@@ -339,17 +400,23 @@ class LayerNormFn(torch.autograd.Function):
         weight = weight.contiguous()
         if bias is not None:
             bias = bias.contiguous()
-        y, mean, rstd, *rest = _layer_norm_fwd(x, weight, bias, eps, residual, is_rms_norm)
-        if residual is not None:
-            residual_out = rest[0]
-        ctx.save_for_backward(x if residual is None else residual_out, weight, bias, mean, rstd)
+        residual_dtype = (
+            residual.dtype
+            if residual is not None
+            else (torch.float32 if residual_in_fp32 else None)
+        )
+        y, mean, rstd, residual_out = _layer_norm_fwd(
+            x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
+        )
+        ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
         ctx.x_shape_og = x_shape_og
         ctx.eps = eps
         ctx.is_rms_norm = is_rms_norm
         ctx.has_residual = residual is not None
+        ctx.prenorm = prenorm
         ctx.x_dtype = x.dtype
         y = y.reshape(x_shape_og)
-        return y if residual is None else (y, residual_out.reshape(x_shape_og))
+        return y if not prenorm else (y, residual_out.reshape(x_shape_og))

     @staticmethod
     def backward(ctx, dy, *args):
@@ -358,7 +425,7 @@ class LayerNormFn(torch.autograd.Function):
         if dy.stride(-1) != 1:
             dy = dy.contiguous()
         assert dy.shape == x.shape
-        if ctx.has_residual:
+        if ctx.prenorm:
             dresidual = args[0]
             dresidual = dresidual.reshape(-1, dresidual.shape[-1])
             if dresidual.stride(-1) != 1:
@@ -366,17 +433,46 @@ class LayerNormFn(torch.autograd.Function):
             assert dresidual.shape == x.shape
         else:
             dresidual = None
-        dx, dw, db, dresidual_in = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, dresidual,
-                                                   ctx.is_rms_norm, x_dtype=ctx.x_dtype)
-        return dx.reshape(ctx.x_shape_og), dw, db, dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, None, None
+        dx, dw, db, dresidual_in = _layer_norm_bwd(
+            dy,
+            x,
+            weight,
+            bias,
+            ctx.eps,
+            mean,
+            rstd,
+            dresidual,
+            ctx.has_residual,
+            ctx.is_rms_norm,
+            x_dtype=ctx.x_dtype,
+        )
+        return (
+            dx.reshape(ctx.x_shape_og),
+            dw,
+            db,
+            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
+            None,
+            None,
+            None,
+            None,
+        )


-def layer_norm_fn(x, weight, bias, residual=None, eps=1e-6, is_rms_norm=False):
-    return LayerNormFn.apply(x, weight, bias, residual, eps, is_rms_norm)
+def layer_norm_fn(
+    x,
+    weight,
+    bias,
+    residual=None,
+    eps=1e-6,
+    prenorm=False,
+    residual_in_fp32=False,
+    is_rms_norm=False,
+):
+    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)


-def rms_norm_fn(x, weight, bias, residual=None, eps=1e-6):
-    return LayerNormFn.apply(x, weight, bias, residual, eps, True)
+def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
+    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)


 class RMSNorm(torch.nn.Module):
@@ -391,5 +487,14 @@ class RMSNorm(torch.nn.Module):
     def reset_parameters(self):
         torch.nn.init.ones_(self.weight)

-    def forward(self, x, residual=None):
-        return layer_norm_fn(x, self.weight, self.bias, residual=residual, eps=self.eps, is_rms_norm=True)
+    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
+        return rms_norm_fn(
+            x,
+            self.weight,
+            self.bias,
+            residual=residual,
+            eps=self.eps,
+            prenorm=prenorm,
+            residual_in_fp32=residual_in_fp32,
+            is_rms_norm=True,
+        )
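The diff above gives rms_norm_fn the same prenorm and residual_in_fp32 arguments. A short sketch under the same illustrative assumptions as before (placeholder shapes; bias passed as None to exercise the HAS_BIAS=False path):

import torch

from flash_attn.ops.triton.layernorm import rms_norm_fn

hidden = 2048  # placeholder size
x = torch.randn(2, 128, hidden, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)

# Postnorm residual: fuse the add into the RMSNorm, return the output only.
out = rms_norm_fn(x, weight, None, residual=residual, eps=1e-6)

# Prenorm: also return the updated residual stream for the next layer.
out, residual_out = rms_norm_fn(x, weight, None, residual=residual, prenorm=True, eps=1e-6)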
tests/ops/triton/test_layer_norm.py
@@ -11,30 +11,32 @@ from flash_attn.ops.triton.layernorm import layer_norm_fn, layer_norm_ref, rms_n
 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


+@pytest.mark.parametrize("prenorm", [True, False])
+# @pytest.mark.parametrize("prenorm", [True])
 @pytest.mark.parametrize("is_rms_norm", [False, True])
 # @pytest.mark.parametrize("is_rms_norm", [True])
 @pytest.mark.parametrize("has_residual", [True, False])
-# @pytest.mark.parametrize("has_residual", [True])
+# @pytest.mark.parametrize("has_residual", [False])
 @pytest.mark.parametrize(
     "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
 )
 # @pytest.mark.parametrize("weight_dtype", [torch.float32])
 @pytest.mark.parametrize(
     "input_dtype,residual_dtype",
     [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
     + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
 )
 # @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
 @pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 8192])
 # @pytest.mark.parametrize("hidden_size", [256])
 def test_layer_norm(
-    hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm
+    hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
 ):
     device = "cuda"
     if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
         atol = 5e-2
     elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
-        atol = 5e-3
+        atol = 1e-2
     else:
         atol = 1e-4
     # set seed
@@ -68,26 +70,36 @@ def test_layer_norm(
     weight_ref = weight.detach().clone().requires_grad_()
     bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
     bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
-    out, *rest = layer_norm_fn(x0, weight, bias, residual=res, eps=1e-6, is_rms_norm=is_rms_norm)
-    out_pt, *rest_pt = layer_norm_ref_fn(x0_pt, weight_pt, bias_pt, residual=res_pt, eps=1e-6)
+    out, *rest = layer_norm_fn(
+        x0,
+        weight,
+        bias,
+        residual=res,
+        eps=1e-6,
+        prenorm=prenorm,
+        residual_in_fp32=residual_in_fp32,
+        is_rms_norm=is_rms_norm,
+    )
+    out_pt, *rest_pt = layer_norm_ref_fn(
+        x0_pt, weight_pt, bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
+    )
     out_ref, *rest_ref = layer_norm_ref_fn(
-        x0_ref, weight_ref, bias_ref, residual=res_ref, eps=1e-6, upcast=True
+        x0_ref, weight_ref, bias_ref, residual=res_ref, eps=1e-6, prenorm=prenorm, upcast=True
     )
-    if has_residual:
+    if prenorm:
         residual = rest[0]
         residual_pt = rest_pt[0]
         residual_ref = rest_ref[0]
-        residual_ref = x0_ref + res_ref
     assert out.dtype == input_dtype
-    if has_residual:
+    if prenorm:
         assert residual.dtype == residual_dtype
         assert allclose(residual, residual_pt, residual_ref)
     assert allclose(out, out_pt, out_ref)

     g = torch.randn_like(out) / batch_size
-    if not has_residual:
+    if not prenorm:
         out.backward(g)
         out_pt.backward(g)
         out_ref.backward(g)
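For reference, the contract the updated test exercises can be seen in plain PyTorch through layer_norm_ref: with prenorm=False only the normalized output comes back, with prenorm=True the post-residual tensor is returned alongside it. A small sketch using CPU tensors and made-up sizes:

import torch

from flash_attn.ops.triton.layernorm import layer_norm_ref

hidden = 256  # placeholder size
x = torch.randn(4, hidden, dtype=torch.float16)
res = torch.randn(4, hidden, dtype=torch.float16)
weight = torch.ones(hidden, dtype=torch.float32)
bias = torch.zeros(hidden, dtype=torch.float32)

# Postnorm: a single tensor is returned.
out = layer_norm_ref(x, weight, bias, residual=res, eps=1e-6)

# Prenorm: the residual stream (x + res) is returned alongside the output.
out_prenorm, x_plus_res = layer_norm_ref(x, weight, bias, residual=res, eps=1e-6, prenorm=True)
assert torch.allclose(out, out_prenorm)
assert torch.allclose(x_plus_res, (x + res).to(x.dtype))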