gaoqiong / flash-attention
"torchvision/csrc/cuda/DeformConv_cuda.cu" did not exist on "5b1716a2ce67359c1aa8831b06011067bdbce1e2"
Commit 26f4b5fb, authored Jul 31, 2024 by Woosuk Kwon
Merge branch 'main' into Dao-AILab/main
Parents: 5018ac6a, 12375706
Pipeline #2015 failed in 0 seconds
Changes: 95, Pipelines: 1
Showing 15 changed files with 107 additions and 3565 deletions:
    flash_attn/ops/triton/k_activations.py    +0   -162
    flash_attn/ops/triton/layer_norm.py       +0   -1086
    flash_attn/ops/triton/linear.py           +0   -594
    flash_attn/ops/triton/mlp.py              +0   -149
    flash_attn/ops/triton/rotary.py           +0   -227
    flash_attn/utils/__init__.py              +0   -0
    flash_attn/utils/benchmark.py             +0   -268
    flash_attn/utils/distributed.py           +0   -144
    flash_attn/utils/generation.py            +0   -740
    flash_attn/utils/pretrained.py            +0   -79
    setup.py                                  +20  -107
    tests/test_flash_attn.py                  +46  -2
    vllm_flash_attn/__init__.py               +2   -2
    vllm_flash_attn/flash_attn_interface.py   +39  -5
    vllm_flash_attn/pyproject.toml            +0   -0
flash_attn/ops/triton/k_activations.py (deleted, 100644 → 0):
# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from enum import Enum
from typing import Optional

import triton
import triton.language as tl

_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)


class Activation(str, Enum):
    SquaredReLU = "squared_relu"
    GeLU = "gelu"
    GeLUApprox = "gelu_approx"
    LeakyReLU = "leaky_relu"
    ReLU = "relu"


def get_triton_activation_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu,
            Activation.LeakyReLU: leaky_relu,
            Activation.GeLU: gelu,
            Activation.GeLUApprox: gelu_approx,
            Activation.SquaredReLU: squared_relu,
        }[activation]
        if activation
        else None
    )


def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu_grad,
            Activation.LeakyReLU: leaky_relu_grad,
            Activation.GeLU: gelu_grad,
            Activation.GeLUApprox: gelu_approx_grad,
            Activation.SquaredReLU: squared_relu_grad,
        }[activation]
        if activation
        else None
    )


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def cosh(x):
    exp_x = tl.exp(x)
    return (exp_x + 1.0 / exp_x) * 0.5


# a Triton implementation of the most used activations
# See for instance http://arxiv.org/abs/1606.08415 for an overview

# ReLU
@triton.jit
def relu(x):
    """
    ReLU_ activation function

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    zero = 0.0
    return tl.where(x >= 0, x, zero.to(x.dtype))


@triton.jit
def relu_grad(x):
    # ReLU is different from other activations
    # in that it does not require the input to retrospectively compute its gradient
    # here the input is the downstream gradient, and we return the upstream gradient directly
    zero = 0.0
    one = 1.0
    return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))


@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, as proposed in the Primer_ paper.

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    x_ = relu(x)
    return (x_ * x_).to(x.dtype)


@triton.jit
def squared_relu_grad(x):
    return tl.where(x >= 0, 2.0 * x, 0.0)


# Leaky ReLU
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    scale = 0.01 + 0.0
    scale = scale.to(x.dtype)
    return tl.where(x >= 0, x, scale * x)


@triton.jit
def leaky_relu_grad(x):
    min_grad = 0.01
    max_grad = 1
    min_grad = min_grad.to(x.dtype)
    max_grad = max_grad.to(x.dtype)
    return tl.where(x >= 0, max_grad, min_grad)


@triton.jit
def gelu(x):
    """Gaussian Error Linear Unit (GELU)"""
    return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))


@triton.jit
def gelu_grad(x):
    cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
    return cdf + x * pdf


@triton.jit
def gelu_approx(x):
    """
    GeLU_ activation - Gaussian error linear unit, with tanh approximation

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))


@triton.jit
def gelu_approx_grad(x):
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
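A minimal usage sketch (not part of the deleted file) showing how the dispatch helpers above are typically consumed; the variable names are illustrative assumptions. The getters return `@triton.jit` functions, so they are meant to be called from inside another Triton kernel (e.g. the fused matmul in linear.py below), not on ordinary torch tensors.

    # Illustrative only: look up the forward/backward activation kernels by enum value.
    from flash_attn.ops.triton.k_activations import (
        Activation,
        get_triton_activation_bwd_kernel,
        get_triton_activation_kernel,
    )

    act_fn = get_triton_activation_kernel(Activation.GeLUApprox)          # -> gelu_approx
    act_bwd_fn = get_triton_activation_bwd_kernel(Activation.GeLUApprox)  # -> gelu_approx_grad
    # Passing None returns None, which callers use to mean "no fused activation".
    assert get_triton_activation_kernel(None) is None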
flash_attn/ops/triton/layer_norm.py (deleted, 100644 → 0):
# Copyright (c) 2024, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd

import triton
import triton.language as tl


def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
        dtype
    )
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = F.layer_norm(
            x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)


def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
            dtype
        )
        return (out, out1) if not prenorm else (out, out1, x)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,
    W1,
    B1,
    Y1,
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(y)
        assert y1.stride(-1) == 1
    else:
        y1 = None
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ):
        residual_out = torch.empty(
            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
        )
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    if dropout_p > 0.0:
        seeds = torch.randint(
            2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
        )
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            x1.stride(0) if x1 is not None else 0,
            y1.stride(0) if y1 is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    if dropout_mask is not None and x1 is not None:
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    else:
        dropout_mask1 = None
    return (
        y,
        y1,
        mean,
        rstd,
        residual_out if residual_out is not None else x,
        seeds,
        dropout_mask,
        dropout_mask1,
    )


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    W1,
    DY1,
    DX1,
    DW1,
    DB1,
    DRESIDUAL_IN,
    ROWSCALE,
    SEEDS,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dy1_row,
    stride_dx1_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_DY1: tl.constexpr,
    HAS_DX1: tl.constexpr,
    HAS_B1: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    # Do not early exit if row_start >= M, because we need to write DW and DB
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if HAS_DY1:
        DY1 += row_start * stride_dy1_row
    if HAS_DX1:
        DX1 += row_start * stride_dx1_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_DY1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_DY1:
        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_B1:
            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if HAS_DY1:
            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_DY1:
            wdy += w1 * dy1
            dw1 += dy1 * xhat
            if HAS_B1:
                db1 += dy1
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DX1:
            if HAS_DROPOUT:
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
                )
                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            else:
                dx1 = dx
            tl.store(DX1 + cols, dx1, mask=mask)
        if HAS_DROPOUT:
            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            dx *= rowscale
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        if HAS_DY1:
            DY1 += stride_dy1_row
        if HAS_DX1:
            DX1 += stride_dx1_row
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
    if HAS_DY1:
        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
        if HAS_B1:
            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)


def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    dy1=None,
    weight1=None,
    bias1=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    has_x1=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = (
        torch.empty_like(x)
        if has_residual
        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
    if recompute_output:
        assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    _db1 = torch.empty_like(_db) if bias1 is not None else None
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dy1.stride(0) if dy1 is not None else 0,
            dx1.stride(0) if dx1 is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    if has_x1 and dropout_p == 0.0:
        dx1 = dx
    return (
        (dx, dw, db, dresidual_in, dx1, dw1, db1)
        if not recompute_output
        else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
    )


class LayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = x1.reshape(-1, x1.shape[-1])
            if x1.stride(-1) != 1:
                x1 = x1.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            x1,
            weight1,
            bias1,
            dropout_p=dropout_p,
            rowscale=rowscale,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
        )
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        y1 = y1.reshape(x_shape_og) if y1 is not None else None
        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
        if not return_dropout_mask:
            if weight1 is None:
                return y if not prenorm else (y, residual_out)
            else:
                return (y, y1) if not prenorm else (y, y1, residual_out)
        else:
            if weight1 is None:
                return (
                    (y, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, residual_out, dropout_mask, dropout_mask1)
                )
            else:
                return (
                    (y, y1, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, y1, residual_out, dropout_mask, dropout_mask1)
                )

    @staticmethod
    def backward(ctx, dy, *args):
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if weight1 is not None:
            dy1, args = args[0], args[1:]
            dy1 = dy1.reshape(-1, dy1.shape[-1])
            if dy1.stride(-1) != 1:
                dy1 = dy1.contiguous()
            assert dy1.shape == x.shape
        else:
            dy1 = None
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
            dw1,
            db1,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
    )


def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if dropout_p > 0.0:
            self.drop = torch.nn.Dropout(dropout_p)
        else:
            self.drop = None
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )


class LayerNormLinearFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        norm_weight = norm_weight.contiguous()
        if norm_bias is not None:
            norm_bias = norm_bias.contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
        )
        y = y.reshape(x_shape_og)
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dy = F.linear(dout, linear_weight.t())
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual=dresidual,
            has_residual=ctx.has_residual,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )


def layer_norm_linear_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormLinearFn.apply(
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual,
        eps,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
    )
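A minimal usage sketch (not from the commit) of the fused forward path above, assuming a CUDA device and illustrative tensor shapes. It shows the plain call, the pre-norm residual pattern, and the RMSNorm variant that routes through the same autograd function with is_rms_norm=True.

    # Illustrative only: shapes and dtypes are assumptions.
    import torch
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, rms_norm_fn

    x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
    w = torch.ones(1024, device="cuda", dtype=torch.float16)
    b = torch.zeros(1024, device="cuda", dtype=torch.float16)

    # Plain fused LayerNorm.
    y = layer_norm_fn(x, w, b, eps=1e-6)

    # Pre-norm residual pattern: also returns the updated residual stream,
    # optionally kept in fp32 for numerical stability.
    residual = torch.randn_like(x)
    y, new_residual = layer_norm_fn(
        x, w, b, residual=residual, prenorm=True, residual_in_fp32=True
    )

    # RMSNorm variant (no bias) through the same kernels.
    y_rms = rms_norm_fn(x, w, None, eps=1e-6)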
flash_attn/ops/triton/linear.py (deleted, 100644 → 0):
# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional

import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from flash_attn.ops.triton.k_activations import (
    gelu,
    gelu_approx,
    gelu_approx_grad,
    gelu_grad,
    squared_relu,
    squared_relu_grad,
)

# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications


def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()


def get_configs_io_bound():
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        triton.Config(
                            {
                                "BLOCK_M": block_m,
                                "BLOCK_N": block_n,
                                "BLOCK_K": block_k,
                                "SPLIT_K": 1,
                            },
                            num_stages=num_stages,
                            num_warps=num_warps,
                        )
                    )
                    # split_k not used
                    # for split_k in [2, 4, 8, 16]:
                    #     configs.append(triton.Config(
                    #         {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
                    #         num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
    return configs


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
    ]
    + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_fwd(
    C,  # Pointers to matrices
    ACT_INPUT,
    A,
    B,
    bias,
    # Matrix dimensions
    M,
    N,
    K,
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_cm,
    # stride_cn,  # Assume that stride_cn == 1
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    A_ROWMAJOR: tl.constexpr,
    B_COLMAJOR: tl.constexpr,
    BIAS: tl.constexpr,
    SAVE_ACT_INPUT: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """
    Kernel for computing Out = activation(A x W + C)
    - Input has shape (M, K)
    - Weight has shape (K, N)
    - Bias has shape (N,)
    - Output has shape (M, N)
    - ActInputs (optional) has shape (M, N)
    'ActInputs' optionally saves the A x W + C intermediate for backward computations
    This kernel will consolidate over K
    """
    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # trick to avoid masking on M and N axis
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    if A_ROWMAJOR:
        A = A + (ram[:, None] * stride_am + rk[None, :])
    else:
        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    if B_COLMAJOR:
        B = B + (rk[:, None] + rbn[None, :] * stride_bn)
    else:
        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        if A_ROWMAJOR:
            A += BLOCK_K
        else:
            A += BLOCK_K * stride_ak
        if B_COLMAJOR:
            B += BLOCK_K
        else:
            B += BLOCK_K * stride_bk

    # Putting bias after the matmul (instead of before) is faster, idk why
    if BIAS:
        bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
        acc += bias[None, :]

    # optional: save the activation inputs
    if SAVE_ACT_INPUT:
        # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        tl.store(act_in_ptrs, acc)

    # optional: fused activation (while the data is in shared memory)
    if ACTIVATION == "gelu":
        acc = gelu(acc)
    elif ACTIVATION == "gelu_approx":
        acc = gelu_approx(acc)
    elif ACTIVATION == "squared_relu":
        acc = squared_relu(acc)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # write back result
    # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc)


def triton_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: str = "id",
    save_act_input: bool = False,
) -> torch.Tensor:
    """
    Compute e = activation(x @ weight.T + bias).
    This wrapper kicks the `kernel_fwd` Triton kernel
    :param x: input tensor
    :param weight: weight matrix
    :param bias: an optional bias tensor
    :param activation: Activation name. Needs to be a Triton kernel.
    :param act_input: an optional tensor to save the activation inputs (for backward)
    :return: result tensor
    """
    # if torch.is_autocast_enabled():
    #     dtype = torch.get_autocast_gpu_dtype()
    #     x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]

    batch_shape, n = x.shape[:-1], x.shape[-1]
    batch_dim = batch_shape.numel()
    x_reshaped = x.reshape(batch_dim, n)

    if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:
        x_reshaped = x_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()
    bias = bias.contiguous() if bias is not None else None

    assert (
        x.dtype == weight.dtype
    ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
    if bias is not None:
        assert (
            x.dtype == bias.dtype
        ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
    assert (
        x_reshaped.shape[1] == weight.shape[1]
    ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"

    assert (
        bias is None or bias.shape[0] == weight.shape[0]
    ), "Incompatible dimensions in between weight and bias"

    M, K = x_reshaped.shape
    N, K = weight.shape

    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
    act_input = torch.empty_like(output) if save_act_input else None

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

    kernel_fwd[grid](
        output,
        act_input,
        x_reshaped,
        weight,  # data ptrs
        bias if bias is not None else x,  # auto skip bias if not present
        M,  # shapes
        N,
        K,
        M // 32,  # key for triton cache (limit number of compilations)
        N // 32,
        K // 32,
        stride_cm=output.stride(0),  # strides
        # stride_cn=output.stride(1),
        stride_am=x_reshaped.stride(0),
        stride_ak=x_reshaped.stride(1),
        stride_bk=weight.stride(1),
        stride_bn=weight.stride(0),
        BIAS=bias is not None,  # optional fused bias
        SAVE_ACT_INPUT=save_act_input,  # optional save activation inputs
        ACTIVATION=activation,  # optional fused activation
        A_ROWMAJOR=x_reshaped.stride(1) == 1,
        B_COLMAJOR=weight.stride(1) == 1,
        GROUP_M=8,  # speed optimization: group the programs
    )

    if not save_act_input:
        return output.reshape(*batch_shape, output.shape[-1])
    else:
        return (
            output.reshape(*batch_shape, output.shape[-1]),
            act_input.reshape(*batch_shape, act_input.shape[-1]),
        )


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
    ]
    + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_bwd(
    C,  # Pointers to matrices
    ACT_INPUT,
    A,
    B,
    # Matrix dimensions
    M,
    N,
    K,
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_cm,
    # stride_cn,  # Assume that stride_cn == 1
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """
    Kernel for computing Out = activation(A x W + C)
    - Input has shape (M, K)
    - Weight has shape (K, N)
    - Output has shape (M, N)
    - ActInputs (optional) has shape (M, N)
    'ActInputs' optionally saves the A x W + C intermediate for backward computations
    This kernel will consolidate over K
    """
    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # trick to avoid masking on M and N axis
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # optional: fused activation (while the data is in shared memory)
    if ACTIVATION != "id":
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        act_input = tl.load(act_in_ptrs).to(acc.dtype)
    if ACTIVATION == "gelu":
        acc *= gelu_grad(act_input)
    elif ACTIVATION == "gelu_approx":
        acc *= gelu_approx_grad(act_input)
    elif ACTIVATION == "squared_relu":
        acc *= squared_relu_grad(act_input)

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # write back result
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)


def triton_dgrad_act(
    grad_output: torch.Tensor,
    weight: torch.Tensor,
    activation: str = "id",
    act_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute e = activation(grad_output @ weight + bias).
    This wrapper kicks the `kernel_fwd` Triton kernel
    :param grad_output: input tensor
    :param weight: weight matrix
    :param activation: Activation name. Needs to be a Triton kernel.
    :param act_input: an optional tensor to save the activation inputs (for backward)
    :return: result tensor
    """
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]

    batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
    batch_dim = batch_shape.numel()
    grad_output_reshaped = grad_output.reshape(batch_dim, n)

    if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
        grad_output_reshaped = grad_output_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()

    assert (
        grad_output.dtype == weight.dtype
    ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
    assert (
        grad_output_reshaped.shape[1] == weight.shape[0]
    ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
    if activation != "id":
        assert act_input is not None, f"act_input is required for activation {activation}"

    # M, N, K in bwd are different from M, N, K in fwd
    M, K = grad_output_reshaped.shape
    K, N = weight.shape

    grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

    kernel_bwd[grid](
        grad_input,
        act_input,
        grad_output_reshaped,
        weight,  # data ptrs
        M,  # shapes
        N,
        K,
        M // 32,  # key for triton cache (limit number of compilations)
        N // 32,
        K // 32,
        stride_cm=grad_input.stride(0),  # strides
        # stride_cn=grad_input.stride(1),
        stride_am=grad_output_reshaped.stride(0),
        stride_ak=grad_output_reshaped.stride(1),
        stride_bk=weight.stride(0),
        stride_bn=weight.stride(1),
        ACTIVATION=activation,  # optional fused activation
        GROUP_M=8,  # speed optimization: group the programs
    )

    return grad_input.reshape(*batch_shape, grad_input.shape[-1])
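A hedged usage sketch (not part of the deleted file) of the fused linear + activation wrapper above; shapes are assumptions. `triton_dgrad_act` is the matching input-gradient path, and flash_attn/ops/triton/mlp.py (also deleted in this commit) shows the full forward/backward pairing.

    # Illustrative only: weight is (out_features, in_features), as expected by the wrapper.
    import torch
    from flash_attn.ops.triton.linear import triton_linear_act

    x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
    w = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, device="cuda", dtype=torch.float16)

    # out = squared_relu(x @ w.T + b); act_input is the pre-activation, saved so the
    # backward pass can apply squared_relu_grad without redoing the matmul.
    out, act_input = triton_linear_act(x, w, b, activation="squared_relu", save_act_input=True)
    assert out.shape == (8, 512, 4096) and act_input.shape == (8, 512, 4096)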
flash_attn/ops/triton/mlp.py (deleted, 100644 → 0):
# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
# to naive implementation.
import
fused_dense_lib
as
fused_dense_cuda
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
from
flash_attn.ops.activations
import
sqrelu_bwd
,
sqrelu_fwd
from
flash_attn.ops.triton.linear
import
triton_dgrad_act
,
triton_linear_act
class
FusedDenseSqreluDenseFunc
(
torch
.
autograd
.
Function
):
@
staticmethod
@
custom_fwd
def
forward
(
ctx
,
x
,
weight1
,
bias1
,
weight2
,
bias2
,
checkpoint_lvl
=
0
):
"""checkpoint_lvl:
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute act_input and gelu_out in the bwd
"""
if
torch
.
is_autocast_enabled
():
dtype
=
torch
.
get_autocast_gpu_dtype
()
x
,
weight1
,
bias1
,
weight2
,
bias2
=
[
a
.
to
(
dtype
=
dtype
)
for
a
in
[
x
,
weight1
,
bias1
,
weight2
,
bias2
]
]
is_bf16
=
x
.
dtype
==
torch
.
bfloat16
assert
checkpoint_lvl
in
[
0
,
1
,
2
]
x
=
x
.
contiguous
()
weight1
=
weight1
.
contiguous
()
bias1
=
bias1
.
contiguous
()
weight2
=
weight2
.
contiguous
()
bias2
=
bias2
.
contiguous
()
batch_shape
,
n
=
x
.
shape
[:
-
1
],
x
.
shape
[
-
1
]
batch_dim
=
batch_shape
.
numel
()
if
is_bf16
:
act_input
=
fused_dense_cuda
.
linear_bias_forward
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
)
output1
=
sqrelu_fwd
(
act_input
)
else
:
save_act_input
=
checkpoint_lvl
!=
2
result
=
triton_linear_act
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
,
activation
=
"squared_relu"
,
save_act_input
=
save_act_input
,
)
if
save_act_input
:
output1
,
act_input
=
result
else
:
output1
=
result
output2
=
fused_dense_cuda
.
linear_bias_forward
(
output1
,
weight2
,
bias2
)
ctx
.
checkpoint_lvl
=
checkpoint_lvl
if
checkpoint_lvl
==
0
:
ctx
.
save_for_backward
(
x
,
weight1
,
bias1
,
weight2
,
act_input
,
output1
)
elif
checkpoint_lvl
==
1
:
ctx
.
save_for_backward
(
x
,
weight1
,
bias1
,
weight2
,
act_input
)
elif
checkpoint_lvl
==
2
:
ctx
.
save_for_backward
(
x
,
weight1
,
bias1
,
weight2
)
return
output2
.
reshape
(
*
batch_shape
,
output2
.
shape
[
-
1
])
@
staticmethod
@
custom_bwd
def
backward
(
ctx
,
grad_output
):
grad_output
=
grad_output
.
contiguous
()
checkpoint_lvl
=
ctx
.
checkpoint_lvl
x
,
weight1
,
bias1
,
weight2
,
*
rest
=
ctx
.
saved_tensors
batch_shape
,
n
=
x
.
shape
[:
-
1
],
x
.
shape
[
-
1
]
batch_dim
=
batch_shape
.
numel
()
is_bf16
=
x
.
dtype
==
torch
.
bfloat16
if
checkpoint_lvl
==
0
:
act_input
,
output1
=
rest
elif
checkpoint_lvl
==
1
:
(
act_input
,)
=
rest
output1
=
sqrelu_fwd
(
act_input
)
elif
checkpoint_lvl
==
2
:
if
is_bf16
:
act_input
=
fused_dense_cuda
.
linear_bias_forward
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
)
output1
=
sqrelu_fwd
(
act_input
)
else
:
output1
,
act_input
=
triton_linear_act
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
,
activation
=
"squared_relu"
,
save_act_input
=
True
,
)
if
is_bf16
:
grad_output
=
grad_output
.
reshape
(
batch_dim
,
grad_output
.
shape
[
-
1
])
grad_weight2
,
grad_bias2
=
fused_dense_cuda
.
linear_bias_wgrad
(
output1
,
grad_output
)
grad_output1
=
grad_output
@
weight2
grad_act_input
=
sqrelu_bwd
(
grad_output1
,
act_input
)
grad_input
,
grad_weight1
,
grad_bias1
=
fused_dense_cuda
.
linear_bias_backward
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
grad_act_input
)
else
:
grad_output
=
grad_output
.
reshape
(
batch_dim
,
grad_output
.
shape
[
-
1
])
grad_weight2
,
grad_bias2
=
fused_dense_cuda
.
linear_bias_wgrad
(
output1
,
grad_output
)
grad_act_input
=
triton_dgrad_act
(
grad_output
,
weight2
,
activation
=
"squared_relu"
,
act_input
=
act_input
)
grad_input
,
grad_weight1
,
grad_bias1
=
fused_dense_cuda
.
linear_bias_backward
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
grad_act_input
)
return
grad_input
.
reshape_as
(
x
),
grad_weight1
,
grad_bias1
,
grad_weight2
,
grad_bias2
,
None
fused_dense_sqrelu_dense_function
=
FusedDenseSqreluDenseFunc
.
apply
class
FusedDenseSqreluDense
(
nn
.
Module
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
bias1
=
True
,
bias2
=
True
,
checkpoint_lvl
=
0
,
device
=
None
,
dtype
=
None
,
):
"""
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd
"""
assert
checkpoint_lvl
in
[
0
,
1
,
2
]
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
*
4
assert
bias1
==
True
,
"DenseSqreluDense module without bias is currently not supported"
assert
bias2
==
True
,
"DenseSqreluDense module without bias is currently not supported"
self
.
checkpoint_lvl
=
checkpoint_lvl
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
,
bias
=
bias1
,
**
factory_kwargs
)
self
.
fc2
=
nn
.
Linear
(
hidden_features
,
out_features
,
bias
=
bias2
,
**
factory_kwargs
)
def
forward
(
self
,
x
):
assert
x
.
is_cuda
return
fused_dense_sqrelu_dense_function
(
x
,
self
.
fc1
.
weight
,
self
.
fc1
.
bias
,
self
.
fc2
.
weight
,
self
.
fc2
.
bias
,
self
.
checkpoint_lvl
)
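
For orientation, a minimal usage sketch of the FusedDenseSqreluDense module above; this is illustrative only and assumes a CUDA device, fp16 inputs, and that the fused_dense_lib extension and Triton are installed:

import torch
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense

mlp = FusedDenseSqreluDense(1024, hidden_features=4096, device="cuda", dtype=torch.float16)
x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
y = mlp(x)          # fused fc1 -> squared ReLU -> fc2 via FusedDenseSqreluDenseFunc
y.sum().backward()  # checkpoint_lvl=0: act_input and output1 are saved, nothing is recomputed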
flash_attn/ops/triton/rotary.py deleted 100644 → 0 View file @ 5018ac6a

# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union

import torch
import triton
import triton.language as tl


@triton.jit
def rotary_kernel(
    OUT,  # Pointers to matrices
    X,
    COS,
    SIN,
    CU_SEQLENS,
    SEQLEN_OFFSETS,  # this could be int or a pointer
    # Matrix dimensions
    seqlen,
    rotary_dim,
    seqlen_ro,
    # strides
    stride_out_batch,
    stride_out_seqlen,
    stride_out_nheads,
    stride_out_headdim,
    stride_x_batch,
    stride_x_seqlen,
    stride_x_nheads,
    stride_x_headdim,
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2

    if not IS_VARLEN:
        X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
        OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
    else:
        start_idx = tl.load(CU_SEQLENS + pid_batch)
        seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
        X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
        OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads

    if pid_m * BLOCK_M >= seqlen:
        return
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    if not IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm + SEQLEN_OFFSETS
    else:
        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
    rk = tl.arange(0, BLOCK_K)
    rk_half = tl.arange(0, BLOCK_K // 2)

    if not INTERLEAVED:
        # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
        X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
        cos = tl.load(
            COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
        ).to(tl.float32)
        sin = tl.load(
            SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
        ).to(tl.float32)
        x0 = tl.load(
            X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
        ).to(tl.float32)
        x1 = tl.load(
            X + rotary_dim_half * stride_x_headdim,
            mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
            other=0.0,
        ).to(tl.float32)
        if CONJUGATE:
            sin = -sin
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
        # write back result
        OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
        tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
        tl.store(
            OUT + rotary_dim_half * stride_out_headdim,
            o1,
            mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
        )
    else:
        # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
        # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
        # Loading x0 will be fast but x1 will be slow.
        # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
        # Then we do the calculation and use tl.where to pick put the right outputs for the even
        # and for the odd indices.
        rk_swap = rk + ((rk + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
        rk_repeat = tl.arange(0, BLOCK_K) // 2
        X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
        X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
        cos = tl.load(
            COS,
            mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
            other=1.0,
        ).to(tl.float32)
        sin = tl.load(
            SIN,
            mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
            other=0.0,
        ).to(tl.float32)
        x0 = tl.load(
            X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0
        ).to(tl.float32)
        x1 = tl.load(
            X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
        ).to(tl.float32)
        if CONJUGATE:
            sin = -sin
        x0_cos = x0 * cos
        x1_sin = x1 * sin
        out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
        OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
        tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))


def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """
    Arguments:
        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim).
        cos: (seqlen_ro, rotary_dim / 2)
        sin: (seqlen_ro, rotary_dim / 2)
        seqlen_offsets: integer or integer tensor of size (batch,)
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Returns:
        y: (batch, seqlen, nheads, headdim)
    """
    is_varlen = cu_seqlens is not None
    if not is_varlen:
        batch, seqlen, nheads, headdim = x.shape
    else:
        assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
        total_seqlen, nheads, headdim = x.shape
        batch_p_1 = cu_seqlens.shape[0]
        batch = batch_p_1 - 1
        seqlen = max_seqlen
    seqlen_ro, rotary_dim = cos.shape
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    BLOCK_K = (
        32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
    )
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)

    # Need this, otherwise Triton tries to launch from cuda:0 and we get
    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    with torch.cuda.device(x.device.index):
        rotary_kernel[grid](
            output,  # data ptrs
            x,
            cos,
            sin,
            cu_seqlens,
            seqlen_offsets,
            seqlen,  # shapes
            rotary_dim,
            seqlen_ro,
            output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
            output.stride(-3),  # seqlen_stride or total_seqlen_stride
            output.stride(-2),  # nheads_stride
            output.stride(-1),  # headdim_stride
            x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
            x.stride(-3),  # seqlen stride or total_seqlen_stride
            x.stride(-2),  # nheads stride
            x.stride(-1),  # headdim stride
            BLOCK_K,
            isinstance(seqlen_offsets, torch.Tensor),
            is_varlen,
            interleaved,
            conjugate,
            BLOCK_M,
        )
    return output
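
A minimal sketch of calling apply_rotary above, assuming fp16 tensors on a CUDA device and rotary_dim equal to headdim; the cos/sin construction follows the (seqlen_ro, rotary_dim / 2) shape documented in the docstring:

import torch
from flash_attn.ops.triton.rotary import apply_rotary

batch, seqlen, nheads, headdim = 2, 128, 8, 64
x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
# Standard RoPE frequencies; cos/sin have shape (seqlen, headdim // 2) here
inv_freq = 1.0 / (10000 ** (torch.arange(0, headdim, 2, device="cuda").float() / headdim))
angles = torch.outer(torch.arange(seqlen, device="cuda").float(), inv_freq)
cos, sin = angles.cos().half(), angles.sin().half()
y = apply_rotary(x, cos, sin, interleaved=False)  # same shape and dtype as x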
flash_attn/utils/__init__.py deleted 100644 → 0 View file @ 5018ac6a (empty file)
flash_attn/utils/benchmark.py deleted 100644 → 0 View file @ 5018ac6a

# Copyright (c) 2023, Tri Dao.

""" Useful functions for writing test code. """

import torch
import torch.utils.benchmark as benchmark


def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(*inputs, y, grad):
        # Set .grad to None to avoid extra operation of gradient accumulation
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward + Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(grad, *inputs, **kwinputs):
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if type(y) is tuple:
                y = y[0]
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_combined(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap benchmark functions in Pytorch profiler to see CUDA information."""
    if backward:
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        g = torch.randn_like(out)
    for _ in range(30):  # Warm up
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        if backward:
            out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)


def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
    if verbose:
        print(f"{desc} max memory: {mem}GB")
    torch.cuda.empty_cache()
    return mem
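
A short illustrative example of the helpers above, assuming a CUDA device (timings and memory are hardware-dependent):

import torch
from flash_attn.utils.benchmark import benchmark_all, benchmark_memory

fn = lambda a, b: a @ b
a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16, requires_grad=True)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16, requires_grad=True)
benchmark_all(fn, a, b, repeats=10, desc="matmul")  # prints forward, backward, fwd+bwd timings
benchmark_memory(fn, a, b, desc="matmul")           # prints and returns peak CUDA memory in GB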
flash_attn/utils/distributed.py deleted 100644 → 0 View file @ 5018ac6a

from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce scatter the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    pamams_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(pamams_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
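
A small worked example of get_dim_for_local_rank above (pure Python, no GPU or process group needed): splitting dim=100 over 3 ranks in multiples of 4 gives 25 groups of 4, and 25 = 3 * 8 + 1, so rank 0 gets one extra group.

from flash_attn.utils.distributed import get_dim_for_local_rank

dims = [get_dim_for_local_rank(100, 3, rank, multiple_of=4) for rank in range(3)]
assert dims == [36, 32, 32] and sum(dims) == 100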
flash_attn/utils/generation.py deleted 100644 → 0 View file @ 5018ac6a

# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function

try:
    from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
except ImportError:
    GreedySearchDecoderOnlyOutput = namedtuple("GreedySearchDecoderOnlyOutput", ["sequences", "scores"])
    SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores"])


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for none top-k values to -inf. Done in-place."""
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for none top-p values to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits.masked_fill_(indices_to_remove, float("-inf"))


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            # Clone so that when we modify for top_p we don't change the original logits
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )


@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    tensor_parallel=1,
    cg=False,
    enable_timing=False,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            ).squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # return rearrange(token, "b -> b 1")
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)
    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        start.record()
    scores, sequences = [], [input_ids]
    while not should_stop(sequences[-1], inference_params):
        scores.append(get_logits(sequences[-1], inference_params))
        inference_params.seqlen_offset += sequences[-1].shape[1]
        sequences.append(sample_tokens(scores[-1], inference_params))
    if enable_timing:
        end.record()
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
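
An illustrative sketch of driving decode above on a CUDA machine; TinyLM is a hypothetical stand-in that only mimics the forward signature decode expects (a real caller would pass a flash-attn GPT-style model such as GPTLMHeadModel instead):

from types import SimpleNamespace

import torch
from flash_attn.utils.generation import decode


class TinyLM(torch.nn.Module):
    # Toy stand-in: forward(input_ids, position_ids=..., inference_params=..., num_last_tokens=...)
    # must return an object with a .logits field, which is all decode() relies on here.
    def __init__(self, vocab_size=128, dim=32):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.lm_head = torch.nn.Linear(dim, vocab_size)

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=1):
        logits = self.lm_head(self.embed(input_ids))[:, -num_last_tokens:]
        return SimpleNamespace(logits=logits)


model = TinyLM().to("cuda")
input_ids = torch.randint(0, 128, (1, 8), dtype=torch.long, device="cuda")
out = decode(input_ids, model, max_length=32, top_k=40, top_p=0.9, temperature=0.8)
print(out.sequences.shape)  # (1, 32): the prompt followed by the sampled tokens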
def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
    """Algorithm 1 from [1]
    [1] Fast Inference from Transformers via Speculative Decoding
    Yaniv Leviathan, Matan Kalman, Yossi Matias
    https://arxiv.org/abs/2211.17192

    Arguments:
        logits: Tensor of shape (batch_size, seqlen + 1, vocab_size)
        logits_draft: Tensor of shape (batch_size, seqlen, vocab_size)
        tokens_draft: Tensor of shape (batch_size, seqlen)
    Return:
        tokens: Tensor of shape (batch_size, seqlen + 1)
        num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1].
            For each sequence in the batch, the number of valid tokens that were sampled by
            speculative sampling.
    """
    batch, seqlen_p_1, vocab_size = logits.shape
    seqlen = seqlen_p_1 - 1
    assert logits_draft.shape == (batch, seqlen, vocab_size)
    assert tokens_draft.shape == (batch, seqlen)
    assert tokens_draft.dtype in [torch.int64, torch.int32]
    # TODO: if top_k = 1 we can simplify things and only work with indices
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
    # Clone so that when we modify for top_p we don't change the original logits
    logits = logits / temperature if temperature != 1.0 else logits.clone()
    logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone()
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        modify_logits_for_top_k_filtering(logits, top_k)
        modify_logits_for_top_k_filtering(logits_draft, top_k)
    modify_logits_for_top_p_filtering(logits, top_p)
    modify_logits_for_top_p_filtering(logits_draft, top_p)
    probs = torch.softmax(logits, dim=-1)
    probs_draft = torch.softmax(logits_draft, dim=-1)
    gather = lambda probs, tokens: rearrange(
        probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..."
    )
    # (batch, seqlen)
    accepted = torch.rand(batch, seqlen, device=probs.device) * gather(
        probs_draft, tokens_draft
    ) <= gather(probs[:, :-1], tokens_draft)
    accepted_all = accepted.all(dim=-1)
    # (batch,)
    first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1))
    probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0)
    # torch.multinomial can deal with unnormalized probabilities
    # probs_diff /= probs_diff.sum(dim=-1, keepdim=True)
    resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1)
    resample_probs = rearrange(
        resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)),
        "b 1 d -> b d",
    )
    resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1)  # (batch,)
    tokens = F.pad(tokens_draft, (0, 1))
    tokens[:, first_rejected_idx] = resample
    return tokens, first_rejected_idx + 1


@torch.inference_mode()
def decode_speculative(
    input_ids,
    model,
    model_draft,
    max_length,
    speculative_lookahead=3,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    vocab_size=None,
    tensor_parallel=1,
    cg=False,
    enable_timing=False,
    debug=False,
):
    """
    TD: WIP, for my own understanding, lightly tested. Only support batch_size == 1 for now.

    Speculative decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
    assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
    if cg:
        if not hasattr(model_draft, "_decoding_cache"):
            model_draft._decoding_cache = None
        model_draft._decoding_cache = update_graph_cache(
            model_draft,
            model_draft._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            # draft model needs to process either 1 or 2 tokens at a time
            decoding_seqlens=(1, 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params_draft = model_draft._decoding_cache.inference_params
        inference_params_draft.reset(max_length, batch_size)
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            decoding_seqlens=range(1, speculative_lookahead + 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            seqlen = input_ids.shape[1]
            # if inference_params.lengths_per_sample is None:
            # TODO: in the case of batched decoding where each sequence has a different length,
            # we need to compute the position_ids for each sequence using lengths_per_sample
            if True:
                cache_seqlens = torch.full(
                    (input_ids.shape[0],),
                    inference_params.seqlen_offset,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            else:
                cache_seqlens = inference_params.lengths_per_sample
            position_ids = cache_seqlens[:, None] + torch.arange(
                seqlen, dtype=torch.long, device=input_ids.device
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=num_last_tokens,
            ).logits
        else:
            # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1].
            # This might not be compatible the num_last_tokens used here.
            assert num_last_tokens <= input_ids.shape[1]
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            )[:, -num_last_tokens:]
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1):
        """Sample `num_tokens` tokens from the model, given the previous logits.
        Also return the logits of the sampled tokens.
        Arguments:
            input_ids: (batch, seqlen)
        Return:
            tokens: (batch, num_tokens)
            scores: (batch, num_tokens), which contains @previous_logits and the logits of the next
                (num_tokens - 1) tokens. The logits of the last token isn't computed.
        """
        assert num_tokens >= 1
        sequences, scores = [input_ids], []
        for i in range(num_tokens):
            scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1])
            inference_params.seqlen_offset += sequences[-1].shape[1]
            sequences.append(sample_fn(scores[-1]).unsqueeze(1))
        return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1)

    sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
    sample_fn = partial(sample, **sampling_kwargs)
    get_logits_main = partial(get_logits, model=model, cg=cg)
    get_logits_draft = partial(get_logits, model=model_draft, cg=cg)
    sample_tokens_main = partial(
        sample_tokens,
        get_logits_fn=get_logits_main,
        sample_fn=sample_fn,
        inference_params=inference_params,
    )
    sample_tokens_draft = partial(
        sample_tokens,
        get_logits_fn=get_logits_draft,
        sample_fn=sample_fn,
        inference_params=inference_params_draft,
    )

    if debug:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        start = time.time()

    sequences, scores = [input_ids], []
    num_main_model_calls = 0
    num_draft_tokens = 0
    num_accepted_tokens_history = []
    if seqlen_og >= max_length - 1:
        # Don't do speculative sampling, just sample 1 token from the model
        tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1)
        sequences.append(tokens)
        scores.append(scores_new)
    else:
        # Sample from draft model, which produces @n_spec_tokens, and @model
        # will then use to produce between 1 and 1 + @n_spec_tokens tokens.
        # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
        n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
        tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens)
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([input_ids, tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
            # breakpoint()
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
            # breakpoint()
        # TODO: we're using the fact that batch_size == 1
        # TODO: check eos_token_id
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
        # that in the next time we call @model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset = seqlen_og + num_generated - 1
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
            # breakpoint()

    while True:
        # seqlen_offset is total length generated - 1
        if inference_params.seqlen_offset >= max_length - 1:
            break
        if inference_params.seqlen_offset >= max_length - 2:
            # Don't do speculative sampling, just sample 1 token from the model
            tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
            sequences.append(tokens)
            scores.append(scores_new)
            break
        # Sample from draft model
        n_spec_tokens = min(
            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
        )
        # If the main model accepts all the draft tokens, plus it samples one new token,
        # then at the next iteration the draft model need to evaluate the logits of the last draft
        # token and the logits of the newly sampled token. So here we pass in the last 2 tokens
        # of sequences[-1].
        # This exception is when the main model rejects all the draft tokens, in which case we
        # will only have 1 token to pass in.
        tokens_draft, scores_draft = sample_tokens_draft(
            sequences[-1][:, -2:], num_tokens=n_spec_tokens
        )
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
            # breakpoint()
        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )  # (batch, n_spec_tokens + 1, vocab_size)
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
            # breakpoint()
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
            # breakpoint()
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # We've evaluated 1 token from sequences[-1][:, -1:] above, plus
        # num_generated_tokens[0].item() - 1 tokens from the draft model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset += num_generated
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
            # breakpoint()

    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        print(f"Number of calls to main model: {num_main_model_calls}")
        print(
            f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%"
        )
    sequences = torch.cat(sequences, dim=1)
    scores = torch.cat(scores, dim=1)
    if debug:
        scores_ref = model(sequences).logits
        print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max())
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=sequences, scores=scores)


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences


def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    tensor_parallel=1,
    dtype=None,
    n_warmups=2,
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(
                model.config,
                "head_dim",
                model.config.hidden_size // model.config.num_attention_heads,
            )
            inf_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                model.config.num_attention_heads // tensor_parallel,
                headdim,
                model.config.num_hidden_layers,
                device,
                dtype,
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for decoding_seqlen in decoding_seqlens:
        if (batch_size, decoding_seqlen) not in cache.callables:
            cache.callables[batch_size, decoding_seqlen] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen,
                decoding_seqlen=decoding_seqlen,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size, decoding_seqlen = input_ids.shape[:2]
        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache


def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    seqlen_offset_og = inference_params.seqlen_offset
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits.clone()

    inference_params.seqlen_offset = seqlen_offset_og
    return run
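
For scale intuition, a minimal example of allocate_inference_cache above, assuming a CUDA device: one KV-cache tensor per layer, laid out as (batch, seqlen, k/v, nheads, headdim).

import torch
from flash_attn.utils.generation import allocate_inference_cache

cache = allocate_inference_cache(
    max_batch_size=2, max_seqlen=512, nheads=16, headdim=64, layers=24,
    device="cuda", dtype=torch.float16,
)
assert cache[0].shape == (2, 512, 2, 16, 64)  # 4 MiB per layer in fp16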
flash_attn/utils/pretrained.py deleted 100644 → 0 View file @ 5018ac6a

import os
from functools import partial

import torch
from safetensors.torch import load_file as safe_load_file
from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)
from transformers.utils.hub import cached_file, get_checkpoint_shard_files


def state_dict_from_pretrained(model_name, device=None, dtype=None):
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    is_sharded = False
    load_safe = False
    resolved_archive_file = None

    weights_path = os.path.join(model_name, WEIGHTS_NAME)
    weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
    safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)

    if os.path.isfile(weights_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
    elif os.path.isfile(weights_index_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
    elif os.path.isfile(safe_weights_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        load_safe = True
    elif os.path.isfile(safe_weights_index_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
        load_safe = True
    else:  # Try loading from HF hub instead of from local files
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        if resolved_archive_file is None:
            resolved_archive_file = cached_file(
                model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
            )
            if resolved_archive_file is not None:
                is_sharded = True

    if resolved_archive_file is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")

    if load_safe:
        loader = partial(safe_load_file, device=mapped_device)
    else:
        loader = partial(torch.load, map_location=mapped_device)

    if is_sharded:
        # resolved_archive_file becomes a list of files that point to the different
        # checkpoint shards in this case.
        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
            model_name, resolved_archive_file
        )
        state_dict = {}
        for sharded_file in resolved_archive_file:
            state_dict.update(loader(sharded_file))
    else:
        state_dict = loader(resolved_archive_file)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
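
A minimal usage sketch of state_dict_from_pretrained above; "gpt2" is just an example Hugging Face model id, and the weights are loaded on CPU first, cast to fp16, then moved to the target device:

import torch
from flash_attn.utils.pretrained import state_dict_from_pretrained

state_dict = state_dict_from_pretrained("gpt2", device="cuda:0", dtype=torch.float16)
print(len(state_dict), next(iter(state_dict)))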
setup.py
View file @
26f4b5fb
...
...
@@ -49,7 +49,7 @@ else:
elif
BUILD_TARGET
==
"rocm"
:
IS_ROCM
=
True
PACKAGE_NAME
=
"flash_attn"
PACKAGE_NAME
=
"
vllm_
flash_attn"
BASE_WHEEL_URL
=
(
"https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
...
...
@@ -57,10 +57,10 @@ BASE_WHEEL_URL = (
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD
=
os
.
getenv
(
"FLASH_ATTENTION_FORCE_BUILD"
,
"FALSE"
)
==
"TRUE"
SKIP_CUDA_BUILD
=
os
.
getenv
(
"FLASH_ATTENTION_SKIP_CUDA_BUILD"
,
"FALSE"
)
==
"TRUE"
FORCE_BUILD
=
True
SKIP_CUDA_BUILD
=
False
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI
=
os
.
getenv
(
"FLASH_ATTENTION_FORCE_CXX11_ABI"
,
"FALSE"
)
==
"TRUE"
FORCE_CXX11_ABI
=
torch
.
_C
.
_GLIBCXX_USE_CXX11_ABI
def
get_platform
():
...
...
@@ -151,7 +151,7 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
if
os
.
path
.
exists
(
os
.
path
.
join
(
torch_dir
,
"include"
,
"ATen"
,
"CUDAGeneratorImpl.h"
)):
generator_flag
=
[
"-DOLD_GENERATOR_PATH"
]
check_if_cuda_home_none
(
"flash_attn"
)
check_if_cuda_home_none
(
PACKAGE_NAME
)
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag
=
[]
if
CUDA_HOME
is
not
None
:
...
...
@@ -177,7 +177,7 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
torch
.
_C
.
_GLIBCXX_USE_CXX11_ABI
=
True
ext_modules
.
append
(
CUDAExtension
(
name
=
"flash_attn_2_cuda"
,
name
=
"
vllm_
flash_attn_2_cuda"
,
sources
=
[
"csrc/flash_attn/flash_api.cpp"
,
"csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu"
,
...
@@ -208,34 +208,6 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
...
@@ -282,10 +254,10 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
                    # "--ptxas-options=-O2",
                    # "-lineinfo",
                    # "-DFLASHATTENTION_DISABLE_BACKWARD",
-                    # "-DFLASHATTENTION_DISABLE_DROPOUT",
+                    "-DFLASHATTENTION_DISABLE_DROPOUT",
                    # "-DFLASHATTENTION_DISABLE_ALIBI",
                    # "-DFLASHATTENTION_DISABLE_SOFTCAP",
-                    # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
+                    "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                    # "-DFLASHATTENTION_DISABLE_LOCAL",
                ]
                + generator_flag
...
@@ -391,7 +363,7 @@ elif not SKIP_CUDA_BUILD and IS_ROCM:

def get_package_version():
-    with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
+    with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
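get_package_version recovers the version by scanning the package's __init__.py for the __version__ assignment and literal-evaluating its right-hand side. The parsing step in isolation, run on an illustrative sample string rather than the real file:

import ast
import re

sample = '__version__ = "2.6.0"\n__all__ = []\n'
match = re.search(r"^__version__\s*=\s*(.*)$", sample, re.MULTILINE)
version = ast.literal_eval(match.group(1))
print(version)  # 2.6.0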
...
@@ -401,37 +373,6 @@ def get_package_version():
        return str(public_version)


-def get_wheel_url():
-    torch_version_raw = parse(torch.__version__)
-    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
-    platform_name = get_platform()
-    flash_version = get_package_version()
-    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
-    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
-
-    if IS_ROCM:
-        torch_hip_version = parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
-        hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
-        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
-    else:
-        # Determine the version numbers that will be used to determine the correct wheel
-        # We're using the CUDA version used to build torch, not the one currently installed
-        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-        torch_cuda_version = parse(torch.version.cuda)
-        # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
-        # to save CI time. Minor versions should be compatible.
-        torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
-        # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
-
-        # Determine wheel URL based on CUDA version, torch version, python version and OS
-        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
-
-    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
-    return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
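The removed get_wheel_url helper assembled release-asset names from the package name, package version, CUDA (or ROCm) version, torch version, C++11-ABI flag, Python tag, and platform. A rough sketch of the CUDA branch of that naming scheme, with made-up example values (not a real release asset):

# Hypothetical inputs, for illustration only.
package_name = "flash_attn"
flash_version = "2.6.3"
cuda_version = "123"      # CUDA 12.x builds target 12.3 per the removed comment
torch_version = "2.4"
cxx11_abi = "FALSE"
python_version = "cp310"
platform_name = "linux_x86_64"

wheel_filename = (
    f"{package_name}-{flash_version}+cu{cuda_version}torch{torch_version}"
    f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
)
# -> flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl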
...
@@ -444,28 +385,6 @@ class CachedWheelsCommand(_bdist_wheel):
        if FORCE_BUILD:
            return super().run()

-        wheel_url, wheel_filename = get_wheel_url()
-        print("Guessing wheel URL: ", wheel_url)
-        try:
-            urllib.request.urlretrieve(wheel_url, wheel_filename)
-
-            # Make the archive
-            # Lifted from the root wheel processing command
-            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
-            if not os.path.exists(self.dist_dir):
-                os.makedirs(self.dist_dir)
-
-            impl_tag, abi_tag, plat_tag = self.get_tag()
-            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
-
-            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
-            print("Raw wheel path", wheel_path)
-            os.rename(wheel_filename, wheel_path)
-        except (urllib.error.HTTPError, urllib.error.URLError):
-            print("Precompiled wheel not found. Building from source...")
-            # If the wheel could not be downloaded, build from source
-            super().run()


class NinjaBuildExtension(BuildExtension):
    def __init__(self, *args, **kwargs) -> None:
...
@@ -487,8 +406,11 @@ class NinjaBuildExtension(BuildExtension):
        super().__init__(*args, **kwargs)


+PYTORCH_VERSION = "2.4.0"
+CUDA_VERSION = "12.1"
+
setup(
-    name=PACKAGE_NAME,
+    name="vllm-flash-attn",
    version=get_package_version(),
    packages=find_packages(
        exclude=(
...
@@ -499,15 +421,13 @@
            "dist",
            "docs",
            "benchmarks",
-            "flash_attn.egg-info",
+            f"{PACKAGE_NAME}.egg-info",
        )
    ),
-    author="Tri Dao",
-    author_email="tri@tridao.me",
-    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
-    long_description=long_description,
-    long_description_content_type="text/markdown",
-    url="https://github.com/Dao-AILab/flash-attention",
+    author="vLLM Team",
+    description="Forward-only flash-attn",
+    long_description=f"Forward-only flash-attn package built for PyTorch {PYTORCH_VERSION} and CUDA {CUDA_VERSION}",
+    url="https://github.com/vllm-project/flash-attention.git",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
...
@@ -520,13 +440,6 @@ setup(
        "bdist_wheel": CachedWheelsCommand,
    },
    python_requires=">=3.8",
-    install_requires=[
-        "torch",
-        "einops",
-    ],
-    setup_requires=[
-        "packaging",
-        "psutil",
-        "ninja",
-    ],
+    install_requires=[f"torch == {PYTORCH_VERSION}"],
+    setup_requires=["psutil"],
)
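With the distribution renamed to vllm-flash-attn, the install pinned to a single torch release, and imports served from the vllm_flash_attn package, a typical consumer flow would look roughly like this (a sketch, not taken from the repository docs):

# pip install vllm-flash-attn        # pulls in torch == 2.4.0 per the pin above
from vllm_flash_attn import flash_attn_func, flash_attn_with_kvcache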
tests/test_flash_attn.py  View file @ 26f4b5fb
...
@@ -1588,7 +1588,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    ],
)
# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
-@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512])
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_causal(
    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
...
@@ -1875,7 +1875,7 @@ def test_flash_attn_splitkv(
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
-@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
+@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256])
# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
# @pytest.mark.parametrize("paged_kv_block_size", [None])
@pytest.mark.parametrize("has_leftpad", [False, True])
...
@@ -2523,3 +2523,47 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
    assert torch.equal(dv, dv0)
    assert torch.equal(dk, dk0)
    assert torch.equal(dq, dq0)
+
+
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [False])
+@pytest.mark.parametrize("paged_kv_block_size", [16])
+# @pytest.mark.parametrize("has_batch_idx", [False])
+@pytest.mark.parametrize("d", [128])
+@pytest.mark.parametrize("nheads", [32])
+@pytest.mark.parametrize("b", [4])
+@pytest.mark.parametrize("n", [10])
+@pytest.mark.parametrize("seqlen_q,seqlen_k", [(170, 170)])
+def test_flash_attn_paged_kvcache_overflow(
+    seqlen_q,
+    seqlen_k,
+    d,
+    nheads,
+    b,
+    n,
+    paged_kv_block_size,
+    causal,
+    dtype,
+):
+    device = "cuda"
+    num_blocks = 1000 * 16 // paged_kv_block_size
+    key_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device)
+    value_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device)
+    cache_seqlens = torch.zeros(b, dtype=torch.int32, device=device)
+
+    for _ in range(n):
+        query = torch.rand([b, seqlen_q, nheads, d], dtype=dtype, device=device)
+        key = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
+        value = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
+        block_tables = torch.randint(
+            0,
+            num_blocks,
+            size=(b, (seqlen_k + paged_kv_block_size - 1) // paged_kv_block_size),
+            dtype=torch.int32,
+            device=device,
+        )
+        output = flash_attn_with_kvcache(
+            query,
+            key_cache,
+            value_cache,
+            k=key,
+            v=value,
+            cache_seqlens=cache_seqlens,
+            block_table=block_tables,
+            causal=causal,
+        )
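The new test keeps the total number of cache slots fixed (1000 * 16) regardless of block size, and gives each sequence ceil(seqlen_k / block_size) block-table entries. The same arithmetic in isolation, using the values from the parametrization above:

paged_kv_block_size = 16
seqlen_k = 170

num_blocks = 1000 * 16 // paged_kv_block_size                                  # 1000 blocks of 16 slots
blocks_per_seq = (seqlen_k + paged_kv_block_size - 1) // paged_kv_block_size   # ceil(170 / 16) = 11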
flash_attn/__init__.py → vllm_flash_attn/__init__.py  View file @ 26f4b5fb
-__version__ = "2.6.3"
+__version__ = "2.6.0"

-from flash_attn.flash_attn_interface import (
+from vllm_flash_attn.flash_attn_interface import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
...
flash_attn/flash_attn_interface.py → vllm_flash_attn/flash_attn_interface.py  View file @ 26f4b5fb
...
@@ -7,7 +7,7 @@ import torch.nn as nn

# isort: off
# We need to import the CUDA kernels after importing torch
-import flash_attn_2_cuda as flash_attn_cuda
+import vllm_flash_attn_2_cuda as flash_attn_cuda

# isort: on
...
@@ -46,14 +46,14 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):


def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
+    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
):
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
        q,
        k,
        v,
-        None,
+        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
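The forward wrappers now take a keyword-only out= argument that is handed to the CUDA kernel in place of the previous None, so a caller such as vLLM can reuse a preallocated output buffer across calls. A hedged usage sketch (shapes, dtypes, and the reuse pattern are illustrative, not prescribed by the diff):

import torch
from vllm_flash_attn import flash_attn_func  # forward-only fork

q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
out_buf = torch.empty_like(q)  # allocated once, reused on every call

attn = flash_attn_func(q, k, v, causal=True, out=out_buf)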
...
@@ -91,7 +91,7 @@ def _flash_attn_varlen_forward(
        q,
        k,
        v,
-        None,
+        out,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
...
@@ -239,6 +239,8 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
        alibi_slopes,
        deterministic,
        return_softmax,
+        *,
+        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
...
@@ -253,6 +255,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
+            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
...
@@ -307,6 +310,8 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
        alibi_slopes,
        deterministic,
        return_softmax,
+        *,
+        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
...
@@ -326,6 +331,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
+            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
...
@@ -384,6 +390,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
        alibi_slopes,
        deterministic,
        return_softmax,
+        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
...
@@ -398,6 +405,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
+            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
...
@@ -457,6 +465,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
        alibi_slopes,
        deterministic,
        return_softmax,
+        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
...
@@ -476,6 +485,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
+            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
...
@@ -540,6 +550,7 @@ class FlashAttnFunc(torch.autograd.Function):
        alibi_slopes,
        deterministic,
        return_softmax,
+        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
...
@@ -554,6 +565,7 @@ class FlashAttnFunc(torch.autograd.Function):
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
+            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
...
@@ -614,6 +626,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        deterministic,
        return_softmax,
        block_table,
+        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
...
@@ -633,6 +646,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
+            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
...
@@ -691,6 +705,8 @@ def flash_attn_qkvpacked_func(
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
+    *,
+    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
...
@@ -736,6 +752,7 @@ def flash_attn_qkvpacked_func(
        alibi_slopes,
        deterministic,
        return_attn_probs,
+        out,
    )
...
@@ -750,6 +767,8 @@ def flash_attn_kvpacked_func(
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
+    *,
+    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
...
@@ -813,6 +832,7 @@ def flash_attn_kvpacked_func(
        alibi_slopes,
        deterministic,
        return_attn_probs,
+        out,
    )
...
@@ -828,6 +848,8 @@ def flash_attn_func(
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
+    *,
+    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
...
@@ -889,6 +911,7 @@ def flash_attn_func(
        alibi_slopes,
        deterministic,
        return_attn_probs,
+        out,
    )
...
@@ -904,6 +927,8 @@ def flash_attn_varlen_qkvpacked_func(
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
+    *,
+    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
...
@@ -954,6 +979,7 @@ def flash_attn_varlen_qkvpacked_func(
        alibi_slopes,
        deterministic,
        return_attn_probs,
+        out,
    )
...
@@ -972,6 +998,8 @@ def flash_attn_varlen_kvpacked_func(
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
+    *,
+    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
...
@@ -1045,6 +1073,7 @@ def flash_attn_varlen_kvpacked_func(
        alibi_slopes,
        deterministic,
        return_attn_probs,
+        out,
    )
...
@@ -1065,6 +1094,8 @@ def flash_attn_varlen_func(
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
+    *,
+    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
...
@@ -1138,6 +1169,7 @@ def flash_attn_varlen_func(
        deterministic,
        return_attn_probs,
        block_table,
+        out,
    )
...
@@ -1161,6 +1193,8 @@ def flash_attn_with_kvcache(
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
+    *,
+    out=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
...
@@ -1274,7 +1308,7 @@ def flash_attn_with_kvcache(
        cache_leftpad,
        block_table,
        alibi_slopes,
-        None,
+        out,
        softmax_scale,
        causal,
        window_size[0],
...
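flash_attn_with_kvcache gets the same treatment: the keyword-only out= argument is threaded through to the kernel in place of the previous None. A minimal decode-style sketch that combines it with a paged cache, loosely modeled on the new test above (all sizes are illustrative):

import torch
from vllm_flash_attn import flash_attn_with_kvcache

b, nheads, d, block_size, num_blocks = 4, 32, 128, 16, 1000
k_cache = torch.zeros(num_blocks, block_size, nheads, d, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((b,), 32, dtype=torch.int32, device="cuda")
block_table = torch.randint(0, num_blocks, (b, 8), dtype=torch.int32, device="cuda")

q = torch.randn(b, 1, nheads, d, dtype=torch.float16, device="cuda")  # one new token per sequence
out = torch.empty_like(q)  # preallocated output buffer passed via the new keyword
flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens,
                        block_table=block_table, causal=True, out=out)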
flash_attn/pyproject.toml → vllm_flash_attn/pyproject.toml  View file @ 26f4b5fb
File moved