gaoqiong / flash-attention / Commits / 665b55e2

Commit 665b55e2, authored Jan 04, 2024 by Tri Dao

[LayerNorm] Implement parallel layer norm in Triton

Parent: aa5c6438
Showing 2 changed files with 381 additions and 32 deletions:

    flash_attn/ops/triton/layernorm.py   (+290, -23)
    tests/ops/triton/test_layer_norm.py  (+91, -9)
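The commit extends the fused Triton layer norm with a "parallel" path: an optional second input x1 is summed with x before normalization, and an optional second set of affine parameters (weight1, bias1) produces a second output from the same normalized activations, with per-input fused dropout. The snippet below is a hypothetical usage sketch, not taken from the commit; it assumes a CUDA device and follows the argument names and return convention visible in the diff below (two outputs when weight1 is given).

# Hypothetical usage sketch of the new parallel path (assumes CUDA; names follow the diff).
import torch
from flash_attn.ops.triton.layernorm import layer_norm_fn

hidden = 1024
x = torch.randn(2, 128, hidden, device="cuda", dtype=torch.float16)
x1 = torch.randn_like(x)                     # second branch, summed with x before the norm
weight = torch.ones(hidden, device="cuda")   # affine params for the first output
bias = torch.zeros(hidden, device="cuda")
weight1 = torch.ones(hidden, device="cuda")  # second affine set -> second output
bias1 = torch.zeros(hidden, device="cuda")

out, out1 = layer_norm_fn(
    x, weight, bias, x1=x1, weight1=weight1, bias1=bias1, eps=1e-6, dropout_p=0.0
)
# out and out1 normalize the same (x + x1) but apply different weight/bias.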
flash_attn/ops/triton/layernorm.py

@@ -21,20 +21,28 @@ def layer_norm_ref(
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:

@@ -42,12 +50,25 @@ def layer_norm_ref(
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(dtype)
-   return out if not prenorm else (out, x)
+   if weight1 is None:
+       return out if not prenorm else (out, x)
+   else:
+       out1 = F.layer_norm(x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps).to(dtype)
+       return (out, out1) if not prenorm else (out, out1, x)
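As a sanity check on the semantics above (not part of the commit): the parallel reference is just two affine projections of one shared normalization of x + x1. A minimal sketch, assuming layer_norm_ref is imported from this module:

import torch
import torch.nn.functional as F
from flash_attn.ops.triton.layernorm import layer_norm_ref

x, x1 = torch.randn(4, 64), torch.randn(4, 64)
w, b = torch.randn(64), torch.randn(64)
w1, b1 = torch.randn(64), torch.randn(64)

out, out1 = layer_norm_ref(x, w, b, x1=x1, weight1=w1, bias1=b1)
h = x + x1  # both outputs normalize the same summed input
assert torch.allclose(out, F.layer_norm(h, (64,), w, b, eps=1e-6), atol=1e-5)
assert torch.allclose(out1, F.layer_norm(h, (64,), w1, b1, eps=1e-6), atol=1e-5)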
def rms_norm_ref(
    ...
@@ -55,20 +76,28 @@ def rms_norm_ref(
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:

@@ -76,12 +105,24 @@ def rms_norm_ref(
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
-   out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
-   out = out.to(dtype)
-   return out if not prenorm else (out, x)
+   out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
+   if weight1 is None:
+       return out if not prenorm else (out, x)
+   else:
+       out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(dtype)
+       return (out, out1) if not prenorm else (out, out1, x)
@triton.autotune(
    ...
@@ -97,6 +138,9 @@ def rms_norm_ref(
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input

@@ -104,6 +148,10 @@ def _layer_norm_fwd_1pass_kernel(
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,
    W1,
    B1,
    Y1,
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row

@@ -114,6 +162,9 @@ def _layer_norm_fwd_1pass_kernel(
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability

@@ -125,6 +176,9 @@ def _layer_norm_fwd_1pass_kernel(
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)

@@ -134,6 +188,10 @@ def _layer_norm_fwd_1pass_kernel(
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)

@@ -147,6 +205,21 @@ def _layer_norm_fwd_1pass_kernel(
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual

@@ -171,6 +244,12 @@ def _layer_norm_fwd_1pass_kernel(
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)


def _layer_norm_fwd(
    ...
@@ -179,6 +258,9 @@ def _layer_norm_fwd(
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,

@@ -198,17 +280,33 @@ def _layer_norm_fwd(
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(y)
        assert y1.stride(-1) == 1
    else:
        y1 = None
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ):
        residual_out = torch.empty(
            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype

@@ -219,11 +317,13 @@ def _layer_norm_fwd(
    mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
    if dropout_p > 0.0:
-       seeds = torch.randint(2**32, (M,), device=x.device, dtype=torch.int64)
+       seeds = torch.randint(
+           2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
+       )
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
-       dropout_mask = torch.empty_like(x, dtype=torch.bool)
+       dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
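Nothing here stores the dropout masks by default; backward regenerates them from these per-row seeds (rows 0..M-1 for x, rows M..2M-1 for x1 when it is present). The sketch below only illustrates that idea, with a torch.Generator standing in for tl.rand; it is not the kernel's Philox RNG.

# Illustration only: per-row seeds let backward regenerate the exact forward dropout mask.
import torch

M, N, p = 4, 8, 0.1
seeds = torch.randint(2**32, (2 * M,), dtype=torch.int64)  # first M rows for x, last M for x1

def row_mask(seed, n, p):
    # stand-in for "tl.rand(seed, cols) > dropout_p" inside the kernel
    g = torch.Generator().manual_seed(int(seed))
    return torch.rand(n, generator=g) > p

masks_x = torch.stack([row_mask(seeds[i], N, p) for i in range(M)])
masks_x1 = torch.stack([row_mask(seeds[M + i], N, p) for i in range(M)])
# Replaying with the same seeds reproduces the same masks, which is what backward relies on.
assert torch.equal(masks_x, torch.stack([row_mask(seeds[i], N, p) for i in range(M)]))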
@@ -231,7 +331,6 @@ def _layer_norm_fwd(
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
-   # heuristics for number of warps
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,

@@ -239,6 +338,10 @@ def _layer_norm_fwd(
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,

@@ -249,6 +352,9 @@ def _layer_norm_fwd(
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            x1.stride(0) if x1 is not None else 0,
            y1.stride(0) if y1 is not None else 0,
            M,
            N,
            eps,
            dropout_p,

@@ -262,7 +368,20 @@ def _layer_norm_fwd(
            rowscale is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
-   return y, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask
+   if dropout_mask is not None and x1 is not None:
+       dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
+   else:
+       dropout_mask1 = None
+   return (
+       y,
+       y1,
+       mean,
+       rstd,
+       residual_out if residual_out is not None else x,
+       seeds,
+       dropout_mask,
+       dropout_mask1,
+   )
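The tensor_split(2, dim=0) above relies on the combined mask having 2 * M rows: the first M (for x) and the last M (for x1) come back as two views. A quick standalone check of that assumption:

import torch

M, N = 3, 4
combined = torch.zeros(2 * M, N, dtype=torch.bool)
mask_x, mask_x1 = combined.tensor_split(2, dim=0)  # same call used in the return path above
assert mask_x.shape == (M, N) and mask_x1.shape == (M, N)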
@triton.autotune(
    ...
@@ -280,6 +399,9 @@ def _layer_norm_fwd(
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(

@@ -292,6 +414,11 @@ def _layer_norm_bwd_kernel(
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    W1,
    DY1,
    DX1,
    DW1,
    DB1,
    DRESIDUAL_IN,
    ROWSCALE,
    SEEDS,

@@ -302,6 +429,8 @@ def _layer_norm_bwd_kernel(
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dy1_row,
    stride_dx1_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X

@@ -315,6 +444,9 @@ def _layer_norm_bwd_kernel(
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_DY1: tl.constexpr,
    HAS_DX1: tl.constexpr,
    HAS_B1: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.

@@ -331,19 +463,31 @@ def _layer_norm_bwd_kernel(
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if HAS_DY1:
        DY1 += row_start * stride_dy1_row
    if HAS_DX1:
        DX1 += row_start * stride_dx1_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_DY1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_DY1:
        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_B1:
            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if HAS_DY1:
            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)

@@ -357,6 +501,11 @@ def _layer_norm_bwd_kernel(
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_DY1:
            wdy += w1 * dy1
            dw1 += dy1 * xhat
            if HAS_B1:
                db1 += dy1
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N

@@ -370,6 +519,15 @@ def _layer_norm_bwd_kernel(
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DX1:
            if HAS_DROPOUT:
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
                )
                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            else:
                dx1 = dx
            tl.store(DX1 + cols, dx1, mask=mask)
        if HAS_DROPOUT:
            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)

@@ -387,9 +545,17 @@ def _layer_norm_bwd_kernel(
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        if HAS_DY1:
            DY1 += stride_dy1_row
        if HAS_DX1:
            DX1 += stride_dx1_row
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
    if HAS_DY1:
        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
        if HAS_B1:
            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
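Each program therefore writes one partial row of the weight/bias gradients (including the new dw1/db1); the Python wrapper finishes the reduction on the host, as in the dw = _dw.sum(0) lines further down. A trivial standalone illustration of that two-stage pattern:

import torch

sm_count, N = 3, 5
_dw1 = torch.randn(sm_count, N)  # one partial row per program, as stored above
dw1 = _dw1.sum(0)                # host-side reduction to the final weight1 gradient
assert dw1.shape == (N,)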
def _layer_norm_bwd(
    ...
@@ -401,10 +567,14 @@ def _layer_norm_bwd(
    mean,
    rstd,
    dresidual=None,
    dy1=None,
    weight1=None,
    bias1=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    has_x1=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,

@@ -421,9 +591,19 @@ def _layer_norm_bwd(
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
-       assert seeds.shape == (M,)
+       assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)

@@ -435,10 +615,14 @@ def _layer_norm_bwd(
    )
    dresidual_in = (
        torch.empty_like(x)
-       if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None)
+       if has_residual
+       and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
    if recompute_output:
        assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()

@@ -452,6 +636,8 @@ def _layer_norm_bwd(
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    _db1 = torch.empty_like(_db) if bias1 is not None else None
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):

@@ -465,6 +651,11 @@ def _layer_norm_bwd(
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,

@@ -475,6 +666,8 @@ def _layer_norm_bwd(
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dy1.stride(0) if dy1 is not None else 0,
            dx1.stride(0) if dx1 is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,

@@ -490,10 +683,18 @@ def _layer_norm_bwd(
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
-   return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
+   if has_x1 and dropout_p == 0.0:
+       dx1 = dx
+   return (
+       (dx, dw, db, dresidual_in, dx1, dw1, db1)
+       if not recompute_output
+       else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
+   )
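The dx1 = dx shortcut above holds because x and x1 enter the op only through their sum, so without dropout their gradients are identical; with dropout, the backward kernel instead re-masks dx using x1's replayed mask. A kernel-independent sanity check of the no-dropout case:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, requires_grad=True)
x1 = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, requires_grad=True)
F.layer_norm(x + x1, (8,), weight=w).sum().backward()
assert torch.allclose(x.grad, x1.grad)  # identical gradients through the shared sum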
class LayerNormFn(torch.autograd.Function):
    ...
@@ -504,6 +705,9 @@ class LayerNormFn(torch.autograd.Function):
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,

@@ -522,9 +726,19 @@ class LayerNormFn(torch.autograd.Function):
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = x1.reshape(-1, x1.shape[-1])
            if x1.stride(-1) != 1:
                x1 = x1.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = (

@@ -532,41 +746,71 @@ class LayerNormFn(torch.autograd.Function):
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
-       y, mean, rstd, residual_out, seeds, dropout_mask = _layer_norm_fwd(
+       y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            x1,
            weight1,
            bias1,
            dropout_p=dropout_p,
            rowscale=rowscale,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
        )
-       ctx.save_for_backward(residual_out, weight, bias, rowscale, seeds, mean, rstd)
+       ctx.save_for_backward(
+           residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
+       )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        y1 = y1.reshape(x_shape_og) if y1 is not None else None
        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
        if not return_dropout_mask:
-           return y if not prenorm else (y, residual_out)
+           if weight1 is None:
+               return y if not prenorm else (y, residual_out)
+           else:
+               return (y, y1) if not prenorm else (y, y1, residual_out)
        else:
-           return (y, dropout_mask) if not prenorm else (y, residual_out, dropout_mask)
+           if weight1 is None:
+               return (
+                   (y, dropout_mask, dropout_mask1)
+                   if not prenorm
+                   else (y, residual_out, dropout_mask, dropout_mask1)
+               )
+           else:
+               return (
+                   (y, y1, dropout_mask, dropout_mask1)
+                   if not prenorm
+                   else (y, y1, residual_out, dropout_mask, dropout_mask1)
+               )
    @staticmethod
    def backward(ctx, dy, *args):
-       x, weight, bias, rowscale, seeds, mean, rstd = ctx.saved_tensors
+       x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if weight1 is not None:
            dy1, args = args[0], args[1:]
            dy1 = dy1.reshape(-1, dy1.shape[-1])
            if dy1.stride(-1) != 1:
                dy1 = dy1.contiguous()
            assert dy1.shape == x.shape
        else:
            dy1 = None
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])

@@ -575,7 +819,7 @@ class LayerNormFn(torch.autograd.Function):
            assert dresidual.shape == x.shape
        else:
            dresidual = None
-       dx, dw, db, dresidual_in = _layer_norm_bwd(
+       dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,

@@ -584,10 +828,14 @@ class LayerNormFn(torch.autograd.Function):
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )

@@ -596,6 +844,9 @@ class LayerNormFn(torch.autograd.Function):
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
            dw1,
            db1,
            None,
            None,
            None,

@@ -611,6 +862,9 @@ def layer_norm_fn(
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,

@@ -624,6 +878,9 @@ def layer_norm_fn(
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,

@@ -639,6 +896,9 @@ def rms_norm_fn(
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,

@@ -651,6 +911,9 @@ def rms_norm_fn(
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
@@ -662,11 +925,15 @@ def rms_norm_fn(
class RMSNorm(torch.nn.Module):

    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.dropout_p = dropout_p
        if dropout_p > 0.0:
            self.drop = torch.nn.Dropout(dropout_p)
        else:
            self.drop = None
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()
@@ -681,7 +948,7 @@ class RMSNorm(torch.nn.Module):
            self.bias,
            residual=residual,
            eps=self.eps,
-           dropout_p=self.dropout_p if self.training else 0.0,
+           dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )
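A hypothetical module-level sketch (not from the commit) of the dropout wiring above: the probability is read from self.drop.p only while the module is in training mode, so eval() disables the fused dropout. This assumes a CUDA device and the module's usual forward(x) call:

import torch
from flash_attn.ops.triton.layernorm import RMSNorm

norm = RMSNorm(1024, eps=1e-5, dropout_p=0.1, device="cuda", dtype=torch.float16)
x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
y_train = norm(x)  # fused dropout active (module defaults to training mode)
norm.eval()
y_eval = norm(x)   # dropout_p resolves to 0.0 per the change above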
tests/ops/triton/test_layer_norm.py

@@ -16,12 +16,16 @@ from flash_attn.ops.triton.layernorm import (
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("has_weight1", [False, True])
# @pytest.mark.parametrize("has_weight1", [True])
@pytest.mark.parametrize("has_x1", [False, True])
# @pytest.mark.parametrize("has_x1", [False])
@pytest.mark.parametrize("has_rowscale", [False, True])
# @pytest.mark.parametrize("has_rowscale", [True])
# @pytest.mark.parametrize("has_rowscale", [False])
@pytest.mark.parametrize("dropout_p", [0.0, 0.27])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [True])
# @pytest.mark.parametrize("prenorm", [False])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])

@@ -48,7 +52,11 @@ def test_layer_norm(
    prenorm,
    dropout_p,
    has_rowscale,
    has_x1,
    has_weight1,
):
    if has_rowscale and has_x1:
        pytest.skip("Not supported")
    device = "cuda"
    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 5e-2
@@ -62,9 +70,16 @@ def test_layer_norm(
    seqlen = 512
    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
    allclose = (
        # Sometimes x0_pt.grad is NaN
        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
        <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
        or (
            # Sometimes x_pt and x_ref are the same (e.g. bfloat16) so we want to perturb is a bit
            # by multiply and divide by 0.3
            (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0
            and (x - x_ref).abs().max()
            <= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol
        )
    )
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
@@ -86,8 +101,35 @@ def test_layer_norm(
    weight_ref = weight.detach().clone().requires_grad_()
    bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
    bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
    if has_x1:
        x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True)
        x1_pt = x1.detach().clone().requires_grad_()
        x1_ref = x1.detach().clone().requires_grad_()
    else:
        x1, x1_pt, x1_ref = None, None, None
    if has_weight1:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().requires_grad_()
        if not is_rms_norm:
            bias1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        else:
            bias1 = None
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
        bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None
    else:
        weight1, weight1_pt, weight1_ref = None, None, None
        bias1, bias1_pt, bias1_ref = None, None, None
-   rowscale = torch.randn(batch_size, seqlen, dtype=input_dtype, device=device) if has_rowscale else None
+   rowscale = (
+       torch.randn(batch_size, seqlen, dtype=input_dtype, device=device)
+       if has_rowscale
+       else None
+   )
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, *rest = layer_norm_fn(
@@ -95,6 +137,9 @@ def test_layer_norm(
        weight,
        bias,
        residual=res,
        x1=x1,
        weight1=weight1,
        bias1=bias1,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,

@@ -103,44 +148,75 @@ def test_layer_norm(
        is_rms_norm=is_rms_norm,
        return_dropout_mask=True,
    )
-   dropout_mask = rest[-1] if dropout_p > 0.0 else None
+   dropout_mask = rest[-2] if dropout_p > 0.0 else None
+   dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None
    out_pt = layer_norm_ref_fn(
        x0_pt,
        weight_pt,
        bias_pt,
        residual=res_pt,
        x1=x1_pt,
        weight1=weight1_pt,
        bias1=bias1_pt,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        dropout_mask=dropout_mask,
        dropout_mask1=dropout_mask1,
    )
    out_ref = layer_norm_ref_fn(
        x0_ref,
        weight_ref,
        bias_ref,
        residual=res_ref,
        x1=x1_ref,
        weight1=weight1_ref,
        bias1=bias1_ref,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        dropout_mask=dropout_mask,
        dropout_mask1=dropout_mask1,
        upcast=True,
    )
-   if prenorm:
-       residual = rest[0]
-       out_pt, residual_pt = out_pt
-       out_ref, residual_ref = out_ref
+   if not has_weight1:
+       if prenorm:
+           residual = rest[0]
+           out_pt, residual_pt = out_pt
+           out_ref, residual_ref = out_ref
+       out1, out1_pt, out1_ref = None, None, None
+   else:
+       out1 = rest.pop(0)
+       if prenorm:
+           residual = rest[0]
+           out_pt, out1_pt, residual_pt = out_pt
+           out_ref, out1_ref, residual_ref = out_ref
+       else:
+           out_pt, out1_pt = out_pt
+           out_ref, out1_ref = out_ref
    assert out.dtype == input_dtype
    if prenorm:
        assert residual.dtype == residual_dtype
        assert allclose(residual, residual_pt, residual_ref)
    assert allclose(out, out_pt, out_ref)
    if out1 is not None:
        assert out1.dtype == input_dtype
        assert allclose(out1, out1_pt, out1_ref)
    if dropout_mask is not None:
        dropout_fraction = 1.0 - dropout_mask.float().mean()
        assert abs(dropout_fraction - dropout_p) < 0.01
    if dropout_mask1 is not None:
        dropout_fraction = 1.0 - dropout_mask1.float().mean()
        assert abs(dropout_fraction - dropout_p) < 0.01
        assert not torch.equal(dropout_mask, dropout_mask1)
    g = torch.randn_like(out) / batch_size
    if has_weight1:
        out = out * F.gelu(out1)
        out_pt = out_pt * F.gelu(out1_pt)
        out_ref = out_ref * F.gelu(out1_ref)
    if not prenorm:
        out.backward(g)
        out_pt.backward(g)
@@ -152,9 +228,15 @@ def test_layer_norm(
    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
    if has_residual:
        assert allclose(res.grad, res_pt.grad, res_ref.grad)
    if has_x1:
        assert allclose(x1.grad, x1_pt.grad, x1_ref.grad)
    assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
    if bias is not None:
        assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
    if has_weight1:
        assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad)
    if bias1 is not None:
        assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad)


@pytest.mark.parametrize("prenorm", [True, False])