Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d29c39ca
Commit
d29c39ca
authored
Apr 30, 2026
by
chenzk
Browse files
vllm kvprune wo:v1.1.0
parent
f81ce56b
Changes
246
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2493 additions
and
0 deletions
+2493
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_value.py
..._kernels/tensor_details/layout_details/blackwell_value.py
+37
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/cdna4_scale.py
...iton_kernels/tensor_details/layout_details/cdna4_scale.py
+50
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/hopper_scale.py
...ton_kernels/tensor_details/layout_details/hopper_scale.py
+91
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/hopper_value.py
...ton_kernels/tensor_details/layout_details/hopper_value.py
+362
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/strided.py
...m/triton_kernels/tensor_details/layout_details/strided.py
+17
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/testing.py
...mpactor-vllm/src/compactor_vllm/triton_kernels/testing.py
+215
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk.py
.../compactor-vllm/src/compactor_vllm/triton_kernels/topk.py
+157
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/__init__.py
...rc/compactor_vllm/triton_kernels/topk_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/_topk_backward.py
...pactor_vllm/triton_kernels/topk_details/_topk_backward.py
+51
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/_topk_forward.py
...mpactor_vllm/triton_kernels/topk_details/_topk_forward.py
+183
-0
vllm/compactor-vllm/src/compactor_vllm/utils/__init__.py
vllm/compactor-vllm/src/compactor_vllm/utils/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/utils/arguments.py
vllm/compactor-vllm/src/compactor_vllm/utils/arguments.py
+408
-0
vllm/compactor-vllm/src/compactor_vllm/utils/context.py
vllm/compactor-vllm/src/compactor_vllm/utils/context.py
+97
-0
vllm/compactor-vllm/src/compactor_vllm/utils/helpers.py
vllm/compactor-vllm/src/compactor_vllm/utils/helpers.py
+35
-0
vllm/compactor-vllm/src/compactor_vllm/utils/sequence.py
vllm/compactor-vllm/src/compactor_vllm/utils/sequence.py
+83
-0
vllm/compactor-vllm/src/compactor_vllm/utils/triton_compat.py
.../compactor-vllm/src/compactor_vllm/utils/triton_compat.py
+61
-0
vllm/compactor-vllm/tests/test_store_kv.py
vllm/compactor-vllm/tests/test_store_kv.py
+239
-0
vllm/compactor-vllm/tests/test_triton_attention.py
vllm/compactor-vllm/tests/test_triton_attention.py
+407
-0
vllm/compactor-vllm/vllm_memory_comparison.png
vllm/compactor-vllm/vllm_memory_comparison.png
+0
-0
vllm/compactor-vllm/vllm_throughput_comparison.png
vllm/compactor-vllm/vllm_throughput_comparison.png
+0
-0
No files found.
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_value.py
0 → 100644
View file @
d29c39ca
import
torch
from
.base
import
Layout
class
BlackwellMXValueLayout
(
Layout
):
name
:
str
=
"BLACKWELL_VALUE"
def
__init__
(
self
,
shape
)
->
None
:
super
().
__init__
(
shape
)
self
.
shape
=
shape
def
swizzle_data
(
self
,
data
):
# permutation needed to make `data` row major
to_row_major
=
sorted
(
range
(
data
.
ndim
),
key
=
lambda
d
:
(
data
.
stride
(
d
),
d
))[::
-
1
]
# permutation needed to retrieve original order
inv
=
[
0
]
*
data
.
ndim
for
i
,
d
in
enumerate
(
to_row_major
):
inv
[
d
]
=
i
# leading dimension must be padded to be aligned to 128
align_dim
=
lambda
x
:
(
x
+
128
-
1
)
//
128
*
128
major_dim
=
data
.
stride
().
index
(
1
)
pad
=
align_dim
(
data
.
shape
[
major_dim
])
-
data
.
shape
[
major_dim
]
data
=
torch
.
nn
.
functional
.
pad
(
data
.
permute
(
to_row_major
),
(
0
,
pad
)).
permute
(
inv
)
return
data
def
unswizzle_data
(
self
,
data
:
torch
.
Tensor
):
# Trim padding along all dims back to the original shape recorded at init.
assert
data
.
ndim
==
len
(
self
.
shape
),
(
"Rank mismatch between data and recorded shape"
)
sizes
=
[
min
(
data
.
size
(
i
),
self
.
shape
[
i
])
for
i
in
range
(
data
.
ndim
)]
return
data
[
tuple
(
slice
(
0
,
s
)
for
s
in
sizes
)]
def
swizzle_block_shape
(
self
,
block_shape
):
return
block_shape
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/cdna4_scale.py
0 → 100644
View file @
d29c39ca
import
triton
import
triton.language
as
tl
from
.base
import
Layout
NON_K_PRESHUFFLE_BLOCK_SIZE
=
32
class
CDNA4MXScaleLayout
(
Layout
):
name
:
str
=
"CDNA4_SCALE"
def
__init__
(
self
,
shape
)
->
None
:
super
().
__init__
(
shape
)
def
swizzle_data
(
self
,
data
):
block_shape
=
data
.
shape
SCALE_K
=
block_shape
[
-
2
]
N
=
block_shape
[
-
1
]
data
=
data
.
transpose
(
-
1
,
-
2
)
data
=
data
.
view
(
-
1
,
N
//
NON_K_PRESHUFFLE_BLOCK_SIZE
,
2
,
16
,
SCALE_K
//
8
,
2
,
4
,
1
)
data
=
data
.
permute
(
0
,
1
,
4
,
6
,
3
,
5
,
2
,
7
).
contiguous
()
if
len
(
block_shape
)
==
3
:
E
=
block_shape
[
0
]
data
=
data
.
reshape
(
E
,
N
//
32
,
SCALE_K
*
32
)
else
:
assert
len
(
block_shape
)
==
2
data
=
data
.
reshape
(
N
//
32
,
SCALE_K
*
32
)
return
data
.
transpose
(
-
1
,
-
2
)
def
unswizzle_data
(
self
,
data
):
raise
NotImplementedError
()
def
swizzle_block_shape
(
self
,
block_shape
):
SCALE_K
=
block_shape
[
-
2
]
N
=
block_shape
[
-
1
]
return
block_shape
[:
-
2
]
+
[
N
//
32
,
SCALE_K
*
32
]
@
triton
.
jit
def
unswizzle_mx_scale_cdna4
(
x
,
BLOCK_N
:
tl
.
constexpr
,
MX_SCALE_BLOCK_K
:
tl
.
constexpr
,
N_PRESHUFFLE_FACTOR
:
tl
.
constexpr
=
NON_K_PRESHUFFLE_BLOCK_SIZE
,
):
x
=
x
.
reshape
(
BLOCK_N
//
N_PRESHUFFLE_FACTOR
,
MX_SCALE_BLOCK_K
//
8
,
4
,
16
,
2
,
2
,
1
)
x
=
x
.
permute
(
0
,
5
,
3
,
1
,
4
,
2
,
6
)
x
=
x
.
reshape
(
BLOCK_N
,
MX_SCALE_BLOCK_K
)
return
x
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/hopper_scale.py
0 → 100644
View file @
d29c39ca
import
torch
import
triton
import
triton.language
as
tl
from
.base
import
Layout
class
HopperMXScaleLayout
(
Layout
):
name
:
str
=
"HOPPER_SCALE"
def
__init__
(
self
,
shape
,
mx_axis
,
num_warps
=
8
)
->
None
:
assert
num_warps
&
(
num_warps
-
1
)
==
0
,
"warps_n must be a power of 2"
super
().
__init__
(
shape
)
self
.
mx_axis
=
mx_axis
self
.
num_warps
=
num_warps
*
self
.
leading_shape
,
_
,
_
=
shape
def
_maybe_mT
(
self
,
data
):
if
self
.
mx_axis
==
len
(
self
.
leading_shape
):
return
data
.
contiguous
().
mT
return
data
def
swizzle_data
(
self
,
data
):
data
=
self
.
_maybe_mT
(
data
).
contiguous
()
*
batch
,
M
,
K
=
data
.
shape
SWIZZLE_ALIGN_M
=
2
*
self
.
num_warps
*
2
*
8
SWIZZLE_ALIGN_K
=
2
pad_m
=
(
SWIZZLE_ALIGN_M
-
(
M
%
SWIZZLE_ALIGN_M
))
%
SWIZZLE_ALIGN_M
pad_k
=
(
SWIZZLE_ALIGN_K
-
(
K
%
SWIZZLE_ALIGN_K
))
%
SWIZZLE_ALIGN_K
data
=
torch
.
nn
.
functional
.
pad
(
data
,
(
0
,
pad_k
,
0
,
pad_m
))
*
batch
,
M
,
K
=
data
.
shape
assert
data
.
is_contiguous
()
assert
M
%
(
2
*
self
.
num_warps
*
2
*
8
)
==
0
and
K
%
2
==
0
,
(
f
"Input tensor must have a subtile of shape (...,
{
2
*
self
.
num_warps
*
2
*
8
}
, 2)"
)
b
=
len
(
batch
)
data
=
data
.
reshape
(
*
batch
,
M
//
(
2
*
self
.
num_warps
*
2
*
8
),
2
,
self
.
num_warps
,
2
,
8
,
K
//
2
,
2
,
)
perm
=
[
0
,
2
,
5
,
1
,
4
,
6
,
3
]
perm
=
list
(
range
(
b
))
+
[
b
+
p
for
p
in
perm
]
data
=
data
.
permute
(
*
perm
)
data
=
data
.
flatten
(
-
5
,
-
1
)
data
=
data
.
flatten
(
-
3
,
-
2
)
assert
data
.
shape
[
-
2
]
==
M
//
32
assert
data
.
shape
[
-
1
]
==
K
*
32
data
=
self
.
_maybe_mT
(
data
)
return
data
def
unswizzle_data
(
self
,
data
):
data
=
self
.
_maybe_mT
(
data
)
*
batch
,
M
,
K
=
data
.
shape
b
=
len
(
batch
)
data
=
data
.
reshape
(
*
batch
,
M
//
self
.
num_warps
,
self
.
num_warps
,
K
//
64
,
2
,
8
,
2
,
2
)
perm
=
[
0
,
3
,
1
,
6
,
4
,
2
,
5
]
perm
=
list
(
range
(
b
))
+
[
b
+
p
for
p
in
perm
]
data
=
data
.
permute
(
*
perm
)
data
=
data
.
reshape
(
*
batch
,
M
*
32
,
K
//
32
)
data
=
self
.
_maybe_mT
(
data
)
return
data
def
swizzle_block_shape
(
self
,
block_shape
):
return
block_shape
@
triton
.
jit
def
unswizzle_mxfp4_scale_hopper
(
x
,
mx_axis
:
tl
.
constexpr
,
num_warps
:
tl
.
constexpr
):
"""
Triton inverse of swizzle_mxfp4_scale_hopper
"""
tl
.
static_assert
(
len
(
x
.
shape
)
==
2
,
"NYI"
)
# implementation assumes mxfp data is packed along the last dimension
x
=
x
.
trans
()
if
mx_axis
==
0
else
x
M
:
tl
.
constexpr
=
x
.
shape
[
0
]
K
:
tl
.
constexpr
=
x
.
shape
[
1
]
tl
.
static_assert
(
M
%
num_warps
==
0
,
f
"M must be divisible by
{
num_warps
}
. Got
{
M
}
"
)
tl
.
static_assert
(
K
%
64
==
0
,
f
"K must be divisible by 64. Got
{
K
}
"
)
x
=
x
.
reshape
(
M
//
num_warps
,
num_warps
,
K
//
64
,
2
,
8
,
2
,
2
)
x
=
x
.
trans
(
0
,
3
,
1
,
6
,
4
,
2
,
5
)
x
=
x
.
reshape
(
M
*
32
,
K
//
32
)
# implementation assumed mxfp data is packed along the last dimension
x
=
x
.
trans
()
if
mx_axis
==
0
else
x
return
x
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/hopper_value.py
0 → 100644
View file @
d29c39ca
import
torch
import
triton
import
triton.language
as
tl
from
.base
import
Layout
def
right_shift_unsigned
(
x
,
shift
):
return
(
x
>>
shift
)
&
((
1
<<
(
32
-
shift
))
-
1
)
# -----------------------------------------------------------------------
# Interleave the bits of four consecutive fp4 values (i.e. 16-bits) as:
# 1000000111000000 (first fp4)
# 1000000111000000 (second fp4)
# 1000000111000000 (third fp4)
# 0110110000000000 (fourth fp4)
# This is done so that dequantization can be done in 14 SASS instructions
# -----------------------------------------------------------------------
def
_compress_fp4
(
x
):
x
=
x
.
to
(
torch
.
int32
)
return
((
x
&
0x8
)
<<
12
)
|
((
x
&
0x7
)
<<
6
)
def
_compress_fourth
(
x
):
x
=
x
.
to
(
torch
.
int32
)
return
((
x
&
0x8
)
<<
11
)
|
((
x
&
0x6
)
<<
9
)
|
((
x
&
0x1
)
<<
13
)
def
_pack_bits
(
x
:
torch
.
Tensor
,
mx_axis
:
int
):
x
=
x
.
contiguous
()
assert
x
.
shape
[
-
1
]
%
4
==
0
,
(
"Input tensor must have a last dimension divisible by 4"
)
x
=
x
.
reshape
(
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
4
,
4
))
first
=
_compress_fp4
(
x
[...,
0
])
|
(
_compress_fp4
(
x
[...,
0
]
>>
4
)
<<
16
)
second
=
_compress_fp4
(
x
[...,
1
])
|
(
_compress_fp4
(
x
[...,
1
]
>>
4
)
<<
16
)
third
=
_compress_fp4
(
x
[...,
2
])
|
(
_compress_fp4
(
x
[...,
2
]
>>
4
)
<<
16
)
fourth
=
_compress_fourth
(
x
[...,
3
])
|
(
_compress_fourth
(
x
[...,
3
]
>>
4
)
<<
16
)
x
=
(
first
|
right_shift_unsigned
(
second
,
3
)
|
right_shift_unsigned
(
third
,
6
)
|
fourth
)
assert
x
.
is_contiguous
()
x
=
x
.
view
(
torch
.
uint8
)
return
x
# -----------------------------------------------------------------------
# inverse operation of _pack_bits
# -----------------------------------------------------------------------
def
_bf16_to_fp4e2m1
(
x
):
# 0bAxxxxxxBCDxxxxxx (int16) -> 0b0000ABCD (uint8)
assert
x
.
dtype
==
torch
.
int16
s
=
(
right_shift_unsigned
(
x
,
15
)
&
0x1
)
<<
3
em
=
right_shift_unsigned
(
x
,
6
)
&
0x7
return
(
s
|
em
).
to
(
torch
.
uint8
)
def
_bf16x2_to_fp4e2m1x2
(
x
):
# 0bAxxxxxxBCDxxxxxx_0bExxxxxxFGHxxxxxx (int32) -> 0bABCD_EFGH (uint8)
assert
x
.
dtype
==
torch
.
int32
lo
=
(
x
&
0xFFFF
).
to
(
torch
.
int16
)
hi
=
(
right_shift_unsigned
(
x
,
16
)
&
0xFFFF
).
to
(
torch
.
int16
)
ret_lo
=
_bf16_to_fp4e2m1
(
lo
)
ret_hi
=
_bf16_to_fp4e2m1
(
hi
)
return
ret_lo
|
(
ret_hi
<<
4
)
def
_unpack_bits
(
x
,
mx_axis
:
int
):
x
=
x
.
view
(
torch
.
int32
)
m
=
0b10000001110000001000000111000000
a
=
(
x
<<
1
)
&
0b10000000000000001000000000000000
b
=
right_shift_unsigned
(
x
,
3
)
&
0b00000001100000000000000110000000
c
=
right_shift_unsigned
(
x
,
7
)
&
0b00000000010000000000000001000000
unpacked
=
[
x
&
m
,
(
x
<<
3
)
&
m
,
(
x
<<
6
)
&
m
,
(
a
|
b
)
|
c
]
x
=
torch
.
stack
(
unpacked
,
dim
=-
1
)
x
=
x
.
flatten
(
-
2
,
-
1
)
x
=
_bf16x2_to_fp4e2m1x2
(
x
)
return
x
# -----------------------------------------------------------------------
class
HopperMXValueLayout
(
Layout
):
name
:
str
=
"HOPPER_VALUE"
def
__init__
(
self
,
shape
,
mx_axis
,
mma_version
=
3
):
super
().
__init__
(
shape
)
assert
mx_axis
in
range
(
len
(
shape
))
self
.
mx_axis
=
mx_axis
self
.
mma_version
=
mma_version
(
*
self
.
leading_shape
,
self
.
K
,
self
.
N
,
)
=
shape
def
_maybe_mT
(
self
,
data
):
if
self
.
mx_axis
==
len
(
self
.
leading_shape
):
return
data
.
mT
return
data
def
swizzle_data
(
self
,
data
):
"""
Given a uint8 tensor of shape (*, M, K), returns a tensor of shape
(*, M // 4, K * 4) such that:
1) Groups contiguously all the elements owned by the same thread of 4
mma tiles along the K axis. The following animation shows a similar
grouping for 2 tiles along M and 2 tiles along K rather than 4 along K
as done here:
https://neuralmagic.com/wp-content/uploads/2024/10/animation_4.gif
2) Moves the elements belonging to thread 4-7 to be contiguous with those
from thread 0-3. This is done to get a full cache line when loading them
from HBM.
mx_axis selects the lhs or rhs of the matmul.
WARNING: Assumes that the matmul will be done in bf16 or fp16!
Implementing it for fp8 is as easy as making the tile size (8, 8)
"""
batch
=
data
.
ndim
-
2
assert
batch
>=
0
assert
self
.
mma_version
in
(
2
,
3
)
data
=
self
.
_maybe_mT
(
data
)
init_shape
=
data
.
shape
# We are loading 8 bf16 elements per thread to use ld.global.v4
# Every u8 represents 2 mxfp4 elements
u8_kwidth
=
8
//
2
if
self
.
mma_version
==
2
else
1
# Pack the 4 // u8_kwidth subtiles of an mma into a u4x8
contig
=
(
1
,
u8_kwidth
)
scott_trick
=
(
2
,
1
)
threads
=
(
4
,
4
)
warp_tile
=
(
2
,
2
)
k_tile
=
(
1
,
4
//
u8_kwidth
)
sizes
=
list
(
data
.
shape
[:
-
2
])
pads
=
[]
# [rest, K, tile, threads] per dimension
for
i
,
(
a
,
b
,
c
,
s
,
d
)
in
enumerate
(
zip
(
k_tile
,
warp_tile
,
threads
,
scott_trick
,
contig
)
):
pack
=
a
*
b
*
c
*
s
*
d
size
=
data
.
shape
[
batch
+
i
]
pad
=
(
pack
-
size
%
pack
)
%
pack
pads
+=
[(
0
,
pad
)]
sizes
.
append
((
size
+
pad
)
//
pack
)
sizes
+=
[
a
,
b
,
c
,
s
,
d
]
pads
=
tuple
(
x
for
t
in
pads
[::
-
1
]
for
x
in
t
)
data
=
torch
.
nn
.
functional
.
pad
(
data
,
pads
)
init_shape
=
data
.
shape
# 0: rest[0]
# 1: k_tile[0]
# 2: warp_tile[0]
# 3: threads[0]
# 4: scott_trick[0]
# 5: contig[0]
# 6: rest[1]
# 7: k_tile[1]
# 8: warp_tile[1]
# 9: threads[1]
# 10: scott_trick[1]
# 11: contig[1]
data
=
data
.
view
(
*
sizes
)
# Want [rest[0], threads[0], rest[1], scott_trick[0], scott_trick[0], threads[1], contig[1], contig[0], k_tile[1], k_tile[0], warp_tile[1], warp_tile[0]]
perm
=
[
0
,
3
,
6
,
10
,
4
,
9
,
7
,
1
,
8
,
2
,
5
,
11
]
perm
=
list
(
range
(
batch
))
+
[
batch
+
p
for
p
in
perm
]
data
=
data
.
permute
(
*
perm
).
contiguous
()
# These are views
data
=
data
.
flatten
(
-
10
,
-
1
)
data
=
data
.
flatten
(
-
3
,
-
2
)
assert
data
.
is_contiguous
()
assert
data
.
shape
[
-
2
]
==
init_shape
[
-
2
]
//
4
assert
data
.
shape
[
-
1
]
==
init_shape
[
-
1
]
*
4
# twiddle the bits
data
=
_pack_bits
(
data
,
self
.
mx_axis
)
data
=
self
.
_maybe_mT
(
data
)
return
data
def
unswizzle_data
(
self
,
data
):
data
=
self
.
_maybe_mT
(
data
)
data
=
_unpack_bits
(
data
,
self
.
mx_axis
)
*
batch
,
M
,
K
=
data
.
shape
# We have two times the elements if we already upcasted to bfloat16
mult
=
2
if
data
.
dtype
==
torch
.
bfloat16
else
1
assert
M
%
4
==
0
,
"M must be divisible by 4"
assert
K
%
(
4
*
8
*
2
*
2
*
mult
)
==
0
,
(
f
"K must be divisible by
{
4
*
8
*
2
*
2
*
mult
}
"
)
# We are loading 8 bf16 elements per thread to use ld.global.v4
# Every u8 represents 2 mxfp4 elements
u8_kwidth
=
8
//
2
if
self
.
mma_version
==
2
else
1
data
=
data
.
reshape
(
*
batch
,
M
//
4
,
4
,
K
//
(
4
*
8
*
2
*
2
*
mult
),
2
,
4
,
8
//
u8_kwidth
,
2
,
u8_kwidth
*
mult
,
)
b
=
len
(
batch
)
perm
=
[
0
,
6
,
1
,
3
,
2
,
5
,
4
,
7
]
perm
=
list
(
range
(
b
))
+
[
b
+
p
for
p
in
perm
]
data
=
data
.
permute
(
*
perm
)
data
=
data
.
reshape
(
*
batch
,
M
*
4
,
K
//
4
)
data
=
self
.
_maybe_mT
(
data
)
return
data
[...,
:
self
.
K
,
:
self
.
N
]
def
swizzle_block_shape
(
self
,
block_shape
):
return
block_shape
@
triton
.
jit
def
_unshuffle_triton
(
x
,
mma_version
:
tl
.
constexpr
):
"""
Triton inverse of swizzle_mxfp4_value_hopper
"""
tl
.
static_assert
(
mma_version
==
2
or
mma_version
==
3
,
"mma_version must be 2 or 3"
)
# if mx_axis == 0:
# x = x.trans()
# We have two times the elements if we already upcasted to bfloat16
mult
:
tl
.
constexpr
=
2
if
x
.
dtype
==
tl
.
bfloat16
else
1
M
:
tl
.
constexpr
=
x
.
shape
[
0
]
K
:
tl
.
constexpr
=
x
.
shape
[
1
]
tl
.
static_assert
(
M
%
4
==
0
,
"M must be divisible by 4"
)
tl
.
static_assert
(
K
%
(
4
*
8
*
2
*
2
*
mult
)
==
0
,
f
"K must be divisible by
{
4
*
8
*
2
*
2
*
mult
}
"
,
)
# We are loading 8 bf16 elements per thread to use ld.global.v4
# Every u8 represents 2 mxfp4 elements
u8_kwidth
:
tl
.
constexpr
=
8
//
2
if
mma_version
==
2
else
1
x
=
x
.
reshape
(
M
//
4
,
4
,
K
//
(
4
*
8
*
2
*
2
*
mult
),
2
,
4
,
8
//
u8_kwidth
,
2
,
u8_kwidth
*
mult
,
)
x
=
x
.
trans
(
0
,
6
,
1
,
3
,
2
,
5
,
4
,
7
)
x
=
x
.
reshape
(
M
*
4
,
K
//
4
)
# if mx_axis == 0:
# x = x.trans()
return
x
@
triton
.
jit
def
_unpack_fp4_to_bf16_triton
(
x
):
# For now we implement just H100 support (mul.bf16x2)
# A100 support is possible via fma
r0
,
r1
=
tl
.
inline_asm_elementwise
(
r
"""
{
.reg .b32 b, c, d<7>, scale;
.reg .b32 bias;
mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
// We add the missing bias to the scale directly
and.b32 $0, $4, 0b10000001110000001000000111000000;
mul.bf16x2 $0, $0, bias;
shl.b32 b, $4, 3;
and.b32 $1, b, 0b10000001110000001000000111000000;
mul.bf16x2 $1, $1, bias;
shl.b32 c, $4, 6;
and.b32 $2, c, 0b10000001110000001000000111000000;
mul.bf16x2 $2, $2, bias;
// Unpack last two elements
shl.b32 d0, $4, 1;
and.b32 d1, d0, 0b10000000000000001000000000000000;
shr.b32 d2, $4, 3;
and.b32 d3, d2, 0b00000001100000000000000110000000;
or.b32 d4, d1, d3;
shr.b32 d5, $4, 7;
and.b32 d6, d5, 0b00000000010000000000000001000000;
or.b32 $3, d4, d6;
mul.bf16x2 $3, $3, bias;
}
"""
,
constraints
=
"=r,=r,=r,=r,r"
,
args
=
[
x
],
dtype
=
(
tl
.
bfloat16
,
tl
.
bfloat16
),
is_pure
=
True
,
pack
=
4
,
)
# Concat each pack of 4
x
=
tl
.
join
(
r0
,
r1
)
x
=
x
.
reshape
(
x
.
shape
[
0
],
x
.
shape
[
1
]
//
4
,
4
,
x
.
shape
[
2
])
x
=
x
.
trans
(
0
,
1
,
3
,
2
)
x
=
x
.
reshape
(
x
.
shape
[
0
],
x
.
shape
[
1
]
*
x
.
shape
[
2
]
*
x
.
shape
[
3
])
return
x
@
triton
.
jit
def
mxfp4_to_bf16_triton
(
x
,
scale
,
mx_axis
:
tl
.
constexpr
):
"""
Implements the bit-untwiddling of a 32-bit integer (8 mxfp4 elements):
(x << 0) & 0b1000000111000000
(x << 3) & 0b1000000111000000
(x << 6) & 0b1000000111000000
((x << 1) & 0b1000000000000000) | ((x >> 3) & 0b0000000110000000) | ((x >> 7) & 0b0000000001000000)
"""
# upcast values to bfloat16
tl
.
static_assert
(
len
(
x
.
shape
)
==
2
)
tl
.
static_assert
(
mx_axis
==
0
or
mx_axis
==
1
,
"mx_axis must be 0 or 1"
)
tl
.
static_assert
(
x
.
shape
[
1
]
%
4
==
0
)
tl
.
static_assert
(
x
.
dtype
==
tl
.
uint8
)
if
mx_axis
==
0
:
x
=
x
.
trans
()
x
=
_unpack_fp4_to_bf16_triton
(
x
)
x
=
_unshuffle_triton
(
x
,
mma_version
=
3
)
if
mx_axis
==
0
:
x
=
x
.
trans
()
# upcast scale to bfloat16
# Add bias missing from the bf16 upcasting sequence
# triton / LLVM generates terrible code for this sequence
# scale = scale.to(tl.uint16)
# scale = scale << 7
# scale = scale.to(tl.bfloat16, bitcast=True)
scale
=
tl
.
inline_asm_elementwise
(
r
"""
{
prmt.b32 $0, $2, 0, 0x5140;
shl.b32 $0, $0, 7;
prmt.b32 $1, $2, 0, 0x7362;
shl.b32 $1, $1, 7;
}
"""
,
constraints
=
"=r,=r,r"
,
args
=
[
scale
],
dtype
=
tl
.
bfloat16
,
is_pure
=
True
,
pack
=
4
,
)
# Broadcast scale
scale
=
scale
.
expand_dims
(
mx_axis
+
1
)
scale
=
scale
.
broadcast_to
(
scale
.
shape
[:
mx_axis
+
1
]
+
[
32
]
+
scale
.
shape
[
mx_axis
+
2
:]
)
scale
=
scale
.
reshape
(
x
.
shape
)
# Combine scale and x
x
=
x
*
scale
return
x
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/strided.py
0 → 100644
View file @
d29c39ca
from
.base
import
Layout
class
StridedLayout
(
Layout
):
name
:
str
=
None
def
__init__
(
self
,
shape
)
->
None
:
super
().
__init__
(
shape
)
def
swizzle_data
(
self
,
data
):
return
data
def
unswizzle_data
(
self
,
data
):
return
data
def
swizzle_block_shape
(
self
,
block_shape
):
return
block_shape
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/testing.py
0 → 100644
View file @
d29c39ca
import
enum
import
functools
import
os
import
subprocess
import
sys
import
torch
from
compactor_vllm.triton_kernels.numerics
import
(
MAX_FINITE_FLOAT8E4B8
,
MAX_FINITE_FLOAT8E4NV
,
MAX_FINITE_FLOAT8E5
,
)
def
assert_equal
(
ref
,
tri
):
if
isinstance
(
ref
,
torch
.
Tensor
):
assert
torch
.
all
(
ref
==
tri
)
else
:
assert
ref
==
tri
def
assert_close
(
ref
,
tri
,
maxtol
=
None
,
rmstol
=
None
,
description
=
"--"
,
verbose
=
True
):
if
tri
.
dtype
.
itemsize
==
1
:
ref_as_type
=
ref
.
to
(
tri
.
dtype
)
if
ref
.
dtype
==
tri
.
dtype
:
assert
torch
.
all
(
ref_as_type
==
tri
)
return
ref
=
ref_as_type
if
ref
.
numel
()
==
0
:
return
if
maxtol
is
None
:
maxtol
=
2e-2
if
rmstol
is
None
:
rmstol
=
4e-3
"""
Compare reference values against obtained values.
"""
# cast to float32:
ref
=
ref
.
to
(
torch
.
float32
).
detach
()
tri
=
tri
.
to
(
torch
.
float32
).
detach
()
assert
ref
.
shape
==
tri
.
shape
,
(
f
"Tensors must have same size
{
ref
.
shape
=
}
{
tri
.
shape
=
}
"
)
# deal with infinite elements:
inf_mask_ref
=
torch
.
isinf
(
ref
)
inf_mask_tri
=
torch
.
isinf
(
tri
)
assert
torch
.
equal
(
inf_mask_ref
,
inf_mask_tri
),
(
"Tensor must have same infinite elements"
)
refn
=
torch
.
where
(
inf_mask_ref
,
0
,
ref
)
trin
=
torch
.
where
(
inf_mask_tri
,
0
,
tri
)
# normalise so that RMS calculation doesn't overflow:
eps
=
1.0e-30
multiplier
=
1.0
/
(
torch
.
max
(
torch
.
abs
(
refn
))
+
eps
)
refn
*=
multiplier
trin
*=
multiplier
ref_rms
=
torch
.
sqrt
(
torch
.
square
(
refn
).
mean
())
+
eps
rel_err
=
torch
.
abs
(
refn
-
trin
)
/
torch
.
maximum
(
ref_rms
,
torch
.
abs
(
refn
))
max_err
=
torch
.
max
(
rel_err
).
item
()
rms_err
=
torch
.
sqrt
(
torch
.
square
(
rel_err
).
mean
()).
item
()
if
verbose
:
print
(
"%s maximum relative error = %s (threshold = %s)"
%
(
description
,
max_err
,
maxtol
)
)
print
(
"%s RMS relative error = %s (threshold = %s)"
%
(
description
,
rms_err
,
rmstol
)
)
if
max_err
>
maxtol
:
bad_idxs
=
torch
.
nonzero
(
rel_err
>
maxtol
)
num_nonzero
=
bad_idxs
.
size
(
0
)
bad_idxs
=
bad_idxs
[:
1000
]
print
(
"%d / %d mismatched elements (shape = %s) at coords %s"
%
(
num_nonzero
,
rel_err
.
numel
(),
tuple
(
rel_err
.
shape
),
bad_idxs
.
tolist
())
)
bad_idxs
=
bad_idxs
.
unbind
(
-
1
)
print
(
"ref values: "
,
ref
[
tuple
(
bad_idxs
)].
cpu
())
print
(
"tri values: "
,
tri
[
tuple
(
bad_idxs
)].
cpu
())
assert
max_err
<=
maxtol
assert
rms_err
<=
rmstol
class
ComputeSanitizerTool
(
enum
.
Enum
):
MEMCHECK
=
"memcheck"
RACECHECK
=
"racecheck"
SYNCCHECK
=
"synccheck"
INITCHECK
=
"initcheck"
def
compute_sanitizer
(
**
target_kwargs
):
"""
Decorator to run a test with compute sanitizer enabled and pytorch caching allocator disabled,
to expose potential memory access errors.
This decorator requires the `request` fixture to be present.
If `run_sanitizer` argument is present and set to False, the sanitizer is not run.
Running tests under compute sanitizer requires launching subprocess and is slow,
so use sparingly
"""
def
decorator
(
test_fn
):
@
functools
.
wraps
(
test_fn
)
def
wrapper
(
*
args
,
**
kwargs
):
if
os
.
environ
.
get
(
"SKIP_COMPUTE_SANITIZER"
)
==
"1"
:
test_fn
(
*
args
,
**
kwargs
)
return
import
psutil
if
target_kwargs
.
pop
(
"clear_torch_cache"
,
False
):
# If we don't pop clear_torch_cache, it won't pass
# target_kwargs.items() <= kwargs.items() condition below.
torch
.
cuda
.
empty_cache
()
tools_to_check
=
target_kwargs
.
pop
(
"tools_to_check"
,
[
ComputeSanitizerTool
.
MEMCHECK
]
)
assert
isinstance
(
tools_to_check
,
list
),
f
"
{
tools_to_check
=
}
"
assert
all
(
tool
in
ComputeSanitizerTool
for
tool
in
tools_to_check
),
(
f
"
{
(
tool
for
tool
in
tools_to_check
if
tool
not
in
ComputeSanitizerTool
)
=
}
"
)
ppid_name
=
psutil
.
Process
(
os
.
getppid
()).
exe
()
run_compute_sanitizer
=
target_kwargs
.
items
()
<=
kwargs
.
items
()
if
"run_sanitizer"
in
kwargs
:
run_compute_sanitizer
&=
kwargs
[
"run_sanitizer"
]
if
run_compute_sanitizer
and
"compute-sanitizer"
not
in
ppid_name
:
for
tool
in
tools_to_check
:
path
=
os
.
path
.
realpath
(
test_fn
.
__globals__
[
"__file__"
])
# get path of current file
env
=
{
"PATH"
:
os
.
environ
[
"PATH"
],
"PYTORCH_NO_CUDA_MEMORY_CACHING"
:
"1"
,
"TORCH_SHOW_CPP_STACKTRACES"
:
"1"
,
"CUDA_LAUNCH_BLOCKING"
:
"1"
,
}
if
"CUDA_VISIBLE_DEVICES"
in
os
.
environ
:
env
[
"CUDA_VISIBLE_DEVICES"
]
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
assert
"request_fixture"
in
kwargs
,
(
"memcheck'ed test must have a (possibly unused) `request` fixture"
)
test_id
=
kwargs
[
"request_fixture"
].
node
.
callspec
.
id
cmd
=
f
"
{
path
}
::
{
test_fn
.
__name__
}
[
{
test_id
}
]"
cmd
=
[
"compute-sanitizer"
,
"--target-processes=application-only"
,
"--destroy-on-device-error=context"
,
f
"--tool=
{
tool
.
value
}
"
,
sys
.
executable
,
"-m"
,
"pytest"
,
"-vsx"
,
cmd
,
]
for
opt
in
[
"--update_checksum"
,
"--ignore_checksum_error"
]:
if
opt
in
sys
.
argv
:
cmd
.
append
(
opt
)
out
=
subprocess
.
run
(
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
STDOUT
,
env
=
env
,
)
sanitizer_ok
=
"ERROR SUMMARY: 0 errors"
in
str
(
out
.
stdout
)
or
"RACECHECK SUMMARY: 0 hazards displayed"
in
str
(
out
.
stdout
)
test_output
=
out
.
stdout
if
type
(
test_output
)
is
bytes
:
test_output
=
test_output
.
decode
()
fail
=
False
if
not
sanitizer_ok
:
print
(
"compute-sanitizer returned an error"
)
fail
=
True
elif
out
.
returncode
!=
0
:
print
(
"The test failed due to some other reason: consider running without compute-sanitizer to verify."
)
print
(
f
"
{
out
.
returncode
=
}
"
)
fail
=
True
if
fail
:
print
(
"*****************************************************"
)
print
(
"******************** TEST OUTPUT ********************"
)
print
(
"*****************************************************"
)
print
(
test_output
)
print
(
"*****************************************************"
)
print
(
"****************** TEST OUTPUT END ******************"
)
print
(
"*****************************************************"
)
assert
None
else
:
test_fn
(
*
args
,
**
kwargs
)
return
wrapper
return
decorator
def
compute_actual_scale
(
x
,
dtype
):
max_finite
=
{
torch
.
float8_e5m2
:
MAX_FINITE_FLOAT8E5
,
torch
.
float8_e4m3fn
:
MAX_FINITE_FLOAT8E4NV
,
torch
.
float8_e4m3fnuz
:
MAX_FINITE_FLOAT8E4B8
,
}[
dtype
]
return
x
.
abs
().
max
()
/
max_finite
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk.py
0 → 100644
View file @
d29c39ca
import
torch
import
triton
from
compactor_vllm.triton_kernels.topk_details._topk_forward
import
_topk_forward
from
compactor_vllm.triton_kernels.topk_details
import
_topk_backward
from
compactor_vllm.triton_kernels.tensor
import
Tensor
,
Bitmatrix
from
typing
import
Optional
,
Union
def
topk_forward
(
x
,
k
,
apply_softmax
=
True
,
dim
=
1
,
return_bitmatrix
=
True
,
y_indx
=
None
,
n_rows
=
None
):
if
not
isinstance
(
x
,
Tensor
):
x_shape
=
[
x
.
shape
[
0
]
if
n_rows
is
None
else
n_rows
,
x
.
shape
[
1
]]
x_shape_max
=
[
x
.
shape
[
0
],
x
.
shape
[
1
]]
x
=
Tensor
(
x
,
shape
=
x_shape
,
shape_max
=
x_shape_max
)
cdiv
=
lambda
a
,
b
:
(
a
+
b
-
1
)
//
b
BLOCK_M
=
32
BLOCK_N
=
32
BLOCK_S
=
128
assert
len
(
x
.
shape
)
==
2
assert
x
.
shape_max
[
-
1
]
<
32768
assert
dim
==
1
assert
return_bitmatrix
n_rows
,
n_cols
=
x
.
shape
n_rows_max
,
_
=
x
.
shape_max
dev
=
x
.
device
# scratchpad tensors
# NOTE: these are not returned
y_vals
=
torch
.
empty
((
n_rows_max
,
k
),
dtype
=
x
.
dtype
,
device
=
dev
)
if
y_indx
is
not
None
:
use_provided_indx
=
True
else
:
y_indx
=
torch
.
empty
((
n_rows_max
,
k
),
dtype
=
torch
.
int16
,
device
=
dev
)
use_provided_indx
=
False
# create bitmatrix in transposed memory layout:
n_cols_pad
=
cdiv
(
n_cols
,
BLOCK_N
)
*
BLOCK_N
n_cols_words
=
n_cols_pad
//
32
bitmatrix
=
torch
.
empty
(
(
n_cols_words
,
cdiv
(
n_rows_max
,
32
)
*
32
),
dtype
=
torch
.
uint32
,
device
=
dev
)
bitmatrix
=
torch
.
transpose
(
bitmatrix
,
0
,
1
)[:
n_rows_max
]
s_blocks
=
cdiv
(
n_cols
,
BLOCK_S
)
s_cols
=
s_blocks
*
BLOCK_S
scratchpad
=
torch
.
empty
((
s_cols
,),
dtype
=
torch
.
int32
,
device
=
dev
)
pids
=
max
(
cdiv
(
n_rows_max
,
BLOCK_M
),
s_blocks
)
_topk_forward
[(
pids
,)](
x
,
x
.
stride
(
0
),
# inputs
y_vals
,
y_indx
,
y_vals
.
stride
(
0
),
use_provided_indx
,
# output [topk]
bitmatrix
,
bitmatrix
.
stride
(
0
),
bitmatrix
.
stride
(
1
),
# output [bitmatrix]
n_rows
,
n_cols
,
# shapes
scratchpad
,
BLOCK_S
,
s_blocks
,
# thing to memset to zero
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
,
# tunable parameter
APPLY_SOFTMAX
=
apply_softmax
,
N_EXPTS_PAD
=
n_cols_pad
,
N_EXPTS_ACT
=
k
,
# constants
)
bitmatrix_shape
=
[
n_rows
,
n_cols_words
*
32
]
bitmatrix_shape_max
=
[
n_rows_max
,
None
]
bitmatrix
=
Bitmatrix
(
bitmatrix
,
shape
=
bitmatrix_shape
,
shape_max
=
bitmatrix_shape_max
,
scratchpad
=
scratchpad
,
)
return
y_vals
,
y_indx
,
bitmatrix
def
topk_backward
(
x
,
y_indx
,
dy_vals
,
k
,
n_rows
,
apply_softmax
):
assert
dy_vals
.
shape
[
-
1
]
==
k
n_expts_pad
=
triton
.
next_power_of_2
(
x
.
shape
[
-
1
])
dx
=
torch
.
empty_like
(
x
)
_topk_backward
[(
dy_vals
.
shape
[
0
],)](
y_indx
,
y_indx
.
stride
(
0
),
dy_vals
,
dy_vals
.
stride
(
0
),
x
,
x
.
stride
(
0
),
# inputs
dx
,
# outputs
dx
.
stride
(
0
),
x
.
shape
[
0
],
n_rows
,
x
.
shape
[
-
1
],
APPLY_SOFTMAX
=
apply_softmax
,
N_EXPTS_ACT
=
k
,
N_EXPTS_PAD
=
n_expts_pad
,
)
return
dx
class
TopK
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
x
,
k
,
apply_softmax
,
dim
,
return_bitmatrix
,
y_indx
,
n_rows
):
y_vals
,
y_indx
,
bitmatrix
=
topk_forward
(
x
,
k
,
apply_softmax
,
dim
,
return_bitmatrix
,
y_indx
,
n_rows
)
ctx
.
save_for_backward
(
x
,
y_indx
)
ctx
.
apply_softmax
=
apply_softmax
ctx
.
k
=
k
ctx
.
n_rows
=
n_rows
return
y_vals
,
y_indx
,
bitmatrix
@
staticmethod
def
backward
(
ctx
,
dy_vals
,
_0
,
_1
):
x
,
y_indx
=
ctx
.
saved_tensors
dx
=
topk_backward
(
x
,
y_indx
,
dy_vals
,
ctx
.
k
,
ctx
.
n_rows
,
ctx
.
apply_softmax
)
return
dx
,
None
,
None
,
None
,
None
,
None
,
None
def
topk
(
x
:
Union
[
Tensor
,
torch
.
Tensor
],
k
:
int
,
apply_softmax
:
bool
=
True
,
dim
:
int
=
1
,
return_bitmatrix
:
bool
=
True
,
y_indx
:
Optional
[
torch
.
Tensor
]
=
None
,
n_rows
:
Optional
[
int
]
=
None
,
):
"""
Computes the top-k values and indices along a specified dimension of a tensor.
Note that the input can be either a `Tensor` or a `torch.Tensor`, but the output will always be a `torch.Tensor`.
Parameters
----------
x : Union[triton_kernels.Tensor, torch.Tensor]
Input tensor of shape (n_tokens, n_expts).
k : int
Number of top elements to retrieve.
apply_softmax : bool, default True
Whether to apply softmax to the input tensor before computing top-k.
dim : int, default 1
Dimension along which to compute top-k.
return_bitmatrix : bool, default True
A bitmatrix of shape (n_tokens, cdiv(n_expts, 32)).
Each bit on [t, b] indicates whether the b-th expert was selected for the t-th token.
y_indx : torch.Tensor, optional
Pre-allocated tensor for storing indices of top-k elements with shape (n_tokens, k).
If provided, we skip the computation of top-k indices and use this tensor instead.
n_rows : int, optional
Number of rows to apply top-k on. If None, we consider all rows in `x`.
Returns
-------
(expt_scal, expt_indx, bitmatrix) : Tuple[torch.Tensor, torch.Tensor, Bitmatrix]
"""
ret
=
TopK
.
apply
(
x
,
k
,
apply_softmax
,
dim
,
return_bitmatrix
,
y_indx
,
n_rows
)
return
ret
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/__init__.py
0 → 100644
View file @
d29c39ca
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/_topk_backward.py
0 → 100644
View file @
d29c39ca
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
_topk_backward
(
Yi
,
stride_ym
,
# topk indices
DY
,
stride_dym
,
# output gradient values
X
,
stride_xm
,
# input values
DX
,
stride_dxm
,
# input gradient values
n_rows
,
NRows
,
n_expts_tot
,
APPLY_SOFTMAX
:
tl
.
constexpr
,
N_EXPTS_ACT
:
tl
.
constexpr
,
N_EXPTS_PAD
:
tl
.
constexpr
,
):
pid_m
=
tl
.
program_id
(
0
)
if
NRows
is
not
None
:
n_rows
=
tl
.
load
(
NRows
)
if
pid_m
>=
n_rows
:
return
Yi
+=
pid_m
*
stride_ym
DY
+=
pid_m
*
stride_dym
X
+=
pid_m
*
stride_xm
DX
+=
pid_m
*
stride_dxm
# --
offs_xn
=
tl
.
arange
(
0
,
N_EXPTS_PAD
)
offs_yn
=
tl
.
arange
(
0
,
N_EXPTS_ACT
)
mask_xn
=
offs_xn
<
n_expts_tot
# recompute softmax
y_indx
=
tl
.
load
(
Yi
+
offs_yn
)
x
=
tl
.
load
(
X
+
y_indx
)
x
=
x
.
to
(
tl
.
float32
)
y
=
tl
.
softmax
(
x
)
# compute input-gradient
dy
=
tl
.
load
(
DY
+
offs_yn
)
dy
=
dy
.
to
(
tl
.
float32
)
s
=
tl
.
sum
(
y
*
dy
,
0
)
# write-back input gradient
tl
.
store
(
DX
+
offs_xn
,
0
,
mask
=
mask_xn
)
tl
.
debug_barrier
()
if
APPLY_SOFTMAX
:
dx
=
y
*
(
dy
-
s
)
else
:
dx
=
dy
tl
.
store
(
DX
+
y_indx
,
dx
)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/_topk_forward.py
0 → 100644
View file @
d29c39ca
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
get_topmask_and_fullmask
(
x
):
tl
.
static_assert
(
x
.
dtype
.
is_int_unsigned
(),
"floating-point value must be passed as bits"
)
tm
:
tl
.
constexpr
=
1
<<
(
-
1
+
x
.
dtype
.
primitive_bitwidth
)
fm
:
tl
.
constexpr
=
(
1
<<
x
.
dtype
.
primitive_bitwidth
)
-
1
tm_arr
=
tl
.
full
(
x
.
shape
,
tm
,
dtype
=
x
.
dtype
)
fm_arr
=
tl
.
full
(
x
.
shape
,
fm
,
dtype
=
x
.
dtype
)
return
tm_arr
,
fm_arr
@
triton
.
jit
def
fpval_to_key
(
x
):
tm
,
fm
=
get_topmask_and_fullmask
(
x
)
return
x
^
tl
.
where
((
x
&
tm
)
!=
0
,
fm
,
tm
)
@
triton
.
jit
def
key_to_fpval
(
x
):
tm
,
fm
=
get_topmask_and_fullmask
(
x
)
return
x
^
tl
.
where
((
x
&
tm
)
==
0
,
fm
,
tm
)
# stable top-k tie-breaks to value with smaller index
@
triton
.
jit
def
indx_to_key
(
indx
,
N_EXPTS_PAD
:
tl
.
constexpr
):
return
N_EXPTS_PAD
-
indx
@
triton
.
jit
def
key_to_indx
(
indx
,
N_EXPTS_PAD
:
tl
.
constexpr
):
return
N_EXPTS_PAD
-
indx
@
triton
.
jit
def
streaming_topk
(
X
,
stride_xm
,
n_expts_tot
,
offs_m
,
mask_m
,
N_EXPTS_PAD
:
tl
.
constexpr
,
N_EXPTS_ACT
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
):
x_nbits
:
tl
.
constexpr
=
X
.
dtype
.
element_ty
.
primitive_bitwidth
x_utype
:
tl
.
constexpr
=
tl
.
dtype
(
f
"uint
{
x_nbits
}
"
)
if
x_nbits
<
16
:
# this ensures that we leave at least 16 bits for expert index
# even if the input dtype is smaller than 16 bits:
y_nbits
:
tl
.
constexpr
=
32
else
:
y_nbits
:
tl
.
constexpr
=
x_nbits
*
2
x_ultype
:
tl
.
constexpr
=
tl
.
dtype
(
f
"uint
{
y_nbits
}
"
)
x_dtype
:
tl
.
constexpr
=
X
.
dtype
.
element_ty
# subtract 1 from loop iterations because we peel the first (masked) iteration:
loop_iterations
:
tl
.
constexpr
=
N_EXPTS_PAD
//
BLOCK_N
-
1
offs_x_n
=
loop_iterations
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
mask_n
=
offs_x_n
[
None
,
:]
<
n_expts_tot
# first iteration:
X_ptrs
=
X
+
offs_m
[:,
None
]
*
stride_xm
+
offs_x_n
[
None
,
:]
x
=
tl
.
load
(
X_ptrs
,
mask
=
(
mask_m
&
mask_n
),
other
=
float
(
"-inf"
))
x
=
fpval_to_key
(
x
.
to
(
x_utype
,
bitcast
=
True
))
x
=
(
x
.
to
(
x_ultype
)
<<
16
)
|
indx_to_key
(
offs_x_n
,
N_EXPTS_PAD
)[
None
,
:]
acc
=
tl
.
topk
(
x
,
N_EXPTS_ACT
,
dim
=
1
)
# subsequent iterations:
for
_i
in
(
tl
.
static_range
if
loop_iterations
<=
4
else
range
)(
loop_iterations
):
acc
=
tl
.
bitonic_merge
(
acc
)
# ensure sorted ascending for the merge
X_ptrs
-=
BLOCK_N
offs_x_n
-=
BLOCK_N
x
=
tl
.
load
(
X_ptrs
,
mask
=
mask_m
,
other
=
float
(
"-inf"
))
x
=
fpval_to_key
(
x
.
to
(
x_utype
,
bitcast
=
True
))
x
=
(
x
.
to
(
x_ultype
)
<<
16
)
|
indx_to_key
(
offs_x_n
,
N_EXPTS_PAD
)[
None
,
:]
acc
=
tl
.
maximum
(
acc
,
tl
.
topk
(
x
,
N_EXPTS_ACT
,
dim
=
1
))
# rotate expert index into upper 16 bits:
# 0000vvvvvvvviiii --> iiii0000vvvvvvvv
acc
=
(
acc
<<
(
y_nbits
-
16
))
|
(
acc
>>
16
)
# sort in ascending order of expert (descending order of key)
acc
=
tl
.
sort
(
acc
,
dim
=
1
,
descending
=
True
)
# iiii0000vvvvvvvv --> 0000iiii:
y_indices_raw
=
(
acc
>>
(
y_nbits
-
16
)).
to
(
tl
.
uint32
)
y_indices
=
key_to_indx
(
y_indices_raw
,
N_EXPTS_PAD
)
# iiii0000vvvvvvvv --> vvvvvvvv:
y_values_raw
=
acc
.
to
(
x_utype
)
y_values
=
key_to_fpval
(
y_values_raw
).
to
(
x_dtype
,
bitcast
=
True
)
return
y_values
,
y_indices
@
triton
.
jit
def
_topk_forward
(
X
,
stride_xm
,
# inputs
Yv
,
Yi
,
stride_ym
,
# topk values/indices
USE_PROVIDED_INDX
:
tl
.
constexpr
,
Bits
,
stride_rm
:
tl
.
constexpr
,
stride_rn
:
tl
.
constexpr
,
# bitmatrix
n_rows
,
n_expts_tot
,
# shape
S
,
BLOCK_S
:
tl
.
constexpr
,
s_blocks
,
# thing to memset
APPLY_SOFTMAX
:
tl
.
constexpr
,
# constant
BLOCK_M
:
tl
.
constexpr
,
N_EXPTS_PAD
:
tl
.
constexpr
,
N_EXPTS_ACT
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
if
isinstance
(
n_rows
,
tl
.
tensor
)
and
n_rows
.
dtype
.
is_ptr
():
n_rows
=
tl
.
load
(
n_rows
)
if
pid
<
s_blocks
:
tl
.
store
(
S
+
BLOCK_S
*
pid
+
tl
.
arange
(
0
,
BLOCK_S
),
tl
.
zeros
([
BLOCK_S
],
tl
.
int32
)
)
if
pid
*
BLOCK_M
>=
n_rows
:
# early exit:
return
tl
.
static_assert
(
BLOCK_N
%
32
==
0
)
tl
.
static_assert
(
N_EXPTS_PAD
%
BLOCK_N
==
0
)
x_dtype
:
tl
.
constexpr
=
X
.
dtype
.
element_ty
# load logits
offs_m
=
pid
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_y_n
=
tl
.
arange
(
0
,
N_EXPTS_ACT
)
mask_m
=
offs_m
[:,
None
]
<
n_rows
if
USE_PROVIDED_INDX
:
Yi_ptrs
=
Yi
+
offs_m
[:,
None
]
*
stride_ym
+
offs_y_n
[
None
,
:]
y_indices
=
tl
.
load
(
Yi_ptrs
,
mask
=
mask_m
)
Xv_ptrs
=
X
+
offs_m
[:,
None
]
*
stride_xm
+
y_indices
y_values
=
tl
.
load
(
Xv_ptrs
,
mask
=
mask_m
)
else
:
y_values
,
y_indices
=
streaming_topk
(
X
,
stride_xm
,
n_expts_tot
,
offs_m
,
mask_m
,
#
N_EXPTS_PAD
,
N_EXPTS_ACT
,
BLOCK_N
,
)
# normalize selected values
if
APPLY_SOFTMAX
:
y_values
=
tl
.
softmax
(
y_values
.
to
(
tl
.
float32
),
dim
=
1
,
keep_dims
=
True
).
to
(
x_dtype
)
# write back
Yv_ptrs
=
Yv
+
offs_m
[:,
None
]
*
stride_ym
+
offs_y_n
[
None
,
:]
tl
.
store
(
Yv_ptrs
,
y_values
,
mask
=
mask_m
)
if
not
USE_PROVIDED_INDX
:
Yi_ptrs
=
Yi
+
offs_m
[:,
None
]
*
stride_ym
+
offs_y_n
[
None
,
:]
tl
.
store
(
Yi_ptrs
,
y_indices
,
mask
=
mask_m
)
# pack into bitmatrix
y_div
=
y_indices
//
32
y_rem
=
y_indices
%
32
loop_iterations
=
N_EXPTS_PAD
//
BLOCK_N
for
i
in
range
(
loop_iterations
):
offs_r_n
=
tl
.
arange
(
0
,
BLOCK_N
//
32
)
+
i
*
(
BLOCK_N
//
32
)
y2
=
tl
.
where
(
y_div
[:,
:,
None
]
==
offs_r_n
[
None
,
None
,
:],
(
1
<<
y_rem
)[:,
:,
None
],
0
)
r
=
tl
.
reduce_or
(
y2
,
axis
=
1
)
BitsPtrs
=
Bits
+
offs_m
[:,
None
]
*
stride_rm
+
offs_r_n
[
None
,
:]
*
stride_rn
tl
.
store
(
BitsPtrs
,
r
,
mask
=
mask_m
)
vllm/compactor-vllm/src/compactor_vllm/utils/__init__.py
0 → 100644
View file @
d29c39ca
vllm/compactor-vllm/src/compactor_vllm/utils/arguments.py
0 → 100644
View file @
d29c39ca
import
itertools
import
math
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
import
torch
import
torch.distributed
as
dist
from
compactor_vllm.compression
import
CompressionMethod
from
compactor_vllm.compression.compression_config
import
BatchCompressionParams
from
compactor_vllm.config.engine_config
import
LLMConfig
from
compactor_vllm.utils.sequence
import
Sequence
@
dataclass
class
PrefillBatchArguments
:
B
:
int
N
:
int
do_compression
:
bool
compression_method
:
CompressionMethod
compression_chunk_size
:
int
seq_ids
:
torch
.
Tensor
input_ids
:
torch
.
Tensor
positions
:
torch
.
Tensor
cu_seqlens_q
:
torch
.
Tensor
cu_seqlens_k
:
torch
.
Tensor
max_seqlen_q
:
int
max_seqlen_k
:
int
batch_tokens_to_retain
:
Optional
[
torch
.
Tensor
]
max_tokens_to_retain
:
Optional
[
int
]
protected_first
:
Optional
[
List
[
int
]]
protected_last
:
Optional
[
List
[
int
]]
PHI
:
Optional
[
torch
.
Tensor
]
# args needed for memory reservation
context_lens
:
torch
.
Tensor
max_new_tokens
:
torch
.
Tensor
# 与 kvpress ``CompactorPress`` blending 默认(未显式指定时用 compression_ratio)对齐
compression_ratio
:
float
=
1.0
class
PackedTensorArguments
:
def
__init__
(
self
,
rank
:
int
,
max_batched_tokens
:
int
,
config
:
LLMConfig
,
seed
:
int
=
42
)
->
None
:
hf_config
=
config
.
hf_config
self
.
rank
=
rank
self
.
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
self
.
max_num_batches
=
config
.
max_num_seqs
self
.
max_batched_tokens
=
max_batched_tokens
self
.
num_kv_heads
=
hf_config
.
num_key_value_heads
//
dist
.
get_world_size
()
self
.
world_size
=
config
.
tensor_parallel_size
self
.
page_size
=
int
(
config
.
kvcache_page_size
)
self
.
head_dim
=
getattr
(
hf_config
,
"head_dim"
,
None
)
self
.
sketch_dim
=
config
.
leverage_sketch_size
self
.
model_dtype
=
hf_config
.
torch_dtype
# i64 pack = [seq_ids (BMAX)] || [input_ids (NMAX)] || [positions (NMAX)] || max_new_tok (BMAX)
self
.
i64_len_max
=
(
self
.
max_num_batches
+
2
*
self
.
max_batched_tokens
+
self
.
max_num_batches
)
self
.
packed_context_i64
=
torch
.
empty
(
self
.
i64_len_max
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
# i32 pack = [header (6): ... + compression_ratio*1e6] || [cu_q (BMAX+1)] || ...
# || [protected_first_tokens (BMAX)] || [protected_last_tokens (BMAX)]
self
.
i32_len_max
=
(
6
+
(
self
.
max_num_batches
+
1
)
+
(
self
.
max_num_batches
+
1
)
+
self
.
max_num_batches
+
self
.
max_num_batches
+
self
.
max_num_batches
+
self
.
max_num_batches
)
self
.
packed_context_i32
=
torch
.
empty
(
self
.
i32_len_max
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
generator
=
torch
.
Generator
(
device
=
self
.
device
).
manual_seed
(
seed
)
self
.
PHI
=
torch
.
randn
(
(
self
.
head_dim
,
self
.
sketch_dim
),
device
=
self
.
packed_context_i32
.
device
,
generator
=
self
.
generator
,
).
to
(
self
.
model_dtype
)
*
(
1
/
math
.
sqrt
(
self
.
sketch_dim
))
def
_master_build_prefill
(
self
,
seqs
:
List
[
Sequence
],
batch_compression_params
:
BatchCompressionParams
)
->
PrefillBatchArguments
:
B
=
len
(
seqs
)
Ls
=
[
x
.
prompt_len
for
x
in
seqs
]
N
=
sum
(
Ls
)
assert
N
<=
self
.
max_batched_tokens
do_compression
=
any
(
x
.
compression_params
.
compression_ratio
<
1.0
for
x
in
seqs
)
do_compression
=
(
do_compression
and
batch_compression_params
.
compression_method
!=
CompressionMethod
.
NONE
)
pack_slices_64
=
self
.
packed_i64_slices
(
B
,
N
)
pack_slices_32
=
self
.
packed_i32_slices
(
B
)
# max_retain = max(retain)
protected_first_list
=
[
x
.
compression_params
.
protected_first_tokens
for
x
in
seqs
]
protected_last_list
=
[
x
.
compression_params
.
protected_last_tokens
for
x
in
seqs
]
retain
=
[
max
(
int
(
round
(
x
.
compression_params
.
compression_ratio
*
(
L
-
s
-
e
)
*
self
.
num_kv_heads
)
),
1
,
)
for
s
,
e
,
L
,
x
in
zip
(
protected_first_list
,
protected_last_list
,
Ls
,
seqs
)
]
retain
=
torch
.
tensor
(
retain
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
True
)
protected_first
=
torch
.
tensor
(
protected_first_list
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
True
)
protected_last
=
torch
.
tensor
(
protected_last_list
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
True
)
self
.
packed_context_i32
[
pack_slices_32
[
"protected_first"
]].
copy_
(
protected_first
,
non_blocking
=
True
)
self
.
packed_context_i32
[
pack_slices_32
[
"protected_last"
]].
copy_
(
protected_last
,
non_blocking
=
True
)
compression_chunk_size
=
(
batch_compression_params
.
chunk_size
if
batch_compression_params
.
do_chunked_compression
else
-
1
)
min_compression_ratio
=
min
(
x
.
compression_params
.
compression_ratio
for
x
in
seqs
)
cr_scaled
=
int
(
round
(
float
(
min_compression_ratio
)
*
1_000_000.0
))
cr_scaled
=
max
(
min
(
cr_scaled
,
2_000_000_000
),
-
2_000_000_000
)
header_host
=
torch
.
tensor
(
[
B
,
N
,
1
if
do_compression
else
0
,
batch_compression_params
.
compression_method
.
value
,
compression_chunk_size
,
cr_scaled
,
],
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
True
,
)
self
.
packed_context_i32
[
pack_slices_32
[
"retain"
]].
copy_
(
retain
,
non_blocking
=
True
)
self
.
packed_context_i32
[
pack_slices_32
[
"header"
]].
copy_
(
header_host
,
non_blocking
=
True
)
max_seq_qk
=
max
(
Ls
)
cu
=
torch
.
tensor
(
list
(
itertools
.
accumulate
(
Ls
,
initial
=
0
)),
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
True
,
)
self
.
packed_context_i32
[
pack_slices_32
[
"cu_q"
]].
copy_
(
cu
,
non_blocking
=
True
)
self
.
packed_context_i32
[
pack_slices_32
[
"cu_k"
]].
copy_
(
cu
,
non_blocking
=
True
)
self
.
packed_context_i32
[
pack_slices_32
[
"context_lens"
]].
copy_
(
cu
.
diff
(),
non_blocking
=
True
)
seq_ids
=
torch
.
tensor
(
[
x
.
seq_id
for
x
in
seqs
],
dtype
=
torch
.
int64
,
device
=
"cpu"
,
pin_memory
=
True
)
input_ids
=
torch
.
tensor
(
[
tid
for
x
in
seqs
for
tid
in
x
.
prompt_token_ids
],
dtype
=
torch
.
int64
,
device
=
"cpu"
,
pin_memory
=
True
,
)
self
.
packed_context_i64
[
pack_slices_64
[
"seq_ids"
]].
copy_
(
seq_ids
,
non_blocking
=
True
)
self
.
packed_context_i64
[
pack_slices_64
[
"input_ids"
]].
copy_
(
input_ids
,
non_blocking
=
True
)
positions
=
torch
.
cat
(
[
torch
.
arange
(
L
,
dtype
=
torch
.
int64
,
device
=
"cpu"
,
pin_memory
=
True
)
for
L
in
Ls
]
)
self
.
packed_context_i64
[
pack_slices_64
[
"positions"
]].
copy_
(
positions
,
non_blocking
=
True
)
max_new_tokens
=
torch
.
tensor
(
[
seq
.
sampling_params
.
max_new_tokens
for
seq
in
seqs
],
dtype
=
torch
.
int64
,
device
=
"cpu"
,
pin_memory
=
True
,
)
self
.
packed_context_i64
[
pack_slices_64
[
"max_new_tokens"
]].
copy_
(
max_new_tokens
,
non_blocking
=
True
)
# `prefill_store_topk_kv(..., PAD_TO_PAGE_SIZE=True)` may scan beyond the
# top-k prefix to fill per-head lengths up to a page boundary. Using a
# full ranking (top_k = max_seq_len * HKV) makes `torch.topk` degenerate
# into a full sort, which is very expensive for long contexts.
#
# Instead, request only a prefix that is large enough for:
# 1) the maximum "keep" budget in the batch, plus
# 2) a conservative extra window for page-padding candidates.
max_seq_len
=
int
(
self
.
packed_context_i32
[
pack_slices_32
[
"context_lens"
]].
max
())
full_budget
=
max_seq_len
*
self
.
num_kv_heads
keep_budget
=
int
(
retain
.
max
().
item
())
pad_search_budget
=
(
self
.
page_size
-
1
)
*
(
self
.
num_kv_heads
**
2
)
max_retain
=
min
(
full_budget
,
keep_budget
+
pad_search_budget
)
dist
.
broadcast
(
self
.
packed_context_i64
,
src
=
0
)
dist
.
broadcast
(
self
.
packed_context_i32
,
src
=
0
)
prefill_args
=
PrefillBatchArguments
(
B
=
B
,
N
=
N
,
do_compression
=
do_compression
,
compression_method
=
batch_compression_params
.
compression_method
,
compression_chunk_size
=
compression_chunk_size
,
seq_ids
=
self
.
packed_context_i64
[
pack_slices_64
[
"seq_ids"
]],
input_ids
=
self
.
packed_context_i64
[
pack_slices_64
[
"input_ids"
]],
positions
=
self
.
packed_context_i64
[
pack_slices_64
[
"positions"
]],
cu_seqlens_q
=
self
.
packed_context_i32
[
pack_slices_32
[
"cu_q"
]],
cu_seqlens_k
=
self
.
packed_context_i32
[
pack_slices_32
[
"cu_k"
]],
max_seqlen_q
=
max_seq_qk
,
max_seqlen_k
=
max_seq_qk
,
batch_tokens_to_retain
=
self
.
packed_context_i32
[
pack_slices_32
[
"retain"
]],
max_tokens_to_retain
=
max_retain
,
PHI
=
self
.
PHI
,
context_lens
=
self
.
packed_context_i32
[
pack_slices_32
[
"context_lens"
]],
max_new_tokens
=
self
.
packed_context_i64
[
pack_slices_64
[
"max_new_tokens"
]],
protected_first
=
protected_first_list
,
protected_last
=
protected_last_list
,
compression_ratio
=
min_compression_ratio
,
)
return
prefill_args
def
_peer_receive_prefill
(
self
)
->
PrefillBatchArguments
:
dist
.
broadcast
(
self
.
packed_context_i64
,
src
=
0
)
dist
.
broadcast
(
self
.
packed_context_i32
,
src
=
0
)
header
=
self
.
packed_context_i32
[:
6
].
tolist
()
B
,
N
=
int
(
header
[
0
]),
int
(
header
[
1
])
do_compression
=
bool
(
int
(
header
[
2
]))
compression_method
=
CompressionMethod
(
int
(
header
[
3
]))
compression_chunk_size
=
int
(
header
[
4
])
compression_ratio
=
int
(
header
[
5
])
/
1_000_000.0
pack_slices_64
=
self
.
packed_i64_slices
(
B
,
N
)
pack_slices_32
=
self
.
packed_i32_slices
(
B
)
max_seq_len
=
int
(
self
.
packed_context_i32
[
pack_slices_32
[
"context_lens"
]].
max
())
full_budget
=
max_seq_len
*
self
.
num_kv_heads
keep_budget
=
int
(
self
.
packed_context_i32
[
pack_slices_32
[
"retain"
]].
max
().
item
())
pad_search_budget
=
(
self
.
page_size
-
1
)
*
(
self
.
num_kv_heads
**
2
)
max_retain
=
min
(
full_budget
,
keep_budget
+
pad_search_budget
)
prefill_args
=
PrefillBatchArguments
(
B
=
B
,
N
=
N
,
do_compression
=
do_compression
,
compression_method
=
compression_method
,
compression_chunk_size
=
compression_chunk_size
,
seq_ids
=
self
.
packed_context_i64
[
pack_slices_64
[
"seq_ids"
]],
input_ids
=
self
.
packed_context_i64
[
pack_slices_64
[
"input_ids"
]],
positions
=
self
.
packed_context_i64
[
pack_slices_64
[
"positions"
]],
cu_seqlens_q
=
self
.
packed_context_i32
[
pack_slices_32
[
"cu_q"
]],
cu_seqlens_k
=
self
.
packed_context_i32
[
pack_slices_32
[
"cu_k"
]],
max_seqlen_q
=
int
(
self
.
packed_context_i32
[
pack_slices_32
[
"cu_q"
]].
max
()),
max_seqlen_k
=
int
(
self
.
packed_context_i32
[
pack_slices_32
[
"cu_k"
]].
max
()),
batch_tokens_to_retain
=
self
.
packed_context_i32
[
pack_slices_32
[
"retain"
]],
max_tokens_to_retain
=
max_retain
,
PHI
=
self
.
PHI
,
context_lens
=
self
.
packed_context_i32
[
pack_slices_32
[
"context_lens"
]],
max_new_tokens
=
self
.
packed_context_i64
[
pack_slices_64
[
"max_new_tokens"
]],
protected_first
=
self
.
packed_context_i32
[
pack_slices_32
[
"protected_first"
]
].
tolist
(),
protected_last
=
self
.
packed_context_i32
[
pack_slices_32
[
"protected_last"
]
].
tolist
(),
compression_ratio
=
compression_ratio
,
)
return
prefill_args
@
torch
.
inference_mode
()
def
build_prefill_args
(
self
,
seqs
:
Optional
[
List
[
Sequence
]]
=
None
,
batch_compression_params
:
Optional
[
BatchCompressionParams
]
=
None
,
)
->
PrefillBatchArguments
:
if
self
.
rank
==
0
:
return
self
.
_master_build_prefill
(
seqs
,
batch_compression_params
)
return
self
.
_peer_receive_prefill
()
def
broadcast
(
self
):
if
self
.
world_size
>
1
:
return
dist
.
broadcast
(
self
.
packed_context_i64
,
src
=
0
)
return
None
@
staticmethod
def
packed_i64_slices
(
B
:
int
,
N
:
int
):
return
{
"seq_ids"
:
slice
(
0
,
B
),
"input_ids"
:
slice
(
B
,
B
+
N
),
"positions"
:
slice
(
B
+
N
,
B
+
2
*
N
),
"max_new_tokens"
:
slice
(
B
+
2
*
N
,
2
*
B
+
2
*
N
),
}
@
staticmethod
def
packed_i32_slices
(
B
:
int
):
h0
,
h1
=
0
,
6
q0
=
h1
q1
=
q0
+
(
B
+
1
)
k0
=
q1
k1
=
k0
+
(
B
+
1
)
r0
=
k1
r1
=
r0
+
B
c0
=
r1
c1
=
r1
+
B
pf0
=
c1
pf1
=
c1
+
B
pl0
=
pf1
pl1
=
pf1
+
B
return
{
"header"
:
slice
(
h0
,
h1
),
"cu_q"
:
slice
(
q0
,
q1
),
"cu_k"
:
slice
(
k0
,
k1
),
"retain"
:
slice
(
r0
,
r1
),
"context_lens"
:
slice
(
c0
,
c1
),
"protected_first"
:
slice
(
pf0
,
pf1
),
"protected_last"
:
slice
(
pl0
,
pl1
),
}
@
dataclass
class
DecodeBatchOutput
:
output_tokens
:
Optional
[
torch
.
Tensor
]
output_seq_ids
:
Optional
[
torch
.
Tensor
]
@
dataclass
class
DecodeBatchArguments
:
batch_mapping
:
Optional
[
torch
.
Tensor
]
=
None
token_ids
:
Optional
[
torch
.
Tensor
]
=
None
positions
:
Optional
[
torch
.
Tensor
]
=
None
max_ctx_lens
:
Optional
[
torch
.
Tensor
]
=
None
seq_ids
:
Optional
[
torch
.
Tensor
]
=
None
temps
:
Optional
[
torch
.
Tensor
]
=
None
desired_batch_occupancy
:
int
=
-
1
num_stashed_batches
:
int
=
0
def
update
(
self
,
batch_mapping
,
token_ids
,
positions
,
max_ctx_lens
,
seq_ids
,
temps
=
None
,
desired_batch_occupancy
:
int
=
None
,
):
if
self
.
batch_mapping
is
not
None
:
self
.
batch_mapping
=
torch
.
cat
([
self
.
batch_mapping
,
batch_mapping
],
dim
=
0
)
else
:
self
.
batch_mapping
=
batch_mapping
.
clone
()
if
self
.
token_ids
is
not
None
:
self
.
token_ids
=
torch
.
cat
([
self
.
token_ids
,
token_ids
],
dim
=
0
)
else
:
self
.
token_ids
=
token_ids
.
clone
()
if
self
.
positions
is
not
None
:
self
.
positions
=
torch
.
cat
([
self
.
positions
,
positions
],
dim
=
0
)
else
:
self
.
positions
=
positions
.
clone
()
if
self
.
max_ctx_lens
is
not
None
:
self
.
max_ctx_lens
=
torch
.
cat
([
self
.
max_ctx_lens
,
max_ctx_lens
],
dim
=
0
)
else
:
self
.
max_ctx_lens
=
max_ctx_lens
.
clone
()
if
self
.
seq_ids
is
not
None
:
self
.
seq_ids
=
torch
.
cat
([
self
.
seq_ids
,
seq_ids
],
dim
=
0
)
else
:
self
.
seq_ids
=
seq_ids
.
clone
()
if
self
.
temps
is
not
None
and
temps
is
not
None
:
self
.
temps
=
torch
.
cat
([
self
.
temps
,
temps
],
dim
=
0
)
elif
temps
is
not
None
:
self
.
temps
=
temps
.
clone
()
if
desired_batch_occupancy
is
not
None
:
self
.
desired_batch_occupancy
=
desired_batch_occupancy
return
self
vllm/compactor-vllm/src/compactor_vllm/utils/context.py
0 → 100644
View file @
d29c39ca
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
import
torch
from
compactor_vllm.compression
import
CompressionMethod
from
compactor_vllm.config.engine_config
import
AttentionBackend
@
dataclass
class
CompressionContext
:
compression_method
:
CompressionMethod
=
CompressionMethod
.
COMPACTOR
compression_chunk_size
:
int
=
-
1
batch_tokens_to_retain
:
torch
.
Tensor
|
None
=
None
max_tokens_to_retain
:
int
=
0
context_lens
:
List
[
int
]
|
None
=
None
PHI
:
torch
.
Tensor
|
None
=
None
# Compactor(与 kvpress ``CompactorPress`` 对齐的可选超参)
sketch_dimension
:
int
=
48
sink_size_start
:
int
=
8
sink_size_end
:
int
=
4
compactor_blending
:
Optional
[
float
]
=
None
# 与 kvpress 一致:未设 ``compactor_blending`` 时用该值(来自请求的 compression_ratio)
compression_ratio
:
Optional
[
float
]
=
None
protected_first_tokens
:
List
[
int
]
|
None
=
None
protected_last_tokens
:
List
[
int
]
|
None
=
None
# CriticalAdaKV
wo_weight
:
Optional
[
torch
.
Tensor
]
=
None
critical_ada_epsilon
:
float
=
1e-4
critical_ada_first_stage_ratio
:
float
=
0.5
critical_ada_alpha_safeguard
:
float
=
0.2
@
dataclass
class
Context
:
is_prefill
:
bool
=
False
do_compression
:
bool
=
False
cu_seqlens_q
:
torch
.
Tensor
|
None
=
None
cu_seqlens_k
:
torch
.
Tensor
|
None
=
None
max_seqlen_q
:
int
=
0
max_seqlen_k
:
int
=
0
batch_mapping
:
torch
.
Tensor
|
None
=
None
max_bh_len
:
int
=
0
compression_context
:
CompressionContext
|
None
=
None
STORE_STREAM
:
torch
.
cuda
.
Stream
|
None
=
None
key_split
:
int
|
None
=
None
attention_backend
:
AttentionBackend
=
AttentionBackend
.
COMPACTOR_TRITON
_CONTEXT
=
Context
()
def
get_context
():
return
_CONTEXT
def
set_context
(
*
,
is_prefill
,
do_compression
=
False
,
cu_seqlens_q
=
None
,
cu_seqlens_k
=
None
,
max_seqlen_q
=
0
,
max_seqlen_k
=
0
,
batch_mapping
=
None
,
max_bh_len
=
0
,
compression_context
:
CompressionContext
=
None
,
STORE_STREAM
=
None
,
key_split
=
None
,
attention_backend
=
AttentionBackend
.
COMPACTOR_TRITON
,
):
global
_CONTEXT
_CONTEXT
=
Context
(
is_prefill
,
do_compression
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlen_q
,
max_seqlen_k
,
batch_mapping
,
max_bh_len
,
compression_context
,
STORE_STREAM
,
key_split
,
attention_backend
,
)
def
reset_context
():
global
_CONTEXT
_CONTEXT
=
Context
()
vllm/compactor-vllm/src/compactor_vllm/utils/helpers.py
0 → 100644
View file @
d29c39ca
from
collections.abc
import
Callable
import
torch
def
maybe_execute_in_stream
(
fn
:
Callable
,
*
args
,
STORE_STREAM
:
torch
.
cuda
.
Stream
=
None
,
**
kwargs
):
if
STORE_STREAM
is
not
None
:
tensors
=
[
arg
for
arg
in
args
if
isinstance
(
arg
,
torch
.
Tensor
)]
tensors
+=
[
val
for
val
in
kwargs
.
values
()
if
isinstance
(
val
,
torch
.
Tensor
)]
obj
=
getattr
(
fn
,
"__self__"
,
None
)
if
isinstance
(
obj
,
torch
.
Tensor
):
tensors
.
append
(
obj
)
STORE_STREAM
.
wait_stream
(
torch
.
cuda
.
default_stream
())
# Some PyTorch builds don't make `torch.cuda.Stream` a context manager.
# The portable API is `torch.cuda.stream(stream)`.
stream_ctx
=
(
STORE_STREAM
if
hasattr
(
STORE_STREAM
,
"__enter__"
)
else
torch
.
cuda
.
stream
(
STORE_STREAM
)
)
with
stream_ctx
:
output
=
fn
(
*
args
,
**
kwargs
)
for
t
in
tensors
:
t
.
record_stream
(
STORE_STREAM
)
if
isinstance
(
output
,
tuple
):
for
o
in
output
:
if
isinstance
(
o
,
torch
.
Tensor
):
o
.
record_stream
(
torch
.
cuda
.
default_stream
())
elif
isinstance
(
output
,
torch
.
Tensor
):
output
.
record_stream
(
torch
.
cuda
.
default_stream
())
return
output
else
:
return
fn
(
*
args
,
**
kwargs
)
vllm/compactor-vllm/src/compactor_vllm/utils/sequence.py
0 → 100644
View file @
d29c39ca
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
,
auto
from
itertools
import
count
from
typing
import
List
from
compactor_vllm.compression.compression_config
import
SequenceCompressionParams
from
compactor_vllm.config.sampling_params
import
SamplingParams
class
SequenceStatus
(
Enum
):
WAITING
=
auto
()
RUNNING
=
auto
()
FINISHED
=
auto
()
@
dataclass
class
Sequence
:
"""
Represents a single user request / sequence being generated.
"""
_counter
=
count
()
prompt_token_ids
:
List
[
int
]
completion_token_ids
:
List
[
int
]
=
field
(
default_factory
=
list
)
sampling_params
:
SamplingParams
=
field
(
default_factory
=
SamplingParams
)
compression_params
:
SequenceCompressionParams
=
field
(
default_factory
=
SequenceCompressionParams
)
status
:
SequenceStatus
=
SequenceStatus
.
WAITING
seq_id
:
int
=
field
(
default_factory
=
lambda
:
next
(
Sequence
.
_counter
),
init
=
False
)
num_tokens_processed
:
int
=
0
@
property
def
num_prompt_tokens
(
self
)
->
int
:
return
len
(
self
.
prompt_token_ids
)
@
property
def
num_generated_tokens
(
self
)
->
int
:
return
len
(
self
.
completion_token_ids
)
def
add_new_token
(
self
,
token_id
:
int
)
->
None
:
if
len
(
self
.
completion_token_ids
)
==
0
:
self
.
num_tokens_processed
+=
self
.
num_prompt_tokens
self
.
completion_token_ids
.
append
(
token_id
)
self
.
num_tokens_processed
+=
1
def
tokens_to_retain_per_layer
(
self
,
num_kv_heads
:
int
)
->
int
:
n
=
int
(
self
.
compression_params
.
compression_ratio
*
self
.
num_prompt_tokens
*
num_kv_heads
)
return
max
(
1
,
n
)
def
__getstate__
(
self
):
return
dict
(
prompt_token_ids
=
list
(
self
.
prompt_token_ids
),
completion_token_ids
=
list
(
self
.
completion_token_ids
),
sampling_params
=
self
.
sampling_params
,
compression_params
=
self
.
compression_params
,
status
=
self
.
status
,
seq_id
=
self
.
seq_id
,
num_tokens_processed
=
self
.
num_tokens_processed
,
)
def
__setstate__
(
self
,
state
):
self
.
prompt_token_ids
=
list
(
state
[
"prompt_token_ids"
])
self
.
completion_token_ids
=
list
(
state
[
"completion_token_ids"
])
self
.
sampling_params
=
state
[
"sampling_params"
]
self
.
compression_params
=
state
[
"compression_params"
]
self
.
status
=
state
[
"status"
]
self
.
seq_id
=
state
[
"seq_id"
]
self
.
num_tokens_processed
=
state
[
"num_tokens_processed"
]
@
property
def
prompt_len
(
self
)
->
int
:
return
len
(
self
.
prompt_token_ids
)
@
property
def
completion_len
(
self
)
->
int
:
return
len
(
self
.
completion_token_ids
)
vllm/compactor-vllm/src/compactor_vllm/utils/triton_compat.py
0 → 100644
View file @
d29c39ca
from
__future__
import
annotations
import
inspect
from
typing
import
Any
,
Callable
,
Mapping
import
torch
def
_filter_kwargs_for_callable
(
fn
:
Callable
[...,
Any
],
kwargs
:
Mapping
[
str
,
Any
]
)
->
dict
[
str
,
Any
]:
try
:
params
=
inspect
.
signature
(
fn
).
parameters
except
(
TypeError
,
ValueError
):
return
dict
(
kwargs
)
return
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
in
params
}
def
autotune
(
*
,
configs
,
key
,
**
kwargs
):
"""
Compatibility wrapper around `triton.autotune`.
Some Triton builds (e.g., custom vendor builds) may not support newer
keyword arguments like `cache_results`. This wrapper filters unsupported
kwargs based on the runtime `triton.autotune` signature.
"""
import
triton
filtered
=
_filter_kwargs_for_callable
(
triton
.
autotune
,
kwargs
)
return
triton
.
autotune
(
configs
=
configs
,
key
=
key
,
**
filtered
)
def
maybe_set_allocator
(
alloc_fn
:
Callable
[[
int
,
int
,
int
|
None
],
Any
])
->
bool
:
"""
Call `triton.set_allocator(alloc_fn)` if present; otherwise no-op.
Returns True if the allocator was set.
"""
import
triton
setter
=
getattr
(
triton
,
"set_allocator"
,
None
)
if
setter
is
None
:
return
False
setter
(
alloc_fn
)
return
True
def
cuda_capability_geq
(
major
:
int
,
minor
:
int
=
0
,
device
:
int
|
None
=
None
)
->
bool
:
"""
Host-side CUDA capability check that works even when `tl.target_info` is absent.
"""
if
not
torch
.
cuda
.
is_available
():
return
False
if
device
is
None
:
try
:
device
=
torch
.
cuda
.
current_device
()
except
Exception
:
device
=
0
cap
=
torch
.
cuda
.
get_device_capability
(
device
)
return
cap
>=
(
major
,
minor
)
vllm/compactor-vllm/tests/test_store_kv.py
0 → 100644
View file @
d29c39ca
import
collections
import
logging
from
dataclasses
import
dataclass
from
typing
import
List
import
pytest
import
torch
import
triton
from
compactor_vllm.compression.common
import
scores_to_retain_indices
from
src.compactor_vllm.kv_cache.store_kv_cache
import
prefill_store_topk_kv
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
Workload
:
name
:
str
batch_size
:
int
nk_heads
:
int
head_dim
:
int
frac
:
float
# per-sequence cached context length fractionf
page_size
:
int
cache_lens
:
List
[
int
]
# per-sequence cached context length
WORKLOADS
:
List
[
Workload
]
=
[
Workload
(
name
=
f
"batch_size=
{
BATCH
}
kv_cache_len=
{
cache_lens
}
"
f
"FRAC=
{
frac
}
HKV=
{
NK_HEADS
}
HEAD_DIM=
{
HEAD_DIM
}
"
,
batch_size
=
BATCH
,
nk_heads
=
NK_HEADS
,
head_dim
=
HEAD_DIM
,
cache_lens
=
[
cache_lens
]
*
BATCH
,
frac
=
frac
,
page_size
=
ps
,
)
for
BATCH
in
[
1
,
2
,
3
,
8
]
for
frac
in
[
0.10
,
0.20
,
0.30
,
0.40
]
for
NK_HEADS
in
[
2
,
4
,
8
]
for
HEAD_DIM
in
[
32
,
64
,
128
]
for
cache_lens
in
[
10
,
20
,
30
,
70
,
1000
]
for
ps
in
[
128
,
256
]
]
@
pytest
.
mark
.
parametrize
(
"workload"
,
WORKLOADS
,
ids
=
lambda
wl
:
wl
.
name
)
def
test_prefill_store_topk_kv
(
workload
:
Workload
):
B
=
workload
.
batch_size
H
=
workload
.
nk_heads
D
=
workload
.
head_dim
TOP_K
=
int
(
workload
.
cache_lens
[
0
]
*
workload
.
nk_heads
*
workload
.
frac
)
PAGE_SIZE
=
workload
.
page_size
dtype
=
torch
.
float16
device
=
triton
.
runtime
.
driver
.
active
.
get_active_torch_device
()
lens
=
torch
.
tensor
(
workload
.
cache_lens
,
dtype
=
torch
.
int32
,
device
=
device
)
cu
=
torch
.
zeros
(
B
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
cu
[
1
:]
=
torch
.
cumsum
(
lens
,
dim
=
0
)
N_total
=
int
(
cu
[
-
1
].
item
())
keys
=
torch
.
randn
((
N_total
,
H
,
D
),
dtype
=
dtype
,
device
=
device
)
vals
=
torch
.
randn_like
(
keys
)
scores_flat
=
torch
.
randn
((
N_total
,
H
),
dtype
=
torch
.
float32
,
device
=
device
)
top_k_eff
=
max
(
0
,
min
(
TOP_K
,
int
(
lens
.
max
().
item
())
*
H
))
max_k_len
=
cu
.
diff
().
max
().
item
()
indices
=
scores_to_retain_indices
(
scores_flat
,
cu
,
max_k_len
,
top_k_eff
,
H
)
# [B, TOP_K]
LP
=
max
(
1
,
(
top_k_eff
+
PAGE_SIZE
-
1
)
//
PAGE_SIZE
)
N_LOGICAL_PAGES_MAX
=
LP
N_PAGES
=
B
*
H
*
LP
+
32
S_LARGE
=
N_PAGES
*
PAGE_SIZE
k_cache
=
torch
.
empty
((
S_LARGE
,
D
),
dtype
=
dtype
,
device
=
device
)
v_cache
=
torch
.
empty_like
(
k_cache
)
page_table
=
torch
.
empty
(
(
B
,
H
,
N_LOGICAL_PAGES_MAX
),
dtype
=
torch
.
int32
,
device
=
device
)
phys
=
0
for
b
in
range
(
B
):
for
h
in
range
(
H
):
for
lp
in
range
(
LP
):
page_table
[
b
,
h
,
lp
]
=
phys
phys
+=
1
assert
phys
<=
N_PAGES
,
"Not enough physical pages"
local_lens
=
torch
.
zeros
((
B
,
H
),
dtype
=
torch
.
int32
,
device
=
device
)
batch_mapping
=
torch
.
arange
(
B
,
dtype
=
torch
.
int32
,
device
=
device
)
num_to_retain
=
torch
.
full
((
B
,),
top_k_eff
,
dtype
=
torch
.
int32
,
device
=
device
)
prefill_store_topk_kv
(
new_keys
=
keys
,
new_vals
=
vals
,
indices_topk
=
indices
,
num_tokens_to_retain
=
num_to_retain
,
page_table
=
page_table
,
batch_mapping
=
batch_mapping
,
bh_lens
=
local_lens
,
PAGE_SIZE
=
PAGE_SIZE
,
k_cache
=
k_cache
,
v_cache
=
v_cache
,
PAD_TO_PAGE_SIZE
=
False
,
TRITON_RESERVED_BATCH
=-
1
,
)
torch
.
cuda
.
synchronize
()
local_lens_cpu
=
local_lens
.
cpu
()
page_table_cpu
=
page_table
.
cpu
()
k_cache_cpu
=
k_cache
.
cpu
()
v_cache_cpu
=
v_cache
.
cpu
()
keys_cpu
=
keys
.
cpu
()
vals_cpu
=
vals
.
cpu
()
indices_cpu
=
indices
.
cpu
()
for
b
in
range
(
B
):
hed
=
(
indices_cpu
[
b
]
%
H
).
numpy
()
counts
=
collections
.
Counter
(
hed
.
tolist
())
for
h
in
range
(
H
):
expected
=
counts
.
get
(
h
,
0
)
# type: ignore
got
=
int
(
local_lens_cpu
[
b
,
h
].
item
())
assert
got
==
expected
,
(
f
"Length mismatch at (b=
{
b
}
, h=
{
h
}
): got
{
got
}
, expected
{
expected
}
"
)
def
rows_for_head
(
b
,
h
,
L
):
"""Return the list of cache row indices storing the first L logical positions for (b,h)."""
rows
=
[]
for
pos
in
range
(
L
):
lp
=
pos
//
PAGE_SIZE
off
=
pos
%
PAGE_SIZE
phys
=
int
(
page_table_cpu
[
b
,
h
,
lp
].
item
())
rows
.
append
(
phys
*
PAGE_SIZE
+
off
)
return
rows
for
b
in
range
(
B
):
# which tokens per head were selected for this batch?
tok
=
(
indices_cpu
[
b
]
//
H
).
numpy
()
hed
=
(
indices_cpu
[
b
]
%
H
).
numpy
()
per_head
=
collections
.
defaultdict
(
list
)
for
t
,
h
in
zip
(
tok
,
hed
):
per_head
[
int
(
h
)].
append
(
int
(
t
))
for
h
in
range
(
H
):
L
=
int
(
local_lens_cpu
[
b
,
h
].
item
())
if
L
==
0
:
continue
# expected vectors (unordered) from source
toks_h
=
per_head
.
get
(
h
,
[])
assert
len
(
toks_h
)
==
L
expK
=
keys_cpu
[
toks_h
,
h
,
:].
contiguous
().
view
(
L
,
-
1
)
expV
=
vals_cpu
[
toks_h
,
h
,
:].
contiguous
().
view
(
L
,
-
1
)
# actual vectors read back from cache rows
rows
=
rows_for_head
(
b
,
h
,
L
)
actK
=
k_cache_cpu
[
rows
,
:].
contiguous
().
view
(
L
,
-
1
)
actV
=
v_cache_cpu
[
rows
,
:].
contiguous
().
view
(
L
,
-
1
)
expK_tuples
=
[
tuple
(
row
)
for
row
in
expK
.
numpy
().
tolist
()]
actK_tuples
=
[
tuple
(
row
)
for
row
in
actK
.
numpy
().
tolist
()]
expV_tuples
=
[
tuple
(
row
)
for
row
in
expV
.
numpy
().
tolist
()]
actV_tuples
=
[
tuple
(
row
)
for
row
in
actV
.
numpy
().
tolist
()]
assert
collections
.
Counter
(
expK_tuples
)
==
collections
.
Counter
(
actK_tuples
),
f
"K content mismatch at (b=
{
b
}
, h=
{
h
}
)"
assert
collections
.
Counter
(
expV_tuples
)
==
collections
.
Counter
(
actV_tuples
),
f
"V content mismatch at (b=
{
b
}
, h=
{
h
}
)"
def
test_prefill_store_topk_kv_pad_to_page_size
():
torch
.
manual_seed
(
0
)
B
,
H
,
D
=
2
,
2
,
64
PAGE_SIZE
=
128
RETAIN
=
64
dtype
=
torch
.
float16
device
=
triton
.
runtime
.
driver
.
active
.
get_active_torch_device
()
lens
=
torch
.
full
((
B
,),
256
,
dtype
=
torch
.
int32
,
device
=
device
)
cu
=
torch
.
zeros
(
B
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
cu
[
1
:]
=
torch
.
cumsum
(
lens
,
dim
=
0
)
N_total
=
int
(
cu
[
-
1
].
item
())
keys
=
torch
.
randn
((
N_total
,
H
,
D
),
dtype
=
dtype
,
device
=
device
)
vals
=
torch
.
randn_like
(
keys
)
scores_flat
=
torch
.
randn
((
N_total
,
H
),
dtype
=
torch
.
float32
,
device
=
device
)
max_k_len
=
int
(
lens
.
max
().
item
())
max_sel
=
max_k_len
*
H
indices
=
scores_to_retain_indices
(
scores_flat
,
cu
,
max_k_len
,
max_sel
,
H
)
N_LOGICAL_PAGES_MAX
=
2
N_PAGES
=
B
*
H
*
N_LOGICAL_PAGES_MAX
+
32
S_LARGE
=
N_PAGES
*
PAGE_SIZE
k_cache
=
torch
.
empty
((
S_LARGE
,
D
),
dtype
=
dtype
,
device
=
device
)
v_cache
=
torch
.
empty_like
(
k_cache
)
page_table
=
torch
.
empty
(
(
B
,
H
,
N_LOGICAL_PAGES_MAX
),
dtype
=
torch
.
int32
,
device
=
device
)
phys
=
0
for
b
in
range
(
B
):
for
h
in
range
(
H
):
for
lp
in
range
(
N_LOGICAL_PAGES_MAX
):
page_table
[
b
,
h
,
lp
]
=
phys
phys
+=
1
assert
phys
<=
N_PAGES
,
"Not enough physical pages"
local_lens
=
torch
.
zeros
((
B
,
H
),
dtype
=
torch
.
int32
,
device
=
device
)
batch_mapping
=
torch
.
arange
(
B
,
dtype
=
torch
.
int32
,
device
=
device
)
num_to_retain
=
torch
.
full
((
B
,),
RETAIN
,
dtype
=
torch
.
int32
,
device
=
device
)
prefill_store_topk_kv
(
new_keys
=
keys
,
new_vals
=
vals
,
indices_topk
=
indices
,
num_tokens_to_retain
=
num_to_retain
,
page_table
=
page_table
,
batch_mapping
=
batch_mapping
,
bh_lens
=
local_lens
,
PAGE_SIZE
=
PAGE_SIZE
,
k_cache
=
k_cache
,
v_cache
=
v_cache
,
PAD_TO_PAGE_SIZE
=
True
,
cu_seqlens_k
=
cu
,
TRITON_RESERVED_BATCH
=-
1
,
)
torch
.
cuda
.
synchronize
()
local_lens_cpu
=
local_lens
.
cpu
()
lens_cpu
=
lens
.
cpu
()
assert
(
local_lens_cpu
%
PAGE_SIZE
==
0
).
all
()
assert
(
local_lens_cpu
<=
lens_cpu
[:,
None
]).
all
()
vllm/compactor-vllm/tests/test_triton_attention.py
0 → 100644
View file @
d29c39ca
import
logging
import
math
from
dataclasses
import
dataclass
from
typing
import
List
import
pytest
import
torch
import
triton
from
flash_attn.flash_attn_interface
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
)
from
compactor_vllm.attention.sparse_decode_kernel
import
head_sparse_decode_attention
from
compactor_vllm.attention.sparse_varlen_kernel
import
(
causal_sparse_varlen_with_cache
,
)
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
Workload
:
name
:
str
batch_size
:
int
nq_heads
:
int
nk_heads
:
int
head_dim
:
int
cache_lens
:
List
[
int
]
# per-sequence cached context length
append_lens
:
List
[
int
]
# per-sequence new tokens this step (Q_app, K_app, V_app)
WORKLOADS
:
List
[
Workload
]
=
[
Workload
(
name
=
f
"batch_size=
{
BATCH
}
kv_cache_len=
{
cache_lens
}
append_len=
{
append_lens
}
"
f
"HQ=
{
NQ_HEADS
}
HKV=
{
NK_HEADS
}
HEAD_DIM=
{
HEAD_DIM
}
"
,
batch_size
=
BATCH
,
nq_heads
=
NQ_HEADS
,
nk_heads
=
NK_HEADS
,
head_dim
=
HEAD_DIM
,
cache_lens
=
[
cache_lens
]
*
BATCH
,
append_lens
=
[
append_lens
]
*
BATCH
,
)
for
BATCH
in
[
1
,
2
,
3
,
8
]
for
NQ_HEADS
in
[
32
]
for
NK_HEADS
in
[
8
]
for
HEAD_DIM
in
[
128
]
for
cache_lens
in
[
0
,
1
,
70
,
128
,
8193
]
for
append_lens
in
[
1
,
2
,
13
,
8000
]
]
WORKLOADS_DECODE
:
List
[
Workload
]
=
[
Workload
(
name
=
f
"batch_size=
{
BATCH
}
kv_cache_len=
{
cache_lens
}
"
f
"HQ=
{
NQ_HEADS
}
HKV=
{
NK_HEADS
}
HEAD_DIM=
{
HEAD_DIM
}
"
,
batch_size
=
BATCH
,
nq_heads
=
NQ_HEADS
,
nk_heads
=
NK_HEADS
,
head_dim
=
HEAD_DIM
,
cache_lens
=
[
cache_lens
]
*
BATCH
,
append_lens
=
[
1
]
*
BATCH
,
)
for
BATCH
in
[
1
,
2
,
3
,
8
]
for
NQ_HEADS
in
[
32
]
for
NK_HEADS
in
[
8
]
for
HEAD_DIM
in
[
128
]
for
cache_lens
in
[
1
,
2
,
70
,
128
,
8000
]
]
def
build_paged_cache_from_lengths
(
B
,
H_kv
,
D
,
PAGE_SIZE
,
N_LOGICAL_PAGES_MAX
,
L_cache_per_b
,
# int32 [B], per-batch cache length
device
,
dtype
,
):
"""
Construct:
- seq_lens_bh[b, h] = L_cache_per_b[b]
- page_table[b, h, lp] giving physical page ids
- K_cache, V_cache filled for valid cached tokens
Physical layout:
physical_page_id = (b * H_kv + h) * N_LOGICAL_PAGES_MAX + lp
CACHE_SIZE = num_phys_pages * PAGE_SIZE
"""
assert
L_cache_per_b
.
shape
[
0
]
==
B
max_len
=
PAGE_SIZE
*
N_LOGICAL_PAGES_MAX
assert
(
L_cache_per_b
<=
max_len
).
all
()
seq_lens_bh
=
torch
.
empty
((
B
,
H_kv
),
dtype
=
torch
.
int32
,
device
=
device
)
for
b
in
range
(
B
):
seq_lens_bh
[
b
,
:].
fill_
(
L_cache_per_b
[
b
])
num_phys_pages
=
B
*
H_kv
*
N_LOGICAL_PAGES_MAX
CACHE_SIZE
=
num_phys_pages
*
PAGE_SIZE
K_cache
=
torch
.
zeros
((
CACHE_SIZE
,
D
),
device
=
device
,
dtype
=
dtype
)
V_cache
=
torch
.
zeros
((
CACHE_SIZE
,
D
),
device
=
device
,
dtype
=
dtype
)
page_table
=
torch
.
empty
(
(
B
,
H_kv
,
N_LOGICAL_PAGES_MAX
),
device
=
device
,
dtype
=
torch
.
int32
)
# assign unique physical pages per (b, h, lp)
phys_page
=
0
for
b
in
range
(
B
):
for
h
in
range
(
H_kv
):
for
lp
in
range
(
N_LOGICAL_PAGES_MAX
):
page_table
[
b
,
h
,
lp
]
=
phys_page
phys_page
+=
1
# fill cached tokens
g
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
1234
)
for
b
in
range
(
B
):
Lc
=
int
(
L_cache_per_b
[
b
].
item
())
for
h
in
range
(
H_kv
):
for
i
in
range
(
Lc
):
lp
=
i
//
PAGE_SIZE
off
=
i
%
PAGE_SIZE
phys
=
int
(
page_table
[
b
,
h
,
lp
].
item
())
idx
=
phys
*
PAGE_SIZE
+
off
K_cache
[
idx
]
=
torch
.
randn
(
D
,
device
=
device
,
dtype
=
dtype
,
generator
=
g
)
V_cache
[
idx
]
=
torch
.
randn
(
D
,
device
=
device
,
dtype
=
dtype
,
generator
=
g
)
return
K_cache
,
V_cache
,
page_table
,
seq_lens_bh
,
CACHE_SIZE
def
materialize_kv_for_flash_mixed
(
K_cache
,
V_cache
,
page_table
,
L_cache_per_b
,
# [B]
k_append_raw
,
# [N, H_kv, D]
v_append_raw
,
# [N, H_kv, D]
cu_seqlens_qk
,
# [B+1]
H_kv
,
PAGE_SIZE
,
):
"""
Build (K_total, V_total, cu_seqlens_k) for flash_attn_varlen_func such that:
For each batch b:
seqlen_q[b] = L_app[b] = cu[b+1] - cu[b]
seqlen_k[b] = L_cache_per_b[b] + L_app[b]
Keys:
- first L_cache_per_b[b] positions from paged cache
- next L_app[b] positions from k_append_raw for that batch
"""
device
=
K_cache
.
device
dtype
=
K_cache
.
dtype
B
=
cu_seqlens_qk
.
numel
()
-
1
N
,
H_kv_raw
,
D
=
k_append_raw
.
shape
assert
H_kv_raw
==
H_kv
# appended lengths
L_app
=
(
cu_seqlens_qk
[
1
:]
-
cu_seqlens_qk
[:
-
1
]).
to
(
torch
.
int32
)
# [B]
seqlen_k
=
L_cache_per_b
+
L_app
# [B]
cu_seqlens_k
=
torch
.
empty
(
B
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
cu_seqlens_k
[
0
]
=
0
total_k
=
int
(
seqlen_k
.
sum
().
item
())
K_total
=
torch
.
empty
((
total_k
,
H_kv
,
D
),
device
=
device
,
dtype
=
dtype
)
V_total
=
torch
.
empty
((
total_k
,
H_kv
,
D
),
device
=
device
,
dtype
=
dtype
)
for
b
in
range
(
B
):
offset_k
=
int
(
cu_seqlens_k
[
b
].
item
())
Lc
=
int
(
L_cache_per_b
[
b
].
item
())
La
=
int
(
L_app
[
b
].
item
())
q_start
=
int
(
cu_seqlens_qk
[
b
].
item
())
# cache segment
for
g
in
range
(
H_kv
):
for
i
in
range
(
Lc
):
lp
=
i
//
PAGE_SIZE
off
=
i
%
PAGE_SIZE
phys
=
int
(
page_table
[
b
,
g
,
lp
].
item
())
idx
=
phys
*
PAGE_SIZE
+
off
K_total
[
offset_k
+
i
,
g
]
=
K_cache
[
idx
]
V_total
[
offset_k
+
i
,
g
]
=
V_cache
[
idx
]
# appended segment
if
k_append_raw
.
numel
()
>
0
:
for
g
in
range
(
H_kv
):
for
j
in
range
(
La
):
src
=
q_start
+
j
dst
=
offset_k
+
Lc
+
j
K_total
[
dst
,
g
]
=
k_append_raw
[
src
,
g
]
V_total
[
dst
,
g
]
=
v_append_raw
[
src
,
g
]
cu_seqlens_k
[
b
+
1
]
=
cu_seqlens_k
[
b
]
+
(
Lc
+
La
)
return
K_total
,
V_total
,
cu_seqlens_k
@
pytest
.
mark
.
parametrize
(
"workload"
,
WORKLOADS
,
ids
=
lambda
wl
:
wl
.
name
)
def
test_causal_sparse_varlen_with_cache
(
workload
:
Workload
):
dtype
=
torch
.
float16
device
=
triton
.
runtime
.
driver
.
active
.
get_active_torch_device
()
DEFAULT_PAGE_SIZE
=
256
N_LOGICAL_PAGES_MAX
=
256
L_cache_per_b
=
torch
.
as_tensor
(
workload
.
cache_lens
,
device
=
device
,
dtype
=
torch
.
int32
)
K_cache
,
V_cache
,
page_table
,
seq_lens_bh
,
CACHE_SIZE
=
(
build_paged_cache_from_lengths
(
B
=
workload
.
batch_size
,
H_kv
=
workload
.
nk_heads
,
D
=
workload
.
head_dim
,
PAGE_SIZE
=
DEFAULT_PAGE_SIZE
,
N_LOGICAL_PAGES_MAX
=
N_LOGICAL_PAGES_MAX
,
L_cache_per_b
=
L_cache_per_b
,
device
=
device
,
dtype
=
dtype
,
)
)
assert
len
(
workload
.
append_lens
)
==
workload
.
batch_size
cu
=
[
0
]
for
L
in
workload
.
append_lens
:
cu
.
append
(
cu
[
-
1
]
+
L
)
cu_seqlens_qk
=
torch
.
tensor
(
cu
,
dtype
=
torch
.
int32
,
device
=
device
)
N
=
int
(
cu_seqlens_qk
[
-
1
].
item
())
q_raw
=
torch
.
randn
(
N
,
workload
.
nq_heads
,
workload
.
head_dim
,
device
=
device
,
dtype
=
dtype
)
k_append_raw
=
torch
.
randn
(
N
,
workload
.
nk_heads
,
workload
.
head_dim
,
device
=
device
,
dtype
=
dtype
)
v_append_raw
=
torch
.
randn_like
(
k_append_raw
)
batch_mapping
=
torch
.
arange
(
workload
.
batch_size
,
device
=
device
,
dtype
=
torch
.
int32
)
sm_scale
=
1.0
/
math
.
sqrt
(
workload
.
head_dim
)
K_total
,
V_total
,
cu_seqlens_k
=
materialize_kv_for_flash_mixed
(
K_cache
=
K_cache
,
V_cache
=
V_cache
,
page_table
=
page_table
,
L_cache_per_b
=
L_cache_per_b
,
k_append_raw
=
k_append_raw
,
v_append_raw
=
v_append_raw
,
cu_seqlens_qk
=
cu_seqlens_qk
,
H_kv
=
workload
.
nk_heads
,
PAGE_SIZE
=
DEFAULT_PAGE_SIZE
,
)
max_seqlen_q
=
int
((
cu_seqlens_qk
[
1
:]
-
cu_seqlens_qk
[:
-
1
]).
max
().
item
())
max_seqlen_k
=
int
((
cu_seqlens_k
[
1
:]
-
cu_seqlens_k
[:
-
1
]).
max
().
item
())
max_seqlen_k_triton
=
seq_lens_bh
.
max
().
item
()
out_triton
=
causal_sparse_varlen_with_cache
(
q
=
q_raw
,
k_cache
=
K_cache
,
v_cache
=
V_cache
,
k
=
k_append_raw
,
v
=
v_append_raw
,
seq_lens_bh
=
seq_lens_bh
,
global_page_table
=
page_table
,
batch_mapping
=
batch_mapping
,
cu_seqlens_q
=
cu_seqlens_qk
,
HKV
=
workload
.
nk_heads
,
PAGE_SIZE
=
DEFAULT_PAGE_SIZE
,
sm_scale
=
sm_scale
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k_cache
=
max_seqlen_k_triton
,
)
out_flash
=
flash_attn_varlen_func
(
q
=
q_raw
,
k
=
K_total
,
v
=
V_total
,
cu_seqlens_q
=
cu_seqlens_qk
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_k
,
dropout_p
=
0.0
,
softmax_scale
=
sm_scale
,
causal
=
True
,
)
assert
torch
.
allclose
(
out_triton
,
out_flash
,
rtol
=
1e-6
,
atol
=
3e-3
)
max_diff
=
(
out_triton
-
out_flash
).
abs
().
max
().
item
()
logger
.
info
(
f
"[causal_sparse_varlen_with_cache:
{
workload
.
name
}
]: max abs diff=
{
max_diff
:
.
5
f
}
"
)
def
materialize_kv_cache_for_flash_decode
(
K_cache
,
V_cache
,
page_table
,
L_cache_per_b
,
# [B] int32
H_kv
:
int
,
PAGE_SIZE
:
int
,
):
"""
Build (K_flash, V_flash) suitable for flash_attn_with_kvcache, with shape:
(B, seqlen_cache_max, H_kv, D)
For each batch b:
- cache_seqlen[b] = L_cache_per_b[b]
- K_flash[b, :cache_seqlen[b], g] and V_flash[...] are filled from the paged KV cache.
- Tokens beyond cache_seqlen[b] (if any) are left as zeros and will be masked out
by flash_attn_with_kvcache via cache_seqlens.
"""
device
=
K_cache
.
device
dtype
=
K_cache
.
dtype
B
=
L_cache_per_b
.
shape
[
0
]
D
=
K_cache
.
shape
[
1
]
seqlen_cache_max
=
int
(
L_cache_per_b
.
max
().
item
())
K_flash
=
torch
.
zeros
((
B
,
seqlen_cache_max
,
H_kv
,
D
),
device
=
device
,
dtype
=
dtype
)
V_flash
=
torch
.
zeros_like
(
K_flash
)
for
b
in
range
(
B
):
Lc
=
int
(
L_cache_per_b
[
b
].
item
())
if
Lc
==
0
:
continue
for
g
in
range
(
H_kv
):
for
i
in
range
(
Lc
):
lp
=
i
//
PAGE_SIZE
off
=
i
%
PAGE_SIZE
phys
=
int
(
page_table
[
b
,
g
,
lp
].
item
())
idx
=
phys
*
PAGE_SIZE
+
off
K_flash
[
b
,
i
,
g
]
=
K_cache
[
idx
]
V_flash
[
b
,
i
,
g
]
=
V_cache
[
idx
]
return
K_flash
,
V_flash
@
pytest
.
mark
.
parametrize
(
"workload"
,
WORKLOADS_DECODE
,
ids
=
lambda
wl
:
wl
.
name
)
def
test_sparse_decode_attention
(
workload
:
Workload
):
dtype
=
torch
.
float16
device
=
triton
.
runtime
.
driver
.
active
.
get_active_torch_device
()
DEFAULT_PAGE_SIZE
=
256
N_LOGICAL_PAGES_MAX
=
256
# per-sequence cache lengths (all equal for WORKLOADS_DECODE)
L_cache_per_b
=
torch
.
as_tensor
(
workload
.
cache_lens
,
device
=
device
,
dtype
=
torch
.
int32
)
# build paged KV cache used by the Triton kernel
K_cache
,
V_cache
,
page_table
,
seq_lens_bh
,
CACHE_SIZE
=
(
build_paged_cache_from_lengths
(
B
=
workload
.
batch_size
,
H_kv
=
workload
.
nk_heads
,
D
=
workload
.
head_dim
,
PAGE_SIZE
=
DEFAULT_PAGE_SIZE
,
N_LOGICAL_PAGES_MAX
=
N_LOGICAL_PAGES_MAX
,
L_cache_per_b
=
L_cache_per_b
,
device
=
device
,
dtype
=
dtype
,
)
)
B
=
workload
.
batch_size
HQ
=
workload
.
nq_heads
HKV
=
workload
.
nk_heads
D
=
workload
.
head_dim
# Triton kernel expects q: [B, HQ, D]
q_triton
=
torch
.
randn
(
B
,
HQ
,
D
,
device
=
device
,
dtype
=
dtype
)
batch_mapping
=
torch
.
arange
(
B
,
device
=
device
,
dtype
=
torch
.
int32
)
sm_scale
=
1.0
/
math
.
sqrt
(
D
)
out_triton
=
head_sparse_decode_attention
(
q
=
q_triton
,
k
=
K_cache
,
v
=
V_cache
,
seq_lens_bh
=
seq_lens_bh
,
global_page_table
=
page_table
,
batch_mapping
=
batch_mapping
,
HKV
=
HKV
,
PAGE_SIZE
=
DEFAULT_PAGE_SIZE
,
sm_scale
=
sm_scale
,
)
# [B, HQ, D]
# materialize contiguous KV cache with shape [B, seqlen_cache_max, HKV, D]
K_flash
,
V_flash
=
materialize_kv_cache_for_flash_decode
(
K_cache
=
K_cache
,
V_cache
=
V_cache
,
page_table
=
page_table
,
L_cache_per_b
=
L_cache_per_b
,
H_kv
=
HKV
,
PAGE_SIZE
=
DEFAULT_PAGE_SIZE
,
)
# flash_attn_with_kvcache expects q: [B, seqlen_q, HQ, D]
q_flash
=
q_triton
.
unsqueeze
(
1
)
# seqlen_q = 1
out_flash
=
flash_attn_with_kvcache
(
q
=
q_flash
,
k_cache
=
K_flash
,
v_cache
=
V_flash
,
cache_seqlens
=
L_cache_per_b
,
softmax_scale
=
sm_scale
,
causal
=
True
,
).
squeeze
(
1
)
# [B, 1, HQ, D]
assert
torch
.
allclose
(
out_triton
,
out_flash
,
rtol
=
1e-6
,
atol
=
3e-3
)
max_diff
=
(
out_triton
-
out_flash
).
abs
().
max
().
item
()
logger
.
info
(
f
"[head_sparse_decode_attention:
{
workload
.
name
}
]: max abs diff=
{
max_diff
:
.
5
f
}
"
)
vllm/compactor-vllm/vllm_memory_comparison.png
0 → 100644
View file @
d29c39ca
79.4 KB
vllm/compactor-vllm/vllm_throughput_comparison.png
0 → 100644
View file @
d29c39ca
98.6 KB
Prev
1
2
3
4
5
6
7
8
9
10
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment