Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5036e878
Commit
5036e878
authored
Jan 21, 2026
by
laibao
Browse files
feat: kvpress新增 SnapKV 打分与 KV compaction Triton 内核
parent
d3acd4a5
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
1080 additions
and
0 deletions
+1080
-0
vllm/v1/attention/kv_compression/__init__.py
vllm/v1/attention/kv_compression/__init__.py
+3
-0
vllm/v1/attention/kv_compression/kv_cache_triton.py
vllm/v1/attention/kv_compression/kv_cache_triton.py
+492
-0
vllm/v1/attention/kv_compression/snapkv_triton.py
vllm/v1/attention/kv_compression/snapkv_triton.py
+585
-0
No files found.
vllm/v1/attention/kv_compression/__init__.py
0 → 100644
View file @
5036e878
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
vllm/v1/attention/kv_compression/kv_cache_triton.py
0 → 100644
View file @
5036e878
from
__future__
import
annotations
from
typing
import
Optional
,
Tuple
import
torch
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
import
triton
import
triton.language
as
tl
def _require_triton() -> None:
    """Fail fast when the Triton backend is not usable.

    Raises:
        RuntimeError: if the module-level ``HAS_TRITON`` flag is falsy.
    """
    if HAS_TRITON:
        return
    raise RuntimeError("Triton is not available.")
def
_check_cuda
(
*
tensors
:
torch
.
Tensor
)
->
None
:
for
t
in
tensors
:
if
not
isinstance
(
t
,
torch
.
Tensor
):
raise
TypeError
(
"Expected torch.Tensor inputs."
)
if
t
.
device
.
type
!=
"cuda"
:
raise
RuntimeError
(
"Triton KV cache ops require CUDA/ROCm tensors."
)
# The decorators below touch `triton` at import time, so the kernel is only
# defined when Triton imported successfully. Without this guard the module
# raises NameError on import and `_require_triton()` can never report the
# problem cleanly.
if HAS_TRITON:

    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_T': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
            triton.Config({'BLOCK_T': 256, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
            triton.Config({'BLOCK_T': 256, 'BLOCK_D': 128}, num_warps=8, num_stages=2),
        ],
        key=["D"],
    )
    @triton.jit
    def _gather_k_to_packed_kernel(
        K_ptr,
        out_ptr,
        blk_ids_ptr,
        req_blk_starts_ptr,
        cu_seqlens_ptr,
        seq_lens_ptr,
        B,
        H,
        max_blocks,
        block_size,
        D,
        sKb,
        sKh,
        sKt,
        sKd,
        so_t,
        so_h,
        so_d,
        BLOCK_T: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Copy one [BLOCK_T, BLOCK_D] tile of keys from the paged cache.

        Grid: axis 0 enumerates (request, head) pairs as ``b * H + h``, axis 1
        tiles the token dimension, axis 2 tiles the head dimension.

        Pointers:
            K_ptr: key cache addressed with strides (sKb, sKh, sKt, sKd) as
                [block, head, token-in-block, dim].
            out_ptr: packed output addressed with strides (so_t, so_h, so_d).
            blk_ids_ptr: flattened block table of physical block ids.
            req_blk_starts_ptr: per-request offset into ``blk_ids_ptr``.
            cu_seqlens_ptr: per-request first row in the packed output.
            seq_lens_ptr: per-request number of tokens to copy.

        ``max_blocks`` is accepted for signature compatibility but is not read
        by the kernel body.
        """
        pid_bh = tl.program_id(0)
        pid_t = tl.program_id(1)
        pid_d = tl.program_id(2)

        b = pid_bh // H
        h = pid_bh % H
        if b >= B:
            return

        seq_len = tl.load(seq_lens_ptr + b)
        if seq_len <= 0:
            return

        # Token and feature coordinates covered by this program.
        t_range = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)
        t_mask = t_range < seq_len
        d_range = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
        d_mask = d_range < D
        tile_mask = t_mask[:, None] & d_mask[None, :]

        # Map logical token indices -> (logical block, offset inside block).
        blk = t_range // block_size
        inb = t_range - blk * block_size

        req_blk_start = tl.load(req_blk_starts_ptr + b)
        gblk = req_blk_start + blk
        # Lanes past seq_len would index beyond this request's block list;
        # point them at entry 0 (they are masked out anyway).
        gblk_safe = tl.where(t_mask, gblk, 0)
        # Promote to int64 before multiplying by the block stride so that
        # block_id * sKb cannot wrap around 32 bits on large caches.
        bid = tl.load(blk_ids_ptr + gblk_safe, mask=t_mask, other=0).to(tl.int64)

        # Source: key cache layout [num_blocks, H, block_size, D].
        src_ptrs = (
            K_ptr
            + bid[:, None] * sKb
            + h * sKh
            + inb[:, None] * sKt
            + d_range[None, :] * sKd
        )

        # Destination: packed output layout [T, H, D]; row index is promoted
        # to int64 for the same overflow reason as above.
        out_start = tl.load(cu_seqlens_ptr + b)
        out_row = (out_start + t_range).to(tl.int64)
        dst_ptrs = (
            out_ptr
            + out_row[:, None] * so_t
            + h * so_h
            + d_range[None, :] * so_d
        )

        tile = tl.load(src_ptrs, mask=tile_mask, other=0)
        tl.store(dst_ptrs, tile, mask=tile_mask)
@torch.inference_mode()
def gather_k_to_packed_triton(
    key_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    cu_seqlens: torch.Tensor,
    *,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Gather a block-wise KV key cache into a packed [T, H, D] tensor.

    Expected layouts:
      - key_cache:   [num_blocks, H, block_size, D]
      - block_table: [B, max_blocks] int32 physical block ids
      - seq_lens:    [B] int32 logical lengths (tokens) to gather
      - cu_seqlens:  [B+1] int32 cumulative offsets into the packed output

    Args:
        key_cache: paged key cache (any strides; they are passed to the kernel).
        block_table: per-request physical block ids.
        seq_lens: number of tokens to gather for each request.
        cu_seqlens: first output row of each request; the last element is
            used as the total number of output rows.
        out: optional preallocated output of shape [cu_seqlens[-1], H, D].

    Returns:
        The packed key tensor (``out`` when it was supplied and B > 0).

    Raises:
        RuntimeError: Triton is unavailable or an input is not on a CUDA/ROCm
            device.
        TypeError: an input is not a tensor.
        ValueError: an input has the wrong rank/shape, or the metadata does
            not cover all ``B`` requests.

    Note:
        Two ``.item()`` calls synchronise with the device (total token count
        and longest sequence).
    """
    _require_triton()
    _check_cuda(key_cache, block_table, seq_lens, cu_seqlens)

    if key_cache.ndim != 4:
        raise ValueError("key_cache must be a 4D tensor [num_blocks, H, Tb, D].")
    if block_table.ndim != 2:
        raise ValueError("block_table must be 2D [B, max_blocks].")
    if seq_lens.ndim != 1:
        raise ValueError("seq_lens must be 1D [B].")
    if cu_seqlens.ndim != 1:
        raise ValueError("cu_seqlens must be 1D [B+1].")

    device = key_cache.device
    H = int(key_cache.shape[1])
    block_size = int(key_cache.shape[2])
    D = int(key_cache.shape[3])

    B = int(seq_lens.numel())
    if B == 0:
        return torch.empty((0, H, D), device=device, dtype=key_cache.dtype)

    # The kernel indexes cu_seqlens[b] and block_table row b for every b < B;
    # reject metadata that would make those reads run out of bounds.
    if int(cu_seqlens.numel()) < B + 1:
        raise ValueError("cu_seqlens must have at least B+1 entries.")
    if int(block_table.shape[0]) < B:
        raise ValueError("block_table must have at least B rows.")

    max_blocks = int(block_table.shape[1])

    # The kernel reads these 1-D buffers with unit stride.
    seq_lens_i32 = seq_lens.to(device=device, dtype=torch.int32).contiguous()
    cu_i32 = cu_seqlens.to(device=device, dtype=torch.int32).contiguous()
    total_tokens = int(cu_i32[-1].item())

    if out is None:
        out = torch.empty((total_tokens, H, D), device=device, dtype=key_cache.dtype)
    elif out.shape != (total_tokens, H, D):
        raise ValueError(
            f"out has shape {tuple(out.shape)}, expected {(total_tokens, H, D)}."
        )

    L_max = int(seq_lens_i32.max().item())
    if total_tokens == 0 or L_max == 0 or D == 0 or H == 0:
        return out

    # Every gathered token needs a block-table entry inside its own row.
    if L_max > max_blocks * block_size:
        raise ValueError("block_table does not cover the longest sequence.")

    # reshape(-1) yields a flat, unit-stride view/copy of the [B, max_blocks] table.
    blk_ids = block_table.to(device=device, dtype=torch.int32).reshape(-1)
    req_starts = torch.arange(B, device=device, dtype=torch.int32) * max_blocks

    sKb, sKh, sKt, sKd = (int(s) for s in key_cache.stride())
    so_t, so_h, so_d = (int(s) for s in out.stride())

    def _grid(meta):
        # Size the launch from the tile shape autotune actually selected so
        # larger tiles do not launch programs that have nothing to copy.
        return (
            B * H,
            triton.cdiv(L_max, meta["BLOCK_T"]),
            triton.cdiv(D, meta["BLOCK_D"]),
        )

    _gather_k_to_packed_kernel[_grid](
        key_cache,
        out,
        blk_ids,
        req_starts,
        cu_i32,
        seq_lens_i32,
        B,
        H,
        max_blocks,
        block_size,
        D,
        sKb,
        sKh,
        sKt,
        sKd,
        so_t,
        so_h,
        so_d,
    )
    return out
# Defined only when Triton imported successfully: the decorators touch
# `triton` at import time and would otherwise raise NameError.
if HAS_TRITON:

    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_T': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
            triton.Config({'BLOCK_T': 256, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
            triton.Config({'BLOCK_T': 512, 'BLOCK_D': 64}, num_warps=8, num_stages=2),
            triton.Config({'BLOCK_T': 256, 'BLOCK_D': 128}, num_warps=8, num_stages=2),
        ],
        key=['K_max', 'Dk'],
        # This kernel rewrites K_ptr in place and is not idempotent: running
        # it twice moves already-moved rows again. Autotune benchmarks every
        # config on the caller's tensors, so the cache must be restored after
        # each benchmark run or the first call per (K_max, Dk) corrupts it.
        # NOTE(review): restoring presumably snapshots the whole tensor during
        # tuning; confirm the temporary memory cost is acceptable.
        restore_value=['K_ptr'],
    )
    @triton.jit
    def _front_compact_inplace_fa_k_kernel(
        K_ptr,
        blk_ids_ptr,
        req_blk_starts_ptr,
        idx_ptr,
        keep_ptr,
        B,
        H,
        K_max,
        block_size,
        Dk,
        sKb,
        sKh,
        sKt,
        sKd,
        si_b,
        si_h,
        si_k,
        BLOCK_T: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Move selected key rows to the front of a request's paged cache.

        Grid: axis 0 enumerates (request, head) as ``b * H + h``; axis 1 tiles
        the head dimension. Each program walks destination slots
        ``k = 0 .. keep[b]-1`` in chunks of BLOCK_T and copies the row at
        logical position ``idx[b, h, k]`` to logical position ``k``.

        Pointers:
            K_ptr: key cache addressed as [block, head, token-in-block, dim]
                with strides (sKb, sKh, sKt, sKd).
            blk_ids_ptr / req_blk_starts_ptr: flattened block table and the
                per-request offset into it.
            idx_ptr: source positions addressed with strides (si_b, si_h, si_k).
            keep_ptr: number of rows to keep per request.
        """
        pid_bh = tl.program_id(0)
        pid_d = tl.program_id(1)

        b = pid_bh // H
        h = pid_bh % H
        if b >= B:
            return

        d_range = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
        d_mask = d_range < Dk
        # Clamp masked-out feature lanes to 0 so their addresses stay in range.
        d_safe = tl.where(d_mask, d_range, 0)

        keep_b = tl.load(keep_ptr + b)
        if keep_b <= 0:
            return

        req_blk_start = tl.load(req_blk_starts_ptr + b)

        # NOTE(review): correctness relies on every source row being read
        # before another lane overwrites it. The code issues no barrier
        # between the tile load and the tile store, nor between loop
        # iterations; confirm whether cross-warp ordering needs an explicit
        # tl.debug_barrier() here.
        k0 = 0
        while k0 < keep_b:
            k_range = k0 + tl.arange(0, BLOCK_T)
            k_mask = (k_range < K_max) & (k_range < keep_b)
            k_safe = tl.where(k_mask, k_range, 0)

            idx_base = idx_ptr + b * si_b + h * si_h + k_safe * si_k
            t_src = tl.load(idx_base, mask=k_mask, other=0)

            # No-op copies (src == dst) can be skipped safely because
            # idx_sorted is ascending, so we always copy from later/equal
            # positions to earlier ones.
            t_dst = k_safe
            copy_mask = k_mask & (t_src != t_dst)
            tile_mask = copy_mask[:, None] & d_mask[None, :]

            # Logical position -> (physical block id, offset in block). Block
            # ids are promoted to int64 so block_id * sKb cannot overflow.
            inb_src = t_src % block_size
            gblk_src = req_blk_start + t_src // block_size
            bid_src = tl.load(
                blk_ids_ptr + gblk_src, mask=copy_mask, other=0
            ).to(tl.int64)

            inb_dst = t_dst % block_size
            gblk_dst = req_blk_start + t_dst // block_size
            bid_dst = tl.load(
                blk_ids_ptr + gblk_dst, mask=copy_mask, other=0
            ).to(tl.int64)

            src_ptrs = (
                K_ptr
                + bid_src[:, None] * sKb
                + h * sKh
                + inb_src[:, None] * sKt
                + d_safe[None, :] * sKd
            )
            dst_ptrs = (
                K_ptr
                + bid_dst[:, None] * sKb
                + h * sKh
                + inb_dst[:, None] * sKt
                + d_safe[None, :] * sKd
            )

            tile = tl.load(src_ptrs, mask=tile_mask, other=0)
            tl.store(dst_ptrs, tile, mask=tile_mask)

            k0 += BLOCK_T
# Defined only when Triton imported successfully: the decorators touch
# `triton` at import time and would otherwise raise NameError.
if HAS_TRITON:

    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_T': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
            triton.Config({'BLOCK_T': 256, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
            triton.Config({'BLOCK_T': 512, 'BLOCK_D': 64}, num_warps=8, num_stages=2),
            triton.Config({'BLOCK_T': 256, 'BLOCK_D': 128}, num_warps=8, num_stages=2),
        ],
        key=['K_max', 'Dv'],
        # In-place and not idempotent: autotune benchmarks each config on the
        # caller's value cache, so it must be restored after every benchmark
        # run or the first call per (K_max, Dv) leaves the cache corrupted.
        # NOTE(review): restoring presumably snapshots the whole tensor during
        # tuning; confirm the temporary memory cost is acceptable.
        restore_value=['V_ptr'],
    )
    @triton.jit
    def _front_compact_inplace_fa_v_kernel(
        V_ptr,
        blk_ids_ptr,
        req_blk_starts_ptr,
        idx_ptr,
        keep_ptr,
        B,
        H,
        K_max,
        block_size,
        Dv,
        sv_b,
        sv_h,
        sv_d,
        sv_t,
        si_b,
        si_h,
        si_k,
        BLOCK_T: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Move selected value rows to the front of a request's paged cache.

        Same traversal as the key compaction kernel, but the value cache is
        addressed as [block, head, dim, token-in-block] with strides
        (sv_b, sv_h, sv_d, sv_t). Grid: axis 0 enumerates (request, head) as
        ``b * H + h``; axis 1 tiles the value head dimension.
        """
        pid_bh = tl.program_id(0)
        pid_d = tl.program_id(1)

        b = pid_bh // H
        h = pid_bh % H
        if b >= B:
            return

        d_range = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
        d_mask = d_range < Dv
        # Clamp masked-out feature lanes to 0 so their addresses stay in range.
        d_safe = tl.where(d_mask, d_range, 0)

        keep_b = tl.load(keep_ptr + b)
        if keep_b <= 0:
            return

        req_blk_start = tl.load(req_blk_starts_ptr + b)

        # NOTE(review): no barrier is issued between the tile load and the
        # tile store, nor between loop iterations; confirm whether cross-warp
        # ordering needs an explicit tl.debug_barrier() for overlapping
        # source/destination rows.
        k0 = 0
        while k0 < keep_b:
            k_range = k0 + tl.arange(0, BLOCK_T)
            k_mask = (k_range < K_max) & (k_range < keep_b)
            k_safe = tl.where(k_mask, k_range, 0)

            idx_base = idx_ptr + b * si_b + h * si_h + k_safe * si_k
            t_src = tl.load(idx_base, mask=k_mask, other=0)

            # Destination slot k receives the row at idx[b, h, k]; rows that
            # are already in place are skipped.
            t_dst = k_safe
            copy_mask = k_mask & (t_src != t_dst)
            tile_mask = copy_mask[:, None] & d_mask[None, :]

            # Block ids are promoted to int64 so block_id * sv_b cannot
            # overflow 32 bits on large caches.
            inb_src = t_src % block_size
            gblk_src = req_blk_start + t_src // block_size
            bid_src = tl.load(
                blk_ids_ptr + gblk_src, mask=copy_mask, other=0
            ).to(tl.int64)

            inb_dst = t_dst % block_size
            gblk_dst = req_blk_start + t_dst // block_size
            bid_dst = tl.load(
                blk_ids_ptr + gblk_dst, mask=copy_mask, other=0
            ).to(tl.int64)

            # value layout: [num_blocks, H, Dv, block_size]
            v_src_ptrs = (
                V_ptr
                + bid_src[:, None] * sv_b
                + h * sv_h
                + d_safe[None, :] * sv_d
                + inb_src[:, None] * sv_t
            )
            v_dst_ptrs = (
                V_ptr
                + bid_dst[:, None] * sv_b
                + h * sv_h
                + d_safe[None, :] * sv_d
                + inb_dst[:, None] * sv_t
            )

            tile = tl.load(v_src_ptrs, mask=tile_mask, other=0)
            tl.store(v_dst_ptrs, tile, mask=tile_mask)

            k0 += BLOCK_T
@torch.inference_mode()
def front_compact_inplace_fa_triton(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_table: torch.Tensor,
    idx_sorted: torch.Tensor,
    keep: torch.Tensor,
) -> None:
    """In-place front compaction for FlashAttention KV cache.

    Moves selected time indices to the front [0..keep[b]) per request for both
    key_cache and value_cache in-place.

    Expected layouts:
      - key_cache:   [num_blocks, H, block_size, Dk]
      - value_cache: [num_blocks, H, Dv, block_size]
      - block_table: [B, max_blocks] int32 physical block ids
      - idx_sorted:  [B, K] int32, [B, 1, K] int32 or [B, H, K] int32
                     (ascending indices); the first two forms are shared by
                     all heads
      - keep:        [B] int32 (<= K), number of kept tokens per request

    Raises:
        RuntimeError: Triton is unavailable or an input is not on a CUDA/ROCm
            device.
        TypeError: an input is not a tensor.
        ValueError: an input has the wrong rank, the value cache layout does
            not match the key cache, or ``idx_sorted``/``keep`` do not cover
            every request and head.
    """
    _require_triton()
    _check_cuda(key_cache, value_cache, block_table, idx_sorted, keep)

    if key_cache.ndim != 4 or value_cache.ndim != 4:
        raise ValueError("key_cache/value_cache must be 4D tensors.")
    if block_table.ndim != 2:
        raise ValueError("block_table must be 2D [B, max_blocks].")
    if idx_sorted.ndim not in (2, 3):
        raise ValueError("idx_sorted must be 2D [B,K] or 3D [B,H,K].")
    if keep.ndim != 1:
        raise ValueError("keep must be 1D [B].")

    device = key_cache.device
    B = int(block_table.shape[0])
    if B == 0:
        return

    H = int(key_cache.shape[1])
    block_size = int(key_cache.shape[2])
    Dk = int(key_cache.shape[3])
    Dv = int(value_cache.shape[2])

    # The value kernel reuses H and block_size taken from the key cache; a
    # mismatching value layout would make it write outside the intended rows.
    if int(value_cache.shape[1]) != H or int(value_cache.shape[3]) != block_size:
        raise ValueError(
            "value_cache must be [num_blocks, H, Dv, block_size] matching key_cache."
        )
    # The kernels read keep[b] and idx_sorted[b] for every block_table row.
    if int(keep.numel()) < B or int(idx_sorted.shape[0]) < B:
        raise ValueError("idx_sorted and keep must cover every block_table row.")

    if idx_sorted.ndim == 2:
        idx_sorted = idx_sorted[:, None, :]
    heads_given = int(idx_sorted.shape[1])
    if heads_given == 1:
        # Share one index list across heads (stride-0 view, no copy).
        idx_sorted = idx_sorted.expand(-1, H, -1)
    elif heads_given < H:
        raise ValueError("3D idx_sorted must provide indices for all H heads.")

    K_max = int(idx_sorted.shape[2])
    if K_max == 0:
        return

    max_blocks = int(block_table.shape[1])
    # reshape(-1) yields a flat, unit-stride view/copy of the block table.
    blk_ids = block_table.to(device=device, dtype=torch.int32).reshape(-1)
    req_starts = torch.arange(B, device=device, dtype=torch.int32) * max_blocks

    idx_i32 = idx_sorted.to(device=device, dtype=torch.int32)
    # keep is read with unit stride by the kernels.
    keep_i32 = keep.to(device=device, dtype=torch.int32).contiguous()

    sKb, sKh, sKt, sKd = (int(s) for s in key_cache.stride())
    sv_b, sv_h, sv_d, sv_t = (int(s) for s in value_cache.stride())
    si_b, si_h, si_k = (int(s) for s in idx_i32.stride())

    def _grid_for(head_dim: int):
        # Size the feature axis from the BLOCK_D autotune actually selected
        # instead of assuming the smallest configured tile.
        def _grid(meta):
            return (B * H, triton.cdiv(head_dim, meta["BLOCK_D"]))

        return _grid

    if Dk > 0:
        _front_compact_inplace_fa_k_kernel[_grid_for(Dk)](
            key_cache,
            blk_ids,
            req_starts,
            idx_i32,
            keep_i32,
            B,
            H,
            K_max,
            block_size,
            Dk,
            sKb,
            sKh,
            sKt,
            sKd,
            si_b,
            si_h,
            si_k,
        )

    if Dv > 0:
        _front_compact_inplace_fa_v_kernel[_grid_for(Dv)](
            value_cache,
            blk_ids,
            req_starts,
            idx_i32,
            keep_i32,
            B,
            H,
            K_max,
            block_size,
            Dv,
            sv_b,
            sv_h,
            sv_d,
            sv_t,
            si_b,
            si_h,
            si_k,
        )
def make_fa_cache_view(
    *,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (K_view, V_view) in the canonical FA compaction layout.

    - K_view: [num_blocks, H, block_size, D]
    - V_view: [num_blocks, H, D, block_size]

    The layout is inferred purely from shapes: when the last two dims of
    ``value_cache`` are the last two dims of ``key_cache`` swapped, the inputs
    are returned untouched; otherwise both are treated as [B, T, H, D] and
    permuted (views, no copy).

    NOTE(review): the two layouts cannot be told apart when the last two dims
    of both caches are equal (e.g. H == D); such inputs take the pass-through
    branch. Confirm callers never hit that case.

    Raises:
        ValueError: if either cache is not 4-dimensional.
    """
    if not (key_cache.ndim == 4 and value_cache.ndim == 4):
        raise ValueError("key_cache/value_cache must be 4D tensors.")

    k_tail = tuple(key_cache.shape[2:])
    v_tail = tuple(value_cache.shape[2:])

    # ROCm path (FlashAttention v1): K=[B,H,T,D] and V=[B,H,D,T]
    if v_tail == k_tail[::-1]:
        return key_cache, value_cache

    # CUDA path: K=[B,T,H,D] and V=[B,T,H,D]
    return key_cache.permute(0, 2, 1, 3), value_cache.permute(0, 2, 3, 1)
vllm/v1/attention/kv_compression/snapkv_triton.py
0 → 100644
View file @
5036e878
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
math
from
typing
import
Optional
,
Union
import
torch
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
import
triton
import
triton.language
as
tl
if HAS_TRITON:

    @triton.autotune(
        configs=[
            triton.Config(
                {"BLOCK_Q": block_q, "BLOCK_K": block_k},
                num_warps=warps,
                num_stages=stages,
            )
            for block_q in [32, 64]
            for block_k in [32, 64]
            for warps in [4, 8]
            for stages in [3, 4]
        ],
        key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
    )
    @triton.jit
    def _lse_and_store_logits_kernel(
        Q,
        K,
        cu_q,
        cu_k,
        w_b,
        out_m,
        out_S,
        LOGITS,
        sm_scale,
        QUERY_GROUP_SIZE: tl.constexpr,
        D: tl.constexpr,
        STRIDE_Q_NQ,
        STRIDE_Q_HQ,
        STRIDE_K_NK,
        STRIDE_K_HK,
        STRIDE_M_B,
        STRIDE_M_H,
        STRIDE_M_R,
        STRIDE_S_B,
        STRIDE_S_H,
        STRIDE_S_R,
        STRIDE_LG_NK,
        STRIDE_LG_HK,
        STRIDE_LG_R,
        BLOCK_Q: tl.constexpr,
        BLOCK_K: tl.constexpr,
        ROWS_MAX,
    ):
        """Store scaled logits and per-row running max / exp2-sum.

        Grid: (batch, kv head, tile of BLOCK_Q rows). A "row" is one
        (window query, query head within the group) pair; row r maps to query
        offset ``r // QUERY_GROUP_SIZE`` and local head ``r % QUERY_GROUP_SIZE``.
        Keys in ``[cu_k[b], cu_k[b+1] - w_b[b])`` are scored against the last
        ``w_b[b]`` queries of the batch.

        Writes:
            LOGITS[key, kv head, row]: logits multiplied by
                ``sm_scale * log2(e)``.
            out_m / out_S[batch, kv head, row]: running max and the sum of
                ``exp2(logit - max)`` over the scored keys.

        The head dimension is addressed with unit stride and
        ``tl.arange(0, D)``. ``ROWS_MAX`` is not read in the body; it only
        appears in the autotune key.

        NOTE(review): the first window query is ``cu_q[b+1] - w_b[b]`` and
        ``cu_q[b]`` is never read, so a window longer than the batch's query
        span would index earlier rows; confirm callers prevent this.
        """
        batch = tl.program_id(0)
        kv_head = tl.program_id(1)
        row_tile = tl.program_id(2)

        q_stop = tl.load(cu_q + batch + 1)
        k_start = tl.load(cu_k + batch)
        k_stop = tl.load(cu_k + batch + 1)
        window = tl.load(w_b + batch)

        # Keys that fall inside the trailing window are not scored here.
        k_limit = k_stop - window
        if (window <= 0) or (k_limit <= k_start):
            return

        n_rows = window * QUERY_GROUP_SIZE
        first_row = row_tile * BLOCK_Q
        if first_row >= n_rows:
            return

        # Fold log2(e) into the scale so exp() can be evaluated as exp2().
        log2e_scale = sm_scale * 1.4426950408889634

        rows = first_row + tl.arange(0, BLOCK_Q)
        rows_ok = rows < n_rows
        q_token = (q_stop - window) + rows // QUERY_GROUP_SIZE
        q_head = kv_head * QUERY_GROUP_SIZE + rows % QUERY_GROUP_SIZE

        dim = tl.arange(0, D)
        q_tile = tl.load(
            Q
            + q_token[:, None] * STRIDE_Q_NQ
            + q_head[:, None] * STRIDE_Q_HQ
            + dim[None, :],
            mask=rows_ok[:, None],
            other=0.0,
        )

        run_max = tl.full([BLOCK_Q], -float("inf"), dtype=tl.float32)
        run_sum = tl.zeros([BLOCK_Q], dtype=tl.float32)

        for k_tile_start in tl.range(k_start, k_limit, BLOCK_K):
            keys = k_tile_start + tl.arange(0, BLOCK_K)
            keys_ok = keys < k_limit

            k_tile = tl.load(
                K
                + keys[:, None] * STRIDE_K_NK
                + kv_head * STRIDE_K_HK
                + dim[None, :],
                mask=keys_ok[:, None],
                other=0.0,
            )

            logits = tl.dot(q_tile, k_tile.T) * log2e_scale
            # Out-of-range keys must not contribute to the max or the sum.
            logits = tl.where(keys_ok[None, :], logits, -float("inf"))

            # Logits are stored key-major, hence the transpose.
            tl.store(
                LOGITS
                + keys[:, None] * STRIDE_LG_NK
                + kv_head * STRIDE_LG_HK
                + rows[None, :] * STRIDE_LG_R,
                logits.T,
                mask=keys_ok[:, None] & rows_ok[None, :],
            )

            # Online-softmax update: rescale the old sum to the new max.
            new_max = tl.maximum(run_max, tl.max(logits, 1))
            run_sum = run_sum * tl.math.exp2(run_max - new_max) + tl.sum(
                tl.math.exp2(logits - new_max[:, None]), 1
            )
            run_max = new_max

        tl.store(
            out_m + batch * STRIDE_M_B + kv_head * STRIDE_M_H + rows * STRIDE_M_R,
            run_max,
            mask=rows_ok,
        )
        tl.store(
            out_S + batch * STRIDE_S_B + kv_head * STRIDE_S_H + rows * STRIDE_S_R,
            run_sum,
            mask=rows_ok,
        )
if HAS_TRITON:

    @triton.jit
    def _lse_and_store_logits_kernel_rocm_safe(
        Q,
        K,
        cu_q,
        cu_k,
        w_b,
        out_m,
        out_S,
        LOGITS,
        sm_scale,
        QUERY_GROUP_SIZE: tl.constexpr,
        D: tl.constexpr,
        STRIDE_Q_NQ,
        STRIDE_Q_HQ,
        STRIDE_K_NK,
        STRIDE_K_HK,
        STRIDE_M_B,
        STRIDE_M_H,
        STRIDE_M_R,
        STRIDE_S_B,
        STRIDE_S_H,
        STRIDE_S_R,
        STRIDE_LG_NK,
        STRIDE_LG_HK,
        STRIDE_LG_R,
        BLOCK_Q: tl.constexpr,
        BLOCK_K: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """ROCm-safe variant of `_lse_and_store_logits_kernel`.

        On some ROCm + Triton (HIP) stacks memory corruption has been observed
        with the tl.dot-based implementation. This variant avoids `tl.dot`:
        the head dimension is processed in chunks of BLOCK_D, and each chunk
        contributes ``sum_d q[r, d] * k[n, d]`` accumulated in float32.

        Grid, row mapping and outputs (LOGITS, out_m, out_S) follow the same
        scheme as the tl.dot kernel: grid is (batch, kv head, tile of BLOCK_Q
        rows), and row r maps to query offset ``r // QUERY_GROUP_SIZE`` and
        local head ``r % QUERY_GROUP_SIZE``.
        """
        batch = tl.program_id(0)
        kv_head = tl.program_id(1)
        row_tile = tl.program_id(2)

        q_stop = tl.load(cu_q + batch + 1)
        k_start = tl.load(cu_k + batch)
        k_stop = tl.load(cu_k + batch + 1)
        window = tl.load(w_b + batch)

        # Keys that fall inside the trailing window are not scored here.
        k_limit = k_stop - window
        if (window <= 0) or (k_limit <= k_start):
            return

        n_rows = window * QUERY_GROUP_SIZE
        first_row = row_tile * BLOCK_Q
        if first_row >= n_rows:
            return

        # Fold log2(e) into the scale so exp() can be evaluated as exp2().
        log2e_scale = sm_scale * 1.4426950408889634

        rows = first_row + tl.arange(0, BLOCK_Q)
        rows_ok = rows < n_rows
        q_token = (q_stop - window) + rows // QUERY_GROUP_SIZE
        q_head = kv_head * QUERY_GROUP_SIZE + rows % QUERY_GROUP_SIZE

        # Row part of the query address; the feature offset is added per chunk.
        q_row_ptrs = Q + q_token[:, None] * STRIDE_Q_NQ + q_head[:, None] * STRIDE_Q_HQ

        run_max = tl.full([BLOCK_Q], -float("inf"), dtype=tl.float32)
        run_sum = tl.zeros([BLOCK_Q], dtype=tl.float32)

        for k_tile_start in tl.range(k_start, k_limit, BLOCK_K):
            keys = k_tile_start + tl.arange(0, BLOCK_K)
            keys_ok = keys < k_limit
            k_row_ptrs = K + keys[:, None] * STRIDE_K_NK + kv_head * STRIDE_K_HK

            # Accumulate logits = Q @ K^T in fp32, one feature chunk at a time.
            logits = tl.zeros([BLOCK_Q, BLOCK_K], dtype=tl.float32)
            for d_start in tl.static_range(0, D, BLOCK_D):
                dim = d_start + tl.arange(0, BLOCK_D)
                dim_ok = dim < D

                q_part = tl.load(
                    q_row_ptrs + dim[None, :],
                    mask=rows_ok[:, None] & dim_ok[None, :],
                    other=0.0,
                ).to(tl.float32)  # [BQ, BD]
                k_part = tl.load(
                    k_row_ptrs + dim[None, :],
                    mask=keys_ok[:, None] & dim_ok[None, :],
                    other=0.0,
                ).to(tl.float32)  # [BK, BD]

                logits += tl.sum(q_part[:, None, :] * k_part[None, :, :], axis=2)

            logits = logits * log2e_scale
            # Out-of-range keys must not contribute to the max or the sum.
            logits = tl.where(keys_ok[None, :], logits, -float("inf"))

            # Logits are stored key-major, hence the transpose.
            tl.store(
                LOGITS
                + keys[:, None] * STRIDE_LG_NK
                + kv_head * STRIDE_LG_HK
                + rows[None, :] * STRIDE_LG_R,
                logits.T,
                mask=keys_ok[:, None] & rows_ok[None, :],
            )

            # Online-softmax update: rescale the old sum to the new max.
            new_max = tl.maximum(run_max, tl.max(logits, 1))
            run_sum = run_sum * tl.math.exp2(run_max - new_max) + tl.sum(
                tl.math.exp2(logits - new_max[:, None]), 1
            )
            run_max = new_max

        tl.store(
            out_m + batch * STRIDE_M_B + kv_head * STRIDE_M_H + rows * STRIDE_M_R,
            run_max,
            mask=rows_ok,
        )
        tl.store(
            out_S + batch * STRIDE_S_B + kv_head * STRIDE_S_H + rows * STRIDE_S_R,
            run_sum,
            mask=rows_ok,
        )
if HAS_TRITON:

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
            for bq in [16, 32, 64]
            for bk in [32, 64, 128]
        ],
        # The key must name real kernel parameters. The previous
        # ["HK", "HQ"] matched none of them, so the entries were either
        # ignored or rejected depending on the Triton release. Key on the
        # compile-time parameters that change the generated code instead.
        key=["QUERY_GROUP_SIZE", "DO_POOL", "KPOOL"],
    )
    @triton.jit
    def _scores_from_logits_kernel(
        cu_k,
        w_b,
        in_m,
        in_S,
        LOGITS,
        OUT,
        QUERY_GROUP_SIZE: tl.constexpr,
        STRIDE_M_B,
        STRIDE_M_H,
        STRIDE_M_R,
        STRIDE_S_B,
        STRIDE_S_H,
        STRIDE_S_R,
        STRIDE_LG_NK,
        STRIDE_LG_HK,
        STRIDE_LG_R,
        STRIDE_OUT_NK,
        STRIDE_OUT_HK,
        BLOCK_Q: tl.constexpr,
        BLOCK_K: tl.constexpr,
        DO_POOL: tl.constexpr,
        KPOOL: tl.constexpr,
        PROTECT_LAST: tl.constexpr,
    ):
        """Turn stored logits into one score per (key, kv head).

        Grid: (batch, kv head). For every key in
        ``[cu_k[b], cu_k[b+1] - w_b[b])`` the kernel sums, over all
        ``w_b[b] * QUERY_GROUP_SIZE`` rows, the softmax probability
        ``exp2(logit - m) / S`` rebuilt from the stored row statistics, and
        writes the sum to ``OUT[key, kv head]``.

        Options:
            DO_POOL / KPOOL: when pooling is on and KPOOL > 1, each score is
                replaced by the mean of itself and up to KPOOL-1 preceding
                valid scores of the same BLOCK_K tile.
            PROTECT_LAST: write +inf for the keys in the trailing window
                ``[cu_k[b+1] - w_b[b], cu_k[b+1])``.

        NOTE(review): the pooling window never reaches across a BLOCK_K tile
        boundary, so pooled values near tile starts depend on the autotuned
        BLOCK_K; confirm this is acceptable.
        """
        b = tl.program_id(0)
        hk = tl.program_id(1)

        k_beg = tl.load(cu_k + b)
        k_end = tl.load(cu_k + b + 1)
        win = tl.load(w_b + b)
        k_eff_end = k_end - win

        if (win <= 0) or (k_eff_end <= k_beg):
            return

        rows_b = win * QUERY_GROUP_SIZE

        for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
            nk = ks + tl.arange(0, BLOCK_K)
            kmask = nk < k_eff_end
            scores = tl.zeros([BLOCK_K], dtype=tl.float32)

            for row0 in tl.range(0, rows_b, BLOCK_Q):
                r_idx = row0 + tl.arange(0, BLOCK_Q)
                rmask = r_idx < rows_b

                m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
                S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
                m = tl.load(
                    m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
                    mask=rmask,
                    other=-float("inf"),
                )
                S = tl.load(
                    S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R,
                    mask=rmask,
                    other=0.0,
                )

                # Rows with an empty sum get neutral statistics so the
                # division below cannot produce NaN/inf; their probabilities
                # are zeroed afterwards.
                valid_row = S > 0
                m = tl.where(valid_row, m, 0.0)
                S = tl.where(valid_row, S, 1.0)

                log_ptrs = (
                    LOGITS
                    + nk[:, None] * STRIDE_LG_NK
                    + hk * STRIDE_LG_HK
                    + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
                )
                s_T = tl.load(
                    log_ptrs,
                    mask=kmask[:, None] & rmask[None, :],
                    other=-float("inf"),
                )
                probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
                probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
                scores += tl.sum(probs_T, 1)

            if DO_POOL and (KPOOL > 1):
                # band[i, j] selects the KPOOL positions ending at i
                # (j in (i - KPOOL, i]) that are valid keys of this tile.
                i = tl.arange(0, BLOCK_K)[:, None]
                j = tl.arange(0, BLOCK_K)[None, :]
                band = (j <= i) & ((i - j) < KPOOL)
                band = band & kmask[None, :]
                sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1)
                denom = tl.sum(band, 1).to(tl.float32)
                denom = tl.where(denom > 0, denom, 1.0)
                scores = sums / denom

            out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
            tl.store(out_ptrs, scores, mask=kmask)

        if PROTECT_LAST:
            pad_beg = k_eff_end
            pad_end = k_end
            if pad_end > pad_beg:
                for ks in tl.range(pad_beg, pad_end, BLOCK_K):
                    nk = ks + tl.arange(0, BLOCK_K)
                    kmask = nk < pad_end
                    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
                    tl.store(
                        out_ptrs,
                        tl.full([BLOCK_K], float("inf"), dtype=tl.float32),
                        mask=kmask,
                    )
if HAS_TRITON:

    @triton.autotune(
        configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
        key=["HK"],
        # The kernel normalises OUT in place. Autotune benchmarks each config
        # on the caller's buffer, so without restoring it the scores would be
        # normalised repeatedly on the first call for a given HK.
        restore_value=["OUT"],
    )
    @triton.jit
    def _zscore_per_batch_epilogue(
        OUT,
        cu_k,
        w_b,
        STRIDE_OUT_NK,
        STRIDE_OUT_HK,
        HK: tl.constexpr,
        EPS: tl.constexpr,
        BLOCK_K: tl.constexpr,
    ):
        """Z-score the scores of one batch in place.

        Grid: (batch,). The statistics are taken jointly over all HK heads and
        all keys in ``[cu_k[b], cu_k[b+1] - w_b[b])``; keys in the trailing
        window are neither read nor modified. Each value becomes
        ``(x - mean) / sqrt(var + EPS)`` with
        ``var = max(E[x^2] - mean^2, 0)`` accumulated in float32.
        """
        b = tl.program_id(0)
        k_beg = tl.load(cu_k + b)
        k_end = tl.load(cu_k + b + 1)
        win = tl.load(w_b + b)
        k_eff_end = k_end - win

        # Same guard as the scoring kernels: a non-positive window means
        # nothing was scored, and a negative one would push k_eff_end past
        # this batch's keys.
        if (win <= 0) or (k_eff_end <= k_beg):
            return

        sumv = tl.zeros([], dtype=tl.float32)
        sumsq = tl.zeros([], dtype=tl.float32)
        count = ((k_eff_end - k_beg) * HK).to(tl.float32)

        # Pass 1: accumulate sum and sum of squares.
        for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
            nk = ks + tl.arange(0, BLOCK_K)
            kmask = nk < k_eff_end
            for h in tl.range(0, HK):
                ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
                vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
                sumv += tl.sum(vals, 0)
                sumsq += tl.sum(vals * vals, 0)

        mean = sumv / count
        # Clamp at 0: E[x^2] - mean^2 can come out slightly negative in fp32.
        var = tl.maximum(sumsq / count - mean * mean, 0.0)
        invstd = 1.0 / tl.sqrt(var + EPS)

        # Pass 2: normalise in place.
        for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
            nk = ks + tl.arange(0, BLOCK_K)
            kmask = nk < k_eff_end
            for h in tl.range(0, HK):
                ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
                vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
                vals = (vals - mean) * invstd
                tl.store(ptrs, vals, mask=kmask)
def query_aware_key_scores(
    q: torch.Tensor,  # [N_q, Hq, D]
    k: torch.Tensor,  # [N_k, Hk, D]
    cu_seqlens_q: torch.Tensor,  # [B+1] int32
    cu_seqlens_k: torch.Tensor,  # [B+1] int32
    w: Union[int, torch.Tensor],  # [B] int32 or scalar
    sm_scale: Optional[float] = None,
    *,
    pool: bool = True,
    kpool: int = 5,
    protect_last: bool = True,
    normalize: bool = False,
) -> torch.Tensor:
    """SnapKV query-aware key scores (Triton), returns [N_k, Hk] float32.

    For each batch, the last ``w`` queries are scored against the keys that
    precede the trailing window of ``w`` keys; the per-key score is the sum of
    softmax probabilities over those queries and over the query heads that
    share a kv head.

    Args:
        q: packed queries [N_q, Hq, D], unit stride in the last dim.
        k: packed keys [N_k, Hk, D], unit stride in the last dim; Hq must be
            a multiple of Hk.
        cu_seqlens_q / cu_seqlens_k: [B+1] cumulative sequence offsets.
        w: window length, either one int for all batches or a [B] tensor.
        sm_scale: logit scale; defaults to ``1 / sqrt(D)``.
        pool: average each score with preceding scores (see ``kpool``).
        kpool: pooling window length, must be >= 1.
        protect_last: write +inf for the keys inside the trailing window.
        normalize: z-score the scores per batch in place.

    Returns:
        float32 tensor [N_k, Hk]. All zeros when the batch is empty or no
        window is positive.

    Raises:
        RuntimeError: Triton is unavailable or q/k are not on CUDA/ROCm.
        ValueError: on inconsistent shapes, strides or arguments.

    Note:
        A tensor ``w`` triggers a device synchronisation (``.item()``).
        NOTE(review): the tl.dot kernel builds its feature index with
        ``tl.arange(0, D)``; confirm D is a power of two on the non-ROCm path.
    """
    if not HAS_TRITON:
        raise RuntimeError("Triton is not available.")
    if q.device.type != "cuda" or k.device.type != "cuda":
        raise RuntimeError("Triton SnapKV requires CUDA/ROCm tensors.")
    if q.ndim != 3 or k.ndim != 3:
        raise ValueError("q and k must be 3D tensors.")
    if q.stride(-1) != 1 or k.stride(-1) != 1:
        raise ValueError("Last dim must be contiguous for Triton SnapKV.")

    device = q.device
    N_q, Hq, D = q.shape
    N_k, Hk, Dk = k.shape
    if D != Dk:
        raise ValueError("q and k must have the same head size.")
    if (Hq % Hk) != 0:
        raise ValueError("Hq must be a multiple of Hk.")

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    B = int(cu_seqlens_q.numel() - 1)
    if B != int(cu_seqlens_k.numel() - 1):
        raise ValueError("cu_seqlens_q and cu_seqlens_k must match.")
    if B < 0:
        raise ValueError("cu_seqlens_q and cu_seqlens_k must be non-empty.")

    G = Hq // Hk

    if isinstance(w, int):
        max_w = int(w)
        w = torch.full((B,), fill_value=max_w, device=device, dtype=torch.int32)
    else:
        if w.numel() != B:
            raise ValueError("w must have shape [B].")
        # The kernels read w with unit stride.
        w = w.to(device=device, dtype=torch.int32).reshape(-1).contiguous()
        # max() of an empty tensor raises, so handle B == 0 explicitly.
        max_w = int(w.max().item()) if B > 0 else 0

    rows_max = max_w * G
    if B == 0 or rows_max <= 0:
        return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)

    if kpool < 1:
        raise ValueError("kpool must be >= 1.")

    out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
    # Scratch written by the LSE kernel and consumed by the scoring kernel.
    m_scratch = torch.empty((B, Hk, rows_max), dtype=torch.float32, device=device)
    S_scratch = torch.empty((B, Hk, rows_max), dtype=torch.float32, device=device)
    logits_buf = torch.empty((N_k, Hk, rows_max), dtype=torch.float32, device=device)

    STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
    STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
    STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
    STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
    STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
    STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()

    def grid(meta):
        return B, Hk, triton.cdiv(rows_max, meta["BLOCK_Q"])

    # The kernels read the offset arrays with unit stride.
    cu_q = cu_seqlens_q.to(device=device, dtype=torch.int32).contiguous()
    cu_k = cu_seqlens_k.to(device=device, dtype=torch.int32).contiguous()

    # Arguments shared by both LSE kernel variants.
    lse_args = (q, k, cu_q, cu_k, w, m_scratch, S_scratch, logits_buf, sm_scale)
    stat_strides = dict(
        STRIDE_M_B=STRIDE_M_B,
        STRIDE_M_H=STRIDE_M_H,
        STRIDE_M_R=STRIDE_M_R,
        STRIDE_S_B=STRIDE_S_B,
        STRIDE_S_H=STRIDE_S_H,
        STRIDE_S_R=STRIDE_S_R,
        STRIDE_LG_NK=STRIDE_LG_NK,
        STRIDE_LG_HK=STRIDE_LG_HK,
        STRIDE_LG_R=STRIDE_LG_R,
    )
    lse_kwargs = dict(
        QUERY_GROUP_SIZE=G,
        D=D,
        STRIDE_Q_NQ=STRIDE_Q_NQ,
        STRIDE_Q_HQ=STRIDE_Q_HQ,
        STRIDE_K_NK=STRIDE_K_NK,
        STRIDE_K_HK=STRIDE_K_HK,
        **stat_strides,
    )

    # NOTE: On ROCm/HIP, we prefer a dot-free kernel variant to avoid known
    # correctness issues (silent memory corruption) observed with the tl.dot
    # implementation on some stacks.
    is_rocm = getattr(torch.version, "hip", None) is not None
    if is_rocm:
        _lse_and_store_logits_kernel_rocm_safe[grid](
            *lse_args,
            **lse_kwargs,
            BLOCK_Q=32,
            BLOCK_K=32,
            BLOCK_D=16,
            num_warps=4,
            num_stages=1,
        )
    else:
        _lse_and_store_logits_kernel[grid](
            *lse_args,
            **lse_kwargs,
            ROWS_MAX=rows_max,
        )

    _scores_from_logits_kernel[(B, Hk)](
        cu_k,
        w,
        m_scratch,
        S_scratch,
        logits_buf,
        out,
        QUERY_GROUP_SIZE=G,
        **stat_strides,
        STRIDE_OUT_NK=STRIDE_OUT_NK,
        STRIDE_OUT_HK=STRIDE_OUT_HK,
        DO_POOL=pool,
        KPOOL=kpool,
        PROTECT_LAST=protect_last,
    )

    if normalize:
        _zscore_per_batch_epilogue[(B,)](
            out,
            cu_k,
            w,
            STRIDE_OUT_NK,
            STRIDE_OUT_HK,
            HK=Hk,
            EPS=1e-12,
        )

    return out
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment