Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ccd7c050
Unverified
Commit
ccd7c050
authored
Jun 17, 2025
by
jvlunteren
Committed by
GitHub
Jun 17, 2025
Browse files
[Kernel] Add Split-KV Support to Unified Triton Attention Kernel (#19152)
Signed-off-by:
Jan van Lunteren
<
jvl@zurich.ibm.com
>
parent
c48c6c40
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
456 additions
and
52 deletions
+456
-52
vllm/attention/ops/triton_unified_attention.py
vllm/attention/ops/triton_unified_attention.py
+456
-52
No files found.
vllm/attention/ops/triton_unified_attention.py
View file @
ccd7c050
...
...
@@ -7,6 +7,7 @@
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
# - Thomas Parnell <tpa@zurich.ibm.com>
import
torch
import
triton
import
triton.language
as
tl
...
...
@@ -28,6 +29,24 @@ def apply_softcap(S, x):
return
x
*
(
p1
-
p2
)
/
(
p1
+
p2
)
@
triton
.
jit
def
find_seq_idx
(
query_start_len_ptr
,
target_idx
,
num_seqs
,
BLOCK_Q
:
tl
.
constexpr
,
use_q_block_mode
:
tl
.
constexpr
):
left
:
tl
.
int32
=
0
right
=
num_seqs
while
left
<
right
:
mid
=
(
left
+
right
)
//
2
val
=
tl
.
load
(
query_start_len_ptr
+
mid
)
mid_val
=
val
//
BLOCK_Q
+
mid
if
use_q_block_mode
else
val
if
mid_val
<=
target_idx
:
left
=
mid
+
1
else
:
right
=
mid
return
left
-
1
@
triton
.
jit
def
kernel_unified_attention_2d
(
output_ptr
,
# [num_tokens, num_query_heads, head_size]
...
...
@@ -67,21 +86,12 @@ def kernel_unified_attention_2d(
num_seqs
:
tl
.
int32
,
BLOCK_M
:
tl
.
constexpr
,
# int
):
q_block_global_idx
=
tl
.
program_id
(
0
)
kv_head_idx
=
tl
.
program_id
(
1
)
left
:
tl
.
int32
=
0
right
=
num_seqs
while
left
<
right
:
mid
=
(
left
+
right
)
//
2
mid_val
=
tl
.
load
(
query_start_len_ptr
+
mid
)
//
BLOCK_Q
+
mid
if
mid_val
<=
q_block_global_idx
:
left
=
mid
+
1
else
:
right
=
mid
seq_idx
=
find_seq_idx
(
query_start_len_ptr
,
q_block_global_idx
,
num_seqs
,
BLOCK_Q
,
True
)
seq_idx
=
left
-
1
q_block_start_idx
=
tl
.
load
(
query_start_len_ptr
+
seq_idx
)
//
BLOCK_Q
+
seq_idx
...
...
@@ -242,6 +252,311 @@ def kernel_unified_attention_2d(
)
@
triton
.
jit
def
kernel_unified_attention_3d
(
segm_output_ptr
,
# [num_tokens, num_query_heads, num_segments, head_size]
segm_max_ptr
,
# [num_tokens, num_query_heads, num_segments]
segm_expsum_ptr
,
# [num_tokens, num_query_heads, num_segments]
query_ptr
,
# [num_tokens, num_query_heads, head_size]
key_cache_ptr
,
# [num_blks, num_kv_heads, head_size // x, blk_size, x]
value_cache_ptr
,
# [num_blks, num_kv_heads, head_size, blk_size]
block_tables_ptr
,
# [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr
,
# [num_seqs]
alibi_slopes_ptr
,
# [num_query_heads]
scale
,
# float32
k_scale
,
# float32
v_scale
,
# float32
softcap
,
# float32
num_query_heads
:
tl
.
constexpr
,
# int
num_queries_per_kv
:
tl
.
constexpr
,
# int
block_table_stride
:
tl
.
int64
,
# int
query_stride_0
:
tl
.
int64
,
# int
query_stride_1
:
tl
.
int64
,
# int, should be equal to head_size
BLOCK_SIZE
:
tl
.
constexpr
,
# int
HEAD_SIZE
:
tl
.
constexpr
,
# int
HEAD_SIZE_PADDED
:
tl
.
constexpr
,
# int, must be power of 2
USE_ALIBI_SLOPES
:
tl
.
constexpr
,
# bool
USE_SOFTCAP
:
tl
.
constexpr
,
# bool
SLIDING_WINDOW
:
tl
.
constexpr
,
# int
stride_k_cache_0
:
tl
.
int64
,
# int
stride_k_cache_1
:
tl
.
int64
,
# int
stride_k_cache_2
:
tl
.
int64
,
# int
stride_k_cache_3
:
tl
.
constexpr
,
# int
stride_v_cache_0
:
tl
.
int64
,
# int
stride_v_cache_1
:
tl
.
int64
,
# int
stride_v_cache_2
:
tl
.
int64
,
# int
stride_v_cache_3
:
tl
.
constexpr
,
# int
query_start_len_ptr
,
# [num_seqs+1]
BLOCK_Q
:
tl
.
constexpr
,
# int
num_seqs
:
tl
.
int32
,
BLOCK_M
:
tl
.
constexpr
,
# int
NUM_SEGMENTS_PER_SEQ
:
tl
.
constexpr
,
# int
):
q_block_global_idx
=
tl
.
program_id
(
0
)
kv_head_idx
=
tl
.
program_id
(
1
)
segm_idx
=
tl
.
program_id
(
2
)
seq_idx
=
find_seq_idx
(
query_start_len_ptr
,
q_block_global_idx
,
num_seqs
,
BLOCK_Q
,
True
)
q_block_start_idx
=
tl
.
load
(
query_start_len_ptr
+
seq_idx
)
//
BLOCK_Q
+
seq_idx
q_block_local_idx
=
q_block_global_idx
-
q_block_start_idx
cur_batch_in_all_start_index
=
tl
.
load
(
query_start_len_ptr
+
seq_idx
)
cur_batch_in_all_stop_index
=
tl
.
load
(
query_start_len_ptr
+
seq_idx
+
1
)
cur_batch_query_len
=
cur_batch_in_all_stop_index
\
-
cur_batch_in_all_start_index
if
q_block_local_idx
*
BLOCK_Q
>=
cur_batch_query_len
:
return
# sequence len for this particular sequence
seq_len
=
tl
.
load
(
seq_lens_ptr
+
seq_idx
)
# number of segments for this particular sequence
num_segments
=
NUM_SEGMENTS_PER_SEQ
blocks_per_segment
=
cdiv_fn
(
seq_len
,
num_segments
*
BLOCK_SIZE
)
if
segm_idx
*
blocks_per_segment
*
BLOCK_SIZE
>=
seq_len
:
return
offs_m
=
tl
.
arange
(
0
,
BLOCK_M
)
offs_d
=
tl
.
arange
(
0
,
HEAD_SIZE_PADDED
)
query_pos
=
q_block_local_idx
*
BLOCK_Q
+
offs_m
//
num_queries_per_kv
query_offset_0
=
cur_batch_in_all_start_index
+
query_pos
query_offset_1
=
kv_head_idx
*
num_queries_per_kv
+
\
offs_m
%
num_queries_per_kv
query_offset
=
(
query_offset_0
[:,
None
]
*
query_stride_0
+
query_offset_1
[:,
None
]
*
query_stride_1
+
offs_d
[
None
,
:])
dim_mask
=
tl
.
where
(
offs_d
<
HEAD_SIZE
,
1
,
0
).
to
(
tl
.
int1
)
query_mask_0
=
tl
.
where
(
query_pos
<
cur_batch_query_len
,
1
,
0
).
to
(
tl
.
int1
)
query_mask_1
=
tl
.
where
(
query_offset_1
<
num_query_heads
,
1
,
0
).
to
(
tl
.
int1
)
# Q : (BLOCK_M, HEAD_SIZE_PADDED)
Q
=
tl
.
load
(
query_ptr
+
query_offset
,
mask
=
dim_mask
[
None
,
:]
&
query_mask_0
[:,
None
]
&
query_mask_1
[:,
None
],
other
=
0.0
,
)
block_table_offset
=
seq_idx
*
block_table_stride
M
=
tl
.
full
([
BLOCK_M
],
float
(
"-inf"
),
dtype
=
tl
.
float32
)
L
=
tl
.
full
([
BLOCK_M
],
1.0
,
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
HEAD_SIZE_PADDED
],
dtype
=
tl
.
float32
)
# context length for this particular sequences
context_len
=
seq_len
-
cur_batch_query_len
# alibi slope for this head
if
USE_ALIBI_SLOPES
:
alibi_slope
=
tl
.
load
(
alibi_slopes_ptr
+
query_offset_1
,
mask
=
query_mask_1
,
other
=
0.0
)
num_blocks
=
cdiv_fn
(
seq_len
,
BLOCK_SIZE
)
# iterate through tiles within current segment
for
j
in
range
(
segm_idx
*
blocks_per_segment
,
min
((
segm_idx
+
1
)
*
blocks_per_segment
,
num_blocks
),
):
physical_block_idx
=
tl
.
load
(
block_tables_ptr
+
block_table_offset
+
j
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
v_offset
=
(
physical_block_idx
*
stride_v_cache_0
+
kv_head_idx
*
stride_v_cache_2
+
offs_d
[
None
,
:]
*
stride_v_cache_3
+
offs_n
[:,
None
]
*
stride_v_cache_1
)
k_offset
=
(
physical_block_idx
*
stride_k_cache_0
+
kv_head_idx
*
stride_k_cache_2
+
offs_d
[:,
None
]
*
stride_k_cache_3
+
offs_n
[
None
,
:]
*
stride_k_cache_1
)
# K : (HEAD_SIZE, BLOCK_SIZE)
K_load
=
tl
.
load
(
key_cache_ptr
+
k_offset
,
mask
=
dim_mask
[:,
None
],
other
=
0.0
)
if
K_load
.
dtype
.
is_fp8
():
if
Q
.
dtype
.
is_fp8
():
K
=
K_load
else
:
K
=
(
K_load
.
to
(
tl
.
float32
)
*
tl
.
load
(
k_scale
)).
to
(
Q
.
dtype
)
else
:
K
=
K_load
# V : (BLOCK_SIZE, HEAD_SIZE)
V_load
=
tl
.
load
(
value_cache_ptr
+
v_offset
,
mask
=
dim_mask
[
None
,
:],
other
=
0.0
)
if
V_load
.
dtype
.
is_fp8
():
if
Q
.
dtype
.
is_fp8
():
V
=
V_load
else
:
V
=
(
V_load
.
to
(
tl
.
float32
)
*
tl
.
load
(
v_scale
)).
to
(
Q
.
dtype
)
else
:
V
=
V_load
seq_offset
=
j
*
BLOCK_SIZE
+
offs_n
seq_mask
=
seq_offset
[
None
,
:]
<
context_len
+
query_pos
[:,
None
]
+
1
# S : (BLOCK_M, BLOCK_SIZE)
S
=
tl
.
zeros
(
shape
=
(
BLOCK_M
,
BLOCK_SIZE
),
dtype
=
tl
.
float32
)
S
+=
scale
*
tl
.
dot
(
Q
,
K
)
if
USE_SOFTCAP
:
S
=
apply_softcap
(
S
,
softcap
)
S
=
tl
.
where
(
query_mask_1
[:,
None
]
&
query_mask_0
[:,
None
]
&
seq_mask
,
S
,
float
(
"-inf"
))
if
SLIDING_WINDOW
>
0
:
S
=
tl
.
where
((
context_len
+
query_pos
[:,
None
]
-
seq_offset
)
<
SLIDING_WINDOW
,
S
,
float
(
"-inf"
))
if
USE_ALIBI_SLOPES
:
S
+=
alibi_slope
[:,
None
]
*
(
seq_offset
-
context_len
)
# compute running maximum
# m_j : (BLOCK_M,)
m_j
=
tl
.
maximum
(
M
,
tl
.
max
(
S
,
axis
=
1
))
# For sliding window there's a chance the max is -inf due to masking of
# the entire row. In this case we need to set m_j 0 to avoid NaN
m_j
=
tl
.
where
(
m_j
>
float
(
"-inf"
),
m_j
,
0.0
)
# P : (BLOCK_M, BLOCK_SIZE,)
P
=
tl
.
exp
(
S
-
m_j
[:,
None
])
# l_j : (BLOCK_M,)
l_j
=
tl
.
sum
(
P
,
axis
=
1
)
# alpha : (BLOCK_M, )
alpha
=
tl
.
exp
(
M
-
m_j
)
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc
=
acc
*
alpha
[:,
None
]
# update constants
L
=
L
*
alpha
+
l_j
M
=
m_j
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc
+=
tl
.
dot
(
P
.
to
(
V
.
dtype
),
V
)
segm_output_offset
=
(
query_offset_0
[:,
None
].
to
(
tl
.
int64
)
*
(
num_query_heads
*
NUM_SEGMENTS_PER_SEQ
*
HEAD_SIZE_PADDED
)
+
query_offset_1
[:,
None
]
*
(
NUM_SEGMENTS_PER_SEQ
*
HEAD_SIZE_PADDED
)
+
segm_idx
*
HEAD_SIZE_PADDED
+
tl
.
arange
(
0
,
HEAD_SIZE_PADDED
)[
None
,
:])
tl
.
store
(
segm_output_ptr
+
segm_output_offset
,
acc
,
mask
=
dim_mask
[
None
,
:]
&
query_mask_0
[:,
None
]
&
query_mask_1
[:,
None
],
)
segm_offset
=
(
query_offset_0
.
to
(
tl
.
int64
)
*
(
num_query_heads
*
NUM_SEGMENTS_PER_SEQ
)
+
query_offset_1
*
NUM_SEGMENTS_PER_SEQ
+
segm_idx
)
tl
.
store
(
segm_max_ptr
+
segm_offset
,
M
,
mask
=
query_mask_0
&
query_mask_1
)
tl
.
store
(
segm_expsum_ptr
+
segm_offset
,
L
,
mask
=
query_mask_0
&
query_mask_1
)
@
triton
.
jit
def
reduce_segments
(
output_ptr
,
# [num_tokens, num_query_heads, head_size]
segm_output_ptr
,
#[num_tokens, num_query_heads, max_num_segments, head_size]
segm_max_ptr
,
# [num_tokens, num_query_heads, max_num_segments]
segm_expsum_ptr
,
# [num_tokens, num_query_heads, max_num_segments]
seq_lens_ptr
,
# [num_seqs]
num_seqs
,
# int
num_query_heads
:
tl
.
constexpr
,
# int
output_stride_0
:
tl
.
int64
,
# int
output_stride_1
:
tl
.
int64
,
# int, should be equal to head_size
block_table_stride
:
tl
.
int64
,
# int
BLOCK_SIZE
:
tl
.
constexpr
,
# int
HEAD_SIZE
:
tl
.
constexpr
,
# int, must be power of 2
HEAD_SIZE_PADDED
:
tl
.
constexpr
,
# int, must be power of 2
query_start_len_ptr
,
# [num_seqs+1]
BLOCK_Q
:
tl
.
constexpr
,
# int
NUM_SEGMENTS_PER_SEQ
:
tl
.
constexpr
,
# int
):
query_token_idx
=
tl
.
program_id
(
0
)
query_head_idx
=
tl
.
program_id
(
1
)
seq_idx
=
find_seq_idx
(
query_start_len_ptr
,
query_token_idx
,
num_seqs
,
BLOCK_Q
,
False
)
# sequence len for this particular sequence
seq_len
=
tl
.
load
(
seq_lens_ptr
+
seq_idx
)
# number of segments for this particular sequence
num_segments
=
NUM_SEGMENTS_PER_SEQ
blocks_per_segment
=
cdiv_fn
(
seq_len
,
num_segments
*
BLOCK_SIZE
)
# create masks for subsequent loads
act_num_segments
=
cdiv_fn
(
seq_len
,
blocks_per_segment
*
BLOCK_SIZE
)
segm_mask
=
tl
.
arange
(
0
,
NUM_SEGMENTS_PER_SEQ
)
<
tl
.
full
(
[
NUM_SEGMENTS_PER_SEQ
],
act_num_segments
,
dtype
=
tl
.
int32
)
dim_mask
=
tl
.
where
(
tl
.
arange
(
0
,
HEAD_SIZE_PADDED
)
<
HEAD_SIZE
,
1
,
0
).
to
(
tl
.
int1
)
# load segment maxima
segm_offset
=
(
query_token_idx
.
to
(
tl
.
int64
)
*
(
num_query_heads
*
NUM_SEGMENTS_PER_SEQ
)
+
query_head_idx
*
NUM_SEGMENTS_PER_SEQ
+
tl
.
arange
(
0
,
NUM_SEGMENTS_PER_SEQ
))
segm_max
=
tl
.
load
(
segm_max_ptr
+
segm_offset
,
mask
=
segm_mask
,
other
=
float
(
"-inf"
))
overall_max
=
tl
.
max
(
segm_max
)
# load and rescale segment exp sums
segm_expsum
=
tl
.
load
(
segm_expsum_ptr
+
segm_offset
,
mask
=
segm_mask
,
other
=
0.0
)
segm_expsum
=
segm_expsum
*
tl
.
exp
(
segm_max
-
overall_max
)
overall_expsum
=
tl
.
sum
(
segm_expsum
)
# load, rescale, and add segment attention outputs
segm_output_offset
=
(
query_token_idx
.
to
(
tl
.
int64
)
*
(
num_query_heads
*
NUM_SEGMENTS_PER_SEQ
*
HEAD_SIZE_PADDED
)
+
query_head_idx
*
(
NUM_SEGMENTS_PER_SEQ
*
HEAD_SIZE_PADDED
)
+
tl
.
arange
(
0
,
NUM_SEGMENTS_PER_SEQ
)[:,
None
]
*
HEAD_SIZE_PADDED
+
tl
.
arange
(
0
,
HEAD_SIZE_PADDED
)[
None
,
:])
segm_output
=
tl
.
load
(
segm_output_ptr
+
segm_output_offset
,
mask
=
segm_mask
[:,
None
]
&
dim_mask
[
None
,
:],
other
=
0.0
,
)
segm_output
*=
tl
.
exp
(
segm_max
-
overall_max
)[:,
None
]
acc_sum
=
tl
.
sum
(
segm_output
,
axis
=
0
)
# safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
acc
=
tl
.
where
(
overall_expsum
==
0.0
,
0.0
,
acc_sum
/
overall_expsum
)
# write result
output_offset
=
(
query_token_idx
*
output_stride_0
+
query_head_idx
*
output_stride_1
+
tl
.
arange
(
0
,
HEAD_SIZE_PADDED
))
tl
.
store
(
output_ptr
+
output_offset
,
acc
,
mask
=
dim_mask
)
def
unified_attention
(
q
,
k
,
...
...
@@ -291,44 +606,133 @@ def unified_attention(
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
total_num_q_blocks
=
q
.
shape
[
0
]
//
BLOCK_Q
+
num_seqs
kernel_unified_attention_2d
[(
total_num_q_blocks
,
num_kv_heads
,
)](
output_ptr
=
out
,
query_ptr
=
q
,
key_cache_ptr
=
k
,
value_cache_ptr
=
v
,
block_tables_ptr
=
block_table
,
seq_lens_ptr
=
seqused_k
,
alibi_slopes_ptr
=
alibi_slopes
,
scale
=
softmax_scale
,
k_scale
=
k_descale
,
v_scale
=
v_descale
,
softcap
=
softcap
,
num_query_heads
=
num_query_heads
,
num_queries_per_kv
=
num_queries_per_kv
,
block_table_stride
=
block_table
.
stride
(
0
),
query_stride_0
=
q
.
stride
(
0
),
query_stride_1
=
q
.
stride
(
1
),
output_stride_0
=
out
.
stride
(
0
),
output_stride_1
=
out
.
stride
(
1
),
BLOCK_SIZE
=
block_size
,
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
USE_ALIBI_SLOPES
=
use_alibi_slopes
,
USE_SOFTCAP
=
(
softcap
>
0
),
SLIDING_WINDOW
=
(
1
+
window_size
[
0
]),
stride_k_cache_0
=
k
.
stride
(
0
),
stride_k_cache_1
=
k
.
stride
(
1
),
stride_k_cache_2
=
k
.
stride
(
2
),
stride_k_cache_3
=
k
.
stride
(
3
),
stride_v_cache_0
=
v
.
stride
(
0
),
stride_v_cache_1
=
v
.
stride
(
1
),
stride_v_cache_2
=
v
.
stride
(
2
),
stride_v_cache_3
=
v
.
stride
(
3
),
query_start_len_ptr
=
cu_seqlens_q
,
BLOCK_Q
=
BLOCK_Q
,
num_seqs
=
num_seqs
,
BLOCK_M
=
BLOCK_M
,
)
# if batch contains a prefill
if
max_seqlen_q
>
1
or
total_num_q_blocks
*
num_kv_heads
>
128
:
kernel_unified_attention_2d
[(
total_num_q_blocks
,
num_kv_heads
,
)](
output_ptr
=
out
,
query_ptr
=
q
,
key_cache_ptr
=
k
,
value_cache_ptr
=
v
,
block_tables_ptr
=
block_table
,
seq_lens_ptr
=
seqused_k
,
alibi_slopes_ptr
=
alibi_slopes
,
scale
=
softmax_scale
,
k_scale
=
k_descale
,
v_scale
=
v_descale
,
softcap
=
softcap
,
num_query_heads
=
num_query_heads
,
num_queries_per_kv
=
num_queries_per_kv
,
block_table_stride
=
block_table
.
stride
(
0
),
query_stride_0
=
q
.
stride
(
0
),
query_stride_1
=
q
.
stride
(
1
),
output_stride_0
=
out
.
stride
(
0
),
output_stride_1
=
out
.
stride
(
1
),
BLOCK_SIZE
=
block_size
,
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
USE_ALIBI_SLOPES
=
use_alibi_slopes
,
USE_SOFTCAP
=
(
softcap
>
0
),
SLIDING_WINDOW
=
(
1
+
window_size
[
0
]),
stride_k_cache_0
=
k
.
stride
(
0
),
stride_k_cache_1
=
k
.
stride
(
1
),
stride_k_cache_2
=
k
.
stride
(
2
),
stride_k_cache_3
=
k
.
stride
(
3
),
stride_v_cache_0
=
v
.
stride
(
0
),
stride_v_cache_1
=
v
.
stride
(
1
),
stride_v_cache_2
=
v
.
stride
(
2
),
stride_v_cache_3
=
v
.
stride
(
3
),
query_start_len_ptr
=
cu_seqlens_q
,
BLOCK_Q
=
BLOCK_Q
,
num_seqs
=
num_seqs
,
BLOCK_M
=
BLOCK_M
,
)
else
:
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
# value that showed good performance in tests
NUM_SEGMENTS
=
16
segm_output
=
torch
.
empty
(
q
.
shape
[
0
],
num_query_heads
,
NUM_SEGMENTS
,
triton
.
next_power_of_2
(
head_size
),
dtype
=
torch
.
float32
,
device
=
q
.
device
,
)
segm_max
=
torch
.
empty
(
q
.
shape
[
0
],
num_query_heads
,
NUM_SEGMENTS
,
dtype
=
torch
.
float32
,
device
=
q
.
device
,
)
segm_expsum
=
torch
.
empty
(
q
.
shape
[
0
],
num_query_heads
,
NUM_SEGMENTS
,
dtype
=
torch
.
float32
,
device
=
q
.
device
,
)
kernel_unified_attention_3d
[(
total_num_q_blocks
,
num_kv_heads
,
NUM_SEGMENTS
)](
segm_output_ptr
=
segm_output
,
segm_max_ptr
=
segm_max
,
segm_expsum_ptr
=
segm_expsum
,
query_ptr
=
q
,
key_cache_ptr
=
k
,
value_cache_ptr
=
v
,
block_tables_ptr
=
block_table
,
seq_lens_ptr
=
seqused_k
,
alibi_slopes_ptr
=
alibi_slopes
,
scale
=
softmax_scale
,
k_scale
=
k_descale
,
v_scale
=
v_descale
,
softcap
=
softcap
,
num_query_heads
=
num_query_heads
,
num_queries_per_kv
=
num_queries_per_kv
,
block_table_stride
=
block_table
.
stride
(
0
),
query_stride_0
=
q
.
stride
(
0
),
query_stride_1
=
q
.
stride
(
1
),
BLOCK_SIZE
=
block_size
,
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
USE_ALIBI_SLOPES
=
use_alibi_slopes
,
USE_SOFTCAP
=
(
softcap
>
0
),
SLIDING_WINDOW
=
(
1
+
window_size
[
0
]),
stride_k_cache_0
=
k
.
stride
(
0
),
stride_k_cache_1
=
k
.
stride
(
1
),
stride_k_cache_2
=
k
.
stride
(
2
),
stride_k_cache_3
=
k
.
stride
(
3
),
stride_v_cache_0
=
v
.
stride
(
0
),
stride_v_cache_1
=
v
.
stride
(
1
),
stride_v_cache_2
=
v
.
stride
(
2
),
stride_v_cache_3
=
v
.
stride
(
3
),
query_start_len_ptr
=
cu_seqlens_q
,
BLOCK_Q
=
BLOCK_Q
,
num_seqs
=
num_seqs
,
BLOCK_M
=
BLOCK_M
,
NUM_SEGMENTS_PER_SEQ
=
NUM_SEGMENTS
,
)
reduce_segments
[(
q
.
shape
[
0
],
num_query_heads
)](
output_ptr
=
out
,
segm_output_ptr
=
segm_output
,
segm_max_ptr
=
segm_max
,
segm_expsum_ptr
=
segm_expsum
,
seq_lens_ptr
=
seqused_k
,
num_seqs
=
num_seqs
,
num_query_heads
=
num_query_heads
,
output_stride_0
=
out
.
stride
(
0
),
output_stride_1
=
out
.
stride
(
1
),
block_table_stride
=
block_table
.
stride
(
0
),
BLOCK_SIZE
=
block_size
,
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
query_start_len_ptr
=
cu_seqlens_q
,
BLOCK_Q
=
BLOCK_Q
,
NUM_SEGMENTS_PER_SEQ
=
NUM_SEGMENTS
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment