Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1863c926
Commit
1863c926
authored
Jul 06, 2024
by
zhuwenwen
Browse files
Use triton fa by default
parent
b6247705
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
1334 additions
and
15 deletions
+1334
-15
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+26
-14
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
+1307
-0
vllm/envs.py
vllm/envs.py
+1
-1
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
1863c926
...
@@ -229,9 +229,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -229,9 +229,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
if
self
.
use_triton_flash_attn
:
if
self
.
use_triton_flash_attn
:
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention
)
# triton_attention)
self
.
attn_func
=
triton_attention
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
flash_attn_varlen_func
)
self
.
attn_func
=
flash_attn_varlen_func
# triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
else
:
else
:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
...
@@ -325,17 +327,27 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -325,17 +327,27 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
if
self
.
use_triton_flash_attn
:
if
self
.
use_triton_flash_attn
:
out
,
_
=
self
.
attn_func
(
# out, _ = self.attn_func(
query
,
# query,
key
,
# key,
value
,
# value,
None
,
# None,
prefill_meta
.
seq_start_loc
,
# prefill_meta.seq_start_loc,
prefill_meta
.
seq_start_loc
,
# prefill_meta.seq_start_loc,
prefill_meta
.
max_prefill_seq_len
,
# prefill_meta.max_prefill_seq_len,
prefill_meta
.
max_prefill_seq_len
,
# prefill_meta.max_prefill_seq_len,
True
,
# True,
self
.
scale
,
# self.scale,
out
=
self
.
attn_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
)
elif
self
.
use_naive_attn
:
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
...
...
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
0 → 100644
View file @
1863c926
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims
"""
import
argparse
import
pytest
import
random
import
sys
import
torch
import
triton
import
triton.language
as
tl
torch_dtype
:
tl
.
constexpr
=
torch
.
float16
TORCH_HAS_FP8E5
=
hasattr
(
torch
,
'float8_e5m2fnuz'
)
if
TORCH_HAS_FP8E5
:
torch_dtype
:
tl
.
constexpr
=
torch
.
float8_e5m2fnuz
class
MetaData
():
cu_seqlens_q
=
None
cu_seqlens_k
=
None
max_seqlens_q
=
0
max_seqlens_k
=
0
bias
=
None
alibi_slopes
=
None
causal
=
False
num_contexts
=
0
varlen
=
False
dropout_p
,
return_encoded_softmax
=
0.0
,
False
def
__init__
(
self
,
sm_scale
=
1.0
,
causal
=
False
,
dropout_p
=
0.0
,
return_encoded_softmax
=
False
):
self
.
sm_scale
=
sm_scale
self
.
causal
=
causal
self
.
dropout_p
=
dropout_p
self
.
return_encoded_softmax
=
return_encoded_softmax
def
set_varlen_params
(
self
,
cu_seqlens_q
,
cu_seqlens_k
):
self
.
varlen
=
True
self
.
cu_seqlens_q
=
cu_seqlens_q
self
.
cu_seqlens_k
=
cu_seqlens_k
# Without "varlen", there should still be one sequence.
assert
len
(
cu_seqlens_q
)
>=
2
assert
len
(
cu_seqlens_q
)
==
len
(
cu_seqlens_k
)
self
.
num_contexts
=
len
(
cu_seqlens_q
)
-
1
for
i
in
range
(
0
,
self
.
num_contexts
):
self
.
max_seqlens_q
=
max
(
cu_seqlens_q
[
i
+
1
].
item
()
-
cu_seqlens_q
[
i
].
item
(),
self
.
max_seqlens_q
)
self
.
max_seqlens_k
=
max
(
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
(),
self
.
max_seqlens_k
)
def
need_bias
(
self
,
bias
,
batch
,
nheads
,
seqlen_q
,
seqlen_k
):
assert
bias
.
is_cuda
assert
bias
.
dim
()
==
4
assert
bias
.
shape
[
0
]
==
1
assert
bias
.
shape
[
2
:]
==
(
seqlen_q
,
seqlen_k
)
self
.
bias
=
bias
def
need_alibi
(
self
,
alibi_slopes
,
batch
,
nheads
):
assert
alibi_slopes
.
is_cuda
assert
alibi_slopes
.
dim
()
==
2
assert
alibi_slopes
.
shape
[
0
]
==
batch
assert
alibi_slopes
.
shape
[
1
]
==
nheads
self
.
alibi_slopes
=
alibi_slopes
def
need_causal
(
self
):
self
.
causal
=
True
def
need_dropout
(
dropout_p
,
return_encoded_softmax
):
self
.
dropout_p
=
dropout_p
self
.
return_encoded_softmax
=
return_encoded_softmax
def
check_args
(
self
,
q
,
k
,
v
,
o
):
assert
q
.
dim
()
==
k
.
dim
()
and
q
.
dim
()
==
v
.
dim
()
if
self
.
varlen
:
assert
q
.
dim
()
==
3
total_q
,
nheads_q
,
head_size
=
q
.
shape
total_k
,
nheads_k
,
_
=
k
.
shape
assert
self
.
cu_seqlens_q
is
not
None
assert
self
.
cu_seqlens_k
is
not
None
assert
len
(
self
.
cu_seqlens_q
)
==
len
(
self
.
cu_seqlens_k
)
# TODO: Remove once bias is supported with varlen
assert
self
.
bias
==
None
# TODO:Remove once dropout is supported with varlen
assert
self
.
dropout_p
==
0.0
assert
not
self
.
return_encoded_softmax
else
:
assert
q
.
dim
()
==
4
batch
,
nheads_q
,
seqlen_q
,
head_size
=
q
.
shape
_
,
nheads_k
,
seqlen_k
,
_
=
k
.
shape
assert
self
.
max_seqlens_q
>
0
and
self
.
max_seqlens_k
>
0
assert
self
.
cu_seqlens_q
is
None
and
self
.
cu_seqlens_k
is
None
assert
k
.
shape
==
v
.
shape
assert
q
.
shape
[
-
1
]
==
k
.
shape
[
-
1
]
and
q
.
shape
[
-
1
]
==
v
.
shape
[
-
1
]
# TODO: Change assert if we support qkl f8 and v f16
assert
q
.
dtype
==
k
.
dtype
and
q
.
dtype
==
v
.
dtype
assert
head_size
<=
256
assert
o
.
shape
==
q
.
shape
assert
(
nheads_q
%
nheads_k
)
==
0
@
triton
.
jit
def
cdiv_fn
(
x
,
y
):
return
(
x
+
y
-
1
)
//
y
@
triton
.
jit
def
max_fn
(
x
,
y
):
return
tl
.
math
.
max
(
x
,
y
)
@
triton
.
jit
def
dropout_offsets
(
philox_seed
,
philox_offset
,
dropout_p
,
m
,
n
,
stride
):
ms
=
tl
.
arange
(
0
,
m
)
ns
=
tl
.
arange
(
0
,
n
)
return
philox_offset
+
ms
[:,
None
]
*
stride
+
ns
[
None
,
:]
@
triton
.
jit
def
dropout_rng
(
philox_seed
,
philox_offset
,
dropout_p
,
m
,
n
,
stride
):
rng_offsets
=
dropout_offsets
(
philox_seed
,
philox_offset
,
dropout_p
,
m
,
n
,
stride
).
to
(
tl
.
uint32
)
# TODO: use tl.randint for better performance
return
tl
.
rand
(
philox_seed
,
rng_offsets
)
@
triton
.
jit
def
dropout_mask
(
philox_seed
,
philox_offset
,
dropout_p
,
m
,
n
,
stride
):
rng_output
=
dropout_rng
(
philox_seed
,
philox_offset
,
dropout_p
,
m
,
n
,
stride
)
rng_keep
=
rng_output
>
dropout_p
return
rng_keep
@
triton
.
jit
def
load_fn
(
block_ptr
,
first
,
second
,
pad
):
if
first
and
second
:
tensor
=
tl
.
load
(
block_ptr
,
boundary_check
=
(
0
,
1
),
padding_option
=
pad
)
elif
first
:
tensor
=
tl
.
load
(
block_ptr
,
boundary_check
=
(
0
,),
padding_option
=
pad
)
elif
second
:
tensor
=
tl
.
load
(
block_ptr
,
boundary_check
=
(
1
,),
padding_option
=
pad
)
else
:
tensor
=
tl
.
load
(
block_ptr
)
return
tensor
@
triton
.
jit
def
print_gpu
(
prefix
,
val
=
None
):
if
(
tl
.
program_id
(
0
)
==
0
)
and
((
tl
.
program_id
(
1
)
==
0
)
and
(
tl
.
program_id
(
2
)
==
0
)):
if
val
is
not
None
:
tl
.
device_print
(
prefix
,
val
)
else
:
tl
.
device_print
(
prefix
)
@
triton
.
jit
def
compute_alibi_block
(
alibi_slope
,
seqlen_q
,
seqlen_k
,
offs_m
,
offs_n
,
transpose
=
False
):
# when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix
# for casual mask we want something like this where (1 is kept and 0 is masked)
# seqlen_q = 2 and seqlen_k = 5
# 1 1 1 1 0
# 1 1 1 1 1
# seqlen_q = 5 and seqlen_k = 2
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
# for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal
# e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
# 1. offs_m[:,None] = [[0],
# [1],
# 2. offs_m[:,None] + seqlen_k = [[5],
# [6],
# 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
# [4],
# 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1],
# [4], [ 4, 3, 2, 1, 0]]
# 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
# [ -4, -3, -2, -1, 0]],
relative_pos_block
=
offs_m
[:,
None
]
+
seqlen_k
-
seqlen_q
-
offs_n
[
None
,:]
alibi_block
=
-
1
*
alibi_slope
*
tl
.
abs
(
relative_pos_block
)
if
transpose
:
return
alibi_block
.
T
else
:
return
alibi_block
@
triton
.
jit
def
_attn_fwd_inner
(
acc
,
l_i
,
m_i
,
q
,
K_block_ptr
,
V_block_ptr
,
start_m
,
actual_seqlen_k
,
actual_seqlen_q
,
dropout_p
,
philox_seed
,
batch_philox_offset
,
encoded_softmax_block_ptr
,
block_min
,
block_max
,
offs_n_causal
,
masked_blocks
,
n_extra_tokens
,
bias_ptr
,
alibi_slope
,
IS_CAUSAL
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
OFFS_M
:
tl
.
constexpr
,
OFFS_N
:
tl
.
constexpr
,
PRE_LOAD_V
:
tl
.
constexpr
,
MASK_STEPS
:
tl
.
constexpr
,
ENABLE_DROPOUT
:
tl
.
constexpr
,
RETURN_ENCODED_SOFTMAX
:
tl
.
constexpr
,
PADDED_HEAD
:
tl
.
constexpr
):
# loop over k, v, and update accumulator
for
start_n
in
range
(
block_min
,
block_max
,
BLOCK_N
):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
k
=
load_fn
(
K_block_ptr
,
PADDED_HEAD
,
MASK_STEPS
and
(
n_extra_tokens
!=
0
),
"zero"
)
if
PRE_LOAD_V
:
v
=
load_fn
(
V_block_ptr
,
MASK_STEPS
and
(
n_extra_tokens
!=
0
),
PADDED_HEAD
,
"zero"
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if
MASK_STEPS
:
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn.
# last step might get wasted but that is okay. check if this masking works For
# that case.
if
(
start_n
+
BLOCK_N
==
block_max
)
and
(
n_extra_tokens
!=
0
):
boundary_m
=
tl
.
full
([
BLOCK_M
],
actual_seqlen_k
,
dtype
=
tl
.
int32
)
size_n
=
start_n
+
OFFS_N
[
None
,:]
mask
=
size_n
<
boundary_m
[:,
None
]
qk
=
tl
.
where
(
mask
,
qk
,
float
(
"-inf"
))
if
IS_CAUSAL
:
causal_boundary
=
start_n
+
offs_n_causal
causal_mask
=
OFFS_M
[:,
None
]
>=
causal_boundary
[
None
,
:]
qk
=
tl
.
where
(
causal_mask
,
qk
,
float
(
"-inf"
))
# -- compute qk ----
qk
+=
tl
.
dot
(
q
,
k
)
if
bias_ptr
is
not
None
:
bias
=
load_fn
(
bias_ptr
,
False
,
MASK_STEPS
and
(
n_extra_tokens
!=
0
),
"zero"
)
# While bias is added after multiplying qk with sm_scale,
# our optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk
+=
(
bias
*
1.44269504089
)
if
alibi_slope
is
not
None
:
# Compute the global position of each token within the sequence
global_m_positions
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
global_n_positions
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
alibi_block
=
compute_alibi_block
(
alibi_slope
,
actual_seqlen_q
,
actual_seqlen_k
,
global_m_positions
,
global_n_positions
)
qk
+=
(
alibi_block
*
1.44269504089
)
# scale factor of log2(e)
# softmax
m_ij
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
1
))
qk
=
qk
-
m_ij
[:,
None
]
p
=
tl
.
math
.
exp2
(
qk
)
# CAVEAT: Must update l_ij before applying dropout
l_ij
=
tl
.
sum
(
p
,
1
)
if
ENABLE_DROPOUT
:
philox_offset
=
batch_philox_offset
+
start_m
*
BLOCK_M
*
actual_seqlen_k
+
start_n
-
BLOCK_N
keep
=
dropout_mask
(
philox_seed
,
philox_offset
,
dropout_p
,
BLOCK_M
,
BLOCK_N
,
actual_seqlen_k
)
if
RETURN_ENCODED_SOFTMAX
:
tl
.
store
(
encoded_softmax_block_ptr
,
tl
.
where
(
keep
,
p
,
-
p
).
to
(
encoded_softmax_block_ptr
.
type
.
element_ty
))
p
=
tl
.
where
(
keep
,
p
,
0.0
)
elif
RETURN_ENCODED_SOFTMAX
:
tl
.
store
(
encoded_softmax_block_ptr
,
p
.
to
(
encoded_softmax_block_ptr
.
type
.
element_ty
))
# -- update output accumulator --
alpha
=
tl
.
math
.
exp2
(
m_i
-
m_ij
)
acc
=
acc
*
alpha
[:,
None
]
if
not
PRE_LOAD_V
:
v
=
load_fn
(
V_block_ptr
,
MASK_STEPS
and
(
n_extra_tokens
!=
0
),
PADDED_HEAD
,
"zero"
)
# -- update m_i and l_i
l_i
=
l_i
*
alpha
+
l_ij
# update m_i and l_i
m_i
=
m_ij
acc
+=
tl
.
dot
(
p
.
to
(
V_block_ptr
.
type
.
element_ty
),
v
)
V_block_ptr
=
tl
.
advance
(
V_block_ptr
,
(
BLOCK_N
,
0
))
K_block_ptr
=
tl
.
advance
(
K_block_ptr
,
(
0
,
BLOCK_N
))
if
bias_ptr
is
not
None
:
bias_ptr
=
tl
.
advance
(
bias_ptr
,
(
0
,
BLOCK_N
))
if
RETURN_ENCODED_SOFTMAX
:
encoded_softmax_block_ptr
=
tl
.
advance
(
encoded_softmax_block_ptr
,
(
0
,
BLOCK_N
))
return
acc
,
l_i
,
m_i
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
'BLOCK_M'
:
256
,
'BLOCK_N'
:
64
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
8
),
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
128
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
256
,
'BLOCK_N'
:
128
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
8
),
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
64
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
True
},
num_stages
=
1
,
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
64
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
32
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
True
},
num_stages
=
2
,
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
32
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
2
,
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
64
,
'BLOCK_N'
:
64
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
8
),
triton
.
Config
({
'BLOCK_M'
:
64
,
'BLOCK_N'
:
64
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
True
},
num_stages
=
1
,
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
32
,
'BLOCK_N'
:
32
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
8
),
# TODO: This config fails with head_size not pow2 with data mismatches. Check why.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton
.
Config
({
'BLOCK_M'
:
16
,
'BLOCK_N'
:
16
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
4
),
],
key
=
[
'IS_CAUSAL'
,
'dropout_p'
,
'BLOCK_DMODEL'
],
# use_cuda_graph=True,
)
@
triton
.
jit
def
attn_fwd
(
Q
,
K
,
V
,
bias
,
sm_scale
,
L
,
Out
,
stride_qz
,
stride_qh
,
stride_qm
,
stride_qk
,
stride_kz
,
stride_kh
,
stride_kn
,
stride_kk
,
stride_vz
,
stride_vh
,
stride_vk
,
stride_vn
,
stride_oz
,
stride_oh
,
stride_om
,
stride_on
,
stride_bz
,
stride_bh
,
stride_bm
,
stride_bn
,
stride_az
,
stride_ah
,
cu_seqlens_q
,
cu_seqlens_k
,
dropout_p
,
philox_seed
,
philox_offset_base
,
encoded_softmax
,
alibi_slopes
,
HQ
:
tl
.
constexpr
,
HK
:
tl
.
constexpr
,
ACTUAL_BLOCK_DMODEL
:
tl
.
constexpr
,
MAX_SEQLENS_Q
:
tl
.
constexpr
,
MAX_SEQLENS_K
:
tl
.
constexpr
,
VARLEN
:
tl
.
constexpr
,
IS_CAUSAL
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
PRE_LOAD_V
:
tl
.
constexpr
,
BIAS_TYPE
:
tl
.
constexpr
,
ENABLE_DROPOUT
:
tl
.
constexpr
,
RETURN_ENCODED_SOFTMAX
:
tl
.
constexpr
,
USE_ALIBI
:
tl
.
constexpr
,
BATCH_SIZE
:
tl
.
constexpr
,
):
start_m
=
tl
.
program_id
(
0
)
off_h_q
=
tl
.
program_id
(
1
)
off_z
=
tl
.
program_id
(
2
)
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
if
VARLEN
:
cu_seqlens_q_start
=
tl
.
load
(
cu_seqlens_q
+
off_z
)
cu_seqlens_q_end
=
tl
.
load
(
cu_seqlens_q
+
off_z
+
1
)
seqlen_q
=
cu_seqlens_q_end
-
cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if
start_m
*
BLOCK_M
>
seqlen_q
:
return
cu_seqlens_k_start
=
tl
.
load
(
cu_seqlens_k
+
off_z
)
cu_seqlens_k_end
=
tl
.
load
(
cu_seqlens_k
+
off_z
+
1
)
seqlen_k
=
cu_seqlens_k_end
-
cu_seqlens_k_start
else
:
cu_seqlens_q_start
=
0
cu_seqlens_k_start
=
0
seqlen_q
=
MAX_SEQLENS_Q
seqlen_k
=
MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks
=
cdiv_fn
(
seqlen_k
,
BLOCK_N
)
if
(
IS_CAUSAL
):
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn matrix
n_blocks_seqlen
=
cdiv_fn
(
(
start_m
+
1
)
*
BLOCK_M
+
seqlen_k
-
seqlen_q
,
BLOCK_N
)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks
=
min
(
n_blocks
,
n_blocks_seqlen
)
# If we have no blocks after adjusting for seqlen deltas, this WG is part of
# the blocks that are all 0. We exit early.
if
n_blocks
<=
0
:
o_offset
=
off_z
*
stride_oz
+
cu_seqlens_q_start
*
stride_om
+
off_h_q
*
stride_oh
O_block_ptr
=
tl
.
make_block_ptr
(
base
=
Out
+
o_offset
,
shape
=
(
seqlen_q
,
BLOCK_DMODEL
),
strides
=
(
stride_om
,
stride_on
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
Out
.
type
.
element_ty
)
# We still need to write 0s to the result
tl
.
store
(
O_block_ptr
,
acc
.
to
(
Out
.
type
.
element_ty
),
boundary_check
=
(
0
,
1
))
l_ptrs
=
L
+
off_z
*
HQ
*
MAX_SEQLENS_Q
+
off_h_q
*
MAX_SEQLENS_Q
+
offs_m
# We store inf to LSE, not -inf because in the bwd pass, we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks.
l
=
tl
.
full
([
BLOCK_M
],
value
=
float
(
"inf"
),
dtype
=
tl
.
float32
)
tl
.
store
(
l_ptrs
,
l
)
# TODO: Should dropout and return encoded softmax be handled here too?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE
:
tl
.
constexpr
=
HQ
//
HK
if
GROUP_SIZE
!=
1
:
off_h_k
=
off_h_q
//
GROUP_SIZE
else
:
off_h_k
=
off_h_q
need_padding
=
False
n_extra_tokens
=
0
if
seqlen_k
<
BLOCK_N
:
need_padding
=
True
n_extra_tokens
=
BLOCK_N
-
seqlen_k
elif
seqlen_k
%
BLOCK_N
:
need_padding
=
True
n_extra_tokens
=
seqlen_k
%
BLOCK_N
PADDED_HEAD
:
tl
.
constexpr
=
(
ACTUAL_BLOCK_DMODEL
!=
BLOCK_DMODEL
)
# Compute pointers for all the tensors used in this kernel.
q_offset
=
off_z
*
stride_qz
+
off_h_q
*
stride_qh
+
cu_seqlens_q_start
*
stride_qm
Q_block_ptr
=
tl
.
make_block_ptr
(
base
=
Q
+
q_offset
,
shape
=
(
seqlen_q
,
ACTUAL_BLOCK_DMODEL
),
strides
=
(
stride_qm
,
stride_qk
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
k_offset
=
off_z
*
stride_kz
+
off_h_k
*
stride_kh
+
cu_seqlens_k_start
*
stride_kn
K_block_ptr
=
tl
.
make_block_ptr
(
base
=
K
+
k_offset
,
shape
=
(
ACTUAL_BLOCK_DMODEL
,
seqlen_k
),
strides
=
(
stride_kk
,
stride_kn
),
offsets
=
(
0
,
0
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_N
),
order
=
(
0
,
1
)
)
v_offset
=
off_z
*
stride_vz
+
off_h_k
*
stride_vh
+
cu_seqlens_k_start
*
stride_vk
V_block_ptr
=
tl
.
make_block_ptr
(
base
=
V
+
v_offset
,
shape
=
(
seqlen_k
,
ACTUAL_BLOCK_DMODEL
),
strides
=
(
stride_vk
,
stride_vn
),
offsets
=
(
0
,
0
),
block_shape
=
(
BLOCK_N
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
if
BIAS_TYPE
!=
0
:
b_offset
=
off_h_q
*
stride_bh
# Note: this might get large enough to overflow on some configs
bias_ptr
=
tl
.
make_block_ptr
(
base
=
bias
+
b_offset
,
shape
=
(
seqlen_q
,
seqlen_k
),
strides
=
(
stride_bm
,
stride_bn
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_N
),
order
=
(
1
,
0
),
)
else
:
bias_ptr
=
None
if
USE_ALIBI
:
a_offset
=
off_z
*
stride_az
+
off_h_q
*
stride_ah
alibi_slope
=
tl
.
load
(
alibi_slopes
+
a_offset
)
else
:
alibi_slope
=
None
if
ENABLE_DROPOUT
:
batch_philox_offset
=
philox_offset_base
+
off_hz
*
seqlen_q
*
seqlen_k
else
:
batch_philox_offset
=
0
# We can ask to return the dropout mask without actually doing any dropout. In
# this case, we return an invalid pointer so indicate the mask is not valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if
RETURN_ENCODED_SOFTMAX
:
encoded_softmax_block_ptr
=
tl
.
make_block_ptr
(
base
=
encoded_softmax
+
off_h_q
*
seqlen_q
*
seqlen_k
,
shape
=
(
seqlen_q
,
seqlen_k
),
strides
=
(
seqlen_k
,
1
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_N
),
order
=
(
1
,
0
)
)
else
:
encoded_softmax_block_ptr
=
0
# initialize pointer to m and l
m_i
=
tl
.
full
([
BLOCK_M
],
float
(
"-inf"
),
dtype
=
tl
.
float32
)
l_i
=
tl
.
full
([
BLOCK_M
],
1.0
,
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale
=
sm_scale
*
1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q
=
load_fn
(
Q_block_ptr
,
True
,
PADDED_HEAD
,
"zero"
)
q
=
(
q
*
qk_scale
).
to
(
Q_block_ptr
.
type
.
element_ty
)
# Here we compute how many full and masked blocks we have.
padded_block_k
=
n_extra_tokens
!=
0
is_modulo_mn
=
not
padded_block_k
and
(
seqlen_q
%
BLOCK_M
==
0
)
if
IS_CAUSAL
:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks
=
BLOCK_M
//
BLOCK_N
+
(
not
is_modulo_mn
)
else
:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks
=
padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional block.
# In this case we might exceed n_blocks so pick the min.
masked_blocks
=
min
(
masked_blocks
,
n_blocks
)
n_full_blocks
=
n_blocks
-
masked_blocks
block_min
=
0
block_max
=
n_blocks
*
BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its actual
# value because there is no masking. Similarly we do not need padding.
if
n_full_blocks
>
0
:
block_max
=
(
n_blocks
-
masked_blocks
)
*
BLOCK_N
acc
,
l_i
,
m_i
=
_attn_fwd_inner
(
acc
,
l_i
,
m_i
,
q
,
K_block_ptr
,
V_block_ptr
,
start_m
,
seqlen_k
,
seqlen_q
,
dropout_p
,
philox_seed
,
batch_philox_offset
,
encoded_softmax_block_ptr
,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min
,
block_max
,
0
,
0
,
0
,
bias_ptr
,
alibi_slope
,
# IS_CAUSAL, ....
False
,
BLOCK_M
,
BLOCK_DMODEL
,
BLOCK_N
,
offs_m
,
offs_n
,
# _, MASK_STEPS, ...
PRE_LOAD_V
,
False
,
ENABLE_DROPOUT
,
RETURN_ENCODED_SOFTMAX
,
PADDED_HEAD
)
block_min
=
block_max
block_max
=
n_blocks
*
BLOCK_N
tl
.
debug_barrier
()
# Remaining blocks, if any, are full / not masked.
if
(
masked_blocks
>
0
):
if
IS_CAUSAL
:
offs_n_causal
=
offs_n
+
(
seqlen_q
-
seqlen_k
)
else
:
offs_n_causal
=
0
K_block_ptr
=
tl
.
advance
(
K_block_ptr
,
(
0
,
n_full_blocks
*
BLOCK_N
))
V_block_ptr
=
tl
.
advance
(
V_block_ptr
,
(
n_full_blocks
*
BLOCK_N
,
0
))
if
bias_ptr
is
not
None
:
bias_ptr
=
tl
.
advance
(
bias_ptr
,
(
0
,
n_full_blocks
*
BLOCK_N
))
if
RETURN_ENCODED_SOFTMAX
:
encoded_softmax_block_ptr
=
tl
.
advance
(
encoded_softmax_block_ptr
,
(
0
,
n_full_blocks
))
acc
,
l_i
,
m_i
=
_attn_fwd_inner
(
acc
,
l_i
,
m_i
,
q
,
K_block_ptr
,
V_block_ptr
,
start_m
,
seqlen_k
,
seqlen_q
,
dropout_p
,
philox_seed
,
batch_philox_offset
,
encoded_softmax_block_ptr
,
block_min
,
block_max
,
offs_n_causal
,
masked_blocks
,
n_extra_tokens
,
bias_ptr
,
alibi_slope
,
IS_CAUSAL
,
BLOCK_M
,
BLOCK_DMODEL
,
BLOCK_N
,
offs_m
,
offs_n
,
# _, MASK_STEPS, ...
PRE_LOAD_V
,
True
,
ENABLE_DROPOUT
,
RETURN_ENCODED_SOFTMAX
,
PADDED_HEAD
)
# epilogue
acc
=
acc
/
l_i
[:,
None
]
if
ENABLE_DROPOUT
:
acc
=
acc
/
(
1
-
dropout_p
)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx
=
(
start_m
+
1
)
*
BLOCK_M
start_m_idx
=
start_m
*
BLOCK_M
causal_start_idx
=
seqlen_q
-
seqlen_k
acc
=
acc
.
to
(
Out
.
type
.
element_ty
)
if
IS_CAUSAL
:
if
causal_start_idx
>
start_m_idx
and
causal_start_idx
<
end_m_idx
:
out_mask_boundary
=
tl
.
full
((
BLOCK_DMODEL
,),
causal_start_idx
,
dtype
=
tl
.
int32
)
mask_m_offsets
=
start_m_idx
+
tl
.
arange
(
0
,
BLOCK_M
)
out_ptrs_mask
=
mask_m_offsets
[:,
None
]
>=
out_mask_boundary
[
None
,
:]
z
=
0.0
acc
=
tl
.
where
(
out_ptrs_mask
,
acc
,
z
.
to
(
acc
.
type
.
element_ty
))
# write back LSE
l_ptrs
=
L
+
off_z
*
HQ
*
MAX_SEQLENS_Q
+
off_h_q
*
MAX_SEQLENS_Q
+
offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows.
# This is only true for the last M block. For others, overflow_size will be -ve
overflow_size
=
end_m_idx
-
seqlen_q
if
overflow_size
>
0
:
boundary
=
tl
.
full
((
BLOCK_M
,),
BLOCK_M
-
overflow_size
,
dtype
=
tl
.
int32
)
# This is a > check because mask being 0 blocks the store.
l_ptrs_mask
=
boundary
>
tl
.
arange
(
0
,
BLOCK_M
)
tl
.
store
(
l_ptrs
,
m_i
+
tl
.
math
.
log2
(
l_i
),
mask
=
l_ptrs_mask
)
else
:
tl
.
store
(
l_ptrs
,
m_i
+
tl
.
math
.
log2
(
l_i
))
# write back O
o_offset
=
off_z
*
stride_oz
+
cu_seqlens_q_start
*
stride_om
+
off_h_q
*
stride_oh
O_block_ptr
=
tl
.
make_block_ptr
(
base
=
Out
+
o_offset
,
shape
=
(
seqlen_q
,
ACTUAL_BLOCK_DMODEL
),
strides
=
(
stride_om
,
stride_on
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl
.
store
(
O_block_ptr
,
acc
,
boundary_check
=
(
0
,
1
))
@
triton
.
jit
def
_attn_bwd_preprocess
(
Out
,
DO
,
Delta
,
stride_oz
,
stride_oh
,
stride_om
,
stride_on
,
stride_doz
,
stride_doh
,
stride_dom
,
stride_don
,
seqlen_q
,
head_dim
,
BLOCK_M
:
tl
.
constexpr
,
D_HEAD
:
tl
.
constexpr
,
):
# off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
# off_n = tl.arange(0, D_HEAD)
off_m
=
tl
.
program_id
(
0
)
*
BLOCK_M
off_h
=
tl
.
program_id
(
1
)
# head index
off_z
=
tl
.
program_id
(
2
)
# batch index
num_h
=
tl
.
num_programs
(
1
)
o_offset
=
off_h
*
stride_oh
+
off_z
*
stride_oz
O_block_ptr
=
tl
.
make_block_ptr
(
base
=
Out
+
o_offset
,
shape
=
(
seqlen_q
,
head_dim
),
strides
=
(
stride_om
,
stride_on
),
offsets
=
(
off_m
,
0
),
block_shape
=
(
BLOCK_M
,
D_HEAD
),
order
=
(
1
,
0
)
)
do_offset
=
off_h
*
stride_doh
+
off_z
*
stride_doz
DO_block_ptr
=
tl
.
make_block_ptr
(
base
=
DO
+
do_offset
,
shape
=
(
seqlen_q
,
head_dim
),
strides
=
(
stride_dom
,
stride_don
),
offsets
=
(
off_m
,
0
),
block_shape
=
(
BLOCK_M
,
D_HEAD
),
order
=
(
1
,
0
)
)
# load
# o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
# do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
o
=
tl
.
load
(
O_block_ptr
,
boundary_check
=
(
0
,
1
),
padding_option
=
"zero"
).
to
(
tl
.
float32
)
do
=
tl
.
load
(
DO_block_ptr
,
boundary_check
=
(
0
,
1
),
padding_option
=
"zero"
).
to
(
tl
.
float32
)
# compute
delta
=
tl
.
sum
(
o
*
do
,
axis
=
1
)
# write-back, shape (q.shape[0] * q.shape[1], q.shape[2])
off_zh
=
off_z
*
num_h
+
off_h
*
1
# Check for OOB accesses
delta_ptrs
=
Delta
+
off_zh
*
seqlen_q
+
off_m
+
tl
.
arange
(
0
,
BLOCK_M
)
overflow
=
off_m
+
BLOCK_M
-
seqlen_q
if
overflow
>
0
:
boundary
=
tl
.
full
((
BLOCK_M
,
),
BLOCK_M
-
overflow
,
dtype
=
tl
.
int32
)
mask
=
boundary
>
tl
.
arange
(
0
,
BLOCK_M
)
tl
.
store
(
delta_ptrs
,
delta
,
mask
=
mask
)
else
:
tl
.
store
(
delta_ptrs
,
delta
)
@
triton
.
jit
def
_bwd_kernel_dk_dv
(
dk
,
dv
,
Q
,
k
,
v
,
sm_scale
,
alibi_slope
,
DO
,
M
,
D
,
# shared by Q/K/V/DO.
stride_tok
,
stride_d
,
H
,
N_CTX
,
BLOCK_M1
:
tl
.
constexpr
,
BLOCK_N1
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
# Filled in by the wrapper.
start_n
,
start_m
,
num_steps
,
MASK
:
tl
.
constexpr
):
offs_m
=
start_m
+
tl
.
arange
(
0
,
BLOCK_M1
)
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N1
)
offs_k
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
QT_block_ptr
=
tl
.
make_block_ptr
(
base
=
Q
,
shape
=
(
BLOCK_DMODEL
,
N_CTX
),
strides
=
(
stride_d
,
stride_tok
),
offsets
=
(
0
,
start_m
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_M1
),
order
=
(
0
,
1
)
)
DO_block_ptr
=
tl
.
make_block_ptr
(
base
=
DO
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_m
,
0
),
block_shape
=
(
BLOCK_M1
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl
.
static_assert
(
BLOCK_N1
%
BLOCK_M1
==
0
)
curr_m
=
start_m
step_m
=
BLOCK_M1
for
blk_idx
in
range
(
num_steps
):
qT
=
tl
.
load
(
QT_block_ptr
)
# Load m before computing qk to reduce pipeline stall.
offs_m
=
curr_m
+
tl
.
arange
(
0
,
BLOCK_M1
)
m
=
tl
.
load
(
M
+
offs_m
)
kqT
=
tl
.
dot
(
k
,
qT
)
if
alibi_slope
is
not
None
:
alibi_block
=
compute_alibi_block
(
alibi_slope
,
N_CTX
,
N_CTX
,
offs_m
,
offs_n
,
True
)
kqT
+=
alibi_block
*
1.44269504089
pT
=
tl
.
math
.
exp2
(
kqT
-
m
[
None
,
:])
# Autoregressive masking.
if
MASK
:
mask
=
(
offs_m
[
None
,
:]
>=
offs_n
[:,
None
])
pT
=
tl
.
where
(
mask
,
pT
,
0.0
)
do
=
tl
.
load
(
DO_block_ptr
)
# Compute dV.
ppT
=
pT
ppT
=
ppT
.
to
(
tl
.
float16
)
dv
+=
tl
.
dot
(
ppT
,
do
)
# D (= delta) is pre-divided by ds_scale.
Di
=
tl
.
load
(
D
+
offs_m
)
# Compute dP and dS.
dpT
=
tl
.
dot
(
v
,
tl
.
trans
(
do
))
dsT
=
pT
*
(
dpT
-
Di
[
None
,
:])
dsT
=
dsT
.
to
(
tl
.
float16
)
dk
+=
tl
.
dot
(
dsT
,
tl
.
trans
(
qT
))
# Increment pointers.
curr_m
+=
step_m
QT_block_ptr
=
tl
.
advance
(
QT_block_ptr
,
(
0
,
step_m
))
DO_block_ptr
=
tl
.
advance
(
DO_block_ptr
,
(
step_m
,
0
))
return
dk
,
dv
@
triton
.
jit
def
_bwd_kernel_dq
(
dq
,
q
,
K
,
V
,
do
,
m
,
D
,
alibi_slope
,
# shared by Q/K/V/DO.
stride_tok
,
stride_d
,
H
,
N_CTX
,
BLOCK_M2
:
tl
.
constexpr
,
BLOCK_N2
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
# Filled in by the wrapper.
start_m
,
start_n
,
num_steps
,
MASK
:
tl
.
constexpr
):
offs_m
=
start_m
+
tl
.
arange
(
0
,
BLOCK_M2
)
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N2
)
offs_k
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
KT_block_ptr
=
tl
.
make_block_ptr
(
base
=
K
,
shape
=
(
BLOCK_DMODEL
,
N_CTX
),
strides
=
(
stride_d
,
stride_tok
),
offsets
=
(
0
,
start_n
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_N2
),
order
=
(
0
,
1
)
)
VT_block_ptr
=
tl
.
make_block_ptr
(
base
=
V
,
shape
=
(
BLOCK_DMODEL
,
N_CTX
),
strides
=
(
stride_d
,
stride_tok
),
offsets
=
(
0
,
start_n
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_N2
),
order
=
(
0
,
1
)
)
# D (= delta) is pre-divided by ds_scale.
Di
=
tl
.
load
(
D
+
offs_m
)
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl
.
static_assert
(
BLOCK_M2
%
BLOCK_N2
==
0
)
curr_n
=
start_n
step_n
=
BLOCK_N2
for
blk_idx
in
range
(
num_steps
):
kT
=
tl
.
load
(
KT_block_ptr
)
qk
=
tl
.
dot
(
q
,
kT
)
if
alibi_slope
is
not
None
:
alibi_block
=
compute_alibi_block
(
alibi_slope
,
N_CTX
,
N_CTX
,
offs_m
,
offs_n
)
qk
+=
alibi_block
*
1.44269504089
p
=
tl
.
math
.
exp2
(
qk
-
m
)
# Autoregressive masking.
if
MASK
:
offs_n
=
curr_n
+
tl
.
arange
(
0
,
BLOCK_N2
)
mask
=
(
offs_m
[:,
None
]
>=
offs_n
[
None
,
:])
p
=
tl
.
where
(
mask
,
p
,
0.0
)
# Compute dP and dS.
vT
=
tl
.
load
(
VT_block_ptr
)
dp
=
tl
.
dot
(
do
,
vT
).
to
(
tl
.
float32
)
ds
=
p
*
(
dp
-
Di
[:,
None
])
ds
=
ds
.
to
(
tl
.
float16
)
# Compute dQ.0.
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
dq
+=
tl
.
dot
(
ds
,
tl
.
trans
(
kT
))
# Increment pointers.
curr_n
+=
step_n
KT_block_ptr
=
tl
.
advance
(
KT_block_ptr
,
(
0
,
step_n
))
VT_block_ptr
=
tl
.
advance
(
VT_block_ptr
,
(
0
,
step_n
))
return
dq
@
triton
.
jit
def
_attn_bwd
(
Q
,
K
,
V
,
sm_scale
,
alibi_slopes
,
DO
,
DQ
,
DK
,
DV
,
M
,
D
,
# shared by Q/K/V/DO.
stride_z
,
stride_h
,
stride_tok
,
stride_d
,
# H = 16, N_CTX = 1024
H
,
N_CTX
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_M1
:
tl
.
constexpr
,
BLOCK_N1
:
tl
.
constexpr
,
BLOCK_M2
:
tl
.
constexpr
,
BLOCK_N2
:
tl
.
constexpr
,
BLK_SLICE_FACTOR
:
tl
.
constexpr
,
USE_ALIBI
:
tl
.
constexpr
):
LN2
:
tl
.
constexpr
=
0.6931471824645996
# = ln(2)
bhid
=
tl
.
program_id
(
2
)
off_chz
=
(
bhid
*
N_CTX
).
to
(
tl
.
int64
)
adj
=
(
stride_h
*
(
bhid
%
H
)
+
stride_z
*
(
bhid
//
H
)).
to
(
tl
.
int64
)
pid
=
tl
.
program_id
(
0
)
# offset pointers for batch/head
Q
+=
adj
K
+=
adj
V
+=
adj
DO
+=
adj
DQ
+=
adj
DK
+=
adj
DV
+=
adj
M
+=
off_chz
D
+=
off_chz
offs_k
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
start_n
=
pid
*
BLOCK_N1
# This assignment is important. It is what allows us to pick the diagonal
# blocks. Later, when we want to do the lower triangular, we update start_m
# after the first dkdv call.
start_m
=
start_n
MASK_BLOCK_M1
:
tl
.
constexpr
=
BLOCK_M1
//
BLK_SLICE_FACTOR
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N1
)
dv
=
tl
.
zeros
([
BLOCK_N1
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
dk
=
tl
.
zeros
([
BLOCK_N1
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
K_block_ptr
=
tl
.
make_block_ptr
(
base
=
K
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_n
,
0
),
block_shape
=
(
BLOCK_N1
,
BLOCK_DMODEL
),
order
=
(
1
,
0
),
)
V_block_ptr
=
tl
.
make_block_ptr
(
base
=
V
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_n
,
0
),
block_shape
=
(
BLOCK_N1
,
BLOCK_DMODEL
),
order
=
(
1
,
0
),
)
# load K and V: they stay in SRAM throughout the inner loop for dkdv.
k
=
tl
.
load
(
K_block_ptr
)
v
=
tl
.
load
(
V_block_ptr
)
if
USE_ALIBI
:
a_offset
=
bhid
alibi_slope
=
tl
.
load
(
alibi_slopes
+
a_offset
)
else
:
alibi_slope
=
None
# compute dK and dV for blocks close to the diagonal that need to be masked
num_steps
=
BLOCK_N1
//
MASK_BLOCK_M1
dk
,
dv
=
_bwd_kernel_dk_dv
(
dk
,
dv
,
Q
,
k
,
v
,
sm_scale
,
alibi_slope
,
DO
,
M
,
D
,
stride_tok
,
stride_d
,
H
,
N_CTX
,
MASK_BLOCK_M1
,
BLOCK_N1
,
BLOCK_DMODEL
,
start_n
,
start_m
,
num_steps
,
MASK
=
True
)
# compute dK and dV for blocks that don't need masking further from the diagonal
start_m
+=
num_steps
*
MASK_BLOCK_M1
num_steps
=
(
N_CTX
-
start_m
)
//
BLOCK_M1
dk
,
dv
=
_bwd_kernel_dk_dv
(
dk
,
dv
,
Q
,
k
,
v
,
sm_scale
,
alibi_slope
,
DO
,
M
,
D
,
stride_tok
,
stride_d
,
H
,
N_CTX
,
BLOCK_M1
,
BLOCK_N1
,
BLOCK_DMODEL
,
start_n
,
start_m
,
num_steps
,
MASK
=
False
)
DV_block_ptrs
=
tl
.
make_block_ptr
(
base
=
DV
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_n
,
0
),
block_shape
=
(
BLOCK_N1
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
tl
.
store
(
DV_block_ptrs
,
dv
.
to
(
v
.
dtype
))
# Write back dK.
dk
*=
sm_scale
DK_block_ptrs
=
tl
.
make_block_ptr
(
base
=
DK
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_n
,
0
),
block_shape
=
(
BLOCK_N1
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
tl
.
store
(
DK_block_ptrs
,
dk
.
to
(
k
.
dtype
))
# THIS BLOCK DOES DQ:
start_m
=
pid
*
BLOCK_M2
end_n
=
start_m
+
BLOCK_M2
MASK_BLOCK_N2
:
tl
.
constexpr
=
BLOCK_N2
//
BLK_SLICE_FACTOR
offs_m
=
start_m
+
tl
.
arange
(
0
,
BLOCK_M2
)
Q_block_ptr
=
tl
.
make_block_ptr
(
base
=
Q
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_m
,
0
),
block_shape
=
(
BLOCK_M2
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
DO_block_ptr
=
tl
.
make_block_ptr
(
base
=
DO
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_m
,
0
),
block_shape
=
(
BLOCK_M2
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
q
=
tl
.
load
(
Q_block_ptr
)
do
=
tl
.
load
(
DO_block_ptr
)
dq
=
tl
.
zeros
([
BLOCK_M2
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
m
=
tl
.
load
(
M
+
offs_m
)
m
=
m
[:,
None
]
# Compute dQ for masked (diagonal) blocks.
# NOTE: This code scans each row of QK^T backward (from right to left,
# but inside each call to _attn_bwd_dq, from left to right), but that's
# not due to anything important. I just wanted to reuse the loop
# structure for dK & dV above as much as possible.
num_steps
=
BLOCK_M2
//
MASK_BLOCK_N2
dq
=
_bwd_kernel_dq
(
dq
,
q
,
K
,
V
,
do
,
m
,
D
,
alibi_slope
,
stride_tok
,
stride_d
,
H
,
N_CTX
,
BLOCK_M2
,
MASK_BLOCK_N2
,
BLOCK_DMODEL
,
start_m
,
end_n
-
num_steps
*
MASK_BLOCK_N2
,
num_steps
,
MASK
=
True
)
end_n
-=
num_steps
*
MASK_BLOCK_N2
# stage 2
num_steps
=
end_n
//
BLOCK_N2
dq
=
_bwd_kernel_dq
(
dq
,
q
,
K
,
V
,
do
,
m
,
D
,
alibi_slope
,
stride_tok
,
stride_d
,
H
,
N_CTX
,
BLOCK_M2
,
BLOCK_N2
,
BLOCK_DMODEL
,
start_m
,
end_n
-
num_steps
*
BLOCK_N2
,
num_steps
,
MASK
=
False
)
# Write back dQ.
DQ_block_ptr
=
tl
.
make_block_ptr
(
base
=
DQ
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_m
,
0
),
block_shape
=
(
BLOCK_M2
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
dq
*=
LN2
tl
.
store
(
DQ_block_ptr
,
dq
.
to
(
q
.
dtype
))
class
_attention
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
o
,
metadata
):
# NOTE: a large bias tensor leads to overflow during pointer arithmetic
if
(
metadata
.
bias
is
not
None
):
assert
(
metadata
.
bias
.
numel
()
<
2
**
31
)
if
o
is
None
:
o
=
torch
.
empty_like
(
q
,
dtype
=
v
.
dtype
)
import
os
if
os
.
environ
.
get
(
"FLASH_ATTENTION_PRINT_PARAM"
,
"0"
)
==
"1"
:
print
(
f
"triton flash attention:
{
q
.
shape
=
}
,
{
k
.
shape
=
}
,
{
v
.
shape
}
,
{
o
.
shape
=
}
"
)
print
(
f
"triton flash attention:
{
q
.
stride
()
=
}
,
{
k
.
stride
()
=
}
,
{
v
.
stride
()
=
}
,
{
o
.
stride
()
=
}
"
)
print
(
f
"triton flash attention:
{
metadata
=
}
"
)
metadata
.
check_args
(
q
,
k
,
v
,
o
)
if
metadata
.
varlen
:
total_q
,
nheads_q
,
head_size
=
q
.
shape
total_k
,
nheads_k
,
_
=
k
.
shape
batch
=
metadata
.
num_contexts
q_strides
=
(
0
,
q
.
stride
(
1
),
q
.
stride
(
0
),
q
.
stride
(
2
))
k_strides
=
(
0
,
k
.
stride
(
1
),
k
.
stride
(
0
),
k
.
stride
(
2
))
v_strides
=
(
0
,
v
.
stride
(
1
),
v
.
stride
(
0
),
v
.
stride
(
2
))
o_strides
=
(
0
,
o
.
stride
(
1
),
o
.
stride
(
0
),
o
.
stride
(
2
))
else
:
batch
,
nheads_q
,
seqlen_q
,
head_size
=
q
.
shape
_
,
nheads_k
,
seqlen_k
,
_
=
k
.
shape
q_strides
=
(
q
.
stride
(
0
),
q
.
stride
(
1
),
q
.
stride
(
2
),
q
.
stride
(
3
))
k_strides
=
(
k
.
stride
(
0
),
k
.
stride
(
1
),
k
.
stride
(
2
),
k
.
stride
(
3
))
v_strides
=
(
v
.
stride
(
0
),
v
.
stride
(
1
),
v
.
stride
(
2
),
v
.
stride
(
3
))
o_strides
=
(
o
.
stride
(
0
),
o
.
stride
(
1
),
o
.
stride
(
2
),
o
.
stride
(
3
))
# Get closest power of 2 over or equal to 32.
padded_d_model
=
1
<<
(
head_size
-
1
).
bit_length
()
padded_d_model
=
max
(
padded_d_model
,
16
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
metadata
.
max_seqlens_q
,
META
[
'BLOCK_M'
]),
nheads_q
,
batch
)
# encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out
# to give a consistent starting point and then populate it with the output of softmax with the sign bit set according
# to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing
# only. This return holds no useful output aside from debugging.
if
metadata
.
return_encoded_softmax
:
encoded_softmax
=
torch
.
zeros
((
q
.
shape
[
0
],
q
.
shape
[
1
],
q
.
shape
[
2
],
k
.
shape
[
2
]),
device
=
q
.
device
,
dtype
=
torch
.
float32
)
else
:
encoded_softmax
=
None
M
=
torch
.
empty
((
batch
,
nheads_q
,
metadata
.
max_seqlens_q
),
device
=
q
.
device
,
dtype
=
torch
.
float32
)
# Seed the RNG so we get reproducible results for testing.
philox_seed
=
0x1BF52
philox_offset
=
0x1D4B42
if
metadata
.
bias
is
not
None
:
bias_strides
=
(
metadata
.
bias
.
stride
(
0
),
metadata
.
bias
.
stride
(
1
),
metadata
.
bias
.
stride
(
2
),
metadata
.
bias
.
stride
(
3
))
else
:
bias_strides
=
(
0
,
0
,
0
,
0
)
if
metadata
.
alibi_slopes
is
not
None
:
alibi_strides
=
(
metadata
.
alibi_slopes
.
stride
(
0
),
metadata
.
alibi_slopes
.
stride
(
1
))
else
:
alibi_strides
=
(
0
,
0
)
attn_fwd
[
grid
](
q
,
k
,
v
,
metadata
.
bias
,
metadata
.
sm_scale
,
M
,
o
,
*
q_strides
,
*
k_strides
,
*
v_strides
,
*
o_strides
,
*
bias_strides
,
*
alibi_strides
,
metadata
.
cu_seqlens_q
,
metadata
.
cu_seqlens_k
,
dropout_p
=
metadata
.
dropout_p
,
philox_seed
=
philox_seed
,
philox_offset_base
=
philox_offset
,
encoded_softmax
=
encoded_softmax
,
alibi_slopes
=
metadata
.
alibi_slopes
,
HQ
=
nheads_q
,
HK
=
nheads_k
,
ACTUAL_BLOCK_DMODEL
=
head_size
,
MAX_SEQLENS_Q
=
metadata
.
max_seqlens_q
,
MAX_SEQLENS_K
=
metadata
.
max_seqlens_k
,
IS_CAUSAL
=
metadata
.
causal
,
VARLEN
=
metadata
.
varlen
,
BLOCK_DMODEL
=
padded_d_model
,
BIAS_TYPE
=
0
if
metadata
.
bias
is
None
else
1
,
USE_ALIBI
=
False
if
metadata
.
alibi_slopes
is
None
else
True
,
ENABLE_DROPOUT
=
metadata
.
dropout_p
>
0.0
,
RETURN_ENCODED_SOFTMAX
=
metadata
.
return_encoded_softmax
,
BATCH_SIZE
=
q
.
shape
[
0
]
)
if
os
.
environ
.
get
(
"FLASH_ATTENTION_PRINT_PARAM"
,
"0"
)
==
"1"
:
best_config
=
attn_fwd
.
get_best_config
()
print
(
f
"
{
best_config
.
kwargs
=
}
,
{
best_config
.
num_stages
=
}
,
{
best_config
.
num_warps
=
}
"
)
ctx
.
save_for_backward
(
q
,
k
,
v
,
o
,
M
)
ctx
.
grid
=
grid
ctx
.
sm_scale
=
metadata
.
sm_scale
ctx
.
BLOCK_DMODEL
=
head_size
ctx
.
causal
=
metadata
.
causal
ctx
.
alibi_slopes
=
metadata
.
alibi_slopes
ctx
.
dropout_p
=
metadata
.
dropout_p
ctx
.
philox_seed
=
philox_seed
ctx
.
philox_offset
=
philox_offset
ctx
.
encoded_softmax
=
encoded_softmax
ctx
.
return_encoded_softmax
=
metadata
.
return_encoded_softmax
return
o
if
not
metadata
.
return_encoded_softmax
else
(
o
,
encoded_softmax
,
)
# S_dmask
@
staticmethod
def
backward
(
ctx
,
do
,
*
args
):
if
torch
.
version
.
hip
is
not
None
:
BLOCK
=
64
else
:
BLOCK
=
128
q
,
k
,
v
,
o
,
M
=
ctx
.
saved_tensors
import
os
if
os
.
environ
.
get
(
"TRITON_FLASHATTN_DEBUG"
,
"0"
)
==
"1"
:
print
(
f
"triton flash attention:
{
q
.
shape
=
}
,
{
k
.
shape
=
}
,
{
v
.
shape
}
,
{
o
.
shape
=
}
,
{
do
.
shape
=
}
"
)
print
(
f
"triton flash attention:
{
q
.
stride
()
=
}
,
{
k
.
stride
()
=
}
,
{
v
.
stride
()
=
}
,
{
o
.
stride
()
=
}
,
{
do
.
stride
()
}
"
)
# assert do.is_contiguous()
assert
q
.
stride
()
==
k
.
stride
()
==
v
.
stride
()
==
o
.
stride
()
==
do
.
stride
()
seqlen_q
=
q
.
shape
[
2
]
dq
=
torch
.
empty_like
(
q
)
dk
=
torch
.
empty_like
(
k
)
dv
=
torch
.
empty_like
(
v
)
BATCH
,
N_HEAD
,
N_CTX
=
q
.
shape
[:
3
]
PRE_BLOCK
=
128
NUM_WARPS
,
NUM_STAGES
=
4
,
1
BLOCK_M1
,
BLOCK_N1
,
BLOCK_M2
,
BLOCK_N2
=
32
,
64
,
64
,
32
BLK_SLICE_FACTOR
=
2
RCP_LN2
=
1.4426950408889634
# = 1.0 / ln(2)
arg_k
=
k
arg_k
=
arg_k
*
(
ctx
.
sm_scale
*
RCP_LN2
)
assert
N_CTX
%
PRE_BLOCK
==
0
delta
=
torch
.
empty_like
(
M
)
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
padded_head
=
(
Lk
!=
ctx
.
BLOCK_DMODEL
)
grid_preprocess
=
(
triton
.
cdiv
(
do
.
shape
[
2
],
BLOCK
),
do
.
shape
[
1
],
do
.
shape
[
0
])
_attn_bwd_preprocess
[
grid_preprocess
](
o
,
do
,
delta
,
o
.
stride
(
0
),
o
.
stride
(
1
),
o
.
stride
(
2
),
o
.
stride
(
3
),
do
.
stride
(
0
),
do
.
stride
(
1
),
do
.
stride
(
2
),
do
.
stride
(
3
),
seqlen_q
,
head_dim
=
Lk
,
BLOCK_M
=
BLOCK
,
D_HEAD
=
ctx
.
BLOCK_DMODEL
,
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
N_CTX
,
META
[
'BLOCK_N1'
]),
1
,
BATCH
*
N_HEAD
)
_attn_bwd
[
grid
](
q
,
arg_k
,
v
,
ctx
.
sm_scale
,
ctx
.
alibi_slopes
,
do
,
dq
,
dk
,
dv
,
M
,
delta
,
q
.
stride
(
0
),
q
.
stride
(
1
),
q
.
stride
(
2
),
q
.
stride
(
3
),
N_HEAD
,
N_CTX
,
BLOCK_DMODEL
=
ctx
.
BLOCK_DMODEL
,
BLOCK_M1
=
BLOCK_M1
,
BLOCK_N1
=
BLOCK_N1
,
BLOCK_M2
=
BLOCK_M2
,
BLOCK_N2
=
BLOCK_N2
,
BLK_SLICE_FACTOR
=
BLK_SLICE_FACTOR
,
USE_ALIBI
=
False
if
ctx
.
alibi_slopes
is
None
else
True
,
)
return
dq
,
dk
,
dv
,
None
,
None
attention
=
_attention
.
apply
# flash_attn wrapper
def
input_helper
(
Z
,
HQ
,
HK
,
N_CTX_Q
,
N_CTX_K
,
D_HEAD
,
dtype
):
torch
.
manual_seed
(
20
)
# Initialize q, k, v
q
=
torch
.
randn
((
Z
,
HQ
,
N_CTX_Q
,
D_HEAD
),
dtype
=
dtype
,
device
=
"cuda"
,
requires_grad
=
True
)
k
=
torch
.
randn
((
Z
,
HK
,
N_CTX_K
,
D_HEAD
),
dtype
=
dtype
,
device
=
"cuda"
,
requires_grad
=
True
)
v
=
torch
.
randn
((
Z
,
HK
,
N_CTX_K
,
D_HEAD
),
dtype
=
dtype
,
device
=
"cuda"
,
requires_grad
=
True
)
sm_scale
=
D_HEAD
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
sm_scale
)
input_metadata
.
max_seqlens_q
=
N_CTX_Q
input_metadata
.
max_seqlens_k
=
N_CTX_K
return
q
,
k
,
v
,
input_metadata
def
padding_bshd
(
t
):
# BSHD
batch
,
seqlen
,
nheads
,
dim
=
t
.
shape
t
=
torch
.
nn
.
functional
.
pad
(
t
.
reshape
(
batch
,
seqlen
,
nheads
*
dim
),
(
0
,
32
),
'constant'
,
0
)[:,:,:
-
32
].
reshape
(
batch
,
seqlen
,
nheads
,
dim
)
# pad: nheads*dim+32
# t = torch.nn.functional.pad(t.reshape(batch, seqlen, nheads, dim), (0, 32), 'constant', 0)[:,:,:,:-32] # pad: dim+32
return
t
def
flash_attn_func
(
q
,
k
,
v
,
dropout_p
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
if
padding_input
:
k
,
v
=
(
padding_bshd
(
t
)
for
t
in
(
k
,
v
))
q
,
k
,
v
=
(
t
.
transpose
(
1
,
2
)
for
t
in
(
q
,
k
,
v
))
softmax_scale
=
softmax_scale
if
softmax_scale
else
q
.
shape
[
-
1
]
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
softmax_scale
,
causal
=
causal
,
dropout_p
=
dropout_p
,
return_encoded_softmax
=
return_attn_probs
)
input_metadata
.
max_seqlens_q
=
q
.
shape
[
2
]
input_metadata
.
max_seqlens_k
=
k
.
shape
[
2
]
return
_attention
.
apply
(
q
,
k
,
v
,
None
,
input_metadata
).
transpose
(
1
,
2
)
def
flash_attn_kvpacked_func
(
q
,
kv
,
dropout_p
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
k
,
v
=
kv
[:,
:,
0
],
kv
[:,
:,
1
]
# batch_size, seqlen, 2, nheads_k, headdim
if
padding_input
:
k
,
v
=
(
padding_bshd
(
t
)
for
t
in
(
k
,
v
))
# pad
q
,
k
,
v
=
(
t
.
transpose
(
1
,
2
)
for
t
in
(
q
,
k
,
v
))
# trans bshd to bhsd
softmax_scale
=
softmax_scale
if
softmax_scale
else
q
.
shape
[
-
1
]
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
softmax_scale
,
causal
=
causal
,
dropout_p
=
dropout_p
,
return_encoded_softmax
=
return_attn_probs
)
input_metadata
.
max_seqlens_q
=
q
.
shape
[
2
]
input_metadata
.
max_seqlens_k
=
k
.
shape
[
2
]
return
_attention
.
apply
(
q
,
k
,
v
,
None
,
input_metadata
).
transpose
(
1
,
2
)
# trans bhsd to bshd
def
flash_attn_qkvpacked_func
(
qkv
,
dropout_p
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
q
,
k
,
v
=
qkv
[:,
:,
0
],
qkv
[:,
:,
1
],
qkv
[:,
:,
2
]
if
padding_input
:
k
,
v
=
(
padding_bshd
(
t
)
for
t
in
(
k
,
v
))
q
,
k
,
v
=
(
t
.
transpose
(
1
,
2
)
for
t
in
(
q
,
k
,
v
))
softmax_scale
=
softmax_scale
if
softmax_scale
else
q
.
shape
[
-
1
]
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
softmax_scale
,
causal
=
causal
,
dropout_p
=
dropout_p
,
return_encoded_softmax
=
return_attn_probs
)
input_metadata
.
max_seqlens_q
=
q
.
shape
[
2
]
input_metadata
.
max_seqlens_k
=
k
.
shape
[
2
]
return
_attention
.
apply
(
q
,
k
,
v
,
None
,
input_metadata
).
transpose
(
1
,
2
)
# varlen flash_attn
def
varlen_input_helper
(
Z
,
HQ
,
HK
,
N_CTX_Q
,
N_CTX_K
,
D_HEAD
,
dtype
,
causal
):
torch
.
manual_seed
(
20
)
# Random sequence lengths. Using N_CTX as kind of max of sum of individual seqs
max_seqlens_q
=
N_CTX_Q
//
Z
max_seqlens_k
=
N_CTX_K
//
Z
seqlens_q
=
torch
.
randint
(
1
,
max_seqlens_q
+
1
,
(
Z
,),
dtype
=
torch
.
int32
)
seqlens_k
=
torch
.
randint
(
1
,
max_seqlens_k
+
1
,
(
Z
,),
dtype
=
torch
.
int32
)
max_seqlens_q
=
torch
.
max
(
seqlens_q
).
item
()
max_seqlens_k
=
torch
.
max
(
seqlens_k
).
item
()
# Calculate cumulative sequence lengths
cu_seqlens_q
=
torch
.
cat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
seqlens_q
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)])
cu_seqlens_k
=
torch
.
cat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
seqlens_k
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)])
cu_seqlens_q
=
cu_seqlens_q
.
to
(
device
=
"cuda"
)
cu_seqlens_k
=
cu_seqlens_k
.
to
(
device
=
"cuda"
)
# -1 because the last entry of cu_seqlens_q specifies the end of the last seq
num_ctxs
=
len
(
cu_seqlens_q
)
-
1
# Initialize q, k, v with variable lengths
total_q
=
cu_seqlens_q
[
-
1
].
item
()
total_k
=
cu_seqlens_k
[
-
1
].
item
()
q
=
torch
.
randn
((
total_q
,
HQ
,
D_HEAD
),
dtype
=
dtype
,
device
=
"cuda"
).
normal_
(
mean
=
0.
,
std
=
0.5
).
requires_grad_
()
k
=
torch
.
randn
((
total_k
,
HK
,
D_HEAD
),
dtype
=
dtype
,
device
=
"cuda"
).
normal_
(
mean
=
0.
,
std
=
0.5
).
requires_grad_
()
v
=
torch
.
randn
((
total_k
,
HK
,
D_HEAD
),
dtype
=
dtype
,
device
=
"cuda"
).
normal_
(
mean
=
0.
,
std
=
0.5
).
requires_grad_
()
sm_scale
=
D_HEAD
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
sm_scale
)
input_metadata
.
set_varlen_params
(
cu_seqlens_q
,
cu_seqlens_k
)
input_metadata
.
max_seqlens_q
=
max_seqlens_q
input_metadata
.
max_seqlens_k
=
max_seqlens_k
if
causal
:
input_metadata
.
need_causal
()
return
q
,
k
,
v
,
input_metadata
def
padding_thd
(
t
):
# THD
total_seqlen
,
nheads
,
dim
=
t
.
shape
t
=
torch
.
nn
.
functional
.
pad
(
t
.
reshape
(
total_seqlen
,
nheads
*
dim
),
(
0
,
32
),
'constant'
,
0
)[:,:
-
32
].
reshape
(
total_seqlen
,
nheads
,
dim
)
# pad: nheads*dim+32
# t = torch.nn.functional.pad(t.reshape(total_seqlen, nheads, dim), (0, 32), 'constant', 0)[:,:-32] # pad: dim+32
return
t
def
flash_attn_varlen_qkvpacked_func
(
qkv
,
cu_seqlens
,
max_seqlens
,
dropout_p
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
q
,
k
,
v
=
qkv
[:,
0
],
qkv
[:,
1
],
qkv
[:,
2
]
# total_seqlen, 3, nheads, dim
if
padding_input
:
k
,
v
=
(
padding_thd
(
t
)
for
t
in
(
k
,
v
))
# pad
softmax_scale
=
softmax_scale
if
softmax_scale
else
q
.
shape
[
-
1
]
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
softmax_scale
,
causal
=
causal
,
dropout_p
=
dropout_p
,
return_encoded_softmax
=
return_attn_probs
)
input_metadata
.
set_varlen_params
(
cu_seqlens
,
cu_seqlens
)
input_metadata
.
max_seqlens_q
=
max_seqlens
input_metadata
.
max_seqlens_k
=
max_seqlens
return
_attention
.
apply
(
q
,
k
,
v
,
None
,
input_metadata
)
def
flash_attn_varlen_kvpacked_func
(
q
,
kv
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
dropout_p
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
k
,
v
=
kv
[:,
0
],
kv
[:,
1
]
# total_seqlen, 2, nheads, dim
if
padding_input
:
k
,
v
=
(
padding_thd
(
t
)
for
t
in
(
k
,
v
))
softmax_scale
=
softmax_scale
if
softmax_scale
else
q
.
shape
[
-
1
]
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
softmax_scale
,
causal
=
causal
,
dropout_p
=
dropout_p
,
return_encoded_softmax
=
return_attn_probs
)
input_metadata
.
set_varlen_params
(
cu_seqlens_q
,
cu_seqlens_k
)
input_metadata
.
max_seqlens_q
=
max_seqlens_q
input_metadata
.
max_seqlens_k
=
max_seqlens_k
return
_attention
.
apply
(
q
,
k
,
v
,
None
,
input_metadata
)
def
flash_attn_varlen_func
(
q
,
k
,
v
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
dropout_p
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
if
padding_input
:
k
,
v
=
(
padding_thd
(
t
)
for
t
in
(
k
,
v
))
softmax_scale
=
softmax_scale
if
softmax_scale
else
q
.
shape
[
-
1
]
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
softmax_scale
,
causal
=
causal
,
dropout_p
=
dropout_p
,
return_encoded_softmax
=
return_attn_probs
)
input_metadata
.
set_varlen_params
(
cu_seqlens_q
,
cu_seqlens_k
)
input_metadata
.
max_seqlens_q
=
max_seqlens_q
input_metadata
.
max_seqlens_k
=
max_seqlens_k
return
_attention
.
apply
(
q
,
k
,
v
,
None
,
input_metadata
)
# legacy interface
def
flash_attn_unpadded_qkvpacked_func
(
qkv
,
cu_seqlens
,
max_seqlens
,
dropout_p
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
return
flash_attn_varlen_qkvpacked_func
(
qkv
,
cu_seqlens
,
max_seqlens
,
dropout_p
,
softmax_scale
,
causal
,
return_attn_probs
,
padding_input
)
def
flash_attn_unpadded_kvpacked_func
(
q
,
kv
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
dropout_p
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
return
flash_attn_varlen_kvpacked_func
(
q
,
kv
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
dropout_p
,
softmax_scale
,
causal
,
return_attn_probs
,
padding_input
)
def
flash_attn_unpadded_func
(
q
,
k
,
v
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
dropout_p
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
return
flash_attn_varlen_func
(
q
,
k
,
v
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
dropout_p
,
softmax_scale
,
causal
,
return_attn_probs
,
padding_input
)
vllm/envs.py
View file @
1863c926
...
@@ -130,7 +130,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -130,7 +130,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN"
:
"VLLM_USE_TRITON_FLASH_ATTN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"
Fals
e"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"
Tru
e"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# local rank of the process in the distributed setting, used to determine
# local rank of the process in the distributed setting, used to determine
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment