Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
flash-attention
Commits
4b535e64
Commit
4b535e64
authored
May 21, 2026
by
zhangshao
Browse files
update
parent
34e67b1e
Changes
22
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2007 additions
and
494 deletions
+2007
-494
benchmarks/benchmark_pa.py
benchmarks/benchmark_pa.py
+406
-0
csrc/flash_attn/flash_api.cpp
csrc/flash_attn/flash_api.cpp
+22
-16
csrc/flash_attn/src/dropout.h
csrc/flash_attn/src/dropout.h
+6
-4
csrc/flash_attn/src/flash_bwd_kernel.h
csrc/flash_attn/src/flash_bwd_kernel.h
+158
-190
csrc/flash_attn/src/flash_bwd_launch_template.h
csrc/flash_attn/src/flash_bwd_launch_template.h
+2
-2
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+814
-162
csrc/flash_attn/src/flash_fwd_launch_template.h
csrc/flash_attn/src/flash_fwd_launch_template.h
+10
-3
csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_causal_sm80.cu
...sh_attn/src/flash_fwd_unified_hdim128_bf16_causal_sm80.cu
+8
-0
csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_sm80.cu
+8
-0
csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_causal_sm80.cu
...sh_attn/src/flash_fwd_unified_hdim128_fp16_causal_sm80.cu
+8
-0
csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_sm80.cu
+8
-0
csrc/flash_attn/src/flash_sparse_util.cu
csrc/flash_attn/src/flash_sparse_util.cu
+152
-0
csrc/flash_attn/src/kernel_traits.h
csrc/flash_attn/src/kernel_traits.h
+153
-0
csrc/flash_attn/src/paged_attention.cu
csrc/flash_attn/src/paged_attention.cu
+53
-38
csrc/flash_attn/src/paged_attention_938.cu
csrc/flash_attn/src/paged_attention_938.cu
+46
-34
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+59
-0
csrc/flash_attn_hg/flash_api.cpp
csrc/flash_attn_hg/flash_api.cpp
+14
-25
flash_attn/__init__.py
flash_attn/__init__.py
+1
-0
flash_attn/flash_attn_interface.py
flash_attn/flash_attn_interface.py
+74
-20
setup.py
setup.py
+5
-0
No files found.
benchmarks/benchmark_pa.py
0 → 100644
View file @
4b535e64
import
math
import
time
import
pytest
import
torch
import
random
import
torch.nn.functional
as
F
import
csv
from
einops
import
rearrange
,
repeat
# from flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
from
flash_attn
import
vllm_flash_attn_with_kvcache
as
_flash_attn_with_kvcache
max_seqlen
=
8192
*
5
# max_seqlen=4352
eager
=
True
# eager=False
def
attention_ref
(
q
,
k
,
v
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
attn_bias
=
None
,
dropout_p
=
0.0
,
dropout_mask
=
None
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
softcap
=
0.0
,
upcast
=
True
,
reorder_ops
=
False
,
key_leftpad
=
None
,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if
causal
:
window_size
=
(
window_size
[
0
],
0
)
dtype_og
=
q
.
dtype
if
upcast
:
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
seqlen_q
,
seqlen_k
=
q
.
shape
[
1
],
k
.
shape
[
1
]
k
=
repeat
(
k
,
"b s h d -> b s (h g) d"
,
g
=
q
.
shape
[
2
]
//
k
.
shape
[
2
])
v
=
repeat
(
v
,
"b s h d -> b s (h g) d"
,
g
=
q
.
shape
[
2
]
//
v
.
shape
[
2
])
d
=
q
.
shape
[
-
1
]
if
not
reorder_ops
:
scores
=
torch
.
einsum
(
"bthd,bshd->bhts"
,
q
/
math
.
sqrt
(
d
),
k
)
else
:
scores
=
torch
.
einsum
(
"bthd,bshd->bhts"
,
q
,
k
/
math
.
sqrt
(
d
))
if
softcap
>
0
:
scores
=
scores
/
softcap
scores
=
scores
.
tanh
()
scores
=
scores
*
softcap
if
key_padding_mask
is
not
None
:
scores
.
masked_fill_
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
float
(
"-inf"
))
if
attn_bias
is
not
None
:
scores
=
scores
+
attn_bias
attention
=
torch
.
softmax
(
scores
,
dim
=-
1
).
to
(
v
.
dtype
)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
attention
=
attention
.
masked_fill
(
torch
.
all
(
local_mask
,
dim
=-
1
,
keepdim
=
True
),
0.0
)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if
query_padding_mask
is
not
None
:
attention
=
attention
.
masked_fill
(
rearrange
(
~
query_padding_mask
,
"b s -> b 1 s 1"
),
0.0
)
dropout_scaling
=
1.0
/
(
1
-
dropout_p
)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if
dropout_mask
is
not
None
:
attention_drop
=
attention
.
masked_fill
(
~
dropout_mask
,
0.0
)
else
:
attention_drop
=
attention
output
=
torch
.
einsum
(
"bhts,bshd->bthd"
,
attention_drop
,
v
*
dropout_scaling
)
if
query_padding_mask
is
not
None
:
output
.
masked_fill_
(
rearrange
(
~
query_padding_mask
,
"b s -> b s 1 1"
),
0.0
)
return
output
.
to
(
dtype
=
dtype_og
),
attention
.
to
(
dtype
=
dtype_og
)
def
test_flash_attn_kvcache
(
seqlen_q
,
seqlen_k
,
d
,
has_batch_idx
,
has_leftpad
,
paged_kv_block_size
,
rotary_fraction
,
rotary_interleaved
,
seqlen_new_eq_seqlen_q
,
causal
,
local
,
alibi
,
new_kv
,
dtype
,
batch_size
,
qhead
,
kv_head
,
prof
=
False
,
):
# if seqlen_q > seqlen_k and new_kv:
# pytest.skip()
# if not new_kv and rotary_fraction > 0.0:
# pytest.skip()
# if has_batch_idx and paged_kv_block_size is not None:
# pytest.skip()
# if has_leftpad and paged_kv_block_size is not None:
# pytest.skip()
device
=
"cuda"
# set seed
torch
.
random
.
manual_seed
(
0
)
# batch_size = 64
# nheads = 32
batch_size_cache
=
batch_size
if
not
has_batch_idx
else
batch_size
*
2
# rotary_dim must be a multiple of 16, and must be <= d
rotary_dim
=
math
.
floor
(
int
(
rotary_fraction
*
d
)
/
16
)
*
16
window_size
=
(
-
1
,
-
1
)
if
not
local
else
torch
.
randint
(
0
,
seqlen_k
,
(
2
,))
q
=
torch
.
randn
(
batch_size
,
seqlen_q
,
qhead
,
d
,
device
=
device
,
dtype
=
dtype
)
seqlen_new
=
seqlen_q
if
seqlen_new_eq_seqlen_q
else
torch
.
randint
(
1
,
seqlen_q
+
1
,
(
1
,)).
item
()
nheads_k
=
kv_head
# alloc k v
if
new_kv
:
k
=
torch
.
randn
(
batch_size
,
seqlen_new
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
)
v
=
torch
.
randn
(
batch_size
,
seqlen_new
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
)
else
:
k
,
v
=
None
,
None
# 生成kvcache
if
paged_kv_block_size
is
None
:
k_cache
=
torch
.
randn
(
batch_size_cache
,
seqlen_k
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
)
v_cache
=
torch
.
randn
(
batch_size_cache
,
seqlen_k
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
)
block_table
=
None
else
:
(
k_cache
,
v_cache
,
block_table
,
k_cache_paged
,
v_cache_paged
,
num_blocks
,
)
=
_generate_block_kvcache
(
seqlen_k
,
paged_kv_block_size
,
batch_size
,
nheads_k
,
d
,
device
,
dtype
)
seq_lens
=
[
seqlen_k
for
_
in
range
(
batch_size
)]
cache_seqlens
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
device
)
if
has_leftpad
:
cache_leftpad
=
torch
.
cat
([
torch
.
randint
(
0
,
cache_seqlens
[
i
].
item
(),
(
1
,),
dtype
=
torch
.
int32
,
device
=
device
)
if
cache_seqlens
[
i
].
item
()
>
0
else
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
,
device
=
device
)
for
i
in
range
(
batch_size
)])
else
:
cache_leftpad
=
None
arange
=
rearrange
(
torch
.
arange
(
seqlen_k
,
device
=
device
),
"s -> 1 s"
)
cache_seqlens_expanded
=
rearrange
(
cache_seqlens
,
"b -> b 1"
)
key_padding_mask
=
arange
<
cache_seqlens_expanded
+
(
seqlen_new
if
new_kv
else
0
)
if
has_leftpad
:
key_padding_mask
=
torch
.
logical_and
(
key_padding_mask
,
arange
>=
cache_leftpad
.
unsqueeze
(
-
1
).
expand
(
-
1
,
seqlen_k
)
)
if
has_batch_idx
:
cache_batch_idx
=
torch
.
randperm
(
batch_size_cache
,
dtype
=
torch
.
int32
,
device
=
device
)[
:
batch_size
]
else
:
cache_batch_idx
=
None
alibi_slopes
,
attn_bias
=
None
,
None
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
cos
,
sin
=
None
,
None
q_ro
,
k_ro
=
q
,
k
# k_cache[:, 64:] = -1
k_cache_ref
=
(
k_cache
if
not
has_batch_idx
else
k_cache
[
cache_batch_idx
.
to
(
dtype
=
torch
.
long
)]
).
clone
()
v_cache_ref
=
(
v_cache
if
not
has_batch_idx
else
v_cache
[
cache_batch_idx
.
to
(
dtype
=
torch
.
long
)]
).
clone
()
if
new_kv
:
update_mask
=
torch
.
logical_and
(
cache_seqlens_expanded
<=
arange
,
arange
<
cache_seqlens_expanded
+
seqlen_new
)
k_cache_ref
[
update_mask
]
=
rearrange
(
k_ro
,
"b s ... -> (b s) ..."
)
v_cache_ref
[
update_mask
]
=
rearrange
(
v
,
"b s ... -> (b s) ..."
)
# k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
# v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
k_cache_rep
=
repeat
(
k_cache_ref
,
"b s h d -> b s (h g) d"
,
g
=
nheads_k
//
nheads_k
)
v_cache_rep
=
repeat
(
v_cache_ref
,
"b s h d -> b s (h g) d"
,
g
=
nheads_k
//
nheads_k
)
q_scale
=
torch
.
tensor
([
0.5
],
dtype
=
torch
.
float32
,
device
=
device
)
k_scale
=
torch
.
tensor
([
0.5
],
dtype
=
torch
.
float32
,
device
=
device
)
v_scale
=
torch
.
tensor
([
0.25
],
dtype
=
torch
.
float32
,
device
=
device
)
# new_type = torch.float8_e5m2
# new_type = torch.float8_e4m3fn
new_type
=
dtype
k_cache_paged
=
k_cache_paged
.
permute
(
0
,
2
,
1
,
3
).
contiguous
().
to
(
new_type
)
v_cache_paged
=
v_cache_paged
.
permute
(
0
,
2
,
3
,
1
).
contiguous
().
to
(
new_type
)
max_seqlen_k
=
seqlen_k
# max_seqlen_k=32768
# warm
for
i
in
range
(
10
):
out
=
_flash_attn_with_kvcache
(
q
,
k_cache
if
paged_kv_block_size
is
None
else
k_cache_paged
,
v_cache
if
paged_kv_block_size
is
None
else
v_cache_paged
,
cache_seqlens
=
cache_seqlens
,
block_table
=
block_table
,
causal
=
causal
,
max_seqlen_k
=
max_seqlen_k
,
q_scale
=
q_scale
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
)
# prof time
torch
.
cuda
.
synchronize
()
repeat_num
=
100
start_time
=
time
.
time
()
for
i
in
range
(
repeat_num
):
out
=
_flash_attn_with_kvcache
(
q
,
k_cache
if
paged_kv_block_size
is
None
else
k_cache_paged
,
v_cache
if
paged_kv_block_size
is
None
else
v_cache_paged
,
cache_seqlens
=
cache_seqlens
,
block_table
=
block_table
,
causal
=
causal
,
max_seqlen_k
=
max_seqlen_k
,
q_scale
=
q_scale
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
)
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
fc1_espl
=
end_time
-
start_time
DCU_time
=
fc1_espl
*
1000
*
1000
/
repeat_num
IO_bytes
=
batch_size
*
seqlen_k
*
kv_head
*
d
*
2
*
k_cache_paged
.
element_size
()
#kv cache size to read
IO_bytes
+=
batch_size
*
qhead
*
d
*
q
.
element_size
()
#q size to read
IO_bytes
+=
(
seqlen_k
//
512
+
1
)
*
batch_size
*
qhead
*
d
*
2
*
2
# temp to write and read
IO_bytes
+=
batch_size
*
qhead
*
d
*
2
#output to write
IO_speed
=
IO_bytes
/
DCU_time
/
1024
/
1024
/
1024
*
1000
*
1000
print
(
'FA_kvcache bs='
,
batch_size
,
' seqlen='
,
seqlen_k
,
' qhead='
,
qhead
,
' kv_head='
,
kv_head
,
' time is'
,
'{:.2f}'
.
format
(
DCU_time
),
'us Bandwidth='
,
'{:.2f}'
.
format
(
IO_speed
),
'GB/s'
)
res_list
=
[
paged_kv_block_size
,
batch_size
,
seqlen_k
,
d
,
qhead
,
kv_head
,
DCU_time
,
IO_speed
]
# print('FA_kvcache bs=', batch_size,' seqlen=',seqlen_k,' qhead=',qhead, ' kv_head=',kv_head, ' time is', '{:.2f}'.format(DCU_time), 'us')
# res_list = [paged_kv_block_size, batch_size, seqlen_k, d, qhead, kv_head, DCU_time]
return
res_list
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
if
new_kv
:
if
paged_kv_block_size
is
None
:
k_cache_select
=
(
k_cache
if
not
has_batch_idx
else
k_cache
[
cache_batch_idx
.
to
(
dtype
=
torch
.
long
)]
)
v_cache_select
=
(
v_cache
if
not
has_batch_idx
else
v_cache
[
cache_batch_idx
.
to
(
dtype
=
torch
.
long
)]
)
else
:
k_cache_select
=
rearrange
(
k_cache_paged
[
block_table
.
to
(
dtype
=
torch
.
long
).
flatten
()],
"(b nblocks) block_size ... -> b (nblocks block_size) ..."
,
b
=
batch_size
,
)[:,
:
seqlen_k
]
v_cache_select
=
rearrange
(
v_cache_paged
[
block_table
.
to
(
dtype
=
torch
.
long
).
flatten
()],
"(b nblocks) block_size ... -> b (nblocks block_size) ..."
,
b
=
batch_size
,
)[:,
:
seqlen_k
]
assert
torch
.
allclose
(
k_cache_select
,
k_cache_ref
,
rtol
=
1e-3
,
atol
=
1e-3
)
assert
torch
.
equal
(
v_cache_select
,
v_cache_ref
)
mult
=
3
if
not
alibi
else
5
assert
(
out
-
out_ref
).
abs
().
max
().
item
()
<=
mult
*
(
out_pt
-
out_ref
).
abs
().
max
().
item
()
+
1e-5
def
_generate_block_kvcache
(
seqlen_k
,
paged_kv_block_size
,
batch_size
,
nheads_k
,
d
,
device
,
dtype
):
num_blocks
=
50000
k_cache_paged
=
torch
.
randn
(
num_blocks
,
paged_kv_block_size
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
)
v_cache_paged
=
torch
.
randn
(
num_blocks
,
paged_kv_block_size
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
)
if
eager
:
max_num_blocks_per_seq
=
(
seqlen_k
+
paged_kv_block_size
-
1
)
//
paged_kv_block_size
else
:
max_num_blocks_per_seq
=
(
max_seqlen
+
paged_kv_block_size
-
1
)
//
paged_kv_block_size
block_tables
=
[]
for
_
in
range
(
batch_size
):
block_table
=
[
random
.
randint
(
0
,
num_blocks
-
1
)
for
_
in
range
(
max_num_blocks_per_seq
)
]
block_tables
.
append
(
block_table
)
block_tables
=
torch
.
tensor
(
block_tables
,
dtype
=
torch
.
int
,
device
=
device
)
# # randperm torch.randperm
# block_table = rearrange(
# torch.randperm(batch_size*max_seqlen//paged_kv_block_size, dtype=torch.int32, device=device),
# "(b nblocks) -> b nblocks",
# b=batch_size,
# )
k_cache
=
rearrange
(
# pytorch 1.12 doesn't have indexing with int32
k_cache_paged
[
block_tables
.
to
(
dtype
=
torch
.
long
).
flatten
()],
"(b nblocks) block_size ... -> b (nblocks block_size) ..."
,
b
=
batch_size
,
)[:,
:
seqlen_k
]
v_cache
=
rearrange
(
v_cache_paged
[
block_tables
.
to
(
dtype
=
torch
.
long
).
flatten
()],
"(b nblocks) block_size ... -> b (nblocks block_size) ..."
,
b
=
batch_size
,
)[:,
:
seqlen_k
]
return
k_cache
,
v_cache
,
block_tables
,
k_cache_paged
,
v_cache_paged
,
num_blocks
# mha
if
__name__
==
"__main__"
:
# HIP_VISIBLE_DEVICES=6 python test_kvcache.py
#config = [(1,16,16),(1,32,32),(1,32,4),(64,32,4),(1,52,4),(64,52,4),(1,16,2),(64,16,2),(1,26,2),(64,26,2),(1,8,1),(64,8,1),(1,13,1),(64,13,1)]
# config = [(120,6,1),(120,8,1),(120,28,4),(120,16,2),(120,20,4)]
# seq_lens=[600,1200,2400,4800]
random
.
seed
(
0
)
torch
.
random
.
manual_seed
(
0
)
# batchsize = [4,8,16,24,32,48,56,64,72,88,120]
# batchsize = [1,2,4,8,16,24,32,40,48,56,64,72,80,88,96,104]
batchsize
=
[
1
,
8
,
32
,
128
]
# batchsize = [128,256,512]
# batchsize = [16,24,32,40,48,56,64,72,80,88,96] #70B,235B
# batchsize = [24,32,40,48,56] #30B
# batchsize = [40,48,56,64,72,80,88,96] #8B
# head = [(32,2)]
# head = [(12,1)]
head
=
[(
16
,
2
),(
32
,
8
)]
# head = [(15,1),(16,1)]
# head = [(8,1),(9,1),(10,1),(11,1),(12,1),(13,1),(14,1),(15,1),(16,1),(17,1),(18,1),(19,1),(20,1),(21,1),(22,1),(23,1),(24,1),(25,1),(26,1),(27,1),(28,1),(29,1),(30,1),(31,1),(32,1)]
# head = [(4,1),(8,1),(12,1),(16,1),(24,1)]
# seq_lens=[100,400,700,1000,1300,1600,1900,2200,2500,2800,3100,3400,3700,4000,4300]
# seq_lens=[2000,2100,2200,2300,2400,2500,2600,2700]
seq_lens
=
[
2048
,
8192
,
32768
]
# seq_lens=[8192,128000]
# seq_lens=[1000,1100,1350,1500,1650,1800,2000,2300,2600,3000,3300,3500,3700,4000,4096,4100,4200,4300,4500,4700,5000]
# seq_lens=[3000,3300,3500,3800,4000,4300,4500,4800,5000]
# seq_lens=[500,700,1000,1300,2000,3000,4000,16000,18000,20000]
# seq_lens=[200,500,800,1100,1300,2000,3000,4000,5000,15000,16000,18000,20000]
# seq_lens=[200,500,800,1100,1300,2000,3000,4000,5000,16000,16500,17000,17500,18000,18500,19000,19500,20000]
# seq_lens=[16000,17000,18000,19000,20000,21000]
# heads = [8, 10, 16, 18, 20, 28, 30, 32, 38, 40, 48, 50, 58, 60, 64, 68, 70]
# batchs = [64]
# seq_lens=[1500]
dtype
=
torch
.
float16
# dtype=torch.bfloat16
print
(
dtype
)
res_time
=
[]
for
qh
,
kh
in
head
:
for
bs
in
batchsize
:
for
seq
in
seq_lens
:
# if (not (seq>=10000 and bs>16)) and seq<max_seqlen:
if
True
:
prof_time
=
test_flash_attn_kvcache
(
seqlen_q
=
1
,
seqlen_k
=
seq
,
#128 512
d
=
128
,
# 64 128 160 256
has_batch_idx
=
False
,
has_leftpad
=
False
,
paged_kv_block_size
=
64
,
#16 256
rotary_fraction
=
0.0
,
rotary_interleaved
=
False
,
seqlen_new_eq_seqlen_q
=
True
,
causal
=
True
,
# 因果注意力机制
local
=
False
,
# 局部注意力
alibi
=
False
,
new_kv
=
False
,
dtype
=
dtype
,
batch_size
=
bs
,
qhead
=
qh
,
kv_head
=
kh
,
prof
=
False
# 运行单次
)
res_time
.
append
(
prof_time
)
with
open
(
'kvcache_time.csv'
,
'w'
,
newline
=
''
)
as
csvfile
:
writer
=
csv
.
writer
(
csvfile
)
for
row
in
res_time
:
writer
.
writerow
(
row
)
csrc/flash_attn/flash_api.cpp
View file @
4b535e64
...
@@ -225,6 +225,7 @@ hg_varlen_bwd_bshd(const at::Tensor &dout,
...
@@ -225,6 +225,7 @@ hg_varlen_bwd_bshd(const at::Tensor &dout,
static
const
bool
print_param
=
get_env_
(
"FLASH_ATTENTION_PRINT_PARAM"
);
static
const
bool
print_param
=
get_env_
(
"FLASH_ATTENTION_PRINT_PARAM"
);
static
const
bool
print_hg_path
=
get_env_
(
"FLASH_ATTENTION_PRINT_HG"
);
static
const
bool
print_hg_path
=
get_env_
(
"FLASH_ATTENTION_PRINT_HG"
);
static
const
bool
disable_varlen_tiny_dim64
=
get_env_
(
"FLASH_ATTENTION_DISABLE_VARLEN_TINY_DIM64"
);
static
const
bool
disable_varlen_tiny_dim64
=
get_env_
(
"FLASH_ATTENTION_DISABLE_VARLEN_TINY_DIM64"
);
static
const
bool
enable_hg_varlen
=
get_env_
(
"FLASH_ATTENTION_ENABLE_HG_VARLEN"
);
#ifdef HAS_HG_DISPATCH
#ifdef HAS_HG_DISPATCH
...
@@ -741,20 +742,23 @@ void run_mha_fwd_prefix_kv_fp8(Flash_fwd_params ¶ms, cudaStream_t stream, bo
...
@@ -741,20 +742,23 @@ void run_mha_fwd_prefix_kv_fp8(Flash_fwd_params ¶ms, cudaStream_t stream, bo
void
run_mha_fwd_unified
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
,
bool
force_split_kernel
=
false
)
{
void
run_mha_fwd_unified
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
,
bool
force_split_kernel
=
false
)
{
FP16_SWITCH
(
!
params
.
is_bf16
,
[
&
]
{
FP16_SWITCH
(
!
params
.
is_bf16
,
[
&
]
{
// using elem_type = cutlass::half_t;
HEADDIM_SWITCH
(
params
.
d
,
[
&
]
// using elem_type = cutlass::float_e5m2_t;
{
// HEADDIM_SWITCH_FP8(params.d, [&] {
// using elem_type = cutlass::half_t;
constexpr
static
int
kHeadDim
=
256
;
// using elem_type = cutlass::float_e5m2_t;
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
// HEADDIM_SWITCH_FP8(params.d, [&] {
if
(
params
.
d
!=
256
)
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
TORCH_CHECK
(
false
,
"unified attn only support dim=256"
);
if
(
params
.
d
!=
256
&&
params
.
d
!=
128
)
{
}
TORCH_CHECK
(
false
,
"unified attn only support dim=128/256"
);
run_mha_fwd_unified_dispatch
<
elem_type
,
kHeadDim
,
Is_causal
>
(
params
,
stream
);
}
run_mha_fwd_unified_dispatch
<
elem_type
,
kHeadDim
,
Is_causal
>
(
params
,
stream
);
});
// });
});
// });
});
});
});
}
}
void
run_mha_fwd_mla
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
,
bool
force_split_kernel
=
false
)
{
void
run_mha_fwd_mla
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
,
bool
force_split_kernel
=
false
)
{
params
.
num_splits
=
1
;
params
.
num_splits
=
1
;
if
(
params
.
is_fp8
==
true
)
if
(
params
.
is_fp8
==
true
)
...
@@ -961,7 +965,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
...
@@ -961,7 +965,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
alibi_slopes_
,
s_aux_
,
alibi_slopes_
,
s_aux_
,
skip_softmax_threshold_scale_factor
,
skip_softmax_threshold_scale_factor
,
is_causal
,
seqlen_q
,
seqlen_k
,
is_causal
,
seqlen_q
,
seqlen_k
,
window_size_left
,
window_size_right
))
{
window_size_left
,
window_size_right
)
&&
(
!
is_bhsd
)
)
{
if
(
print_param
||
print_hg_path
)
{
if
(
print_param
||
print_hg_path
)
{
printf
(
"[flash_attn] HG PATH layout=%s q=(%d,%d,%d,%d) k=(%d,%d,%d,%d) v=(%d,%d,%d,%d)
\n
"
,
printf
(
"[flash_attn] HG PATH layout=%s q=(%d,%d,%d,%d) k=(%d,%d,%d,%d) v=(%d,%d,%d,%d)
\n
"
,
is_bhsd
?
"bhsd"
:
"bshd"
,
is_bhsd
?
"bhsd"
:
"bshd"
,
...
@@ -2019,7 +2023,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
...
@@ -2019,7 +2023,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
if
(
can_use_hg_dense_bwd
(
if
(
can_use_hg_dense_bwd
(
q
.
scalar_type
(),
alibi_slopes_
,
q
.
scalar_type
(),
alibi_slopes_
,
head_size
,
head_size_value
,
is_causal
,
seqlen_q
,
seqlen_k
,
head_size
,
head_size_value
,
is_causal
,
seqlen_q
,
seqlen_k
,
window_size_left
,
window_size_right
,
p_dropout
))
{
window_size_left
,
window_size_right
,
p_dropout
)
&&
(
!
is_bhsd
)
)
{
if
(
print_param
||
print_hg_path
)
{
if
(
print_param
||
print_hg_path
)
{
printf
(
"[flash_attn] HG BWD PATH layout=%s q=(%d,%d,%d,%d) k=(%d,%d,%d,%d) v=(%d,%d,%d,%d) dout=(%d,%d,%d,%d)
\n
"
,
printf
(
"[flash_attn] HG BWD PATH layout=%s q=(%d,%d,%d,%d) k=(%d,%d,%d,%d) v=(%d,%d,%d,%d) dout=(%d,%d,%d,%d)
\n
"
,
is_bhsd
?
"bhsd"
:
"bshd"
,
is_bhsd
?
"bhsd"
:
"bshd"
,
...
@@ -2308,7 +2312,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
...
@@ -2308,7 +2312,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
CHECK_SHAPE
(
cu_seqlens_k
,
batch_size
+
1
);
CHECK_SHAPE
(
cu_seqlens_k
,
batch_size
+
1
);
#ifdef HAS_HG_DISPATCH
#ifdef HAS_HG_DISPATCH
if
(
can_use_hg_varlen_bwd
(
if
(
enable_hg_varlen
&&
can_use_hg_varlen_bwd
(
q
.
scalar_type
(),
alibi_slopes_
,
q
.
scalar_type
(),
alibi_slopes_
,
head_size
,
head_size_value
,
total_q
,
total_k
,
max_seqlen_k
,
head_size
,
head_size_value
,
total_q
,
total_k
,
max_seqlen_k
,
window_size_left
,
window_size_right
,
p_dropout
))
{
window_size_left
,
window_size_right
,
p_dropout
))
{
...
@@ -4459,7 +4464,7 @@ TORCH_LIBRARY_IMPL(flash_attn2_c_op, CUDA, m) {
...
@@ -4459,7 +4464,7 @@ TORCH_LIBRARY_IMPL(flash_attn2_c_op, CUDA, m) {
return
std
::
make_tuple
(
results
[
0
],
results
[
1
]);
return
std
::
make_tuple
(
results
[
0
],
results
[
1
]);
});
});
}
}
at
::
Tensor
mean_pool_fast
(
const
at
::
Tensor
&
input
,
int
blk
,
const
c10
::
optional
<
at
::
Tensor
>
&
mean
);
// ============================================================================
// ============================================================================
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
...
@@ -4484,6 +4489,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -4484,6 +4489,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"varlen_bwd_attnmask"
,
&
mha_varlen_bwd_attnmask
,
"Backward pass (variable length), with explicit attention mask"
);
m
.
def
(
"varlen_bwd_attnmask"
,
&
mha_varlen_bwd_attnmask
,
"Backward pass (variable length), with explicit attention mask"
);
m
.
def
(
"paged_attention"
,
&
paged_attention
,
"Forward pass, with KV-cache"
);
m
.
def
(
"paged_attention"
,
&
paged_attention
,
"Forward pass, with KV-cache"
);
m
.
def
(
"fwd_sparse"
,
&
mha_fwd_sparse
,
"Forward sparse pass"
);
m
.
def
(
"fwd_sparse"
,
&
mha_fwd_sparse
,
"Forward sparse pass"
);
m
.
def
(
"fwd_sparse_mean_pool_fast"
,
&
mean_pool_fast
,
"before mha_fwd_sparse"
);
m
.
def
(
"varlen_fwd_sparse"
,
&
mha_varlen_fwd_sparse
,
"Forward pass sparse (variable length)"
);
m
.
def
(
"varlen_fwd_sparse"
,
&
mha_varlen_fwd_sparse
,
"Forward pass sparse (variable length)"
);
m
.
def
(
"varlen_fwd_unified"
,
&
unified2D_attention_fwd
,
"Forward pass unified attn (variable length && block table)"
);
m
.
def
(
"varlen_fwd_unified"
,
&
unified2D_attention_fwd
,
"Forward pass unified attn (variable length && block table)"
);
}
}
csrc/flash_attn/src/dropout.h
View file @
4b535e64
...
@@ -311,7 +311,7 @@ struct Dropout {
...
@@ -311,7 +311,7 @@ struct Dropout {
for
(
int
i
=
0
;
i
<
size
<
1
>
(
tensor
);
++
i
)
for
(
int
i
=
0
;
i
<
size
<
1
>
(
tensor
);
++
i
)
{
{
const
int
row_idx_base
=
block_row_start
+
i
*
block_row_stride
+
(
threadIdx
.
x
/
64
)
*
16
;
const
int
row_idx_base
=
block_row_start
+
i
*
block_row_stride
+
(
threadIdx
.
x
/
64
)
*
16
+
lane_id
%
16
;
const
int
row_idx
=
row_idx_base
;
const
int
row_idx
=
row_idx_base
;
uint2
rowcol
=
make_uint2
(
row_idx
,
col_idx_offset
);
uint2
rowcol
=
make_uint2
(
row_idx
,
col_idx_offset
);
uint4
random_uint4
=
flash
::
philox
(
seed
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
),
offset
);
uint4
random_uint4
=
flash
::
philox
(
seed
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
),
offset
);
...
@@ -344,7 +344,7 @@ struct Dropout {
...
@@ -344,7 +344,7 @@ struct Dropout {
}
}
};
};
const
int
lane_id
=
threadIdx
.
x
%
64
;
const
int
lane_id
=
threadIdx
.
x
%
64
;
const
int
col_idx_offset
=
block_col_start
+
(
threadIdx
.
x
/
64
)
*
16
;
const
int
col_idx_offset
=
block_col_start
+
(
threadIdx
.
x
/
64
)
*
16
+
lane_id
%
16
;
extern
__shared__
char
smem_
[];
extern
__shared__
char
smem_
[];
uint8_t
*
p_rand_8
=
reinterpret_cast
<
uint8_t
*>
(
smem_
+
16384
);
uint8_t
*
p_rand_8
=
reinterpret_cast
<
uint8_t
*>
(
smem_
+
16384
);
...
@@ -369,8 +369,10 @@ struct Dropout {
...
@@ -369,8 +369,10 @@ struct Dropout {
uint8_t
(
&
rnd_8
)[
16
]
=
reinterpret_cast
<
uint8_t
(
&
)[
16
]
>
(
random_uint4
);
uint8_t
(
&
rnd_8
)[
16
]
=
reinterpret_cast
<
uint8_t
(
&
)[
16
]
>
(
random_uint4
);
*
reinterpret_cast
<
uint4
*>
(
&
p_rand_8
[
row_
*
RAND_STRIDE
+
col_
])
=
random_uint4
;
*
reinterpret_cast
<
uint4
*>
(
&
p_rand_8
[
row_
*
RAND_STRIDE
+
col_
])
=
random_uint4
;
__syncthreads
();
// __syncthreads();
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
size
<
2
>
(
tensor
);
++
j
)
{
for
(
int
j
=
0
;
j
<
size
<
2
>
(
tensor
);
++
j
)
{
#pragma unroll
#pragma unroll
...
...
csrc/flash_attn/src/flash_bwd_kernel.h
View file @
4b535e64
This diff is collapsed.
Click to expand it.
csrc/flash_attn/src/flash_bwd_launch_template.h
View file @
4b535e64
...
@@ -384,7 +384,7 @@ void run_flash_bwd_separate_prefetch(Flash_bwd_params ¶ms, cudaStream_t stre
...
@@ -384,7 +384,7 @@ void run_flash_bwd_separate_prefetch(Flash_bwd_params ¶ms, cudaStream_t stre
const
bool
is_even_K
=
params
.
d
==
Kernel_traits
::
kHeadDim
;
const
bool
is_even_K
=
params
.
d
==
Kernel_traits
::
kHeadDim
;
constexpr
int
smem_size_dropout
=
Kernel_trans_traits
::
kBlockM
*
Kernel_trans_traits
::
kBlockN
;
constexpr
int
smem_size_dropout
=
Kernel_trans_traits
::
kBlockM
*
Kernel_trans_traits
::
kBlockN
;
constexpr
int
smem_size_dk_dv
=
Kernel_trans_traits
::
kSmemPrefetchSize
;
constexpr
int
smem_size_dk_dv
=
Kernel_trans_traits
::
kSmemPrefetchSize
;
constexpr
int
smem_size_dk_dv_total
=
(
Kernel_trans_traits
::
kHeadDim
==
128
)
?
(
smem_size_dk_dv
+
smem_size_dropout
)
:
(
smem_size_dk_dv
);
constexpr
int
smem_size_dk_dv_total
=
(
Kernel_trans_traits
::
kHeadDim
==
128
||
Kernel_trans_traits
::
kHeadDim
==
64
)
?
(
smem_size_dk_dv
+
smem_size_dropout
)
:
(
smem_size_dk_dv
);
constexpr
int
smem_size_dq
=
Kernel_traits
::
kSmemPrefetchSize
;
constexpr
int
smem_size_dq
=
Kernel_traits
::
kSmemPrefetchSize
;
BOOL_SWITCH
(
is_even_MN
,
IsEvenMNConst
,
[
&
]
{
BOOL_SWITCH
(
is_even_MN
,
IsEvenMNConst
,
[
&
]
{
// constexpr static bool IsEvenMNConst = false;
// constexpr static bool IsEvenMNConst = false;
...
@@ -561,7 +561,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) {
...
@@ -561,7 +561,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) {
DROPOUT_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
DROPOUT_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
if
(
get_device_name
()
==
"gfx936"
||
get_device_name
()
==
"gfx938"
)
if
(
get_device_name
()
==
"gfx936"
||
get_device_name
()
==
"gfx938"
)
{
{
using
kernel_trans_traits
=
Flash_bwd_kernel_trans_16x64_prefetch_traits
<
Headdim
,
/*kBlockM_*/
64
,
/*kBlockN_*/
128
,
/*kNWarps_*/
4
,
T
,
3
>
;
using
kernel_trans_traits
=
Flash_bwd_kernel_trans_16x64_prefetch_traits
<
Headdim
,
/*kBlockM_*/
64
,
/*kBlockN_*/
Is_dropout
?
64
:
128
,
/*kNWarps_*/
4
,
T
,
3
>
;
using
kernel_traits
=
Flash_bwd_kernel_dq_16x64_prefetch_traits
<
Headdim
,
/*kBlockM_*/
128
,
/*kBlockN_*/
64
,
/*kNWarps_*/
4
,
using
kernel_traits
=
Flash_bwd_kernel_dq_16x64_prefetch_traits
<
Headdim
,
/*kBlockM_*/
128
,
/*kBlockN_*/
64
,
/*kNWarps_*/
4
,
/*AtomLayoutMSdP_*/
4
,
/*AtomLayoutNdKV*/
1
,
/*AtomLayoutMdQ*/
4
,
/*Is_V_in_regs_*/
false
,
/*AtomLayoutMSdP_*/
4
,
/*AtomLayoutNdKV*/
1
,
/*AtomLayoutMdQ*/
4
,
/*Is_V_in_regs_*/
false
,
/*No_double_buffer_*/
true
,
/*Is_Q_in_regs_*/
false
,
/*Share_Q_K_smem_*/
true
,
T
,
3
>
;
/*No_double_buffer_*/
true
,
/*Is_Q_in_regs_*/
false
,
/*Share_Q_K_smem_*/
true
,
T
,
3
>
;
...
...
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
4b535e64
This diff is collapsed.
Click to expand it.
csrc/flash_attn/src/flash_fwd_launch_template.h
View file @
4b535e64
...
@@ -770,8 +770,15 @@ void run_mha_fwd_unified_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream)
...
@@ -770,8 +770,15 @@ void run_mha_fwd_unified_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream)
using
combine_kernel_traits
=
Flash_fwd_kernel_16x64_traits_splitkv
<
256
,
kBlockM
,
kBlockN
,
4
,
false
,
false
,
T
,
256
>
;
using
combine_kernel_traits
=
Flash_fwd_kernel_16x64_traits_splitkv
<
256
,
kBlockM
,
kBlockN
,
4
,
false
,
false
,
T
,
256
>
;
run_flash_splitkv_fwd_16x64_unified_prefetch
<
prefetch_kernel_traits
,
combine_kernel_traits
,
Is_causal
>
(
params
,
stream
);
run_flash_splitkv_fwd_16x64_unified_prefetch
<
prefetch_kernel_traits
,
combine_kernel_traits
,
Is_causal
>
(
params
,
stream
);
}
}
}
else
{
}
else
if
constexpr
(
Headdim
==
128
)
{
assert
(
false
&&
"unified attn only supported headdim=256"
);
if
(
get_device_name
()
==
"gfx936"
||
get_device_name
()
==
"gfx938"
)
{
assert
(
params
.
knew_ptr
==
nullptr
&&
params
.
block_table
!=
nullptr
);
using
prefetch_kernel_traits
=
Flash_fwd_kernel_16x64_splitkv_prefetch_unified_traits
<
128
,
64
,
64
,
4
,
T
,
3
,
128
>
;
using
combine_kernel_traits
=
Flash_fwd_kernel_16x64_traits_splitkv
<
128
,
kBlockM
,
kBlockN
,
4
,
false
,
false
,
T
,
128
>
;
run_flash_splitkv_fwd_16x64_unified_prefetch
<
prefetch_kernel_traits
,
combine_kernel_traits
,
Is_causal
>
(
params
,
stream
);
}
}
else
{
assert
(
false
&&
"unified attn only supported headdim=128/256"
);
}
}
}
}
...
@@ -797,7 +804,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) {
...
@@ -797,7 +804,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) {
if
(
params
.
seqlen_q
<=
64
||
params
.
h
*
params
.
b
*
mblocks
<
4
*
sm_count
)
{
if
(
params
.
seqlen_q
<=
64
||
params
.
h
*
params
.
b
*
mblocks
<
4
*
sm_count
)
{
run_flash_fwd_16x64_prefetch
<
Flash_fwd_kernel_16x64_prefetch_traits_dim64
<
Headdim
,
64
,
64
,
4
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
run_flash_fwd_16x64_prefetch
<
Flash_fwd_kernel_16x64_prefetch_traits_dim64
<
Headdim
,
64
,
64
,
4
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
}
else
{
}
else
{
run_flash_fwd_16x64_prefetch
<
Flash_fwd_kernel_16x64_prefetch_traits_dim64
<
Headdim
,
256
,
64
,
4
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
run_flash_fwd_16x64_prefetch
<
Flash_fwd_kernel_16x64_prefetch_traits_dim64
<
Headdim
,
Is_dropout
?
128
:
256
,
64
,
4
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
}
}
}
else
{
}
else
{
run_flash_fwd_16x64
<
Flash_fwd_kernel_16x64_traits
<
Headdim
,
256
,
64
,
4
,
/*Is_Q_use_smem_=*/
false
,
/*Share_K_V_smem_=*/
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
run_flash_fwd_16x64
<
Flash_fwd_kernel_16x64_traits
<
Headdim
,
256
,
64
,
4
,
/*Is_Q_use_smem_=*/
false
,
/*Share_K_V_smem_=*/
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
...
...
csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_causal_sm80.cu
0 → 100644
View file @
4b535e64
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template
void
run_mha_fwd_unified_dispatch
<
cutlass
::
bfloat16_t
,
128
,
true
>(
Flash_fwd_params
&
params
,
cudaStream_t
stream
);
csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_sm80.cu
0 → 100644
View file @
4b535e64
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template
void
run_mha_fwd_unified_dispatch
<
cutlass
::
bfloat16_t
,
128
,
false
>(
Flash_fwd_params
&
params
,
cudaStream_t
stream
);
csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_causal_sm80.cu
0 → 100644
View file @
4b535e64
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template
void
run_mha_fwd_unified_dispatch
<
cutlass
::
half_t
,
128
,
true
>(
Flash_fwd_params
&
params
,
cudaStream_t
stream
);
csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_sm80.cu
0 → 100644
View file @
4b535e64
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template
void
run_mha_fwd_unified_dispatch
<
cutlass
::
half_t
,
128
,
false
>(
Flash_fwd_params
&
params
,
cudaStream_t
stream
);
csrc/flash_attn/src/flash_sparse_util.cu
0 → 100644
View file @
4b535e64
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
template
<
typename
scalar_t
>
static
__device__
inline
void
from_float
(
scalar_t
&
out
,
float
f
){
if
constexpr
(
std
::
is_same
<
scalar_t
,
_Float16
>::
value
||
std
::
is_same
<
scalar_t
,
float
>::
value
){
out
=
f
;
}
else
{
uint32_t
u
=
*
(
uint32_t
*
)(
&
f
);
u
+=
0x7fff
+
((
u
>>
16
)
&
1
);
// u += 0x8000;
out
=
u
>>
16
;
}
}
template
<
typename
scalar_t
>
static
__device__
inline
float
to_float
(
scalar_t
in
){
if
constexpr
(
std
::
is_same
<
scalar_t
,
_Float16
>::
value
||
std
::
is_same
<
scalar_t
,
float
>::
value
){
return
in
;
}
else
{
union
{
uint32_t
int32
;
float
fp32
;
}
u
=
{
uint32_t
(
in
)
<<
16
};
return
u
.
fp32
;
}
}
#define Input_Type_SWITCH(SRC_DTYPE, ...) \
[&] { \
if (SRC_DTYPE == at::ScalarType::Half) { \
using scalar_t=_Float16; \
return __VA_ARGS__(); \
}else { \
using scalar_t=uint16_t; \
return __VA_ARGS__(); \
} \
}()
#define BLK_SWITCH(blk,...) \
[&] { \
if (blk==64){ \
constexpr static int BLK = 64; \
return __VA_ARGS__(); \
}else { \
constexpr static int BLK = 128; \
return __VA_ARGS__(); \
} \
}()
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
template
<
typename
scalar_t
,
int
blocksize
,
int
DIM
,
int
BLK
,
bool
has_mean
>
__global__
void
mean_pool_fast_kernel
(
scalar_t
*
out
,
const
scalar_t
*
input
,
int
L_BLOCKS
,
int
b
,
int
s
,
int
h
,
const
scalar_t
*
mean
){
int
tid
=
threadIdx
.
x
;
if
(
blockIdx
.
x
<
L_BLOCKS
-
1
||
s
==
L_BLOCKS
*
BLK
){
const
scalar_t
*
input_cur
=
input
+
blockIdx
.
z
*
s
*
h
*
DIM
+
blockIdx
.
y
*
DIM
+
(
blockIdx
.
x
*
BLK
+
tid
/
16
)
*
h
*
DIM
+
tid
%
16
*
8
;
scalar_t
*
out_cur
=
out
+
blockIdx
.
z
*
h
*
L_BLOCKS
*
DIM
+
blockIdx
.
y
*
L_BLOCKS
*
DIM
+
blockIdx
.
x
*
DIM
;
const
scalar_t
*
mean_cur
=
has_mean
?
mean
+
blockIdx
.
z
*
h
*
DIM
+
blockIdx
.
y
*
DIM
+
tid
%
16
*
8
:
nullptr
;
constexpr
int
n
=
DIM
*
BLK
;
using
half_vec
=
__attribute__
(
(
__vector_size__
(
8
*
sizeof
(
scalar_t
))
))
scalar_t
;
using
float_vec
=
__attribute__
(
(
__vector_size__
(
8
*
sizeof
(
float
))
))
float
;
__shared__
float
lds_ptr
[
blocksize
*
8
];
{
float_vec
sum
=
{
0.0
f
,
0.0
f
,
0.0
f
,
0.0
f
,
0.0
f
,
0.0
f
,
0.0
f
,
0.0
f
};
half_vec
mean_temp
;
if
constexpr
(
has_mean
){
mean_temp
=
*
reinterpret_cast
<
const
half_vec
*>
(
mean_cur
);
// if(tid==0)printf("mean_temp =%.5f,%.5f,%.5f,%.5f, %.5f,%.5f,%.5f,%.5f,\n", to_float(mean_temp[0]), to_float(mean_temp[1]), to_float(mean_temp[2]), to_float(mean_temp[3])
// , to_float(mean_temp[4]), to_float(mean_temp[5]), to_float(mean_temp[6]), to_float(mean_temp[7]));
}
for
(
int
i
=
0
;
i
<
n
;
i
+=
blocksize
*
8
){
half_vec
temp
=
*
reinterpret_cast
<
const
half_vec
*>
(
input_cur
+
i
*
h
);
for
(
int
ii
=
0
;
ii
<
8
;
ii
++
){
if
constexpr
(
has_mean
){
sum
[
ii
]
+=
to_float
(
temp
[
ii
])
-
to_float
(
mean_temp
[
ii
]);
}
else
{
sum
[
ii
]
+=
to_float
(
temp
[
ii
]);
}
}
}
*
reinterpret_cast
<
float_vec
*>
(
lds_ptr
+
tid
*
8
)
=
sum
;
__syncthreads
();
}
float
sum
=
0.0
f
;
for
(
int
i
=
0
;
i
<
8
;
i
++
){
sum
+=
lds_ptr
[
tid
+
DIM
*
i
];
}
sum
/=
BLK
;
from_float
(
out_cur
[
tid
],
sum
);
}
else
{
int
s_lenth
=
s
%
BLK
;
const
scalar_t
*
input_cur
=
input
+
blockIdx
.
z
*
s
*
h
*
DIM
+
blockIdx
.
y
*
DIM
+
(
blockIdx
.
x
*
BLK
)
*
h
*
DIM
+
tid
;
scalar_t
*
out_cur
=
out
+
blockIdx
.
z
*
h
*
L_BLOCKS
*
DIM
+
blockIdx
.
y
*
L_BLOCKS
*
DIM
+
blockIdx
.
x
*
DIM
;
const
scalar_t
*
mean_cur
=
has_mean
?
mean
+
blockIdx
.
z
*
h
*
DIM
+
blockIdx
.
y
*
DIM
+
tid
:
nullptr
;
float
sum
=
0.0
f
;
float
mean_temp
=
0.0
f
;
if
constexpr
(
has_mean
){
mean_temp
=
to_float
(
*
(
mean_cur
));
}
for
(
int
i
=
0
;
i
<
s_lenth
;
i
++
){
scalar_t
temp
=
*
(
input_cur
+
i
*
h
*
DIM
);
if
constexpr
(
has_mean
){
sum
+=
(
to_float
(
temp
)
-
mean_temp
);
}
else
{
sum
+=
to_float
(
temp
);
}
}
sum
/=
s_lenth
;
from_float
(
out_cur
[
tid
],
sum
);
}
}
at
::
Tensor
mean_pool_fast
(
const
at
::
Tensor
&
input
,
int
blk
,
const
c10
::
optional
<
at
::
Tensor
>
&
mean
){
//assume dim=128
int
b
=
input
.
size
(
0
);
int
s
=
input
.
size
(
1
);
int
h
=
input
.
size
(
2
);
int
d
=
input
.
size
(
3
);
int
L_BLOCKS
=
(
s
+
blk
-
1
)
/
blk
;
auto
out
=
torch
::
empty
({
b
,
h
,
L_BLOCKS
,
d
},
input
.
options
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
grid
(
L_BLOCKS
,
h
,
b
);
Input_Type_SWITCH
(
input
.
scalar_type
(),[
&
]{
BLK_SWITCH
(
blk
,[
&
]{
const
scalar_t
*
mean_ptr
=
mean
?
reinterpret_cast
<
const
scalar_t
*>
(
mean
.
value
().
data_ptr
())
:
nullptr
;
BOOL_SWITCH
(
mean_ptr
!=
nullptr
,
has_mean
,[
&
]{
const
scalar_t
*
input_ptr
=
reinterpret_cast
<
const
scalar_t
*>
(
input
.
data_ptr
());
scalar_t
*
out_ptr
=
reinterpret_cast
<
scalar_t
*>
(
out
.
data_ptr
());
mean_pool_fast_kernel
<
scalar_t
,
128
,
128
,
BLK
,
has_mean
><<<
grid
,
128
,
0
,
stream
>>>
(
out_ptr
,
input_ptr
,
L_BLOCKS
,
b
,
s
,
h
,
mean_ptr
);
});
});
});
return
out
;
}
\ No newline at end of file
csrc/flash_attn/src/kernel_traits.h
View file @
4b535e64
...
@@ -1711,6 +1711,159 @@ struct Flash_fwd_kernel_16x64_splitkv_prefetch_mla_traits : public Base {
...
@@ -1711,6 +1711,159 @@ struct Flash_fwd_kernel_16x64_splitkv_prefetch_mla_traits : public Base {
Layout
<
Shape
<
Int
<
kGmemRowsPerThread
>
,
_8
>
,
Stride
<
_8
,
_1
>>
{}));
// Val layout, 8 vals per load
Layout
<
Shape
<
Int
<
kGmemRowsPerThread
>
,
_8
>
,
Stride
<
_8
,
_1
>>
{}));
// Val layout, 8 vals per load
};
};
template
<
int
kHeadDim_
,
int
kBlockM_
,
int
kBlockN_
,
int
kNWarps_
,
typename
elem_type
=
cutlass
::
half_t
,
int
kStages_
=
1
,
int
kHeadDimV_
=
kHeadDim_
,
typename
Base
=
Flash_kernel_traits
<
kHeadDim_
,
kBlockM_
,
kBlockN_
,
kNWarps_
,
elem_type
>
>
struct
Flash_fwd_kernel_16x64_splitkv_prefetch_unified_traits
:
public
Base
{
using
Element
=
typename
Base
::
Element
;
using
ElementAccum
=
typename
Base
::
ElementAccum
;
using
index_t
=
typename
Base
::
index_t
;
static
constexpr
bool
Has_cp_async
=
Base
::
Has_cp_async
;
using
SmemCopyAtom
=
typename
Base
::
SmemCopyAtom
;
using
SmemCopyAtomTransposed
=
typename
Base
::
SmemCopyAtomTransposed
;
static
constexpr
bool
Share_Q_K_smem
=
true
;
// The number of threads.
static
constexpr
int
kNWarps
=
kNWarps_
;
static
constexpr
int
kNThreads
=
kNWarps
*
64
;
static
constexpr
int
kBlockM
=
kBlockM_
;
static
constexpr
int
kBlockN
=
kBlockN_
;
static
constexpr
int
kHeadDim
=
kHeadDim_
;
static
constexpr
int
kHeadDimV
=
kHeadDimV_
;
static_assert
(
kBlockN
%
64
==
0
);
static_assert
(
kHeadDim
%
32
==
0
);
static_assert
(
kHeadDimV
%
32
==
0
);
static
constexpr
int
kStages
=
kStages_
;
static
constexpr
int
kBlockKSmem
=
kHeadDim
%
64
==
0
?
64
:
32
;
static
constexpr
int
kBlockKGmem
=
kHeadDim
%
128
==
0
?
128
:
(
kHeadDim
%
64
==
0
?
64
:
32
);
static
constexpr
int
kSwizzle
=
kBlockKSmem
==
32
?
2
:
3
;
using
MMA_Atom_Arch_16x64
=
std
::
conditional_t
<
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
MMA_Atom
<
GFX928_16x64x32_F32F16F16F32_NT
>
,
MMA_Atom
<
GFX928_16x64x32_F32BF16BF16F32_NT
>
>
;
using
MMA_Atom_Arch_16x64_BLayout
=
std
::
conditional_t
<
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
MMA_Atom
<
GFX928_16x64x32_F32F16F16F32_NT_BLayout
>
,
MMA_Atom
<
GFX928_16x64x32_F32BF16BF16F32_NT_BLayout
>
>
;
using
MMA_Atom_Arch_16x32
=
std
::
conditional_t
<
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
MMA_Atom
<
GFX928_16x32x16_F32F16F16F32_NT
>
,
MMA_Atom
<
GFX928_16x32x16_F32BF16BF16F32_NT
>
>
;
using
TiledMma
=
TiledMMA
<
typename
Base
::
MMA_Atom_Arch
,
Layout
<
Shape
<
Int
<
kNWarps
>
,
_1
,
_1
>>
,
// 4x1x1 or 8x1x1 thread group
typename
Base
::
ValLayoutMNK
>
;
using
TiledMma16x64
=
TiledMMA
<
MMA_Atom_Arch_16x64
,
Layout
<
Shape
<
Int
<
kNWarps
>
,
_1
,
_1
>>
,
// 4x1x1 or 8x1x1 thread group
typename
Base
::
ValLayoutMNK
>
;
using
TiledMma16x64BLayout
=
TiledMMA
<
MMA_Atom_Arch_16x64_BLayout
,
Layout
<
Shape
<
Int
<
kNWarps
>
,
_1
,
_1
>>
,
typename
Base
::
ValLayoutMNK
>
;
using
TiledMma16x32
=
TiledMMA
<
MMA_Atom_Arch_16x32
,
Layout
<
Shape
<
Int
<
kNWarps
>
,
_1
,
_1
>>
,
// 4x1x1 or 8x1x1 thread group
typename
Base
::
ValLayoutMNK
>
;
using
SmemLayoutAtomQ
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
Layout
<
Shape
<
_8
,
Int
<
kBlockKSmem
>>
,
Stride
<
Int
<
kBlockKSmem
>
,
_1
>>
{}));
using
SmemLayoutQ
=
decltype
(
tile_to_shape
(
SmemLayoutAtomQ
{},
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{}));
static
constexpr
uint32_t
LayoutBlock
=
64
;
static
constexpr
uint32_t
LayoutDim
=
128
;
using
SmemLayoutAtomK
=
Layout
<
Shape
<
Int
<
kBlockN
>
,
Int
<
128
>>
,
Stride
<
Int
<
128
>
,
_1
>>
;
using
SmemLayoutKV
=
decltype
(
tile_to_shape
(
SmemLayoutAtomK
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
128
>>
{}));
using
SmemLayoutK
=
Layout
<
Shape
<
Int
<
kBlockN
*
(
128
/
64
)
>
,
Int
<
64
>>
,
Stride
<
Int
<
64
>
,
_1
>>
;
using
SmemLayoutAtomO
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
Layout
<
Shape
<
Int
<
8
>
,
Int
<
kBlockKSmem
>>
,
Stride
<
Int
<
kBlockKSmem
>
,
_1
>>
{}));
using
SmemLayoutO
=
decltype
(
tile_to_shape
(
SmemLayoutAtomO
{},
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{}));
using
SmemCopyAtomO
=
Copy_Atom
<
DefaultCopy
,
Element
>
;
using
SmemCopyAtomOaccum
=
Copy_Atom
<
DefaultCopy
,
ElementAccum
>
;
using
SmemLayoutAtomV
=
Layout
<
Shape
<
Int
<
16
>
,
Int
<
32
>>
,
Stride
<
Int
<
32
>
,
_1
>>
;
using
SmemLayoutV
=
decltype
(
tile_to_shape
(
SmemLayoutAtomV
{},
Shape
<
Int
<
LayoutBlock
>
,
Int
<
LayoutDim
>>
{}));
using
SmemLayoutVtransposed
=
decltype
(
composition
(
SmemLayoutV
{},
make_layout
(
Shape
<
Int
<
LayoutDim
>
,
Int
<
LayoutBlock
>>
{},
GenRowMajor
{})));
using
SmemLayoutVsplit
=
decltype
(
tile_to_shape
(
SmemLayoutAtomV
{},
Shape
<
Int
<
16
>
,
Int
<
4
*
LayoutDim
>>
{}));
using
SmemLayoutVtransSplit
=
decltype
(
composition
(
SmemLayoutVsplit
{},
make_layout
(
Shape
<
Int
<
4
*
LayoutDim
>
,
Int
<
16
>>
{},
GenRowMajor
{})));
static
constexpr
int
kSmemKVSize
=
size
(
SmemLayoutKV
{})
*
2
*
sizeof
(
Element
);
static
constexpr
int
kSmemKSize
=
size
(
SmemLayoutKV
{})
*
sizeof
(
Element
);
static
constexpr
int
kSmemOSize
=
size
(
SmemLayoutO
{})
*
sizeof
(
Element
);
// static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element);
// static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize);
static
constexpr
int
kSmemSize
=
kSmemKSize
;
static
constexpr
int
kGmemElemsPerLoad
=
sizeof
(
cute
::
uint128_t
)
/
sizeof
(
Element
);
static_assert
(
kHeadDim
%
kGmemElemsPerLoad
==
0
,
"kHeadDim must be a multiple of kGmemElemsPerLoad"
);
static
constexpr
int
kGmemThreadsPerRow
=
kNThreads
==
512
?
16
:
kBlockKSmem
/
kGmemElemsPerLoad
;
static_assert
(
kNThreads
%
kGmemThreadsPerRow
==
0
,
"kNThreads must be a multiple of kGmemThreadsPerRow"
);
#if 1
using
GmemLayoutAtom
=
Layout
<
Shape
<
Int
<
kNThreads
/
kGmemThreadsPerRow
>
,
Int
<
kGmemThreadsPerRow
>>
,
Stride
<
Int
<
kGmemThreadsPerRow
>
,
_1
>>
;
#else
using
GmemLayoutAtom
=
Layout
<
Shape
<
_64
,
_4
>
,
Stride
<
_4
,
_1
>>
;
#endif
using
Gmem_copy_struct
=
std
::
conditional_t
<
Has_cp_async
,
SM80_CP_ASYNC_CACHEGLOBAL
<
cute
::
uint128_t
>
,
DefaultCopy
>
;
using
GmemTiledCopyQKV
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
Gmem_copy_struct
,
Element
>
{},
GmemLayoutAtom
{},
Layout
<
Shape
<
_1
,
_8
>>
{}));
// Val layout, 8 vals per read
using
GmemTiledCopyO
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
GmemLayoutAtom
{},
Layout
<
Shape
<
_1
,
_8
>>
{}));
// Val layout, 8 vals per store
using
GmemLayoutAtomOaccum
=
std
::
conditional_t
<
kBlockKSmem
==
32
,
Layout
<
Shape
<
_32
,
_8
>
,
// Thread layout, 8 threads per row
Stride
<
_8
,
_1
>>
,
Layout
<
Shape
<
_16
,
_16
>
,
// Thread layout, 16 threads per row
Stride
<
_16
,
_1
>>
>
;
using
GmemTiledCopyOaccum
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
ElementAccum
>
{},
GmemLayoutAtomOaccum
{},
Layout
<
Shape
<
_1
,
_4
>>
{}));
// Val layout, 4 vals per store
static
constexpr
int
kGmemRowsPerThread
=
kBlockN
/
(
kNThreads
/
kGmemThreadsPerRow
);
using
GmemTiledCopyQKVPaged
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
UniversalCopy
<
uint128_t
>
,
Element
>
{},
GmemLayoutAtom
{},
Layout
<
Shape
<
Int
<
kGmemRowsPerThread
>
,
_8
>
,
Stride
<
_8
,
_1
>>
{}));
using
GmemLayoutAtomRotcossin
=
GmemLayoutAtom
;
using
GmemTiledCopyRotcossin
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
UniversalCopy
<
uint64_t
>
,
Element
>
{},
GmemLayoutAtomRotcossin
{},
Layout
<
Shape
<
_1
,
_4
>>
{}));
// Val layout, 4 vals per load
using
GmemTiledCopyRotcossinCont
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
GmemLayoutAtomRotcossin
{},
Layout
<
Shape
<
_1
,
_8
>>
{}));
// Val layout, 8 vals per load
using
GmemTiledCopyRotcossinPaged
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
UniversalCopy
<
uint64_t
>
,
Element
>
{},
GmemLayoutAtomRotcossin
{},
Layout
<
Shape
<
Int
<
kGmemRowsPerThread
>
,
_4
>
,
Stride
<
_4
,
_1
>>
{}));
// Val layout, 4 vals per load
using
GmemTiledCopyRotcossinContPaged
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
UniversalCopy
<
uint128_t
>
,
Element
>
{},
GmemLayoutAtomRotcossin
{},
Layout
<
Shape
<
Int
<
kGmemRowsPerThread
>
,
_8
>
,
Stride
<
_8
,
_1
>>
{}));
// Val layout, 8 vals per load
};
template
<
int
kHeadDim_
,
int
kBlockM_
,
int
kBlockN_
,
int
kNWarps_
,
typename
elem_type
=
cutlass
::
half_t
,
template
<
int
kHeadDim_
,
int
kBlockM_
,
int
kBlockN_
,
int
kNWarps_
,
typename
elem_type
=
cutlass
::
half_t
,
int
kStages_
=
1
,
int
kHeadDimV_
=
kHeadDim_
,
typename
Base
=
Flash_kernel_traits
<
kHeadDim_
,
kBlockM_
,
kBlockN_
,
kNWarps_
,
elem_type
>
>
int
kStages_
=
1
,
int
kHeadDimV_
=
kHeadDim_
,
typename
Base
=
Flash_kernel_traits
<
kHeadDim_
,
kBlockM_
,
kBlockN_
,
kNWarps_
,
elem_type
>
>
...
...
csrc/flash_attn/src/paged_attention.cu
View file @
4b535e64
...
@@ -315,8 +315,8 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
...
@@ -315,8 +315,8 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
v_scale
=*
v_scale_ptr
;
v_scale
=*
v_scale_ptr
;
}
}
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
head_idx
=
blockIdx
.
x
*
num_queries_per_kv
;
const
int
kv_head_idx
=
blockIdx
.
x
;
const
int
kv_head_idx
=
blockIdx
.
x
;
const
int
head_idx
=
num_queries_per_kv
/
mtp
*
kv_head_idx
;
constexpr
int
reuse_group
=
(
REUSE_KV_TIMES
-
1
)
/
4
+
1
;
constexpr
int
reuse_group
=
(
REUSE_KV_TIMES
-
1
)
/
4
+
1
;
constexpr
int
Mloop
=
(
REUSE_KV_TIMES
-
1
)
/
16
+
1
;
constexpr
int
Mloop
=
(
REUSE_KV_TIMES
-
1
)
/
16
+
1
;
extern
__shared__
char
shared_mem
[];
extern
__shared__
char
shared_mem
[];
...
@@ -353,13 +353,20 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
...
@@ -353,13 +353,20 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
q_zero
.
data
[
0
]
=
{
0
,
0
,
0
,
0
};
q_zero
.
data
[
0
]
=
{
0
,
0
,
0
,
0
};
q_zero
.
data
[
1
]
=
{
0
,
0
,
0
,
0
};
q_zero
.
data
[
1
]
=
{
0
,
0
,
0
,
0
};
scalar_t
*
s_q
=
reinterpret_cast
<
scalar_t
*>
(
shared_mem
);
scalar_t
*
s_q
=
reinterpret_cast
<
scalar_t
*>
(
shared_mem
);
for
(
int
i
=
thread_idx
*
8
;
i
<
num_queries_per_kv
*
HEAD_SIZE
;
i
+=
NUM_THREADS
*
8
){
{
*
reinterpret_cast
<
half4x2
*>
(
s_q
+
i
)
=*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
i
);
int
head_offset
=
HEAD_SIZE
*
num_queries_per_kv
/
mtp
;
for
(
int
i
=
thread_idx
*
8
;
i
<
num_queries_per_kv
*
HEAD_SIZE
;
i
+=
NUM_THREADS
*
8
){
int
qoffset
=
i
/
head_offset
;
qoffset
*=
num_kv_heads
*
head_offset
;
qoffset
+=
i
%
head_offset
;
*
reinterpret_cast
<
half4x2
*>
(
s_q
+
i
)
=*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
qoffset
);
}
}
}
__syncthreads
();
__syncthreads
();
for
(
int
m
=
0
;
m
<
Mloop
;
m
++
){
for
(
int
m
=
0
;
m
<
Mloop
;
m
++
){
int
head_idx_
=
rowid
+
16
*
m
;
for
(
int
i
=
0
;
i
<
HEAD_SIZE
/
32
;
i
++
){
for
(
int
i
=
0
;
i
<
HEAD_SIZE
/
32
;
i
++
){
int
head_idx_
=
rowid
+
16
*
m
;
if
(
head_idx_
<
num_queries_per_kv
)
q_vec
[
m
][
i
]
=*
reinterpret_cast
<
const
half4x2
*>
(
s_q
+
head_idx_
*
HEAD_SIZE
+
(
i
*
4
+
rows
)
*
8
);
if
(
head_idx_
<
num_queries_per_kv
)
q_vec
[
m
][
i
]
=*
reinterpret_cast
<
const
half4x2
*>
(
s_q
+
head_idx_
*
HEAD_SIZE
+
(
i
*
4
+
rows
)
*
8
);
else
q_vec
[
m
][
i
]
=
q_zero
;
else
q_vec
[
m
][
i
]
=
q_zero
;
}
}
...
@@ -422,7 +429,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
...
@@ -422,7 +429,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
else
{
else
{
scalar_t
temp
;
scalar_t
temp
;
if
(
mtp
>
1
){
if
(
mtp
>
1
){
int
casual
=
mtp
-
reuse_kv_idx
*
mtp
/
num_
heads
;
int
casual
=
mtp
-
reuse_kv_idx
*
mtp
/
num_
queries_per_kv
;
if
(
token_idx
+
casual
>
seq_len
)
qk_vec
[
m
][
ii
]
=-
INFINITY
;
if
(
token_idx
+
casual
>
seq_len
)
qk_vec
[
m
][
ii
]
=-
INFINITY
;
}
}
from_float
(
temp
,
qk_vec
[
m
][
ii
]);
from_float
(
temp
,
qk_vec
[
m
][
ii
]);
...
@@ -643,33 +650,38 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
...
@@ -643,33 +650,38 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
}
}
}
}
}
scalar_t
*
out_ptr_base
;
{
int
out_offset
;
scalar_t
*
out_ptr_base
;
if
(
num_partitions
>
1
){
int
out_offset
;
out_offset
=
max_num_partitions
*
HEAD_SIZE
;
if
(
num_partitions
>
1
){
out_ptr_base
=
out_tmp
+
out_tmp_offset
+
seq_idx
*
num_heads
*
out_offset
+
head_idx
*
out_offset
+
partition_idx
*
HEAD_SIZE
;
out_offset
=
max_num_partitions
*
HEAD_SIZE
;
}
out_ptr_base
=
out_tmp
+
out_tmp_offset
+
seq_idx
*
num_heads
*
out_offset
+
head_idx
*
out_offset
+
partition_idx
*
HEAD_SIZE
;
else
{
}
out_offset
=
HEAD_SIZE
;
else
{
out_ptr_base
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
out_offset
=
HEAD_SIZE
;
}
out_ptr_base
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
}
int
reusekvid
=
g
*
4
+
rows
;
int
head_offset
=
num_queries_per_kv
/
mtp
;
if
(
reusekvid
<
num_queries_per_kv
){
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
scalar_t
*
out_ptr
=
out_ptr_base
+
reusekvid
*
out_offset
;
int
reusekvid
=
g
*
4
+
rows
;
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
if
(
reusekvid
<
num_queries_per_kv
){
const
int
row_idx
=
rowid
+
16
*
warp_idx
+
i
*
WARP_SIZE
;
int
out_head
=
reusekvid
/
head_offset
*
num_kv_heads
*
head_offset
+
reusekvid
%
head_offset
;
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
reusekvid
/
16
][
i
][
g
%
4
]
*
v_scale
);
scalar_t
*
out_ptr
=
out_ptr_base
+
out_head
*
out_offset
;
// if(reusekvid==0)printf("patition=%d,tid=%d,i=%d,g=%d,acc=%f\n",partition_idx,thread_idx,i,g,accs[i][g]);
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
rowid
+
16
*
warp_idx
+
i
*
WARP_SIZE
;
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
reusekvid
/
16
][
i
][
g
%
4
]
*
v_scale
);
// if(reusekvid==0)printf("patition=%d,tid=%d,i=%d,g=%d,acc=%f\n",partition_idx,thread_idx,i,g,accs[i][g]);
}
}
}
}
}
}
if
(
num_partitions
>
1
&&
thread_idx
<
num_queries_per_kv
){
if
(
num_partitions
>
1
&&
thread_idx
<
num_queries_per_kv
){
int
out_head
=
thread_idx
/
head_offset
*
num_kv_heads
*
head_offset
+
thread_idx
%
head_offset
;
int
offset
=
seq_idx
*
num_heads
*
max_num_partitions
+
(
head_idx
+
thread_idx
)
*
max_num_partitions
+
partition_idx
;
int
offset
=
seq_idx
*
num_heads
*
max_num_partitions
+
(
head_idx
+
out_head
)
*
max_num_partitions
+
partition_idx
;
float
*
exp_sums
=
reinterpret_cast
<
float
*>
(
out_tmp
);
float
*
exp_sums
=
reinterpret_cast
<
float
*>
(
out_tmp
);
float
*
max_logits
=
reinterpret_cast
<
float
*>
(
out_tmp
+
max_tmp_offset
);
float
*
max_logits
=
reinterpret_cast
<
float
*>
(
out_tmp
+
max_tmp_offset
);
*
(
exp_sums
+
offset
)
=
expsum_out
[
thread_idx
];
*
(
exp_sums
+
offset
)
=
expsum_out
[
thread_idx
];
*
(
max_logits
+
offset
)
=
max_out
[
thread_idx
];
*
(
max_logits
+
offset
)
=
max_out
[
thread_idx
];
}
}
}
}
}
...
@@ -797,19 +809,22 @@ void paged_attention(
...
@@ -797,19 +809,22 @@ void paged_attention(
int
num_kv_heads
=
key_cache
.
size
(
1
);
int
num_kv_heads
=
key_cache
.
size
(
1
);
int
PARTITION_SIZE
=
512
;
int
PARTITION_SIZE
=
512
;
int
reusekv
=
get_reusekv
(
num_heads
,
num_kv_heads
);
int
reusekv
=
get_reusekv
(
num_heads
,
num_kv_heads
);
if
(
reusekv
>
15
)
PARTITION_SIZE
=
256
;
//if seq<10,the seq is invalid
if
(
max_seq_len
<=
10
||
(
max_seq_len
>=
8192
&&
max_seq_len
==
max_num_blocks_per_seq
*
block_size
)){
if
(
max_seq_len
<=
10
||
(
max_seq_len
>=
8192
&&
max_seq_len
==
max_num_blocks_per_seq
*
block_size
)){
int
meanseq
=
num_blocks
*
block_size
/
num_seqs
+
8192
;
int
meanseq
=
num_blocks
*
block_size
/
num_seqs
+
4096
;
int
maxseq
=
100000000
/
num_seqs
/
headsize
/
num_heads
*
64
;
int
maxseq
=
100000000
/
num_seqs
/
headsize
/
num_heads
*
64
;
if
(
reusekv
<
=
8
)
maxseq
*=
2
;
if
(
reusekv
<
16
)
maxseq
*=
2
;
max_seq_len
=
MIN
(
max_num_blocks_per_seq
*
block_size
,
MIN
(
meanseq
,
maxseq
));
max_seq_len
=
MIN
(
max_num_blocks_per_seq
*
block_size
,
MIN
(
meanseq
,
maxseq
));
}
}
else
{
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
if
(
max_num_partitions
*
num_seqs
*
num_kv_heads
<=
160
)
PARTITION_SIZE
=
256
;
if
(
num_seqs
*
num_kv_heads
<=
32
&&
max_seq_len
<=
32768
)
PARTITION_SIZE
=
256
;
}
int
real_reuse_times
=
num_heads
/
num_kv_heads
;
int
real_reuse_times
=
num_heads
/
num_kv_heads
;
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
if
(
max_num_partitions
*
num_seqs
*
num_kv_heads
<=
160
||
reusekv
>
15
)
PARTITION_SIZE
=
256
;
if
(
num_seqs
*
num_kv_heads
<=
32
&&
max_seq_len
<=
32768
)
PARTITION_SIZE
=
256
;
// if(max_num_partitions*num_seqs*num_kv_heads>200&&real_reuse_times<6&&max_seq_len>30000)PARTITION_SIZE=1024;
if
(
PA_PARTITION_SIZE
!=
0
)
PARTITION_SIZE
=
PA_PARTITION_SIZE
;
if
(
PA_PARTITION_SIZE
!=
0
)
PARTITION_SIZE
=
PA_PARTITION_SIZE
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
static
float
*
tmp_out_ptr
=
nullptr
;
static
float
*
tmp_out_ptr
=
nullptr
;
constexpr
int
temp_out_size
=
110000000
;
constexpr
int
temp_out_size
=
110000000
;
if
(
tmp_out_ptr
==
nullptr
){
if
(
tmp_out_ptr
==
nullptr
){
...
@@ -881,7 +896,7 @@ void paged_attention(
...
@@ -881,7 +896,7 @@ void paged_attention(
int
shared_mem_size
=
PARTITION_SIZE
*
2
*
real_reuse_times
+
other_use
;
int
shared_mem_size
=
PARTITION_SIZE
*
2
*
real_reuse_times
+
other_use
;
grid
.
z
=
max_num_partitions
;
grid
.
z
=
max_num_partitions
;
dim3
block
(
NUM_THREADS
);
dim3
block
(
NUM_THREADS
);
if
(
PA_PRINT_PARAM
)
printf
(
"is_fp8=%d,shared_mem_size=%d,HEAD_SIZE=%d,BLOCK_SIZE=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d,PARTITION_SIZE=%d,max_num_partitions=%d
\n
"
,
if
(
PA_PRINT_PARAM
&&
static_cast
<
int32_t
>
(
query
.
get_device
())
==
0
)
printf
(
"is_fp8=%d,shared_mem_size=%d,HEAD_SIZE=%d,BLOCK_SIZE=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d,PARTITION_SIZE=%d,max_num_partitions=%d
\n
"
,
(
int
)(
sizeof
(
cache_t
)
==
1
),
shared_mem_size
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
grid
.
x
,
grid
.
y
,
grid
.
z
,
num_heads
,
num_kv_heads
,
max_seq_len
,
num_seqs
,
PARTITION_SIZE
,
max_num_partitions
);
(
int
)(
sizeof
(
cache_t
)
==
1
),
shared_mem_size
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
grid
.
x
,
grid
.
y
,
grid
.
z
,
num_heads
,
num_kv_heads
,
max_seq_len
,
num_seqs
,
PARTITION_SIZE
,
max_num_partitions
);
paged_attention_kernel
<
scalar_t
,
cache_t
,
is_e4m3
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
REUSE_KV_TIMES
><<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
paged_attention_kernel
<
scalar_t
,
cache_t
,
is_e4m3
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
REUSE_KV_TIMES
><<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
(
scalar_t
*
)
out_ptr
,(
scalar_t
*
)
tmp_out_ptr
,
(
scalar_t
*
)
query_ptr
,(
cache_t
*
)
key_cache_ptr
,
(
cache_t
*
)
value_cache_ptr
,
(
scalar_t
*
)
out_ptr
,(
scalar_t
*
)
tmp_out_ptr
,
(
scalar_t
*
)
query_ptr
,(
cache_t
*
)
key_cache_ptr
,
(
cache_t
*
)
value_cache_ptr
,
...
...
csrc/flash_attn/src/paged_attention_938.cu
View file @
4b535e64
...
@@ -363,8 +363,9 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
...
@@ -363,8 +363,9 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
}
k_scale
*=
q_scale
;
k_scale
*=
q_scale
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
head_idx
=
blockIdx
.
x
*
num_queries_per_kv
;
const
int
kv_head_idx
=
blockIdx
.
x
;
const
int
kv_head_idx
=
blockIdx
.
x
;
const
int
head_idx
=
num_queries_per_kv
/
mtp
*
kv_head_idx
;
constexpr
int
reuse_group
=
(
REUSE_KV_TIMES
-
1
)
/
4
+
1
;
constexpr
int
reuse_group
=
(
REUSE_KV_TIMES
-
1
)
/
4
+
1
;
constexpr
int
Mloop
=
(
REUSE_KV_TIMES
-
1
)
/
16
+
1
;
constexpr
int
Mloop
=
(
REUSE_KV_TIMES
-
1
)
/
16
+
1
;
extern
__shared__
char
shared_mem
[];
extern
__shared__
char
shared_mem
[];
...
@@ -397,12 +398,19 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
...
@@ -397,12 +398,19 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
}
intx4
q_vec
[
Mloop
][
HEAD_SIZE
/
64
];
intx4
q_vec
[
Mloop
][
HEAD_SIZE
/
64
];
q_type
*
s_q
=
reinterpret_cast
<
q_type
*>
(
shared_mem
);
q_type
*
s_q
=
reinterpret_cast
<
q_type
*>
(
shared_mem
);
for
(
int
i
=
thread_idx
*
8
;
i
<
num_queries_per_kv
*
HEAD_SIZE
;
i
+=
NUM_THREADS
*
8
){
{
if
constexpr
(
q_is_fp8
){
int
head_offset
=
HEAD_SIZE
*
num_queries_per_kv
/
mtp
;
*
reinterpret_cast
<
intx2
*>
(
s_q
+
i
)
=*
reinterpret_cast
<
const
intx2
*>
(
q_ptr
+
i
);
for
(
int
i
=
thread_idx
*
8
;
i
<
num_queries_per_kv
*
HEAD_SIZE
;
i
+=
NUM_THREADS
*
8
){
}
int
qoffset
=
i
/
head_offset
;
else
{
qoffset
*=
num_kv_heads
*
head_offset
;
*
reinterpret_cast
<
half4x2
*>
(
s_q
+
i
)
=*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
i
);
qoffset
+=
i
%
head_offset
;
if
constexpr
(
q_is_fp8
){
*
reinterpret_cast
<
intx2
*>
(
s_q
+
i
)
=*
reinterpret_cast
<
const
intx2
*>
(
q_ptr
+
qoffset
);
}
else
{
*
reinterpret_cast
<
half4x2
*>
(
s_q
+
i
)
=*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
qoffset
);
}
}
}
}
}
__syncthreads
();
__syncthreads
();
...
@@ -475,7 +483,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
...
@@ -475,7 +483,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
else
{
else
{
scalar_t
temp
;
scalar_t
temp
;
if
(
mtp
>
1
){
if
(
mtp
>
1
){
int
casual
=
mtp
-
reuse_kv_idx
*
mtp
/
num_
heads
;
int
casual
=
mtp
-
reuse_kv_idx
*
mtp
/
num_
queries_per_kv
;
if
(
token_idx
+
casual
>
seq_len
)
qk_vec
[
m
][
ii
]
=-
INFINITY
;
if
(
token_idx
+
casual
>
seq_len
)
qk_vec
[
m
][
ii
]
=-
INFINITY
;
}
}
from_float
(
temp
,
qk_vec
[
m
][
ii
]);
from_float
(
temp
,
qk_vec
[
m
][
ii
]);
...
@@ -680,34 +688,38 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
...
@@ -680,34 +688,38 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
}
}
}
}
}
{
scalar_t
*
out_ptr_base
;
scalar_t
*
out_ptr_base
;
int
out_offset
;
int
out_offset
;
if
(
num_partitions
>
1
){
if
(
num_partitions
>
1
){
out_offset
=
max_num_partitions
*
HEAD_SIZE
;
out_offset
=
max_num_partitions
*
HEAD_SIZE
;
out_ptr_base
=
out_tmp
+
out_tmp_offset
+
seq_idx
*
num_heads
*
out_offset
+
head_idx
*
out_offset
+
partition_idx
*
HEAD_SIZE
;
out_ptr_base
=
out_tmp
+
out_tmp_offset
+
seq_idx
*
num_heads
*
out_offset
+
head_idx
*
out_offset
+
partition_idx
*
HEAD_SIZE
;
}
}
else
{
else
{
out_offset
=
HEAD_SIZE
;
out_offset
=
HEAD_SIZE
;
out_ptr_base
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
out_ptr_base
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
}
}
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
int
head_offset
=
num_queries_per_kv
/
mtp
;
int
reusekvid
=
g
*
4
+
rows
;
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
if
(
reusekvid
<
num_queries_per_kv
){
int
reusekvid
=
g
*
4
+
rows
;
scalar_t
*
out_ptr
=
out_ptr_base
+
reusekvid
*
out_offset
;
if
(
reusekvid
<
num_queries_per_kv
){
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
int
out_head
=
reusekvid
/
head_offset
*
num_kv_heads
*
head_offset
+
reusekvid
%
head_offset
;
const
int
row_idx
=
rowid
+
16
*
warp_idx
+
i
*
WARP_SIZE
;
scalar_t
*
out_ptr
=
out_ptr_base
+
out_head
*
out_offset
;
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
reusekvid
/
16
][
i
][
g
%
4
]
*
v_scale
);
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
// if(reusekvid==0)printf("patition=%d,tid=%d,i=%d,g=%d,acc=%f\n",partition_idx,thread_idx,i,g,accs[i][g]);
const
int
row_idx
=
rowid
+
16
*
warp_idx
+
i
*
WARP_SIZE
;
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
reusekvid
/
16
][
i
][
g
%
4
]
*
v_scale
);
// if(reusekvid==0)printf("patition=%d,tid=%d,i=%d,g=%d,acc=%f\n",partition_idx,thread_idx,i,g,accs[i][g]);
}
}
}
}
}
}
if
(
num_partitions
>
1
&&
thread_idx
<
num_queries_per_kv
){
if
(
num_partitions
>
1
&&
thread_idx
<
num_queries_per_kv
){
int
out_head
=
thread_idx
/
head_offset
*
num_kv_heads
*
head_offset
+
thread_idx
%
head_offset
;
int
offset
=
seq_idx
*
num_heads
*
max_num_partitions
+
(
head_idx
+
thread_idx
)
*
max_num_partitions
+
partition_idx
;
int
offset
=
seq_idx
*
num_heads
*
max_num_partitions
+
(
head_idx
+
out_head
)
*
max_num_partitions
+
partition_idx
;
float
*
exp_sums
=
reinterpret_cast
<
float
*>
(
out_tmp
);
float
*
exp_sums
=
reinterpret_cast
<
float
*>
(
out_tmp
);
float
*
max_logits
=
reinterpret_cast
<
float
*>
(
out_tmp
+
max_tmp_offset
);
float
*
max_logits
=
reinterpret_cast
<
float
*>
(
out_tmp
+
max_tmp_offset
);
*
(
exp_sums
+
offset
)
=
expsum_out
[
thread_idx
];
*
(
exp_sums
+
offset
)
=
expsum_out
[
thread_idx
];
*
(
max_logits
+
offset
)
=
max_out
[
thread_idx
];
*
(
max_logits
+
offset
)
=
max_out
[
thread_idx
];
}
}
}
#endif
#endif
}
}
...
...
csrc/flash_attn/src/utils.h
View file @
4b535e64
...
@@ -675,7 +675,66 @@ __forceinline__ __device__ void gemm_rr(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tC
...
@@ -675,7 +675,66 @@ __forceinline__ __device__ void gemm_rr(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tC
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
row
,
int
col
,
typename
Tensor0
,
typename
Tensor1
>
__forceinline__
__device__
static
void
__ds_read_m32x16_row_col_alt
(
Tensor0
&
src
,
Tensor1
&
dst
)
{
auto
lds
=
reinterpret_cast
<
__fp16
*>
(
src
.
data
().
get
());
auto
layout
=
src
.
layout
();
constexpr
short
offset
=
layout
(
0
,
row
,
col
)
*
2
;
auto
d
=
__builtin_amdgcn_ds_read_m32x16f16_alt
((
__attribute__
((
address_space
(
3
)))
__fp16
*
)(
lds
),
offset
);
uint16_t
*
d_ptr
=
reinterpret_cast
<
uint16_t
*>
(
&
d
);
uint16_t
*
dst_ptr
=
reinterpret_cast
<
uint16_t
*>
(
&
(
dst
(
0
,
row
,
col
)));
dst_ptr
[
0
]
=
d_ptr
[
0
];
dst_ptr
[
1
]
=
d_ptr
[
1
];
dst_ptr
[
2
]
=
d_ptr
[
2
];
dst_ptr
[
3
]
=
d_ptr
[
3
];
dst_ptr
[
4
]
=
d_ptr
[
4
];
dst_ptr
[
5
]
=
d_ptr
[
5
];
dst_ptr
[
6
]
=
d_ptr
[
6
];
dst_ptr
[
7
]
=
d_ptr
[
7
];
}
template
<
int
k_idx
,
typename
Tensor0
,
typename
Tensor1
,
typename
Tensor2
,
typename
Tensor3
,
typename
TiledMma
,
typename
TiledCopy
,
typename
ThrCopy
>
__forceinline__
__device__
void
gemm_k_rs_ds_read_m32x16_alt
(
Tensor0
&
acc
,
Tensor1
&
tCrA
,
Tensor2
&
tCrB
,
Tensor3
const
&
tCsB
,
TiledMma
tiled_mma
,
TiledCopy
smem_tiled_copy_B
,
ThrCopy
smem_thr_copy_B
)
{
CUTE_STATIC_ASSERT_V
(
size
<
1
>
(
tCrA
)
==
size
<
1
>
(
acc
));
// MMA_M
CUTE_STATIC_ASSERT_V
(
size
<
1
>
(
tCrB
)
==
size
<
2
>
(
acc
));
// MMA_N
CUTE_STATIC_ASSERT_V
(
size
<
2
>
(
tCrA
)
==
size
<
2
>
(
tCrB
));
// MMA_K
Tensor
tCrB_copy_view
=
smem_thr_copy_B
.
retile_D
(
tCrB
);
CUTE_STATIC_ASSERT_V
(
size
<
1
>
(
tCsB
)
==
size
<
1
>
(
tCrB_copy_view
));
// N
auto
shape
=
tCsB
.
shape
();
constexpr
int
rows
=
get
<
1
>
(
shape
);
static_assert
(
rows
==
6
||
rows
==
4
||
rows
==
3
||
rows
==
2
);
if
constexpr
(
rows
==
6
)
{
__ds_read_m32x16_row_col_alt
<
0
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
__ds_read_m32x16_row_col_alt
<
1
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
__ds_read_m32x16_row_col_alt
<
2
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
__ds_read_m32x16_row_col_alt
<
3
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
__ds_read_m32x16_row_col_alt
<
4
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
__ds_read_m32x16_row_col_alt
<
5
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
}
else
if
constexpr
(
rows
==
4
)
{
__ds_read_m32x16_row_col_alt
<
0
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
__ds_read_m32x16_row_col_alt
<
1
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
__ds_read_m32x16_row_col_alt
<
2
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
__ds_read_m32x16_row_col_alt
<
3
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
}
else
if
constexpr
(
rows
==
3
)
{
__ds_read_m32x16_row_col_alt
<
0
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
__ds_read_m32x16_row_col_alt
<
1
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
__ds_read_m32x16_row_col_alt
<
2
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
}
else
if
constexpr
(
rows
==
2
)
{
__ds_read_m32x16_row_col_alt
<
0
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
__ds_read_m32x16_row_col_alt
<
1
,
k_idx
>
(
tCsB
,
tCrB_copy_view
);
}
// cute::copy(smem_tiled_copy_B, tCsB(_, _, k_idx), tCrB_copy_view(_, _, k_idx));
cute
::
gemm
(
tiled_mma
,
tCrA
(
_
,
_
,
k_idx
),
tCrB
(
_
,
_
,
k_idx
),
acc
);
}
template
<
int
row
,
int
col
,
typename
Tensor0
,
typename
Tensor1
>
template
<
int
row
,
int
col
,
typename
Tensor0
,
typename
Tensor1
>
__forceinline__
__device__
static
void
__ds_read_m32x16_row_col
(
Tensor0
&
src
,
Tensor1
&
dst
)
__forceinline__
__device__
static
void
__ds_read_m32x16_row_col
(
Tensor0
&
src
,
Tensor1
&
dst
)
...
...
csrc/flash_attn_hg/flash_api.cpp
View file @
4b535e64
...
@@ -353,10 +353,11 @@ void set_params_dropout(Flash_fwd_params ¶ms, float p_dropout,
...
@@ -353,10 +353,11 @@ void set_params_dropout(Flash_fwd_params ¶ms, float p_dropout,
c10
::
optional
<
at
::
Generator
>
gen_
,
c10
::
optional
<
at
::
Generator
>
gen_
,
at
::
TensorOptions
opts
,
at
::
TensorOptions
opts
,
at
::
Tensor
&
dropout_debug_count
)
{
at
::
Tensor
&
dropout_debug_count
)
{
rng_state
=
at
::
empty
({
2
},
opts
.
dtype
(
at
::
ScalarType
::
Long
));
// Match the generic FlashAttention API contract: rng_state is returned as a
// tensor even when dropout is disabled.
params
.
rng_state
=
reinterpret_cast
<
uint64_t
*>
(
rng_state
.
data_ptr
());
if
(
p_dropout
>
0
)
{
if
(
p_dropout
>
0
)
{
rng_state
=
at
::
empty
({
2
},
opts
.
dtype
(
at
::
ScalarType
::
Long
));
// Forward kernel will populate memory with the seed and offset.
params
.
rng_state
=
reinterpret_cast
<
uint64_t
*>
(
rng_state
.
data_ptr
());
auto
gen
=
at
::
get_generator_or_default
<
at
::
CUDAGeneratorImpl
>
(
auto
gen
=
at
::
get_generator_or_default
<
at
::
CUDAGeneratorImpl
>
(
gen_
,
at
::
cuda
::
detail
::
getDefaultCUDAGenerator
());
gen_
,
at
::
cuda
::
detail
::
getDefaultCUDAGenerator
());
// See Note [Acquire lock when using random generators]
// See Note [Acquire lock when using random generators]
...
@@ -371,8 +372,6 @@ void set_params_dropout(Flash_fwd_params ¶ms, float p_dropout,
...
@@ -371,8 +372,6 @@ void set_params_dropout(Flash_fwd_params ¶ms, float p_dropout,
params
.
dropout_debug_count
=
params
.
dropout_debug_count
=
reinterpret_cast
<
uint32_t
*>
(
dropout_debug_count
.
data_ptr
());
reinterpret_cast
<
uint32_t
*>
(
dropout_debug_count
.
data_ptr
());
#endif
#endif
}
else
{
params
.
rng_state
=
nullptr
;
}
}
}
}
...
@@ -1637,16 +1636,11 @@ std::vector<at::Tensor> varlen_fwd_bhsd(
...
@@ -1637,16 +1636,11 @@ std::vector<at::Tensor> varlen_fwd_bhsd(
params
.
total_k
=
total_k
;
params
.
total_k
=
total_k
;
at
::
Tensor
rng_state
;
at
::
Tensor
rng_state
;
if
(
p_dropout
>
0
)
{
auto
options
=
auto
options
=
at
::
TensorOptions
()
at
::
TensorOptions
().
dtype
(
at
::
ScalarType
::
Float
).
device
(
at
::
DeviceType
::
CUDA
);
.
dtype
(
at
::
ScalarType
::
Float
)
rng_state
=
at
::
empty
({
2
},
options
.
dtype
(
at
::
ScalarType
::
Long
));
.
device
(
at
::
DeviceType
::
CUDA
);
// Keep the return tuple compatible with the generic FlashAttention path.
rng_state
=
at
::
empty
({
2
},
options
.
dtype
(
at
::
ScalarType
::
Long
));
params
.
rng_state
=
reinterpret_cast
<
uint64_t
*>
(
rng_state
.
data_ptr
());
// Forward kernel will populate memory with the seed and offset.
params
.
rng_state
=
reinterpret_cast
<
uint64_t
*>
(
rng_state
.
data_ptr
());
}
else
{
params
.
rng_state
=
nullptr
;
}
set_params_alibi
(
params
,
alibi_slopes_
,
batch_size
,
num_heads
);
set_params_alibi
(
params
,
alibi_slopes_
,
batch_size
,
num_heads
);
...
@@ -1884,16 +1878,11 @@ std::vector<at::Tensor> hg_prefix_prefill_varlen_fwd(
...
@@ -1884,16 +1878,11 @@ std::vector<at::Tensor> hg_prefix_prefill_varlen_fwd(
}
}
at
::
Tensor
rng_state
;
at
::
Tensor
rng_state
;
if
(
p_dropout
>
0
)
{
auto
options
=
auto
options
=
at
::
TensorOptions
()
at
::
TensorOptions
().
dtype
(
at
::
ScalarType
::
Float
).
device
(
at
::
DeviceType
::
CUDA
);
.
dtype
(
at
::
ScalarType
::
Float
)
rng_state
=
at
::
empty
({
2
},
options
.
dtype
(
at
::
ScalarType
::
Long
));
.
device
(
at
::
DeviceType
::
CUDA
);
// Keep the return tuple compatible with the generic FlashAttention path.
rng_state
=
at
::
empty
({
2
},
options
.
dtype
(
at
::
ScalarType
::
Long
));
params
.
rng_state
=
reinterpret_cast
<
uint64_t
*>
(
rng_state
.
data_ptr
());
// Forward kernel will populate memory with the seed and offset.
params
.
rng_state
=
reinterpret_cast
<
uint64_t
*>
(
rng_state
.
data_ptr
());
}
else
{
params
.
rng_state
=
nullptr
;
}
set_params_alibi
(
params
,
alibi_slopes_
,
batch_size
,
num_heads
);
set_params_alibi
(
params
,
alibi_slopes_
,
batch_size
,
num_heads
);
...
...
flash_attn/__init__.py
View file @
4b535e64
...
@@ -26,6 +26,7 @@ if torch.cuda.is_available():
...
@@ -26,6 +26,7 @@ if torch.cuda.is_available():
flash_attn_varlen_with_mask_func
,
flash_attn_varlen_with_mask_func
,
# unified attn functions
# unified attn functions
varlen_fwd_unified
,
varlen_fwd_unified
,
fwd_sparse_mean_pool_fast
,
)
)
# triton fa interface
# triton fa interface
from
flash_attn.flash_attn_triton_interface
import
flash_attn_func
as
triton_flash_attn_func
from
flash_attn.flash_attn_triton_interface
import
flash_attn_func
as
triton_flash_attn_func
...
...
flash_attn/flash_attn_interface.py
View file @
4b535e64
...
@@ -161,7 +161,7 @@ def _flash_attn_varlen_forward(
...
@@ -161,7 +161,7 @@ def _flash_attn_varlen_forward(
# breakpoint()
# breakpoint()
return
out
,
softmax_lse
,
S_dmask
,
rng_state
return
out
,
softmax_lse
,
S_dmask
,
rng_state
@
torch
.
library
.
register_fake
(
"flash_attn2_c_op::varlen_fwd"
)
@
_
torch
_
register_fake
_wrapper
(
"flash_attn2_c_op::varlen_fwd"
)
def
varlen_fwd_fake
(
def
varlen_fwd_fake
(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
@@ -2008,7 +2008,7 @@ def vllm_flash_attn_varlen_func(
...
@@ -2008,7 +2008,7 @@ def vllm_flash_attn_varlen_func(
# if mtp, k head must be 1.
# if mtp, k head must be 1.
# todo : support k head >1
# todo : support k head >1
is_mtp
=
(
max_seqlen_q
*
bs
==
total_q
and
max_seqlen_q
>
1
and
max_seqlen_q
<
5
)
is_mtp
=
(
max_seqlen_q
*
bs
==
total_q
and
max_seqlen_q
>
1
and
max_seqlen_q
<
5
)
if
(
max_seqlen_q
==
1
or
(
is_mtp
and
k
.
shape
[
1
]
==
1
)
)
and
real_window_size
[
0
]
==-
1
:
if
(
max_seqlen_q
==
1
or
is_mtp
)
and
real_window_size
[
0
]
==-
1
:
if
out
==
None
:
if
out
==
None
:
if
q
.
dtype
==
torch
.
float8_e4m3fn
or
q
.
dtype
==
torch
.
float8_e5m2
:
if
q
.
dtype
==
torch
.
float8_e4m3fn
or
q
.
dtype
==
torch
.
float8_e5m2
:
out
=
torch
.
empty
(
q
.
size
(),
device
=
q
.
device
,
dtype
=
torch
.
bfloat16
)
out
=
torch
.
empty
(
q
.
size
(),
device
=
q
.
device
,
dtype
=
torch
.
bfloat16
)
...
@@ -3816,6 +3816,22 @@ def spas_fa2_attn_meansim_topk_varlen_cuda(
...
@@ -3816,6 +3816,22 @@ def spas_fa2_attn_meansim_topk_varlen_cuda(
)
)
def
fwd_sparse_mean_pool_fast
(
x
,
BLK
,
mean
=
None
):
return
flash_attn_cuda
.
fwd_sparse_mean_pool_fast
(
x
,
BLK
,
mean
)
def
get_block_map_fast
(
q
,
k
,
topk_ratio
,
BLKQ
=
128
,
BLKK
=
64
):
meank
=
torch
.
mean
(
k
,
dim
=-
3
,
keepdim
=
True
)
pooled_kblocks
=
fwd_sparse_mean_pool_fast
(
k
,
BLKK
,
meank
)
pooled_qblocks
=
fwd_sparse_mean_pool_fast
(
q
,
BLKQ
)
pooled_score
=
pooled_qblocks
@
pooled_kblocks
.
transpose
(
-
1
,
-
2
)
K
=
pooled_score
.
shape
[
-
1
]
topk
=
min
(
K
,
int
(
topk_ratio
*
K
))
lut
=
torch
.
topk
(
pooled_score
,
topk
,
dim
=-
1
,
sorted
=
False
).
indices
sparse_map
=
torch
.
zeros_like
(
pooled_score
,
dtype
=
torch
.
int8
)
sparse_map
.
scatter_
(
-
1
,
lut
,
1
)
return
sparse_map
,
lut
,
topk
class
SparseLinearAttention
(
nn
.
Module
):
class
SparseLinearAttention
(
nn
.
Module
):
def
__init__
(
self
,
head_dim
,
topk
,
feature_map
=
'softmax'
,
use_bf16
=
True
,
use_fp8
=
False
,
tie_feature_map_qk
=
True
):
def
__init__
(
self
,
head_dim
,
topk
,
feature_map
=
'softmax'
,
use_bf16
=
True
,
use_fp8
=
False
,
tie_feature_map_qk
=
True
):
R
'''
R
'''
...
@@ -3872,19 +3888,15 @@ class SparseLinearAttention(nn.Module):
...
@@ -3872,19 +3888,15 @@ class SparseLinearAttention(nn.Module):
'''
'''
B
,
seqlen_q
,
H
,
headdim
=
q
.
shape
B
,
seqlen_q
,
H
,
headdim
=
q
.
shape
q_bhld
=
q
.
transpose
(
1
,
2
).
contiguous
()
# (B, H, L, D)
k_bhld
=
k
.
transpose
(
1
,
2
).
contiguous
()
# v_bhld = v.transpose(1, 2).contiguous()
# import pdb
# pdb.set_trace()
if
headdim
==
64
:
if
headdim
==
64
:
block_m
=
64
if
seqlen_q
<=
2048
else
128
block_m
=
64
if
seqlen_q
<=
2048
else
128
elif
headdim
==
128
:
elif
headdim
==
128
:
block_m
=
64
if
seqlen_q
<=
2048
else
128
block_m
=
64
if
seqlen_q
<=
2048
else
128
block_k
=
64
block_k
=
64
sparse_map
,
lut
,
real_topk
=
get_block_map
(
q_bhld
,
k_bhld
,
topk_ratio
=
self
.
topk
,
BLKQ
=
block_m
,
BLKK
=
block_k
)
if
headdim
==
64
:
sparse_map
,
lut
,
real_topk
=
get_block_map
(
q
.
transpose
(
1
,
2
).
contiguous
(),
k
.
transpose
(
1
,
2
).
contiguous
(),
topk_ratio
=
self
.
topk
,
BLKQ
=
block_m
,
BLKK
=
block_k
)
else
:
sparse_map
,
lut
,
real_topk
=
get_block_map_fast
(
q
,
k
,
topk_ratio
=
self
.
topk
,
BLKQ
=
block_m
,
BLKK
=
block_k
)
q
=
q
.
to
(
self
.
dtype
)
q
=
q
.
to
(
self
.
dtype
)
k
=
k
.
to
(
self
.
dtype
)
k
=
k
.
to
(
self
.
dtype
)
...
@@ -3900,10 +3912,10 @@ class SparseLinearAttention(nn.Module):
...
@@ -3900,10 +3912,10 @@ class SparseLinearAttention(nn.Module):
seqlen_k
=
k
.
size
(
1
)
seqlen_k
=
k
.
size
(
1
)
num_blocks_q
=
(
seqlen_q
+
block_m
-
1
)
//
block_m
num_blocks_q
=
(
seqlen_q
+
block_m
-
1
)
//
block_m
num_blocks_k
=
(
seqlen_k
+
block_k
-
1
)
//
block_k
num_blocks_k
=
(
seqlen_k
+
block_k
-
1
)
//
block_k
column_count
=
torch
.
zeros
(
column_count
=
torch
.
empty
(
(
B
,
H
,
num_blocks_q
),
dtype
=
torch
.
int32
,
device
=
q
.
device
(
B
,
H
,
num_blocks_q
),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
)
column_index
=
torch
.
zeros
(
column_index
=
torch
.
empty
(
(
B
,
H
,
num_blocks_q
,
1
),
dtype
=
torch
.
int32
,
device
=
q
.
device
(
B
,
H
,
num_blocks_q
,
1
),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
)
...
@@ -3972,14 +3984,56 @@ def sparse_attn_with_sla(
...
@@ -3972,14 +3984,56 @@ def sparse_attn_with_sla(
maybe_contiguous
=
lambda
x
:
x
.
contiguous
()
if
x
is
not
None
and
x
.
stride
(
-
1
)
!=
1
else
x
maybe_contiguous
=
lambda
x
:
x
.
contiguous
()
if
x
is
not
None
and
x
.
stride
(
-
1
)
!=
1
else
x
q
,
k
,
v
=
[
maybe_contiguous
(
x
)
for
x
in
(
q
,
k
,
v
)]
q
,
k
,
v
=
[
maybe_contiguous
(
x
)
for
x
in
(
q
,
k
,
v
)]
attn
=
SparseLinearAttention
(
dtype
=
torch
.
bfloat16
if
use_bf16
else
torch
.
float16
head_dim
=
q
.
size
(
-
1
),
dtype
=
torch
.
float8_e4m3fn
if
use_fp8
else
dtype
topk
=
topk
,
# = 1 - sparsity
B
,
seqlen_q
,
H
,
headdim
=
q
.
shape
feature_map
=
feature_map
,
# options: elu, relu, softmax
assert
not
(
use_bf16
and
use_fp8
),
"Only one of bf16 and fp8 can be used."
use_bf16
=
use_bf16
,
assert
headdim
in
(
64
,
128
),
"Dtype fp16/bf16 only support dim (64, 128)."
use_fp8
=
use_fp8
,
assert
not
(
use_fp8
and
headdim
==
64
),
"Dtype fp8 only support dim 128."
).
cuda
()
if
headdim
==
64
:
return
attn
(
q
,
k
,
v
,
return_sparsity
=
return_sparsity
)
block_m
=
64
if
seqlen_q
<=
2048
else
128
elif
headdim
==
128
:
block_m
=
64
if
seqlen_q
<=
2048
else
128
block_k
=
64
if
headdim
==
64
:
sparse_map
,
lut
,
real_topk
=
get_block_map
(
q
.
transpose
(
1
,
2
).
contiguous
(),
k
.
transpose
(
1
,
2
).
contiguous
(),
topk_ratio
=
topk
,
BLKQ
=
block_m
,
BLKK
=
block_k
)
else
:
sparse_map
,
lut
,
real_topk
=
get_block_map_fast
(
q
,
k
,
topk_ratio
=
topk
,
BLKQ
=
block_m
,
BLKK
=
block_k
)
q
=
q
.
to
(
dtype
)
k
=
k
.
to
(
dtype
)
v
=
v
.
to
(
dtype
)
########## SPARGE BEGIN ##########
headdim
=
q
.
size
(
-
1
)
block_offset
,
block_count
=
block_map_to_block_offset_triton
(
sparse_map
)
block_offset
=
block_offset
*
block_k
softmax_scale
=
1.0
/
(
headdim
**
0.5
)
assert
headdim
in
[
64
,
128
],
"headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
seqlen_k
=
k
.
size
(
1
)
num_blocks_q
=
(
seqlen_q
+
block_m
-
1
)
//
block_m
num_blocks_k
=
(
seqlen_k
+
block_k
-
1
)
//
block_k
column_count
=
torch
.
empty
(
(
B
,
H
,
num_blocks_q
),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
column_index
=
torch
.
empty
(
(
B
,
H
,
num_blocks_q
,
1
),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
o_s
=
sparse_attn_func
(
q
,
k
,
v
,
# Use original BLHD layout
block_count
=
block_count
,
block_offset
=
block_offset
,
column_count
=
column_count
,
column_index
=
column_index
,
softmax_scale
=
softmax_scale
,
is_sla
=
True
,
)
if
return_sparsity
:
return
o_s
,
real_topk
/
sparse_map
.
shape
[
-
1
]
else
:
return
o_s
def
_require_hg_varlen_symbol
(
name
:
str
):
def
_require_hg_varlen_symbol
(
name
:
str
):
...
...
setup.py
View file @
4b535e64
...
@@ -870,10 +870,15 @@ if not SKIP_CUDA_BUILD:
...
@@ -870,10 +870,15 @@ if not SKIP_CUDA_BUILD:
"csrc/flash_attn/src/flash_fwd_split_hdim256_fp8_outfp16_e5m2_causal_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_split_hdim256_fp8_outfp16_e5m2_causal_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_split_hdim256_q_bf16_kv_e5m2_causal_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_split_hdim256_q_bf16_kv_e5m2_causal_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_split_hdim256_q_fp16_kv_e5m2_causal_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_split_hdim256_q_fp16_kv_e5m2_causal_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_causal_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_causal_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_unified_hdim256_fp16_causal_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_unified_hdim256_fp16_causal_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_unified_hdim256_fp16_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_unified_hdim256_fp16_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_causal_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_causal_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_sm80.cu"
,
"csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_sm80.cu"
,
"csrc/flash_attn/src/flash_sparse_util.cu"
],
],
extra_compile_args
=
{
extra_compile_args
=
{
"cxx"
:
[
"-O3"
,
"-w"
,
"-std=c++17"
,
"cxx"
:
[
"-O3"
,
"-w"
,
"-std=c++17"
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment