gaoqiong / flash-attention · Commits

Commit 79bd1a2d, authored Nov 13, 2023 by Tri Dao
Parent: 3566596a

[LayerNorm] Implement residual + LayerNorm/RMSNorm in Triton

Showing 3 changed files with 506 additions and 1 deletion:

  flash_attn/ops/triton/layernorm.py    +395  -0
  flash_attn/utils/benchmark.py           +8  -1
  tests/ops/triton/test_layer_norm.py   +103  -0
flash_attn/ops/triton/layernorm.py  (new file, 0 → 100644)
# Copyright (c) 2023, Tri Dao.
# Implement residual + layer_norm / rms_norm.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math

import torch
import torch.nn.functional as F

import triton
import triton.language as tl


def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False):
    dtype = x.dtype
    if upcast:
        weight = weight.float()
        bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        residual = residual.float() if residual is not None else residual
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
        dtype
    )
    return out if residual is None else (out, x)


def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False):
    dtype = x.dtype
    if upcast:
        weight = weight.float()
        bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        residual = residual.float() if residual is not None else residual
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
    out = out.to(dtype)
    return out if residual is None else (out, x)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    RESIDUAL_OUT,  # pointer to the residual
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
        RESIDUAL_OUT += row * stride_res_out_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.).to(tl.float32)
        x += residual
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
def _layer_norm_fwd(x, weight, bias, eps, residual=None, is_rms_norm=False):
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    y = torch.empty_like(x)
    assert y.stride(-1) == 1
    if residual is not None:
        residual_out = torch.empty_like(residual)
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device='cuda') if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device='cuda')
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            residual_out,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual is not None else 0,
            N,
            eps,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            bias is not None,
        )
    return y, mean, rstd, residual_out
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    DRESIDUAL_IN,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = (
        torch.empty_like(dresidual)
        if dresidual is not None and dx.dtype != dresidual.dtype
        else None
    )
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            dresidual_in,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    # Don't need to compute dresidual_in separately in this case
    if dresidual is not None and dx.dtype == dresidual.dtype:
        dresidual_in = dx
    return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
class LayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, residual=None, eps=1e-6, is_rms_norm=False):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y, mean, rstd, *rest = _layer_norm_fwd(x, weight, bias, eps, residual, is_rms_norm)
        if residual is not None:
            residual_out = rest[0]
        ctx.save_for_backward(x if residual is None else residual_out, weight, bias, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        return y if residual is None else (y, residual_out.reshape(x_shape_og))

    @staticmethod
    def backward(ctx, dy, *args):
        x, weight, bias, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if ctx.has_residual:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
        )
def layer_norm_fn(x, weight, bias, residual=None, eps=1e-6, is_rms_norm=False):
    return LayerNormFn.apply(x, weight, bias, residual, eps, is_rms_norm)


def rms_norm_fn(x, weight, bias, residual=None, eps=1e-6):
    return LayerNormFn.apply(x, weight, bias, residual, eps, True)


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None):
        return layer_norm_fn(
            x, self.weight, self.bias, residual=residual, eps=self.eps, is_rms_norm=True
        )
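
The file above exposes both a functional API (layer_norm_fn / rms_norm_fn) and an RMSNorm module whose forward takes an optional residual. The snippet below is a minimal usage sketch added for illustration only; it is not part of the commit, the shapes, dtypes, and the comparison at the end are assumptions, and it requires a CUDA GPU with Triton installed.

    # Illustrative usage sketch only -- not part of this commit.
    import torch
    from flash_attn.ops.triton.layernorm import layer_norm_fn, layer_norm_ref

    batch, seqlen, hidden = 2, 128, 2048  # hypothetical sizes
    x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x, dtype=torch.float32)  # residual stream kept in fp32
    weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
    bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)

    # With a residual, the fused kernel returns both the normalized output and
    # the updated residual stream (x + residual), avoiding a separate add.
    out, residual_out = layer_norm_fn(x, weight, bias, residual=residual, eps=1e-6)

    # Reference implementation from the same file, upcast to fp32 internally.
    out_ref, residual_ref = layer_norm_ref(x, weight, bias, residual=residual, eps=1e-6, upcast=True)
    print((out.float() - out_ref.float()).abs().max())  # small, dtype-dependent error

Because the kernel writes the summed residual back out, a pre-norm Transformer block can thread residual_out into the next layer without recomputing the addition.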
flash_attn/utils/benchmark.py
@@ -213,7 +213,10 @@ def pytorch_profiler(
     """Wrap benchmark functions in Pytorch profiler to see CUDA information."""
     if backward:
         with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
-            g = torch.randn_like(fn(*inputs, **kwinputs))
+            out = fn(*inputs, **kwinputs)
+            if type(out) is tuple:
+                out = out[0]
+            g = torch.randn_like(out)
     for _ in range(30):  # Warm up
         if backward:
             for x in inputs:
@@ -221,6 +224,8 @@ def pytorch_profiler(
                 x.grad = None
         with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
             out = fn(*inputs, **kwinputs)
+            if type(out) is tuple:
+                out = out[0]
         # Backward should be done outside autocast
         if backward:
             out.backward(g, retain_graph=True)
@@ -239,6 +244,8 @@ def pytorch_profiler(
                 x.grad = None
         with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
             out = fn(*inputs, **kwinputs)
+            if type(out) is tuple:
+                out = out[0]
         if backward:
             out.backward(g, retain_graph=True)
     if verbose:
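
These benchmark.py hunks are needed because layer_norm_fn now returns a tuple (out, residual_out) when a residual is supplied, so the profiler takes the first element before building the gradient and calling backward. Below is a standalone sketch of the same pattern, added for illustration only; run_backward_once and fused_add_norm are hypothetical names, not functions from this repository.

    # Standalone sketch of the tuple-handling pattern added above (illustrative only).
    import torch

    def run_backward_once(fn, *inputs):
        out = fn(*inputs)
        if type(out) is tuple:  # e.g. layer_norm_fn(..., residual=...) -> (out, residual_out)
            out = out[0]
        g = torch.randn_like(out)  # the gradient must match the tensor we backprop through
        out.backward(g)

    # Hypothetical function that returns (y, residual_out), mimicking the fused API.
    def fused_add_norm(x, res):
        summed = x + res
        y = torch.nn.functional.layer_norm(summed, summed.shape[-1:])
        return y, summed

    x = torch.randn(4, 8, requires_grad=True)
    res = torch.randn(4, 8)
    run_backward_once(fused_add_norm, x, res)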
tests/ops/triton/test_layer_norm.py  (new file, 0 → 100644)
import math
from functools import partial

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from flash_attn.ops.triton.layernorm import layer_norm_fn, layer_norm_ref, rms_norm_ref

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [True])
@pytest.mark.parametrize(
    "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("weight_dtype", [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 8192])
# @pytest.mark.parametrize("hidden_size", [256])
def test_layer_norm(hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm):
    device = "cuda"
    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 5e-2
    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 5e-3
    else:
        atol = 1e-4
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    # batch_size = 1
    # seqlen = 1
    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
    allclose = (
        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
        <= 2 * (x_pt - x_ref).abs().max() + atol
    )
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0_pt = x0.detach().clone().requires_grad_()
    x0_ref = x0.detach().clone().requires_grad_()
    if has_residual:
        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res_pt = res.detach().clone().requires_grad_()
        res_ref = res.detach().clone().requires_grad_()
    else:
        res, res_pt, res_ref = None, None, None
    weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    if not is_rms_norm:
        bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    else:
        bias = None
    weight_pt = weight.detach().clone().requires_grad_()
    weight_ref = weight.detach().clone().requires_grad_()
    bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
    bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, *rest = layer_norm_fn(x0, weight, bias, residual=res, eps=1e-6, is_rms_norm=is_rms_norm)
    out_pt, *rest_pt = layer_norm_ref_fn(x0_pt, weight_pt, bias_pt, residual=res_pt, eps=1e-6)
    out_ref, *rest_ref = layer_norm_ref_fn(
        x0_ref, weight_ref, bias_ref, residual=res_ref, eps=1e-6, upcast=True
    )
    if has_residual:
        residual = rest[0]
        residual_pt = rest_pt[0]
        residual_ref = rest_ref[0]
        residual_ref = x0_ref + res_ref
    assert out.dtype == input_dtype
    if has_residual:
        assert residual.dtype == residual_dtype
        assert allclose(residual, residual_pt, residual_ref)
    assert allclose(out, out_pt, out_ref)

    g = torch.randn_like(out) / batch_size
    if not has_residual:
        out.backward(g)
        out_pt.backward(g)
        out_ref.backward(g)
    else:
        (out * F.sigmoid(residual)).backward(g)
        (out_pt * F.sigmoid(residual_pt)).backward(g)
        (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
    if has_residual:
        assert allclose(res.grad, res_pt.grad, res_ref.grad)
    assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
    if bias is not None:
        assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
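
The test's allclose lambda encodes a relative acceptance criterion rather than a fixed tolerance: the Triton output may deviate from the fp32-upcast reference by at most twice as much as the plain low-precision PyTorch implementation does, plus a dtype-dependent atol. The snippet below is a standalone restatement of that criterion added for illustration; the helper name and the toy tensors are not from the repository.

    # Restatement of the test's acceptance criterion (illustrative helper, not in the repo).
    import torch

    def within_2x_reference_error(x, x_pt, x_ref, atol):
        """x: kernel output, x_pt: low-precision PyTorch output, x_ref: fp32-upcast reference."""
        err_kernel = (x - x_ref).abs().max()
        err_pytorch = (x_pt - x_ref).abs().max()
        return err_kernel <= 2 * err_pytorch + atol

    # Toy tensors just to show the call shape: both "outputs" carry similar small noise.
    x_ref = torch.randn(16)
    x_pt = x_ref + 1e-3 * torch.randn(16)   # stand-in for low-precision rounding error
    x = x_ref + 1e-3 * torch.randn(16)      # stand-in for the kernel's error
    assert within_2x_reference_error(x, x_pt, x_ref, atol=5e-3)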