gaoqiong / flash-attention
Commit 26f4b5fb, authored Jul 31, 2024 by Woosuk Kwon
Merge branch 'main' into Dao-AILab/main
Parents: 5018ac6a, 12375706
Pipeline #2015 failed with stages in 0 seconds
Changes: 95 · Pipelines: 1
Showing 15 changed files with 107 additions and 3565 deletions (+107, -3565)
  flash_attn/ops/triton/k_activations.py     +0   -162
  flash_attn/ops/triton/layer_norm.py        +0   -1086
  flash_attn/ops/triton/linear.py            +0   -594
  flash_attn/ops/triton/mlp.py               +0   -149
  flash_attn/ops/triton/rotary.py            +0   -227
  flash_attn/utils/__init__.py               +0   -0
  flash_attn/utils/benchmark.py              +0   -268
  flash_attn/utils/distributed.py            +0   -144
  flash_attn/utils/generation.py             +0   -740
  flash_attn/utils/pretrained.py             +0   -79
  setup.py                                   +20  -107
  tests/test_flash_attn.py                   +46  -2
  vllm_flash_attn/__init__.py                +2   -2
  vllm_flash_attn/flash_attn_interface.py    +39  -5
  vllm_flash_attn/pyproject.toml             +0   -0
flash_attn/ops/triton/k_activations.py  deleted  100644 → 0  (view file @ 5018ac6a)
# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from enum import Enum
from typing import Optional

import triton
import triton.language as tl

_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)


class Activation(str, Enum):
    SquaredReLU = "squared_relu"
    GeLU = "gelu"
    GeLUApprox = "gelu_approx"
    LeakyReLU = "leaky_relu"
    ReLU = "relu"


def get_triton_activation_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu,
            Activation.LeakyReLU: leaky_relu,
            Activation.GeLU: gelu,
            Activation.GeLUApprox: gelu_approx,
            Activation.SquaredReLU: squared_relu,
        }[activation]
        if activation
        else None
    )


def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu_grad,
            Activation.LeakyReLU: leaky_relu_grad,
            Activation.GeLU: gelu_grad,
            Activation.GeLUApprox: gelu_approx_grad,
            Activation.SquaredReLU: squared_relu_grad,
        }[activation]
        if activation
        else None
    )


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def cosh(x):
    exp_x = tl.exp(x)
    return (exp_x + 1.0 / exp_x) * 0.5


# a Triton implementation of the most used activations
# See for instance http://arxiv.org/abs/1606.08415 for an overview

# ReLU
@triton.jit
def relu(x):
    """
    ReLU_ activation function

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    zero = 0.0
    return tl.where(x >= 0, x, zero.to(x.dtype))


@triton.jit
def relu_grad(x):
    # ReLU is different from other activations
    # in that it does not require the input to retrospectively compute its gradient
    # here the input is the downstream gradient, and we return the upstream gradient directly
    zero = 0.0
    one = 1.0
    return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))


@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, as proposed in the Primer_ paper.

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    x_ = relu(x)
    return (x_ * x_).to(x.dtype)


@triton.jit
def squared_relu_grad(x):
    return tl.where(x >= 0, 2.0 * x, 0.0)


# Leaky ReLU
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    scale = 0.01 + 0.0
    scale = scale.to(x.dtype)
    return tl.where(x >= 0, x, scale * x)


@triton.jit
def leaky_relu_grad(x):
    min_grad = 0.01
    max_grad = 1
    min_grad = min_grad.to(x.dtype)
    max_grad = max_grad.to(x.dtype)
    return tl.where(x >= 0, max_grad, min_grad)


@triton.jit
def gelu(x):
    """Gaussian Error Linear Unit (GELU)"""
    return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))


@triton.jit
def gelu_grad(x):
    cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
    return cdf + x * pdf


@triton.jit
def gelu_approx(x):
    """
    GeLU_ activation - Gaussian error linear unit, with tanh approximation

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))


@triton.jit
def gelu_approx_grad(x):
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
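
For context on how this removed module was meant to be used: the activation helpers above are device-side @triton.jit functions intended to be called from inside other Triton kernels (for example the fused linear kernels later in this commit), not from Python directly. Below is a minimal sketch of fusing one of them into an elementwise kernel. It assumes a checkout from before this commit, where flash_attn.ops.triton.k_activations still exists; the kernel name apply_squared_relu is purely illustrative.

import torch
import triton
import triton.language as tl

# Assumption: this import only works on a tree that still contains the deleted module.
from flash_attn.ops.triton.k_activations import squared_relu


@triton.jit
def apply_squared_relu(X, Y, n_elements, BLOCK: tl.constexpr):
    # Each program handles one BLOCK-sized chunk of the flattened tensor.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(X + offs, mask=mask, other=0.0)
    # squared_relu is itself @triton.jit, so it can be called from device code.
    y = squared_relu(x)
    tl.store(Y + offs, y, mask=mask)


x = torch.randn(4096, device="cuda", dtype=torch.float16)
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
apply_squared_relu[grid](x, y, x.numel(), BLOCK=1024)
# Reference: squared ReLU is relu(x) ** 2.
torch.testing.assert_close(y, torch.relu(x) ** 2, rtol=1e-3, atol=1e-3)
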
flash_attn/ops/triton/layer_norm.py  deleted  100644 → 0  (view file @ 5018ac6a)
# Copyright (c) 2024, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd

import triton
import triton.language as tl


def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
        dtype
    )
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = F.layer_norm(
            x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)


def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
            dtype
        )
        return (out, out1) if not prenorm else (out, out1, x)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,
    W1,
    B1,
    Y1,
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(y)
        assert y1.stride(-1) == 1
    else:
        y1 = None
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ):
        residual_out = torch.empty(
            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
        )
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    if dropout_p > 0.0:
        seeds = torch.randint(
            2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
        )
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            x1.stride(0) if x1 is not None else 0,
            y1.stride(0) if y1 is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    if dropout_mask is not None and x1 is not None:
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    else:
        dropout_mask1 = None
    return (
        y,
        y1,
        mean,
        rstd,
        residual_out if residual_out is not None else x,
        seeds,
        dropout_mask,
        dropout_mask1,
    )


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    W1,
    DY1,
    DX1,
    DW1,
    DB1,
    DRESIDUAL_IN,
    ROWSCALE,
    SEEDS,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dy1_row,
    stride_dx1_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_DY1: tl.constexpr,
    HAS_DX1: tl.constexpr,
    HAS_B1: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    # Do not early exit if row_start >= M, because we need to write DW and DB
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if HAS_DY1:
        DY1 += row_start * stride_dy1_row
    if HAS_DX1:
        DX1 += row_start * stride_dx1_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_DY1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_DY1:
        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_B1:
        db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if HAS_DY1:
            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_DY1:
            wdy += w1 * dy1
            dw1 += dy1 * xhat
        if HAS_B1:
            db1 += dy1
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DX1:
            if HAS_DROPOUT:
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
                )
                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            else:
                dx1 = dx
            tl.store(DX1 + cols, dx1, mask=mask)
        if HAS_DROPOUT:
            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            dx *= rowscale
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        if HAS_DY1:
            DY1 += stride_dy1_row
        if HAS_DX1:
            DX1 += stride_dx1_row
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
    if HAS_DY1:
        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
    if HAS_B1:
        tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)


def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    dy1=None,
    weight1=None,
    bias1=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    has_x1=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = (
        torch.empty_like(x)
        if has_residual
        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
    if recompute_output:
        assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    _db1 = torch.empty_like(_db) if bias1 is not None else None
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dy1.stride(0) if dy1 is not None else 0,
            dx1.stride(0) if dx1 is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    if has_x1 and dropout_p == 0.0:
        dx1 = dx
    return (
        (dx, dw, db, dresidual_in, dx1, dw1, db1)
        if not recompute_output
        else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
    )


class LayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = x1.reshape(-1, x1.shape[-1])
            if x1.stride(-1) != 1:
                x1 = x1.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            x1,
            weight1,
            bias1,
            dropout_p=dropout_p,
            rowscale=rowscale,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
        )
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        y1 = y1.reshape(x_shape_og) if y1 is not None else None
        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
        if not return_dropout_mask:
            if weight1 is None:
                return y if not prenorm else (y, residual_out)
            else:
                return (y, y1) if not prenorm else (y, y1, residual_out)
        else:
            if weight1 is None:
                return (
                    (y, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, residual_out, dropout_mask, dropout_mask1)
                )
            else:
                return (
                    (y, y1, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, y1, residual_out, dropout_mask, dropout_mask1)
                )

    @staticmethod
    def backward(ctx, dy, *args):
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if weight1 is not None:
            dy1, args = args[0], args[1:]
            dy1 = dy1.reshape(-1, dy1.shape[-1])
            if dy1.stride(-1) != 1:
                dy1 = dy1.contiguous()
            assert dy1.shape == x.shape
        else:
            dy1 = None
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
            dw1,
            db1,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
    )


def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if dropout_p > 0.0:
            self.drop = torch.nn.Dropout(dropout_p)
        else:
            self.drop = None
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )


class LayerNormLinearFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        norm_weight = norm_weight.contiguous()
        if norm_bias is not None:
            norm_bias = norm_bias.contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
        )
        y = y.reshape(x_shape_og)
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dy = F.linear(dout, linear_weight.t())
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual=dresidual,
            has_residual=ctx.has_residual,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )


def layer_norm_linear_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormLinearFn.apply(
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual,
        eps,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
    )
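
For reference, a minimal sketch of how the removed layer_norm_fn entry point was typically called: it fuses the residual add with the normalization and, with prenorm=True, also returns the updated residual stream for the next block. It assumes a CUDA device and a checkout from before this commit, where flash_attn.ops.triton.layer_norm is still importable; the pure-PyTorch layer_norm_ref from the same file serves as the reference.

import torch

# Assumption: these imports only resolve on a tree that still contains the deleted module.
from flash_attn.ops.triton.layer_norm import layer_norm_fn, layer_norm_ref

batch, seqlen, hidden = 2, 128, 1024
dtype = torch.float16
x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=dtype)
residual = torch.randn_like(x)
weight = torch.ones(hidden, device="cuda", dtype=dtype)
bias = torch.zeros(hidden, device="cuda", dtype=dtype)

# prenorm=True returns (normalized output, x + residual); the second value feeds the
# next transformer block's residual connection.
out, new_residual = layer_norm_fn(x, weight, bias, residual=residual, eps=1e-6, prenorm=True)

# Check against the reference implementation (upcast=True computes in fp32).
out_ref, new_residual_ref = layer_norm_ref(
    x, weight, bias, residual=residual, eps=1e-6, prenorm=True, upcast=True
)
torch.testing.assert_close(out, out_ref.to(dtype), rtol=1e-2, atol=1e-2)
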
flash_attn/ops/triton/linear.py  deleted  100644 → 0  (view file @ 5018ac6a)
# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional

import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from flash_attn.ops.triton.k_activations import (
    gelu,
    gelu_approx,
    gelu_approx_grad,
    gelu_grad,
    squared_relu,
    squared_relu_grad,
)

# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications


def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()


def get_configs_io_bound():
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        triton.Config(
                            {
                                "BLOCK_M": block_m,
                                "BLOCK_N": block_n,
                                "BLOCK_K": block_k,
                                "SPLIT_K": 1,
                            },
                            num_stages=num_stages,
                            num_warps=num_warps,
                        )
                    )
                    # split_k not used
                    # for split_k in [2, 4, 8, 16]:
                    #     configs.append(triton.Config(
                    #         {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
                    #         num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
    return configs


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
    ]
    + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_fwd(
    C,  # Pointers to matrices
    ACT_INPUT,
    A,
    B,
    bias,
    # Matrix dimensions
    M,
    N,
    K,
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_cm,
    # stride_cn,  # Assume that stride_cn == 1
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    A_ROWMAJOR: tl.constexpr,
    B_COLMAJOR: tl.constexpr,
    BIAS: tl.constexpr,
    SAVE_ACT_INPUT: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """
    Kernel for computing Out = activation(A x W + C)
    - Input has shape (M, K)
    - Weight has shape (K, N)
    - Bias has shape (N,)
    - Output has shape (M, N)
    - ActInputs (optional) has shape (M, N)
    'ActInputs' optionally saves the A x W + C intermediate for backward computations
    This kernel will consolidate over K
    """
    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # trick to avoid masking on M and N axis
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    if A_ROWMAJOR:
        A = A + (ram[:, None] * stride_am + rk[None, :])
    else:
        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    if B_COLMAJOR:
        B = B + (rk[:, None] + rbn[None, :] * stride_bn)
    else:
        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        if A_ROWMAJOR:
            A += BLOCK_K
        else:
            A += BLOCK_K * stride_ak
        if B_COLMAJOR:
            B += BLOCK_K
        else:
            B += BLOCK_K * stride_bk

    # Putting bias after the matmul (instead of before) is faster, idk why
    if BIAS:
        bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
        acc += bias[None, :]

    # optional: save the activation inputs
    if SAVE_ACT_INPUT:
        # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        tl.store(act_in_ptrs, acc)

    # optional: fused activation (while the data is in shared memory)
    if ACTIVATION == "gelu":
        acc = gelu(acc)
    elif ACTIVATION == "gelu_approx":
        acc = gelu_approx(acc)
    elif ACTIVATION == "squared_relu":
        acc = squared_relu(acc)

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # write back result
    # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc)


def triton_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: str = "id",
    save_act_input: bool = False,
) -> torch.Tensor:
    """
    Compute e = activation(x @ weight.T + bias).
    This wrapper kicks the `kernel_fwd` Triton kernel
    :param x: input tensor
    :param weight: weight matrix
    :param bias: an optional bias tensor
    :param activation: Activation name. Needs to be a Triton kernel.
    :param act_input: an optional tensor to save the activation inputs (for backward)
    :return: result tensor
    """
    # if torch.is_autocast_enabled():
    #     dtype = torch.get_autocast_gpu_dtype()
    #     x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
    batch_shape, n = x.shape[:-1], x.shape[-1]
    batch_dim = batch_shape.numel()
    x_reshaped = x.reshape(batch_dim, n)

    if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:
        x_reshaped = x_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()
    bias = bias.contiguous() if bias is not None else None

    assert (
        x.dtype == weight.dtype
    ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
    if bias is not None:
        assert (
            x.dtype == bias.dtype
        ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
    assert (
        x_reshaped.shape[1] == weight.shape[1]
    ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"

    assert (
        bias is None or bias.shape[0] == weight.shape[0]
    ), "Incompatible dimensions in between weight and bias"

    M, K = x_reshaped.shape
    N, K = weight.shape

    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
    act_input = torch.empty_like(output) if save_act_input else None

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

    kernel_fwd[grid](
        output,
        act_input,
        x_reshaped,
        weight,  # data ptrs
        bias if bias is not None else x,  # auto skip bias if not present
        M,  # shapes
        N,
        K,
        M // 32,  # key for triton cache (limit number of compilations)
        N // 32,
        K // 32,
        stride_cm=output.stride(0),  # strides
        # stride_cn=output.stride(1),
        stride_am=x_reshaped.stride(0),
        stride_ak=x_reshaped.stride(1),
        stride_bk=weight.stride(1),
        stride_bn=weight.stride(0),
        BIAS=bias is not None,  # optional fused bias
        SAVE_ACT_INPUT=save_act_input,  # optional save activation inputs
        ACTIVATION=activation,  # optional fused activation
        A_ROWMAJOR=x_reshaped.stride(1) == 1,
        B_COLMAJOR=weight.stride(1) == 1,
        GROUP_M=8,  # speed optimization: group the programs
    )

    if not save_act_input:
        return output.reshape(*batch_shape, output.shape[-1])
    else:
        return (
            output.reshape(*batch_shape, output.shape[-1]),
            act_input.reshape(*batch_shape, act_input.shape[-1]),
        )


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
    ]
    + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_bwd(
    C,  # Pointers to matrices
    ACT_INPUT,
    A,
    B,
    # Matrix dimensions
    M,
    N,
    K,
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_cm,
    # stride_cn, # Assume that stride_cn == 1
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """
    Kernel for computing Out = activation(A x W + C)
    - Input has shape (M, K)
    - Weight has shape (K, N)
    - Output has shape (M, N)
    - ActInputs (optional) has shape (M, N)
    'ActInputs' optionally saves the A x W + C intermediate for backward computations
    This kernel will consolidate over K
    """
    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # trick to avoid masking on M and N axis
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # optional: fused activation (while the data is in shared memory)
    if ACTIVATION != "id":
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        act_input = tl.load(act_in_ptrs).to(acc.dtype)
    if ACTIVATION == "gelu":
        acc *= gelu_grad(act_input)
    elif ACTIVATION == "gelu_approx":
        acc *= gelu_approx_grad(act_input)
    elif ACTIVATION == "squared_relu":
        acc *= squared_relu_grad(act_input)

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # write back result
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)


def triton_dgrad_act(
    grad_output: torch.Tensor,
    weight: torch.Tensor,
    activation: str = "id",
    act_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute e = activation(grad_output @ weight + bias).
    This wrapper kicks the `kernel_fwd` Triton kernel
    :param grad_output: input tensor
    :param weight: weight matrix
    :param activation: Activation name. Needs to be a Triton kernel.
    :param act_input: an optional tensor to save the activation inputs (for backward)
    :return: result tensor
    """
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
    batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
    batch_dim = batch_shape.numel()
    grad_output_reshaped = grad_output.reshape(batch_dim, n)

    if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
        grad_output_reshaped = grad_output_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()

    assert (
        grad_output.dtype == weight.dtype
    ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
    assert (
        grad_output_reshaped.shape[1] == weight.shape[0]
    ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
    if activation != "id":
        assert act_input is not None, f"act_input is required for activation {activation}"

    # M, N, K in bwd are different from M, N, K in fwd
    M, K = grad_output_reshaped.shape
    K, N = weight.shape

    grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

    kernel_bwd[grid](
        grad_input,
        act_input,
        grad_output_reshaped,
        weight,  # data ptrs
        M,  # shapes
        N,
        K,
        M // 32,  # key for triton cache (limit number of compilations)
        N // 32,
        K // 32,
        stride_cm=grad_input.stride(0),  # strides
        # stride_cn=grad_input.stride(1),
        stride_am=grad_output_reshaped.stride(0),
        stride_ak=grad_output_reshaped.stride(1),
        stride_bk=weight.stride(0),
        stride_bn=weight.stride(1),
        ACTIVATION=activation,  # optional fused activation
        GROUP_M=8,  # speed optimization: group the programs
    )

    return grad_input.reshape(*batch_shape, grad_input.shape[-1])
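
A short sketch of how the removed triton_linear_act wrapper was used: it computes activation(x @ weight.T + bias) with the activation fused into the matmul and, with save_act_input=True, also returns the pre-activation that triton_dgrad_act needs in the backward pass. This assumes fp16 CUDA tensors and a checkout that still contains flash_attn.ops.triton.linear; the shapes are illustrative.

import torch
import torch.nn.functional as F

# Assumption: this module is only importable on a tree from before this commit.
from flash_attn.ops.triton.linear import triton_linear_act

batch, in_features, out_features = 64, 1024, 4096
x = torch.randn(batch, in_features, device="cuda", dtype=torch.float16)
weight = torch.randn(out_features, in_features, device="cuda", dtype=torch.float16) * 0.02
bias = torch.zeros(out_features, device="cuda", dtype=torch.float16)

# Fused matmul + squared ReLU; act_input holds x @ weight.T + bias for the backward.
out, act_input = triton_linear_act(
    x, weight, bias, activation="squared_relu", save_act_input=True
)

# Unfused reference: squared ReLU applied to the plain linear layer.
ref = torch.relu(F.linear(x, weight, bias)) ** 2
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
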
flash_attn/ops/triton/mlp.py  deleted  100644 → 0  (view file @ 5018ac6a)
# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
# to naive implementation.
import
fused_dense_lib
as
fused_dense_cuda
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
from
flash_attn.ops.activations
import
sqrelu_bwd
,
sqrelu_fwd
from
flash_attn.ops.triton.linear
import
triton_dgrad_act
,
triton_linear_act
class
FusedDenseSqreluDenseFunc
(
torch
.
autograd
.
Function
):
@
staticmethod
@
custom_fwd
def
forward
(
ctx
,
x
,
weight1
,
bias1
,
weight2
,
bias2
,
checkpoint_lvl
=
0
):
"""checkpoint_lvl:
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute act_input and gelu_out in the bwd
"""
if
torch
.
is_autocast_enabled
():
dtype
=
torch
.
get_autocast_gpu_dtype
()
x
,
weight1
,
bias1
,
weight2
,
bias2
=
[
a
.
to
(
dtype
=
dtype
)
for
a
in
[
x
,
weight1
,
bias1
,
weight2
,
bias2
]
]
is_bf16
=
x
.
dtype
==
torch
.
bfloat16
assert
checkpoint_lvl
in
[
0
,
1
,
2
]
x
=
x
.
contiguous
()
weight1
=
weight1
.
contiguous
()
bias1
=
bias1
.
contiguous
()
weight2
=
weight2
.
contiguous
()
bias2
=
bias2
.
contiguous
()
batch_shape
,
n
=
x
.
shape
[:
-
1
],
x
.
shape
[
-
1
]
batch_dim
=
batch_shape
.
numel
()
if
is_bf16
:
act_input
=
fused_dense_cuda
.
linear_bias_forward
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
)
output1
=
sqrelu_fwd
(
act_input
)
else
:
save_act_input
=
checkpoint_lvl
!=
2
result
=
triton_linear_act
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
,
activation
=
"squared_relu"
,
save_act_input
=
save_act_input
,
)
if
save_act_input
:
output1
,
act_input
=
result
else
:
output1
=
result
output2
=
fused_dense_cuda
.
linear_bias_forward
(
output1
,
weight2
,
bias2
)
ctx
.
checkpoint_lvl
=
checkpoint_lvl
if
checkpoint_lvl
==
0
:
ctx
.
save_for_backward
(
x
,
weight1
,
bias1
,
weight2
,
act_input
,
output1
)
elif
checkpoint_lvl
==
1
:
ctx
.
save_for_backward
(
x
,
weight1
,
bias1
,
weight2
,
act_input
)
elif
checkpoint_lvl
==
2
:
ctx
.
save_for_backward
(
x
,
weight1
,
bias1
,
weight2
)
return
output2
.
reshape
(
*
batch_shape
,
output2
.
shape
[
-
1
])
@
staticmethod
@
custom_bwd
def
backward
(
ctx
,
grad_output
):
grad_output
=
grad_output
.
contiguous
()
checkpoint_lvl
=
ctx
.
checkpoint_lvl
x
,
weight1
,
bias1
,
weight2
,
*
rest
=
ctx
.
saved_tensors
batch_shape
,
n
=
x
.
shape
[:
-
1
],
x
.
shape
[
-
1
]
batch_dim
=
batch_shape
.
numel
()
is_bf16
=
x
.
dtype
==
torch
.
bfloat16
if
checkpoint_lvl
==
0
:
act_input
,
output1
=
rest
elif
checkpoint_lvl
==
1
:
(
act_input
,)
=
rest
output1
=
sqrelu_fwd
(
act_input
)
elif
checkpoint_lvl
==
2
:
if
is_bf16
:
act_input
=
fused_dense_cuda
.
linear_bias_forward
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
)
output1
=
sqrelu_fwd
(
act_input
)
else
:
output1
,
act_input
=
triton_linear_act
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
,
activation
=
"squared_relu"
,
save_act_input
=
True
,
)
if
is_bf16
:
grad_output
=
grad_output
.
reshape
(
batch_dim
,
grad_output
.
shape
[
-
1
])
grad_weight2
,
grad_bias2
=
fused_dense_cuda
.
linear_bias_wgrad
(
output1
,
grad_output
)
grad_output1
=
grad_output
@
weight2
grad_act_input
=
sqrelu_bwd
(
grad_output1
,
act_input
)
grad_input
,
grad_weight1
,
grad_bias1
=
fused_dense_cuda
.
linear_bias_backward
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
grad_act_input
)
else
:
grad_output
=
grad_output
.
reshape
(
batch_dim
,
grad_output
.
shape
[
-
1
])
grad_weight2
,
grad_bias2
=
fused_dense_cuda
.
linear_bias_wgrad
(
output1
,
grad_output
)
grad_act_input
=
triton_dgrad_act
(
grad_output
,
weight2
,
activation
=
"squared_relu"
,
act_input
=
act_input
)
grad_input
,
grad_weight1
,
grad_bias1
=
fused_dense_cuda
.
linear_bias_backward
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
grad_act_input
)
return
grad_input
.
reshape_as
(
x
),
grad_weight1
,
grad_bias1
,
grad_weight2
,
grad_bias2
,
None
fused_dense_sqrelu_dense_function
=
FusedDenseSqreluDenseFunc
.
apply
class
FusedDenseSqreluDense
(
nn
.
Module
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
bias1
=
True
,
bias2
=
True
,
checkpoint_lvl
=
0
,
device
=
None
,
dtype
=
None
,
):
"""
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd
"""
assert
checkpoint_lvl
in
[
0
,
1
,
2
]
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
*
4
assert
bias1
==
True
,
"DenseSqreluDense module without bias is currently not supported"
assert
bias2
==
True
,
"DenseSqreluDense module without bias is currently not supported"
self
.
checkpoint_lvl
=
checkpoint_lvl
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
,
bias
=
bias1
,
**
factory_kwargs
)
self
.
fc2
=
nn
.
Linear
(
hidden_features
,
out_features
,
bias
=
bias2
,
**
factory_kwargs
)
def
forward
(
self
,
x
):
assert
x
.
is_cuda
return
fused_dense_sqrelu_dense_function
(
x
,
self
.
fc1
.
weight
,
self
.
fc1
.
bias
,
self
.
fc2
.
weight
,
self
.
fc2
.
bias
,
self
.
checkpoint_lvl
)
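A minimal usage sketch of FusedDenseSqreluDense as it existed before this commit, assuming a CUDA device, fp16 inputs, and that the fused_dense_lib CUDA extension is built; the sizes are illustrative only:

# Hypothetical example, not part of the repository.
import torch
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense

mlp = FusedDenseSqreluDense(4096, checkpoint_lvl=1, device="cuda", dtype=torch.float16)
x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16, requires_grad=True)
y = mlp(x)          # (2, 512, 4096): fused linear -> squared-ReLU -> linear
y.sum().backward()  # gradients flow through FusedDenseSqreluDenseFunc.backward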
flash_attn/ops/triton/rotary.py deleted 100644 → 0 View file @ 5018ac6a

# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union

import torch
import triton
import triton.language as tl


@triton.jit
def rotary_kernel(
    OUT,  # Pointers to matrices
    X,
    COS,
    SIN,
    CU_SEQLENS,
    SEQLEN_OFFSETS,  # this could be int or a pointer
    # Matrix dimensions
    seqlen,
    rotary_dim,
    seqlen_ro,
    # strides
    stride_out_batch,
    stride_out_seqlen,
    stride_out_nheads,
    stride_out_headdim,
    stride_x_batch,
    stride_x_seqlen,
    stride_x_nheads,
    stride_x_headdim,
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2

    if not IS_VARLEN:
        X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
        OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
    else:
        start_idx = tl.load(CU_SEQLENS + pid_batch)
        seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
        X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
        OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads

    if pid_m * BLOCK_M >= seqlen:
        return
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    if not IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm + SEQLEN_OFFSETS
    else:
        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
    rk = tl.arange(0, BLOCK_K)
    rk_half = tl.arange(0, BLOCK_K // 2)

    if not INTERLEAVED:
        # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
        X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
        cos = tl.load(
            COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
        ).to(tl.float32)
        sin = tl.load(
            SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
        ).to(tl.float32)
        x0 = tl.load(
            X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
        ).to(tl.float32)
        x1 = tl.load(
            X + rotary_dim_half * stride_x_headdim,
            mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
            other=0.0,
        ).to(tl.float32)
        if CONJUGATE:
            sin = -sin
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
        # write back result
        OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
        tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
        tl.store(
            OUT + rotary_dim_half * stride_out_headdim,
            o1,
            mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
        )
    else:
        # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
        # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
        # Loading x0 will be fast but x1 will be slow.
        # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
        # Then we do the calculation and use tl.where to pick put the right outputs for the even
        # and for the odd indices.
        rk_swap = rk + ((rk + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
        rk_repeat = tl.arange(0, BLOCK_K) // 2
        X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
        X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
        cos = tl.load(
            COS,
            mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
            other=1.0,
        ).to(tl.float32)
        sin = tl.load(
            SIN,
            mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
            other=0.0,
        ).to(tl.float32)
        x0 = tl.load(
            X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0
        ).to(tl.float32)
        x1 = tl.load(
            X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
        ).to(tl.float32)
        if CONJUGATE:
            sin = -sin
        x0_cos = x0 * cos
        x1_sin = x1 * sin
        out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
        OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
        tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))


def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """
    Arguments:
        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim).
        cos: (seqlen_ro, rotary_dim / 2)
        sin: (seqlen_ro, rotary_dim / 2)
        seqlen_offsets: integer or integer tensor of size (batch,)
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Returns:
        y: (batch, seqlen, nheads, headdim)
    """
    is_varlen = cu_seqlens is not None
    if not is_varlen:
        batch, seqlen, nheads, headdim = x.shape
    else:
        assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
        total_seqlen, nheads, headdim = x.shape
        batch_p_1 = cu_seqlens.shape[0]
        batch = batch_p_1 - 1
        seqlen = max_seqlen
    seqlen_ro, rotary_dim = cos.shape
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"

    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    BLOCK_K = (
        32
        if rotary_dim <= 32
        else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
    )
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)

    # Need this, otherwise Triton tries to launch from cuda:0 and we get
    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    with torch.cuda.device(x.device.index):
        rotary_kernel[grid](
            output,  # data ptrs
            x,
            cos,
            sin,
            cu_seqlens,
            seqlen_offsets,
            seqlen,  # shapes
            rotary_dim,
            seqlen_ro,
            output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
            output.stride(-3),  # seqlen_stride or total_seqlen_stride
            output.stride(-2),  # nheads_stride
            output.stride(-1),  # headdim_stride
            x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
            x.stride(-3),  # seqlen stride or total_seqlen_stride
            x.stride(-2),  # nheads stride
            x.stride(-1),  # headdim stride
            BLOCK_K,
            isinstance(seqlen_offsets, torch.Tensor),
            is_varlen,
            interleaved,
            conjugate,
            BLOCK_M,
        )
    return output
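A minimal usage sketch of apply_rotary, assuming a CUDA device and fp16 tensors; the shapes follow the docstring above, and the 10000 frequency base is illustrative:

# Hypothetical example, not part of the repository.
import torch
from flash_attn.ops.triton.rotary import apply_rotary

batch, seqlen, nheads, headdim = 2, 128, 8, 64
x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
t = torch.arange(seqlen, device="cuda", dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 32, device="cuda", dtype=torch.float32) / 32))
freqs = torch.outer(t, inv_freq)                   # (seqlen, rotary_dim / 2) with rotary_dim = 64
cos, sin = freqs.cos().half(), freqs.sin().half()  # must match x.dtype
y = apply_rotary(x, cos, sin)                      # (batch, seqlen, nheads, headdim)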
flash_attn/utils/__init__.py deleted 100644 → 0 View file @ 5018ac6a
flash_attn/utils/benchmark.py deleted 100644 → 0 View file @ 5018ac6a

# Copyright (c) 2023, Tri Dao.
""" Useful functions for writing test code. """

import torch
import torch.utils.benchmark as benchmark


def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(*inputs, y, grad):
        # Set .grad to None to avoid extra operation of gradient accumulation
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward + Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(grad, *inputs, **kwinputs):
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if type(y) is tuple:
                y = y[0]
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_combined(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap benchmark functions in Pytorch profiler to see CUDA information."""
    if backward:
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        g = torch.randn_like(out)
    for _ in range(30):  # Warm up
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        if backward:
            out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)


def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
    if verbose:
        print(f"{desc} max memory: {mem}GB")
    torch.cuda.empty_cache()
    return mem
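A minimal sketch of how these helpers might be used, assuming a CUDA device and using a plain matmul as a stand-in workload:

# Hypothetical example, not part of the repository.
import torch
from flash_attn.utils.benchmark import benchmark_forward, benchmark_memory

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
timer, measurement = benchmark_forward(torch.matmul, a, b, repeats=30, desc="matmul")
peak_gb = benchmark_memory(torch.matmul, a, b, desc="matmul")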
flash_attn/utils/distributed.py deleted 100644 → 0 View file @ 5018ac6a

from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce scatter the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    pamams_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(pamams_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.
    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
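get_dim_for_local_rank is pure arithmetic, so a small worked example (numbers chosen for illustration, run without any process group) shows the uneven split it produces: with dim=4096, world_size=3 and multiple_of=64 there are 4096 // 64 = 64 blocks; 64 // 3 = 21 with remainder 1, so rank 0 gets 22 * 64 = 1408 while ranks 1 and 2 get 21 * 64 = 1344, and 1408 + 1344 + 1344 = 4096.

# Hypothetical example, not part of the repository.
from flash_attn.utils.distributed import get_dim_for_local_rank

assert get_dim_for_local_rank(4096, world_size=3, local_rank=0, multiple_of=64) == 1408
assert get_dim_for_local_rank(4096, world_size=3, local_rank=1, multiple_of=64) == 1344
assert get_dim_for_local_rank(4096, world_size=3, local_rank=2, multiple_of=64) == 1344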
flash_attn/utils/generation.py deleted 100644 → 0 View file @ 5018ac6a

# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function

try:
    from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
except ImportError:
    GreedySearchDecoderOnlyOutput = namedtuple("GreedySearchDecoderOnlyOutput", ["sequences", "scores"])
    SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores"])


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for none top-k values to -inf. Done in-place."""
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for none top-p values to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits.masked_fill_(indices_to_remove, float("-inf"))


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            # Clone so that when we modify for top_p we don't change the original logits
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )


@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    tensor_parallel=1,
    cg=False,
    enable_timing=False,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            ).squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # return rearrange(token, "b -> b 1")
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)
    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        start.record()
    scores, sequences = [], [input_ids]
    while not should_stop(sequences[-1], inference_params):
        scores.append(get_logits(sequences[-1], inference_params))
        inference_params.seqlen_offset += sequences[-1].shape[1]
        sequences.append(sample_tokens(scores[-1], inference_params))
    if enable_timing:
        end.record()
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))


def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
    """Algorithm 1 from [1]
    [1] Fast Inference from Transformers via Speculative Decoding
    Yaniv Leviathan, Matan Kalman, Yossi Matias
    https://arxiv.org/abs/2211.17192

    Arguments:
        logits: Tensor of shape (batch_size, seqlen + 1, vocab_size)
        logits_draft: Tensor of shape (batch_size, seqlen, vocab_size)
        tokens_draft: Tensor of shape (batch_size, seqlen)
    Return:
        tokens: Tensor of shape (batch_size, seqlen + 1)
        num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1].
            For each sequence in the batch, the number of valid tokens that were sampled by
            speculative sampling.
    """
    batch, seqlen_p_1, vocab_size = logits.shape
    seqlen = seqlen_p_1 - 1
    assert logits_draft.shape == (batch, seqlen, vocab_size)
    assert tokens_draft.shape == (batch, seqlen)
    assert tokens_draft.dtype in [torch.int64, torch.int32]
    # TODO: if top_k = 1 we can simplify things and only work with indices
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
    # Clone so that when we modify for top_p we don't change the original logits
    logits = logits / temperature if temperature != 1.0 else logits.clone()
    logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone()
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        modify_logits_for_top_k_filtering(logits, top_k)
        modify_logits_for_top_k_filtering(logits_draft, top_k)
    modify_logits_for_top_p_filtering(logits, top_p)
    modify_logits_for_top_p_filtering(logits_draft, top_p)
    probs = torch.softmax(logits, dim=-1)
    probs_draft = torch.softmax(logits_draft, dim=-1)
    gather = lambda probs, tokens: rearrange(
        probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..."
    )
    # (batch, seqlen)
    accepted = torch.rand(batch, seqlen, device=probs.device) * gather(
        probs_draft, tokens_draft
    ) <= gather(probs[:, :-1], tokens_draft)
    accepted_all = accepted.all(dim=-1)
    # (batch,)
    first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1))
    probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0)
    # torch.multinomial can deal with unnormalized probabilities
    # probs_diff /= probs_diff.sum(dim=-1, keepdim=True)
    resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1)
    resample_probs = rearrange(
        resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)),
        "b 1 d -> b d",
    )
    resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1)  # (batch,)
    tokens = F.pad(tokens_draft, (0, 1))
    tokens[:, first_rejected_idx] = resample
    return tokens, first_rejected_idx + 1


@torch.inference_mode()
def decode_speculative(
    input_ids,
    model,
    model_draft,
    max_length,
    speculative_lookahead=3,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    vocab_size=None,
    tensor_parallel=1,
    cg=False,
    enable_timing=False,
    debug=False,
):
    """
    TD: WIP, for my own understanding, lightly tested. Only support batch_size == 1 for now.

    Speculative decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
    assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
    if cg:
        if not hasattr(model_draft, "_decoding_cache"):
            model_draft._decoding_cache = None
        model_draft._decoding_cache = update_graph_cache(
            model_draft,
            model_draft._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            # draft model needs to process either 1 or 2 tokens at a time
            decoding_seqlens=(1, 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params_draft = model_draft._decoding_cache.inference_params
        inference_params_draft.reset(max_length, batch_size)
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            decoding_seqlens=range(1, speculative_lookahead + 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            seqlen = input_ids.shape[1]
            # if inference_params.lengths_per_sample is None:
            # TODO: in the case of batched decoding where each sequence has a different length,
            # we need to compute the position_ids for each sequence using lengths_per_sample
            if True:
                cache_seqlens = torch.full(
                    (input_ids.shape[0],),
                    inference_params.seqlen_offset,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            else:
                cache_seqlens = inference_params.lengths_per_sample
            position_ids = cache_seqlens[:, None] + torch.arange(
                seqlen, dtype=torch.long, device=input_ids.device
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=num_last_tokens,
            ).logits
        else:
            # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1].
            # This might not be compatible the num_last_tokens used here.
            assert num_last_tokens <= input_ids.shape[1]
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            )[:, -num_last_tokens:]
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1):
        """Sample `num_tokens` tokens from the model, given the previous logits.
        Also return the logits of the sampled tokens.
        Arguments:
            input_ids: (batch, seqlen)
        Return:
            tokens: (batch, num_tokens)
            scores: (batch, num_tokens), which contains @previous_logits and the logits of the next
                (num_tokens - 1) tokens. The logits of the last token isn't computed.
        """
        assert num_tokens >= 1
        sequences, scores = [input_ids], []
        for i in range(num_tokens):
            scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1])
            inference_params.seqlen_offset += sequences[-1].shape[1]
            sequences.append(sample_fn(scores[-1]).unsqueeze(1))
        return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1)

    sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
    sample_fn = partial(sample, **sampling_kwargs)
    get_logits_main = partial(get_logits, model=model, cg=cg)
    get_logits_draft = partial(get_logits, model=model_draft, cg=cg)
    sample_tokens_main = partial(
        sample_tokens,
        get_logits_fn=get_logits_main,
        sample_fn=sample_fn,
        inference_params=inference_params,
    )
    sample_tokens_draft = partial(
        sample_tokens,
        get_logits_fn=get_logits_draft,
        sample_fn=sample_fn,
        inference_params=inference_params_draft,
    )

    if debug:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        start = time.time()

    sequences, scores = [input_ids], []
    num_main_model_calls = 0
    num_draft_tokens = 0
    num_accepted_tokens_history = []
    if seqlen_og >= max_length - 1:
        # Don't do speculative sampling, just sample 1 token from the model
        tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1)
        sequences.append(tokens)
        scores.append(scores_new)
    else:
        # Sample from draft model, which produces @n_spec_tokens, and @model
        # will then use to produce between 1 and 1 + @n_spec_tokens tokens.
        # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
        n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
        tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens)
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())

        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([input_ids, tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
            # breakpoint()
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
            # breakpoint()
        # TODO: we're using the fact that batch_size == 1
        # TODO: check eos_token_id
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
        # that in the next time we call @model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset = seqlen_og + num_generated - 1
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
            # breakpoint()

    while True:
        # seqlen_offset is total length generated - 1
        if inference_params.seqlen_offset >= max_length - 1:
            break
        if inference_params.seqlen_offset >= max_length - 2:
            # Don't do speculative sampling, just sample 1 token from the model
            tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
            sequences.append(tokens)
            scores.append(scores_new)
            break
        # Sample from draft model
        n_spec_tokens = min(
            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
        )
        # If the main model accepts all the draft tokens, plus it samples one new token,
        # then at the next iteration the draft model need to evaluate the logits of the last draft
        # token and the logits of the newly sampled token. So here we pass in the last 2 tokens
        # of sequences[-1].
        # This exception is when the main model rejects all the draft tokens, in which case we
        # will only have 1 token to pass in.
        tokens_draft, scores_draft = sample_tokens_draft(
            sequences[-1][:, -2:], num_tokens=n_spec_tokens
        )
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
            # breakpoint()
        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )  # (batch, n_spec_tokens + 1, vocab_size)
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
            # breakpoint()
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
            # breakpoint()
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # We've evaluated 1 token from sequences[-1][:, -1:] above, plus
        # num_generated_tokens[0].item() - 1 tokens from the draft model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset += num_generated
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
            # breakpoint()

    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        print(f"Number of calls to main model: {num_main_model_calls}")
        print(
            f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%"
        )
    sequences = torch.cat(sequences, dim=1)
    scores = torch.cat(scores, dim=1)
    if debug:
        scores_ref = model(sequences).logits
        print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max())
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=sequences, scores=scores)


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences


def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    tensor_parallel=1,
    dtype=None,
    n_warmups=2,
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(
                model.config,
                "head_dim",
                model.config.hidden_size // model.config.num_attention_heads,
            )
            inf_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                model.config.num_attention_heads // tensor_parallel,
                headdim,
                model.config.num_hidden_layers,
                device,
                dtype,
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for decoding_seqlen in decoding_seqlens:
        if (batch_size, decoding_seqlen) not in cache.callables:
            cache.callables[batch_size, decoding_seqlen] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen,
                decoding_seqlen=decoding_seqlen,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size, decoding_seqlen = input_ids.shape[:2]
        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache


def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    seqlen_offset_og = inference_params.seqlen_offset
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits.clone()

    inference_params.seqlen_offset = seqlen_offset_og
    return run
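The sampling helpers above can be exercised without a model; a minimal sketch on synthetic logits, assuming the module path of this deleted file and with illustrative batch and vocabulary sizes:

# Hypothetical example, not part of the repository.
import torch
from flash_attn.utils.generation import sample

logits = torch.randn(4, 32000)                                 # (batch_size, vocab_size)
greedy = sample(logits, top_k=1)                               # argmax, shape (4,)
nucleus = sample(logits, top_k=0, top_p=0.9, temperature=0.8)  # pure top-p sampling, shape (4,)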
flash_attn/utils/pretrained.py deleted 100644 → 0 View file @ 5018ac6a

import os
from functools import partial

import torch
from safetensors.torch import load_file as safe_load_file
from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)
from transformers.utils.hub import cached_file, get_checkpoint_shard_files


def state_dict_from_pretrained(model_name, device=None, dtype=None):
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    is_sharded = False
    load_safe = False
    resolved_archive_file = None

    weights_path = os.path.join(model_name, WEIGHTS_NAME)
    weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
    safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)

    if os.path.isfile(weights_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
    elif os.path.isfile(weights_index_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
    elif os.path.isfile(safe_weights_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        load_safe = True
    elif os.path.isfile(safe_weights_index_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
        load_safe = True
    else:
        # Try loading from HF hub instead of from local files
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        if resolved_archive_file is None:
            resolved_archive_file = cached_file(
                model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
            )
            if resolved_archive_file is not None:
                is_sharded = True

    if resolved_archive_file is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")

    if load_safe:
        loader = partial(safe_load_file, device=mapped_device)
    else:
        loader = partial(torch.load, map_location=mapped_device)

    if is_sharded:
        # resolved_archive_file becomes a list of files that point to the different
        # checkpoint shards in this case.
        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
            model_name, resolved_archive_file
        )
        state_dict = {}
        for sharded_file in resolved_archive_file:
            state_dict.update(loader(sharded_file))
    else:
        state_dict = loader(resolved_archive_file)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
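A minimal usage sketch of state_dict_from_pretrained; "gpt2" is an illustrative checkpoint name, and the call needs either network access to the HF hub or a local copy of the weights:

# Hypothetical example, not part of the repository.
import torch
from flash_attn.utils.pretrained import state_dict_from_pretrained

sd = state_dict_from_pretrained("gpt2", device="cpu", dtype=torch.float16)
print(len(sd), "tensors loaded")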
setup.py  View file @ 26f4b5fb

@@ -49,7 +49,7 @@ else:
 elif BUILD_TARGET == "rocm":
     IS_ROCM = True

-PACKAGE_NAME = "flash_attn"
+PACKAGE_NAME = "vllm_flash_attn"

 BASE_WHEEL_URL = (
     "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
@@ -57,10 +57,10 @@ BASE_WHEEL_URL = (
 # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
 # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
-FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
-SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+FORCE_BUILD = True
+SKIP_CUDA_BUILD = False
 # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
-FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
+FORCE_CXX11_ABI = torch._C._GLIBCXX_USE_CXX11_ABI

 def get_platform():
@@ -151,7 +151,7 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
     if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
         generator_flag = ["-DOLD_GENERATOR_PATH"]

-    check_if_cuda_home_none("flash_attn")
+    check_if_cuda_home_none(PACKAGE_NAME)
     # Check, if CUDA11 is installed for compute capability 8.0
     cc_flag = []
     if CUDA_HOME is not None:
@@ -177,7 +177,7 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
         torch._C._GLIBCXX_USE_CXX11_ABI = True
     ext_modules.append(
         CUDAExtension(
-            name="flash_attn_2_cuda",
+            name="vllm_flash_attn_2_cuda",
             sources=[
                 "csrc/flash_attn/flash_api.cpp",
                 "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
@@ -208,34 +208,6 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
                 "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
@@ -282,10 +254,10 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
                     # "--ptxas-options=-O2",
                     # "-lineinfo",
                     # "-DFLASHATTENTION_DISABLE_BACKWARD",
-                    # "-DFLASHATTENTION_DISABLE_DROPOUT",
+                    "-DFLASHATTENTION_DISABLE_DROPOUT",
                     # "-DFLASHATTENTION_DISABLE_ALIBI",
                     # "-DFLASHATTENTION_DISABLE_SOFTCAP",
-                    # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
+                    "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                     # "-DFLASHATTENTION_DISABLE_LOCAL",
                 ]
                 + generator_flag
@@ -391,7 +363,7 @@ elif not SKIP_CUDA_BUILD and IS_ROCM:
 def get_package_version():
-    with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
+    with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
         version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
     public_version = ast.literal_eval(version_match.group(1))
     local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
@@ -401,37 +373,6 @@ def get_package_version():
         return str(public_version)

-def get_wheel_url():
-    torch_version_raw = parse(torch.__version__)
-    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
-    platform_name = get_platform()
-    flash_version = get_package_version()
-    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
-    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
-
-    if IS_ROCM:
-        torch_hip_version = parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
-        hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
-        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
-    else:
-        # Determine the version numbers that will be used to determine the correct wheel
-        # We're using the CUDA version used to build torch, not the one currently installed
-        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-        torch_cuda_version = parse(torch.version.cuda)
-        # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
-        # to save CI time. Minor versions should be compatible.
-        torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
-        # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
-
-        # Determine wheel URL based on CUDA version, torch version, python version and OS
-        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
-
-    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
-    return wheel_url, wheel_filename
-
 class CachedWheelsCommand(_bdist_wheel):
     """
     The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
@@ -444,28 +385,6 @@ class CachedWheelsCommand(_bdist_wheel):
         if FORCE_BUILD:
             return super().run()

-        wheel_url, wheel_filename = get_wheel_url()
-        print("Guessing wheel URL: ", wheel_url)
-        try:
-            urllib.request.urlretrieve(wheel_url, wheel_filename)
-
-            # Make the archive
-            # Lifted from the root wheel processing command
-            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
-            if not os.path.exists(self.dist_dir):
-                os.makedirs(self.dist_dir)
-
-            impl_tag, abi_tag, plat_tag = self.get_tag()
-            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
-
-            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
-            print("Raw wheel path", wheel_path)
-            os.rename(wheel_filename, wheel_path)
-        except (urllib.error.HTTPError, urllib.error.URLError):
-            print("Precompiled wheel not found. Building from source...")
-            # If the wheel could not be downloaded, build from source
-            super().run()

 class NinjaBuildExtension(BuildExtension):
     def __init__(self, *args, **kwargs) -> None:
@@ -487,8 +406,11 @@ class NinjaBuildExtension(BuildExtension):
         super().__init__(*args, **kwargs)

+PYTORCH_VERSION = "2.4.0"
+CUDA_VERSION = "12.1"

 setup(
-    name=PACKAGE_NAME,
+    name="vllm-flash-attn",
     version=get_package_version(),
     packages=find_packages(
         exclude=(
@@ -499,15 +421,13 @@ setup(
             "dist",
             "docs",
             "benchmarks",
-            "flash_attn.egg-info",
+            f"{PACKAGE_NAME}.egg-info",
         )
     ),
-    author="Tri Dao",
-    author_email="tri@tridao.me",
-    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
-    long_description=long_description,
-    long_description_content_type="text/markdown",
-    url="https://github.com/Dao-AILab/flash-attention",
+    author="vLLM Team",
+    description="Forward-only flash-attn",
+    long_description=f"Forward-only flash-attn package built for PyTorch {PYTORCH_VERSION} and CUDA {CUDA_VERSION}",
+    url="https://github.com/vllm-project/flash-attention.git",
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: BSD License",
@@ -520,13 +440,6 @@ setup(
         "bdist_wheel": CachedWheelsCommand,
     },
     python_requires=">=3.8",
-    install_requires=[
-        "torch",
-        "einops",
-    ],
-    setup_requires=[
-        "packaging",
-        "psutil",
-        "ninja",
-    ],
+    install_requires=[f"torch == {PYTORCH_VERSION}"],
+    setup_requires=["psutil"],
 )
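For context on what the removed prebuilt-wheel path did: get_wheel_url() assembled a wheel filename from the package name, package version, CUDA and torch versions, the C++11 ABI flag, the Python tag, and the platform, then formatted it into BASE_WHEEL_URL. A standalone sketch of that naming scheme with example values (the concrete versions below are illustrative, not taken from this commit):

# Standalone sketch of the wheel-filename scheme used by the removed get_wheel_url().
PACKAGE_NAME = "flash_attn"
flash_version = "2.6.3"
cuda_version = "123"        # CUDA 12.3, as pinned in the removed helper
torch_version = "2.4"
cxx11_abi = "TRUE"
python_version = "cp310"
platform_name = "linux_x86_64"

wheel_filename = (
    f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}"
    f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
)
print(wheel_filename)
# flash_attn-2.6.3+cu123torch2.4cxx11abiTRUE-cp310-cp310-linux_x86_64.whl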
tests/test_flash_attn.py  View file @ 26f4b5fb

@@ -1588,7 +1588,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
     ],
 )
 # TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
-@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512])
 # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
 def test_flash_attn_varlen_causal(
     seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
@@ -1875,7 +1875,7 @@ def test_flash_attn_splitkv(
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
-@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
+@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256])
 # @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
 # @pytest.mark.parametrize("paged_kv_block_size", [None])
 @pytest.mark.parametrize("has_leftpad", [False, True])
@@ -2523,3 +2523,47 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
         assert torch.equal(dv, dv0)
         assert torch.equal(dk, dk0)
         assert torch.equal(dq, dq0)
+
+
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [False])
+@pytest.mark.parametrize("paged_kv_block_size", [16])
+# @pytest.mark.parametrize("has_batch_idx", [False])
+@pytest.mark.parametrize("d", [128])
+@pytest.mark.parametrize("nheads", [32])
+@pytest.mark.parametrize("b", [4])
+@pytest.mark.parametrize("n", [10])
+@pytest.mark.parametrize("seqlen_q,seqlen_k", [(170, 170)])
+def test_flash_attn_paged_kvcache_overflow(
+    seqlen_q,
+    seqlen_k,
+    d,
+    nheads,
+    b,
+    n,
+    paged_kv_block_size,
+    causal,
+    dtype,
+):
+    device = "cuda"
+    num_blocks = 1000 * 16 // paged_kv_block_size
+    key_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device)
+    value_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device)
+    cache_seqlens = torch.zeros(b, dtype=torch.int32, device=device)
+    for _ in range(n):
+        query = torch.rand([b, seqlen_q, nheads, d], dtype=dtype, device=device)
+        key = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
+        value = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
+        block_tables = torch.randint(
+            0,
+            num_blocks,
+            size=(b, (seqlen_k + paged_kv_block_size - 1) // paged_kv_block_size),
+            dtype=torch.int32,
+            device=device,
+        )
+        output = flash_attn_with_kvcache(
+            query,
+            key_cache,
+            value_cache,
+            k=key,
+            v=value,
+            cache_seqlens=cache_seqlens,
+            block_table=block_tables,
+            causal=causal,
+        )
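The new test above drives flash_attn_with_kvcache with a paged KV cache (block size 16) and a randomly filled block table. As background, a block table maps each logical block of a sequence's KV cache to a physical block in key_cache/value_cache. The following is a conceptual sketch of that lookup only, assuming the [num_blocks, block_size, nheads, d] cache layout used in the test; the real lookup happens inside the CUDA kernel:

import torch

def gather_key(key_cache, block_table, batch_idx, token_pos, block_size):
    """Return the cached key vectors for one token of one sequence."""
    logical_block = token_pos // block_size      # which block of this sequence's cache
    offset = token_pos % block_size              # slot inside that block
    physical_block = block_table[batch_idx, logical_block]
    return key_cache[physical_block, offset]     # shape: [nheads, d]

# Toy usage: 8 physical blocks of 16 slots, 2 heads, head_dim 4, batch of 1.
key_cache = torch.randn(8, 16, 2, 4)
block_table = torch.tensor([[3, 7, 1]], dtype=torch.int32)  # sequence 0 owns blocks 3, 7, 1
k_vec = gather_key(key_cache, block_table, batch_idx=0, token_pos=20, block_size=16)  # block 7, slot 4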
flash_attn/__init__.py → vllm_flash_attn/__init__.py  View file @ 26f4b5fb

-__version__ = "2.6.3"
+__version__ = "2.6.0"

-from flash_attn.flash_attn_interface import (
+from vllm_flash_attn.flash_attn_interface import (
     flash_attn_func,
     flash_attn_kvpacked_func,
     flash_attn_qkvpacked_func,
 ...
flash_attn/flash_attn_interface.py → vllm_flash_attn/flash_attn_interface.py  View file @ 26f4b5fb

@@ -7,7 +7,7 @@ import torch.nn as nn
 # isort: off
 # We need to import the CUDA kernels after importing torch
-import flash_attn_2_cuda as flash_attn_cuda
+import vllm_flash_attn_2_cuda as flash_attn_cuda
 # isort: on
@@ -46,14 +46,14 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
+    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
         q,
         k,
         v,
-        None,
+        out,
         alibi_slopes,
         dropout_p,
         softmax_scale,
@@ -91,7 +91,7 @@ def _flash_attn_varlen_forward(
         q,
         k,
         v,
-        None,
+        out,
         cu_seqlens_q,
         cu_seqlens_k,
         seqused_k,
@@ -239,6 +239,8 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -253,6 +255,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -307,6 +310,8 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -326,6 +331,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -384,6 +390,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -398,6 +405,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -457,6 +465,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -476,6 +485,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -540,6 +550,7 @@ class FlashAttnFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -554,6 +565,7 @@ class FlashAttnFunc(torch.autograd.Function):
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -614,6 +626,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         deterministic,
         return_softmax,
         block_table,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -633,6 +646,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -691,6 +705,8 @@ def flash_attn_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -736,6 +752,7 @@ def flash_attn_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out,
     )
@@ -750,6 +767,8 @@ def flash_attn_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -813,6 +832,7 @@ def flash_attn_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out,
     )
@@ -828,6 +848,8 @@ def flash_attn_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -889,6 +911,7 @@ def flash_attn_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out,
     )
@@ -904,6 +927,8 @@ def flash_attn_varlen_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -954,6 +979,7 @@ def flash_attn_varlen_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out,
     )
@@ -972,6 +998,8 @@ def flash_attn_varlen_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -1045,6 +1073,7 @@ def flash_attn_varlen_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out,
     )
@@ -1065,6 +1094,8 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_attn_probs=False,
     block_table=None,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1138,6 +1169,7 @@ def flash_attn_varlen_func(
         deterministic,
         return_attn_probs,
         block_table,
+        out,
     )
@@ -1161,6 +1193,8 @@ def flash_attn_with_kvcache(
     alibi_slopes=None,
     num_splits=0,
     return_softmax_lse=False,
+    *,
+    out=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1274,7 +1308,7 @@ def flash_attn_with_kvcache(
         cache_leftpad,
         block_table,
         alibi_slopes,
-        None,
+        out,
         softmax_scale,
         causal,
         window_size[0],
 ...
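The hunks above thread a new keyword-only out= argument from the public API down to the CUDA kernels, so callers can have the attention output written into a preallocated buffer instead of a freshly allocated tensor. A usage sketch under the assumption that this fork is installed and a CUDA device is available (the shapes are illustrative):

# Sketch of the new out= argument; q/k/v shapes follow the usual [batch, seqlen, nheads, head_dim] layout.
import torch
from vllm_flash_attn import flash_attn_func

b, s, h, d = 2, 128, 8, 64
q = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")
k = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")
v = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")

out_buf = torch.empty_like(q)                      # preallocated output buffer
flash_attn_func(q, k, v, causal=True, out=out_buf)  # attention result lands in out_buf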
flash_attn/pyproject.toml → vllm_flash_attn/pyproject.toml  View file @ 26f4b5fb
File moved