Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f81ce56b
Commit
f81ce56b
authored
Apr 23, 2026
by
chenzk
Browse files
vllm kvprune:v1.0.1
parent
2b7160c6
Changes
237
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
0 additions
and
2037 deletions
+0
-2037
vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/hopper_value.py
...ton_kernels/tensor_details/layout_details/hopper_value.py
+0
-362
vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/strided.py
...e/triton_kernels/tensor_details/layout_details/strided.py
+0
-17
vllm/kvprune_legacy_save/triton_kernels/testing.py
vllm/kvprune_legacy_save/triton_kernels/testing.py
+0
-215
vllm/kvprune_legacy_save/triton_kernels/topk.py
vllm/kvprune_legacy_save/triton_kernels/topk.py
+0
-157
vllm/kvprune_legacy_save/triton_kernels/topk_details/__init__.py
...prune_legacy_save/triton_kernels/topk_details/__init__.py
+0
-0
vllm/kvprune_legacy_save/triton_kernels/topk_details/_topk_backward.py
...legacy_save/triton_kernels/topk_details/_topk_backward.py
+0
-51
vllm/kvprune_legacy_save/triton_kernels/topk_details/_topk_forward.py
..._legacy_save/triton_kernels/topk_details/_topk_forward.py
+0
-183
vllm/kvprune_legacy_save/utils/__init__.py
vllm/kvprune_legacy_save/utils/__init__.py
+0
-29
vllm/kvprune_legacy_save/utils/arguments.py
vllm/kvprune_legacy_save/utils/arguments.py
+0
-445
vllm/kvprune_legacy_save/utils/context.py
vllm/kvprune_legacy_save/utils/context.py
+0
-109
vllm/kvprune_legacy_save/utils/helpers.py
vllm/kvprune_legacy_save/utils/helpers.py
+0
-35
vllm/kvprune_legacy_save/utils/kv_dist.py
vllm/kvprune_legacy_save/utils/kv_dist.py
+0
-35
vllm/kvprune_legacy_save/utils/layout_bridge.py
vllm/kvprune_legacy_save/utils/layout_bridge.py
+0
-167
vllm/kvprune_legacy_save/utils/sequence.py
vllm/kvprune_legacy_save/utils/sequence.py
+0
-83
vllm/kvprune_legacy_save/utils/tp_collectives.py
vllm/kvprune_legacy_save/utils/tp_collectives.py
+0
-48
vllm/kvprune_legacy_save/utils/tp_utils.py
vllm/kvprune_legacy_save/utils/tp_utils.py
+0
-40
vllm/kvprune_legacy_save/utils/triton_compat.py
vllm/kvprune_legacy_save/utils/triton_compat.py
+0
-61
No files found.
vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/hopper_value.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
triton
import
triton.language
as
tl
from
.base
import
Layout
def right_shift_unsigned(x, shift):
    """Right-shift ``x`` by ``shift`` bits and clear the bits shifted in at the top.

    ``>>`` is an arithmetic shift on signed values, so the upper ``shift``
    bits may be copies of the sign bit. Masking with the low ``32 - shift``
    bits gives the result a logical shift of a 32-bit word would produce.
    """
    kept_bits_mask = (1 << (32 - shift)) - 1
    shifted = x >> shift
    return shifted & kept_bits_mask
# -----------------------------------------------------------------------
# Interleave the bits of four consecutive fp4 values (i.e. 16-bits) as:
# 1000000111000000 (first fp4)
# 1000000111000000 (second fp4)
# 1000000111000000 (third fp4)
# 0110110000000000 (fourth fp4)
# This is done so that dequantization can be done in 14 SASS instructions
# -----------------------------------------------------------------------
def
_compress_fp4
(
x
):
x
=
x
.
to
(
torch
.
int32
)
return
((
x
&
0x8
)
<<
12
)
|
((
x
&
0x7
)
<<
6
)
def
_compress_fourth
(
x
):
x
=
x
.
to
(
torch
.
int32
)
return
((
x
&
0x8
)
<<
11
)
|
((
x
&
0x6
)
<<
9
)
|
((
x
&
0x1
)
<<
13
)
def _pack_bits(x: torch.Tensor, mx_axis: int):
    """Interleave groups of four bytes (eight nibbles) into one 32-bit word each.

    The last dimension is split into groups of 4 bytes. For every byte the
    low nibble is compressed into the low 16 bits and the high nibble into
    the high 16 bits of an int32 lane. The four lanes are then OR-ed
    together, the second and third shifted right by 3 and 6 bits
    respectively. The int32 result is reinterpreted as uint8 before being
    returned.

    ``mx_axis`` is accepted for signature symmetry and is not read here.
    """
    x = x.contiguous()
    last_dim = x.shape[-1]
    assert last_dim % 4 == 0, (
        "Input tensor must have a last dimension divisible by 4"
    )
    quads = x.reshape(x.shape[:-1] + (last_dim // 4, 4))

    def _both_nibbles(lane, compress):
        # Low nibble goes to the low half-word, high nibble to the high one.
        byte = quads[..., lane]
        return compress(byte) | (compress(byte >> 4) << 16)

    lane0 = _both_nibbles(0, _compress_fp4)
    lane1 = _both_nibbles(1, _compress_fp4)
    lane2 = _both_nibbles(2, _compress_fp4)
    lane3 = _both_nibbles(3, _compress_fourth)

    packed = (
        lane0
        | right_shift_unsigned(lane1, 3)
        | right_shift_unsigned(lane2, 6)
        | lane3
    )
    assert packed.is_contiguous()
    return packed.view(torch.uint8)
# -----------------------------------------------------------------------
# inverse operation of _pack_bits
# -----------------------------------------------------------------------
def _bf16_to_fp4e2m1(x):
    """Collapse ``0bAxxxxxxBCDxxxxxx`` (int16) into ``0b0000ABCD`` (uint8).

    Bit 15 of the input becomes bit 3 of the output and bits 6-8 become
    bits 0-2.
    """
    assert x.dtype == torch.int16
    low_three = right_shift_unsigned(x, 6) & 0x7
    top_one = (right_shift_unsigned(x, 15) & 0x1) << 3
    nibble = top_one | low_three
    return nibble.to(torch.uint8)
def _bf16x2_to_fp4e2m1x2(x):
    """Collapse two 16-bit halves of an int32 into one byte of two nibbles.

    ``0bAxxxxxxBCDxxxxxx_0bExxxxxxFGHxxxxxx`` (int32) -> ``0bABCD_EFGH`` (uint8):
    the nibble from the low half-word lands in bits 0-3, the one from the
    high half-word in bits 4-7.
    """
    assert x.dtype == torch.int32
    low_half = (x & 0xFFFF).to(torch.int16)
    high_half = (right_shift_unsigned(x, 16) & 0xFFFF).to(torch.int16)
    low_nibble = _bf16_to_fp4e2m1(low_half)
    high_nibble = _bf16_to_fp4e2m1(high_half)
    return low_nibble | (high_nibble << 4)
def _unpack_bits(x, mx_axis: int):
    """Split every packed 32-bit word back into four lanes of two nibbles.

    The input is reinterpreted as int32. Lanes 0-2 are recovered by shifting
    left by 0, 3 and 6 bits and masking; lane 3 is reassembled from three
    separately shifted and masked pieces. The four lanes are stacked on a new
    trailing axis, merged into the previous axis, and each lane is collapsed
    to a single uint8 holding two 4-bit values.

    ``mx_axis`` is accepted for signature symmetry and is not read here.
    """
    words = x.view(torch.int32)
    lane_mask = 0b10000001110000001000000111000000

    # Lanes 0..2 share the same mask and differ only by the left shift.
    lanes = [words & lane_mask]
    lanes.extend((words << shift) & lane_mask for shift in (3, 6))

    # Lane 3 is scattered over three bit groups of the packed word.
    top_piece = (words << 1) & 0b10000000000000001000000000000000
    mid_piece = right_shift_unsigned(words, 3) & 0b00000001100000000000000110000000
    low_piece = right_shift_unsigned(words, 7) & 0b00000000010000000000000001000000
    lanes.append((top_piece | mid_piece) | low_piece)

    stacked = torch.stack(lanes, dim=-1)
    merged = stacked.flatten(-2, -1)
    return _bf16x2_to_fp4e2m1x2(merged)
# -----------------------------------------------------------------------
class HopperMXValueLayout(Layout):
    """Layout that swizzles packed mxfp4 values and interleaves their bits.

    ``swizzle_data`` pads and permutes the last two dimensions and then
    bit-packs the result with ``_pack_bits``; ``unswizzle_data`` applies
    ``_unpack_bits`` followed by the reverse reshape/permute and crops back to
    the shape given at construction time.
    """

    # Identifier of this layout.
    name: str = "HOPPER_VALUE"

    def __init__(self, shape, mx_axis, mma_version=3):
        """Record the logical shape and the microscaling axis.

        Args:
            shape: full tensor shape; the last two entries are stored as
                ``self.K`` and ``self.N``, anything before them as
                ``self.leading_shape``.
            mx_axis: index into ``shape``; must be a valid axis.
            mma_version: 2 or 3 (checked in ``swizzle_data``).
        """
        super().__init__(shape)
        assert mx_axis in range(len(shape))
        self.mx_axis = mx_axis
        self.mma_version = mma_version
        (
            *self.leading_shape,
            self.K,
            self.N,
        ) = shape

    def _maybe_mT(self, data):
        """Transpose the last two dims when ``mx_axis`` is the first of them."""
        # len(leading_shape) is the index of the second-to-last axis.
        if self.mx_axis == len(self.leading_shape):
            return data.mT
        return data

    def swizzle_data(self, data):
        """
        Given a uint8 tensor of shape (*, M, K), returns a tensor of shape
        (*, M // 4, K * 4) such that:

        1) Groups contiguously all the elements owned by the same thread of 4
           mma tiles along the K axis. The following animation shows a similar
           grouping for 2 tiles along M and 2 tiles along K rather than 4 along K
           as done here:
           https://neuralmagic.com/wp-content/uploads/2024/10/animation_4.gif

        2) Moves the elements belonging to thread 4-7 to be contiguous with those
           from thread 0-3. This is done to get a full cache line when loading them
           from HBM.

        mx_axis selects the lhs or rhs of the matmul.

        WARNING: Assumes that the matmul will be done in bf16 or fp16!
        Implementing it for fp8 is as easy as making the tile size (8, 8)
        """
        batch = data.ndim - 2
        assert batch >= 0
        assert self.mma_version in (2, 3)

        data = self._maybe_mT(data)
        # NOTE(review): this value is overwritten after padding below and is
        # not read in between.
        init_shape = data.shape

        # We are loading 8 bf16 elements per thread to use ld.global.v4
        # Every u8 represents 2 mxfp4 elements
        u8_kwidth = 8 // 2 if self.mma_version == 2 else 1

        # Pack the 4 // u8_kwidth subtiles of an mma into a u4x8
        contig = (1, u8_kwidth)
        scott_trick = (2, 1)
        threads = (4, 4)
        warp_tile = (2, 2)
        k_tile = (1, 4 // u8_kwidth)

        sizes = list(data.shape[:-2])
        pads = []
        # Per dimension the view is split as
        # [rest, k_tile, warp_tile, threads, scott_trick, contig]
        for i, (a, b, c, s, d) in enumerate(
            zip(k_tile, warp_tile, threads, scott_trick, contig)
        ):
            pack = a * b * c * s * d
            size = data.shape[batch + i]
            # Amount needed to round `size` up to a multiple of `pack`.
            pad = (pack - size % pack) % pack
            pads += [(0, pad)]
            sizes.append((size + pad) // pack)
            sizes += [a, b, c, s, d]

        # torch.nn.functional.pad takes (before, after) pairs starting from
        # the last dimension, hence the reversal.
        pads = tuple(x for t in pads[::-1] for x in t)
        data = torch.nn.functional.pad(data, pads)
        init_shape = data.shape
        # 0: rest[0]
        # 1: k_tile[0]
        # 2: warp_tile[0]
        # 3: threads[0]
        # 4: scott_trick[0]
        # 5: contig[0]
        # 6: rest[1]
        # 7: k_tile[1]
        # 8: warp_tile[1]
        # 9: threads[1]
        # 10: scott_trick[1]
        # 11: contig[1]
        data = data.view(*sizes)
        # With the index table above, the permutation below produces
        # [rest[0], threads[0], rest[1], scott_trick[1], scott_trick[0],
        #  threads[1], k_tile[1], k_tile[0], warp_tile[1], warp_tile[0],
        #  contig[0], contig[1]]
        perm = [0, 3, 6, 10, 4, 9, 7, 1, 8, 2, 5, 11]
        perm = list(range(batch)) + [batch + p for p in perm]
        data = data.permute(*perm).contiguous()
        # These are views
        # Merge the last 10 axes into the new column axis, then rest[0] and
        # threads[0] into the new row axis.
        data = data.flatten(-10, -1)
        data = data.flatten(-3, -2)
        assert data.is_contiguous()
        assert data.shape[-2] == init_shape[-2] // 4
        assert data.shape[-1] == init_shape[-1] * 4
        # twiddle the bits
        data = _pack_bits(data, self.mx_axis)
        data = self._maybe_mT(data)
        return data

    def unswizzle_data(self, data):
        """Undo ``swizzle_data`` and crop to the original ``(K, N)`` extent."""
        data = self._maybe_mT(data)
        data = _unpack_bits(data, self.mx_axis)
        *batch, M, K = data.shape
        # We have two times the elements if we already upcasted to bfloat16
        mult = 2 if data.dtype == torch.bfloat16 else 1
        assert M % 4 == 0, "M must be divisible by 4"
        assert K % (4 * 8 * 2 * 2 * mult) == 0, (
            f"K must be divisible by {4 * 8 * 2 * 2 * mult}"
        )
        # We are loading 8 bf16 elements per thread to use ld.global.v4
        # Every u8 represents 2 mxfp4 elements
        u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
        data = data.reshape(
            *batch,
            M // 4,
            4,
            K // (4 * 8 * 2 * 2 * mult),
            2,
            4,
            8 // u8_kwidth,
            2,
            u8_kwidth * mult,
        )
        b = len(batch)
        perm = [0, 6, 1, 3, 2, 5, 4, 7]
        perm = list(range(b)) + [b + p for p in perm]
        data = data.permute(*perm)
        data = data.reshape(*batch, M * 4, K // 4)
        data = self._maybe_mT(data)
        # Drop the padding that swizzle_data may have added.
        return data[..., : self.K, : self.N]

    def swizzle_block_shape(self, block_shape):
        """Return ``block_shape`` unchanged."""
        return block_shape
@triton.jit
def _unshuffle_triton(x, mma_version: tl.constexpr):
    """
    Triton inverse of swizzle_mxfp4_value_hopper

    Takes a 2D block of shape (M, K) and returns a block of shape
    (M * 4, K // 4) by reshaping into 8 axes, permuting them and flattening
    again. The reshape/permutation mirrors the one used in
    HopperMXValueLayout.unswizzle_data for a tensor without batch dimensions.
    """
    tl.static_assert(mma_version == 2 or mma_version == 3, "mma_version must be 2 or 3")
    # NOTE: transposition for mx_axis == 0 is left to the caller:
    # if mx_axis == 0:
    #     x = x.trans()

    # We have two times the elements if we already upcasted to bfloat16
    mult: tl.constexpr = 2 if x.dtype == tl.bfloat16 else 1
    M: tl.constexpr = x.shape[0]
    K: tl.constexpr = x.shape[1]
    tl.static_assert(M % 4 == 0, "M must be divisible by 4")
    tl.static_assert(
        K % (4 * 8 * 2 * 2 * mult) == 0,
        f"K must be divisible by {4 * 8 * 2 * 2 * mult}",
    )

    # We are loading 8 bf16 elements per thread to use ld.global.v4
    # Every u8 represents 2 mxfp4 elements
    u8_kwidth: tl.constexpr = 8 // 2 if mma_version == 2 else 1
    x = x.reshape(
        M // 4,
        4,
        K // (4 * 8 * 2 * 2 * mult),
        2,
        4,
        8 // u8_kwidth,
        2,
        u8_kwidth * mult,
    )
    x = x.trans(0, 6, 1, 3, 2, 5, 4, 7)
    x = x.reshape(M * 4, K // 4)
    # NOTE: see above, the caller handles mx_axis == 0:
    # if mx_axis == 0:
    #     x = x.trans()
    return x
@triton.jit
def _unpack_fp4_to_bf16_triton(x):
    # Expands packed fp4 data into bfloat16 with inline PTX, then reorders the
    # result so the values produced from each group of four inputs are laid out
    # along the last axis.
    #
    # The PTX has four outputs ($0..$3) and one input ($4). $0..$2 are the
    # input shifted left by 0, 3 and 6 bits and masked with
    # 0b1000000111000000 per half-word; $3 is assembled from three separately
    # shifted/masked pieces. Every output is multiplied by the `bias` constant
    # with mul.bf16x2.

    # For now we implement just H100 support (mul.bf16x2)
    # A100 support is possible via fma
    r0, r1 = tl.inline_asm_elementwise(
        r"""
{
.reg .b32 b, c, d<7>, scale;
.reg .b32 bias;
mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
// We add the missing bias to the scale directly
and.b32 $0, $4, 0b10000001110000001000000111000000;
mul.bf16x2 $0, $0, bias;
shl.b32 b, $4, 3;
and.b32 $1, b, 0b10000001110000001000000111000000;
mul.bf16x2 $1, $1, bias;
shl.b32 c, $4, 6;
and.b32 $2, c, 0b10000001110000001000000111000000;
mul.bf16x2 $2, $2, bias;
// Unpack last two elements
shl.b32 d0, $4, 1;
and.b32 d1, d0, 0b10000000000000001000000000000000;
shr.b32 d2, $4, 3;
and.b32 d3, d2, 0b00000001100000000000000110000000;
or.b32 d4, d1, d3;
shr.b32 d5, $4, 7;
and.b32 d6, d5, 0b00000000010000000000000001000000;
or.b32 $3, d4, d6;
mul.bf16x2 $3, $3, bias;
}
""",
        constraints="=r,=r,=r,=r,r",
        args=[x],
        dtype=(tl.bfloat16, tl.bfloat16),
        is_pure=True,
        pack=4,
    )
    # Concat each pack of 4
    x = tl.join(r0, r1)
    # Split axis 1 into (groups, 4), swap the group-of-4 axis with the joined
    # axis, then flatten everything after axis 0 back into one axis.
    x = x.reshape(x.shape[0], x.shape[1] // 4, 4, x.shape[2])
    x = x.trans(0, 1, 3, 2)
    x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
    return x
@triton.jit
def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
    """
    Implements the bit-untwiddling of a 32-bit integer (8 mxfp4 elements):
    (x << 0) & 0b1000000111000000
    (x << 3) & 0b1000000111000000
    (x << 6) & 0b1000000111000000
    ((x << 1) & 0b1000000000000000) | ((x >> 3) & 0b0000000110000000) | ((x >> 7) & 0b0000000001000000)

    `x` must be a 2D uint8 block whose second dimension is divisible by 4 and
    `mx_axis` must be 0 or 1 (all enforced with static asserts). The values
    are unpacked to bfloat16, unshuffled, and multiplied by `scale` after it
    has been converted to bfloat16 and broadcast to the shape of `x`.
    """
    # upcast values to bfloat16
    tl.static_assert(len(x.shape) == 2)
    tl.static_assert(mx_axis == 0 or mx_axis == 1, "mx_axis must be 0 or 1")
    tl.static_assert(x.shape[1] % 4 == 0)
    tl.static_assert(x.dtype == tl.uint8)
    # The helpers below work along axis 1, so transpose around them when the
    # microscaling axis is 0.
    if mx_axis == 0:
        x = x.trans()
    x = _unpack_fp4_to_bf16_triton(x)
    # NOTE(review): mma_version is fixed to 3 here; there is no way for the
    # caller to select version 2 through this entry point.
    x = _unshuffle_triton(x, mma_version=3)
    if mx_axis == 0:
        x = x.trans()

    # upcast scale to bfloat16
    # Add bias missing from the bf16 upcasting sequence
    # triton / LLVM generates terrible code for this sequence
    # scale = scale.to(tl.uint16)
    # scale = scale << 7
    # scale = scale.to(tl.bfloat16, bitcast=True)
    scale = tl.inline_asm_elementwise(
        r"""
{
prmt.b32 $0, $2, 0, 0x5140;
shl.b32 $0, $0, 7;
prmt.b32 $1, $2, 0, 0x7362;
shl.b32 $1, $1, 7;
}
""",
        constraints="=r,=r,r",
        args=[scale],
        dtype=tl.bfloat16,
        is_pure=True,
        pack=4,
    )
    # Broadcast scale
    # Insert a new axis right after mx_axis, expand it to 32 entries, and fold
    # it back so that scale has the same shape as x.
    scale = scale.expand_dims(mx_axis + 1)
    scale = scale.broadcast_to(scale.shape[: mx_axis + 1] + [32] + scale.shape[mx_axis + 2 :])
    scale = scale.reshape(x.shape)

    # Combine scale and x
    x = x * scale
    return x
vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/strided.py
deleted
100644 → 0
View file @
2b7160c6
from
.base
import
Layout
class StridedLayout(Layout):
    """Layout that leaves data and block shapes untouched."""

    # No layout name is set for the plain strided case.
    name: str = None

    def __init__(self, shape) -> None:
        """Forward ``shape`` to the base ``Layout`` constructor."""
        super().__init__(shape)

    def swizzle_data(self, data):
        """Return ``data`` unchanged."""
        return data

    def unswizzle_data(self, data):
        """Return ``data`` unchanged."""
        return data

    def swizzle_block_shape(self, block_shape):
        """Return ``block_shape`` unchanged."""
        return block_shape
vllm/kvprune_legacy_save/triton_kernels/testing.py
deleted
100644 → 0
View file @
2b7160c6
import
enum
import
functools
import
os
import
subprocess
import
sys
import
torch
from
vllm.kvprune.triton_kernels.numerics
import
(
MAX_FINITE_FLOAT8E4B8
,
MAX_FINITE_FLOAT8E4NV
,
MAX_FINITE_FLOAT8E5
,
)
def assert_equal(ref, tri):
    """Assert that ``ref`` and ``tri`` are equal.

    Tensors are compared element-wise and every element must match; any other
    value is compared with ``==``.
    """
    matches = torch.all(ref == tri) if isinstance(ref, torch.Tensor) else ref == tri
    assert matches
def
assert_close
(
ref
,
tri
,
maxtol
=
None
,
rmstol
=
None
,
description
=
"--"
,
verbose
=
True
):
if
tri
.
dtype
.
itemsize
==
1
:
ref_as_type
=
ref
.
to
(
tri
.
dtype
)
if
ref
.
dtype
==
tri
.
dtype
:
assert
torch
.
all
(
ref_as_type
==
tri
)
return
ref
=
ref_as_type
if
ref
.
numel
()
==
0
:
return
if
maxtol
is
None
:
maxtol
=
2e-2
if
rmstol
is
None
:
rmstol
=
4e-3
"""
Compare reference values against obtained values.
"""
# cast to float32:
ref
=
ref
.
to
(
torch
.
float32
).
detach
()
tri
=
tri
.
to
(
torch
.
float32
).
detach
()
assert
ref
.
shape
==
tri
.
shape
,
(
f
"Tensors must have same size
{
ref
.
shape
=
}
{
tri
.
shape
=
}
"
)
# deal with infinite elements:
inf_mask_ref
=
torch
.
isinf
(
ref
)
inf_mask_tri
=
torch
.
isinf
(
tri
)
assert
torch
.
equal
(
inf_mask_ref
,
inf_mask_tri
),
(
"Tensor must have same infinite elements"
)
refn
=
torch
.
where
(
inf_mask_ref
,
0
,
ref
)
trin
=
torch
.
where
(
inf_mask_tri
,
0
,
tri
)
# normalise so that RMS calculation doesn't overflow:
eps
=
1.0e-30
multiplier
=
1.0
/
(
torch
.
max
(
torch
.
abs
(
refn
))
+
eps
)
refn
*=
multiplier
trin
*=
multiplier
ref_rms
=
torch
.
sqrt
(
torch
.
square
(
refn
).
mean
())
+
eps
rel_err
=
torch
.
abs
(
refn
-
trin
)
/
torch
.
maximum
(
ref_rms
,
torch
.
abs
(
refn
))
max_err
=
torch
.
max
(
rel_err
).
item
()
rms_err
=
torch
.
sqrt
(
torch
.
square
(
rel_err
).
mean
()).
item
()
if
verbose
:
print
(
"%s maximum relative error = %s (threshold = %s)"
%
(
description
,
max_err
,
maxtol
)
)
print
(
"%s RMS relative error = %s (threshold = %s)"
%
(
description
,
rms_err
,
rmstol
)
)
if
max_err
>
maxtol
:
bad_idxs
=
torch
.
nonzero
(
rel_err
>
maxtol
)
num_nonzero
=
bad_idxs
.
size
(
0
)
bad_idxs
=
bad_idxs
[:
1000
]
print
(
"%d / %d mismatched elements (shape = %s) at coords %s"
%
(
num_nonzero
,
rel_err
.
numel
(),
tuple
(
rel_err
.
shape
),
bad_idxs
.
tolist
())
)
bad_idxs
=
bad_idxs
.
unbind
(
-
1
)
print
(
"ref values: "
,
ref
[
tuple
(
bad_idxs
)].
cpu
())
print
(
"tri values: "
,
tri
[
tuple
(
bad_idxs
)].
cpu
())
assert
max_err
<=
maxtol
assert
rms_err
<=
rmstol
class ComputeSanitizerTool(enum.Enum):
    """Tools that can be selected when running a test under compute-sanitizer.

    The member value is the string passed as ``--tool=<value>`` on the
    compute-sanitizer command line by ``compute_sanitizer``.
    """

    MEMCHECK = "memcheck"
    RACECHECK = "racecheck"
    SYNCCHECK = "synccheck"
    INITCHECK = "initcheck"
def compute_sanitizer(**target_kwargs):
    """
    Decorator to run a test with compute sanitizer enabled and pytorch caching allocator disabled,
    to expose potential memory access errors.
    This decorator requires the `request` fixture to be present.
    If `run_sanitizer` argument is present and set to False, the sanitizer is not run.
    Running tests under compute sanitizer requires launching subprocess and is slow,
    so use sparingly

    Decorator-only options (consumed here, never matched against test kwargs):
        clear_torch_cache: if true, call ``torch.cuda.empty_cache()`` before
            every sanitized invocation.
        tools_to_check: list of ``ComputeSanitizerTool`` members to run;
            defaults to ``[ComputeSanitizerTool.MEMCHECK]``.

    Every remaining keyword is a filter: the sanitizer is only used for test
    invocations whose kwargs contain all of those key/value pairs.

    Setting the environment variable ``SKIP_COMPUTE_SANITIZER=1`` runs the
    test directly.
    """
    # Consume the decorator-only options once, at decoration time. Popping
    # them inside the wrapper would mutate state shared by every invocation,
    # so only the first call of a parametrized test would see the requested
    # values and later calls would silently fall back to the defaults.
    clear_torch_cache = target_kwargs.pop("clear_torch_cache", False)
    requested_tools = target_kwargs.pop("tools_to_check", None)

    def decorator(test_fn):

        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
                test_fn(*args, **kwargs)
                return

            # Imported lazily so that psutil is only required when the
            # sanitizer path is actually taken.
            import psutil

            if clear_torch_cache:
                torch.cuda.empty_cache()

            tools_to_check = (
                [ComputeSanitizerTool.MEMCHECK]
                if requested_tools is None
                else requested_tools
            )
            assert isinstance(tools_to_check, list), f"{tools_to_check=}"
            unknown_tools = [
                tool for tool in tools_to_check
                if not isinstance(tool, ComputeSanitizerTool)
            ]
            assert not unknown_tools, f"{unknown_tools=}"

            # When the parent process is compute-sanitizer we are already the
            # child launched below and must run the test body itself.
            ppid_name = psutil.Process(os.getppid()).exe()
            run_compute_sanitizer = target_kwargs.items() <= kwargs.items()
            if "run_sanitizer" in kwargs:
                run_compute_sanitizer &= kwargs["run_sanitizer"]

            if run_compute_sanitizer and "compute-sanitizer" not in ppid_name:
                for tool in tools_to_check:
                    # get path of current file
                    path = os.path.realpath(test_fn.__globals__["__file__"])
                    env = {
                        "PATH": os.environ["PATH"],
                        "PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
                        "TORCH_SHOW_CPP_STACKTRACES": "1",
                        "CUDA_LAUNCH_BLOCKING": "1",
                    }
                    if "CUDA_VISIBLE_DEVICES" in os.environ:
                        env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
                    assert "request_fixture" in kwargs, (
                        "memcheck'ed test must have a (possibly unused) `request` fixture"
                    )
                    test_id = kwargs["request_fixture"].node.callspec.id
                    test_node = f"{path}::{test_fn.__name__}[{test_id}]"
                    cmd = [
                        "compute-sanitizer",
                        "--target-processes=application-only",
                        "--destroy-on-device-error=context",
                        f"--tool={tool.value}",
                        sys.executable,
                        "-m",
                        "pytest",
                        "-vsx",
                        test_node,
                    ]
                    # Forward checksum-related pytest options to the child.
                    for opt in ["--update_checksum", "--ignore_checksum_error"]:
                        if opt in sys.argv:
                            cmd.append(opt)
                    out = subprocess.run(
                        cmd,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.STDOUT,
                        env=env,
                    )
                    sanitizer_ok = (
                        "ERROR SUMMARY: 0 errors" in str(out.stdout)
                        or "RACECHECK SUMMARY: 0 hazards displayed" in str(out.stdout)
                    )
                    test_output = out.stdout
                    if type(test_output) is bytes:
                        test_output = test_output.decode()

                    fail = False
                    if not sanitizer_ok:
                        print("compute-sanitizer returned an error")
                        fail = True
                    elif out.returncode != 0:
                        print(
                            "The test failed due to some other reason: consider running without compute-sanitizer to verify."
                        )
                        print(f"{out.returncode=}")
                        fail = True

                    if fail:
                        print("*****************************************************")
                        print("******************** TEST OUTPUT ********************")
                        print("*****************************************************")
                        print(test_output)
                        print("*****************************************************")
                        print("****************** TEST OUTPUT END ******************")
                        print("*****************************************************")
                        # Raise explicitly: a bare `assert` is stripped under
                        # `python -O` and would let the failure pass silently.
                        raise AssertionError(
                            f"compute-sanitizer ({tool.value}) run failed"
                        )
            else:
                test_fn(*args, **kwargs)

        return wrapper

    return decorator
def compute_actual_scale(x, dtype):
    """Return ``max(|x|)`` divided by the largest finite value of ``dtype``.

    ``dtype`` must be one of ``torch.float8_e5m2``, ``torch.float8_e4m3fn`` or
    ``torch.float8_e4m3fnuz``; any other value raises ``KeyError``.
    """
    max_finite_by_dtype = {
        torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
        torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
        torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
    }
    largest_finite = max_finite_by_dtype[dtype]
    abs_max = x.abs().max()
    return abs_max / largest_finite
vllm/kvprune_legacy_save/triton_kernels/topk.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
triton
from
vllm.kvprune.triton_kernels.topk_details._topk_forward
import
_topk_forward
from
vllm.kvprune.triton_kernels.topk_details
import
_topk_backward
from
vllm.kvprune.triton_kernels.tensor
import
Tensor
,
Bitmatrix
from
typing
import
Optional
,
Union
def topk_forward(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None):
    """Launch the top-k forward kernel and wrap its bitmatrix output.

    A plain tensor input is first wrapped in ``Tensor`` with ``n_rows`` (when
    given) as the logical row count and the physical row count as the
    maximum. Only ``dim == 1`` with ``return_bitmatrix=True`` is supported,
    and the last dimension must be smaller than 32768.

    Returns ``(y_vals, y_indx, bitmatrix)`` where ``y_indx`` is the tensor
    supplied by the caller when one was passed in.
    """
    if not isinstance(x, Tensor):
        logical_rows = x.shape[0] if n_rows is None else n_rows
        x = Tensor(
            x,
            shape=[logical_rows, x.shape[1]],
            shape_max=[x.shape[0], x.shape[1]],
        )

    def _ceil_div(a, b):
        return (a + b - 1) // b

    BLOCK_M = 32
    BLOCK_N = 32
    BLOCK_S = 128
    assert len(x.shape) == 2
    assert x.shape_max[-1] < 32768
    assert dim == 1
    assert return_bitmatrix

    rows, cols = x.shape
    rows_max, _ = x.shape_max
    device = x.device

    # Output buffers; the index buffer is only allocated when the caller did
    # not supply one.
    y_vals = torch.empty((rows_max, k), dtype=x.dtype, device=device)
    use_provided_indx = y_indx is not None
    if not use_provided_indx:
        y_indx = torch.empty((rows_max, k), dtype=torch.int16, device=device)

    # create bitmatrix in transposed memory layout:
    cols_padded = _ceil_div(cols, BLOCK_N) * BLOCK_N
    cols_words = cols_padded // 32
    bit_storage = torch.empty(
        (cols_words, _ceil_div(rows_max, 32) * 32),
        dtype=torch.uint32,
        device=device,
    )
    bitmatrix = torch.transpose(bit_storage, 0, 1)[:rows_max]

    # Scratchpad that the kernel clears block by block.
    s_blocks = _ceil_div(cols, BLOCK_S)
    scratchpad = torch.empty((s_blocks * BLOCK_S,), dtype=torch.int32, device=device)

    # Enough programs to cover both the row blocks and the scratchpad blocks.
    grid = (max(_ceil_div(rows_max, BLOCK_M), s_blocks),)
    _topk_forward[grid](
        x, x.stride(0),  # inputs
        y_vals, y_indx, y_vals.stride(0), use_provided_indx,  # output [topk]
        bitmatrix, bitmatrix.stride(0), bitmatrix.stride(1),  # output [bitmatrix]
        rows, cols,  # shapes
        scratchpad, BLOCK_S, s_blocks,  # thing to memset to zero
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,  # tunable parameter
        APPLY_SOFTMAX=apply_softmax, N_EXPTS_PAD=cols_padded, N_EXPTS_ACT=k,  # constants
    )

    bitmatrix = Bitmatrix(
        bitmatrix,
        shape=[rows, cols_words * 32],
        shape_max=[rows_max, None],
        scratchpad=scratchpad,
    )
    return y_vals, y_indx, bitmatrix
def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax):
    """Launch the top-k backward kernel and return the gradient w.r.t. ``x``.

    Args:
        x: input of the forward pass.
        y_indx: selected indices produced by the forward pass.
        dy_vals: gradient of the selected values; last dimension must be ``k``.
        k: number of selected entries per row.
        n_rows: forwarded to the kernel as its ``NRows`` argument (``None``
            means "use ``x.shape[0]``" inside the kernel).
        apply_softmax: whether the forward pass applied softmax.

    Returns:
        A tensor shaped like ``x`` holding the input gradient.
    """
    from types import ModuleType

    assert dy_vals.shape[-1] == k

    # `from <pkg>.topk_details import _topk_backward` binds the *submodule*
    # when the package __init__ does not re-export the kernel, and a module
    # cannot be subscripted with a launch grid. Accept both bindings.
    kernel = _topk_backward
    if isinstance(kernel, ModuleType):
        kernel = kernel._topk_backward

    n_expts_pad = triton.next_power_of_2(x.shape[-1])
    # Zero-initialise: the kernel returns early for rows >= n_rows and only
    # one program per row of dy_vals is launched, so rows it never visits
    # would otherwise hold uninitialised memory.
    dx = torch.zeros_like(x)
    kernel[(dy_vals.shape[0],)](
        y_indx, y_indx.stride(0), dy_vals, dy_vals.stride(0), x, x.stride(0),  # inputs
        dx,  # outputs
        dx.stride(0), x.shape[0], n_rows, x.shape[-1],
        APPLY_SOFTMAX=apply_softmax, N_EXPTS_ACT=k, N_EXPTS_PAD=n_expts_pad,
    )
    return dx
class TopK(torch.autograd.Function):
    """Autograd wrapper pairing ``topk_forward`` with ``topk_backward``."""

    @staticmethod
    def forward(ctx, x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows):
        """Run the forward kernel and stash what the backward pass needs."""
        outputs = topk_forward(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
        y_vals, y_indx, bitmatrix = outputs
        ctx.k = k
        ctx.n_rows = n_rows
        ctx.apply_softmax = apply_softmax
        ctx.save_for_backward(x, y_indx)
        return y_vals, y_indx, bitmatrix

    @staticmethod
    def backward(ctx, dy_vals, _0, _1):
        """Return the gradient for ``x``; the other six inputs get ``None``."""
        x, y_indx = ctx.saved_tensors
        dx = topk_backward(x, y_indx, dy_vals, ctx.k, ctx.n_rows, ctx.apply_softmax)
        return (dx, *([None] * 6))
def topk(
    x: Union[Tensor, torch.Tensor],
    k: int,
    apply_softmax: bool = True,
    dim: int = 1,
    return_bitmatrix: bool = True,
    y_indx: Optional[torch.Tensor] = None,
    n_rows: Optional[int] = None,
):
    """
    Select the ``k`` largest entries along ``dim`` through the ``TopK`` autograd function.

    The input may be a `Tensor` wrapper or a plain `torch.Tensor`.

    Parameters
    ----------
    x : Union[triton_kernels.Tensor, torch.Tensor]
        Input of shape (n_tokens, n_expts).
    k : int
        How many entries to keep per row.
    apply_softmax : bool, default True
        Forwarded to the kernels as their softmax switch.
    dim : int, default 1
        Dimension to reduce over.
    return_bitmatrix : bool, default True
        Request the bitmatrix of shape (n_tokens, cdiv(n_expts, 32)); bit
        [t, b] tells whether expert b was selected for token t.
    y_indx : torch.Tensor, optional
        Pre-computed indices of shape (n_tokens, k). When given, index
        selection is skipped and these indices are used as-is.
    n_rows : int, optional
        Number of rows to process. ``None`` means every row of `x`.

    Returns
    -------
    (expt_scal, expt_indx, bitmatrix) : Tuple[torch.Tensor, torch.Tensor, Bitmatrix]
    """
    return TopK.apply(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
vllm/kvprune_legacy_save/triton_kernels/topk_details/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/kvprune_legacy_save/triton_kernels/topk_details/_topk_backward.py
deleted
100644 → 0
View file @
2b7160c6
import
triton
import
triton.language
as
tl
@triton.jit
def _topk_backward(
    Yi, stride_ym,  # topk indices
    DY, stride_dym,  # output gradient values
    X, stride_xm,  # input values
    DX, stride_dxm,  # input gradient values
    n_rows, NRows, n_expts_tot,
    APPLY_SOFTMAX: tl.constexpr,
    N_EXPTS_ACT: tl.constexpr,
    N_EXPTS_PAD: tl.constexpr,
):
    # One program per row: writes the input gradient for row `pid_m` of DX.
    # The row is first cleared over the first `n_expts_tot` columns, then the
    # gradient of the N_EXPTS_ACT selected entries is scattered to their
    # column indices.
    pid_m = tl.program_id(0)
    # When NRows is provided, the row count is read from it and overrides the
    # n_rows argument.
    if NRows is not None:
        n_rows = tl.load(NRows)
    # Rows at or beyond n_rows are skipped without writing anything to DX.
    if pid_m >= n_rows:
        return
    # Advance all base pointers to the current row.
    Yi += pid_m * stride_ym
    DY += pid_m * stride_dym
    X += pid_m * stride_xm
    DX += pid_m * stride_dxm
    # --
    offs_xn = tl.arange(0, N_EXPTS_PAD)
    offs_yn = tl.arange(0, N_EXPTS_ACT)
    mask_xn = offs_xn < n_expts_tot
    # recompute softmax
    # (over the selected entries only; gathered through the saved indices)
    y_indx = tl.load(Yi + offs_yn)
    x = tl.load(X + y_indx)
    x = x.to(tl.float32)
    y = tl.softmax(x)
    # compute input-gradient
    dy = tl.load(DY + offs_yn)
    dy = dy.to(tl.float32)
    s = tl.sum(y * dy, 0)
    # write-back input gradient
    # Zero the row first; the barrier separates this store from the scatter
    # below, which writes into the same row.
    tl.store(DX + offs_xn, 0, mask=mask_xn)
    tl.debug_barrier()
    if APPLY_SOFTMAX:
        # Softmax backward: dx_i = y_i * (dy_i - sum_j y_j * dy_j)
        dx = y * (dy - s)
    else:
        dx = dy
    tl.store(DX + y_indx, dx)
vllm/kvprune_legacy_save/triton_kernels/topk_details/_topk_forward.py
deleted
100644 → 0
View file @
2b7160c6
import
triton
import
triton.language
as
tl
@triton.jit
def get_topmask_and_fullmask(x):
    # Returns two tensors shaped like `x` and of the same unsigned dtype: one
    # filled with only the most significant bit set, one with all bits set.
    tl.static_assert(x.dtype.is_int_unsigned(), "floating-point value must be passed as bits")
    tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)
    fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1
    tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
    fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
    return tm_arr, fm_arr
@triton.jit
def fpval_to_key(x):
    # `x` holds floating-point bit patterns as unsigned integers. Values with
    # the top bit set have every bit flipped; all others have only the top bit
    # flipped. key_to_fpval applies the reverse mapping.
    tm, fm = get_topmask_and_fullmask(x)
    return x ^ tl.where((x & tm) != 0, fm, tm)
@triton.jit
def key_to_fpval(x):
    # Reverse of fpval_to_key: keys with the top bit clear have every bit
    # flipped; keys with the top bit set have only the top bit flipped.
    tm, fm = get_topmask_and_fullmask(x)
    return x ^ tl.where((x & tm) == 0, fm, tm)
# stable top-k tie-breaks to value with smaller index
@triton.jit
def indx_to_key(indx, N_EXPTS_PAD: tl.constexpr):
    # Maps an index to N_EXPTS_PAD - index, so a smaller index yields a larger
    # key.
    return N_EXPTS_PAD - indx
@triton.jit
def key_to_indx(indx, N_EXPTS_PAD: tl.constexpr):
    # Reverse of indx_to_key; the mapping N_EXPTS_PAD - v is its own inverse.
    return N_EXPTS_PAD - indx
@triton.jit
def streaming_topk(
    X, stride_xm, n_expts_tot, offs_m, mask_m,
    N_EXPTS_PAD: tl.constexpr,
    N_EXPTS_ACT: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Computes, for the rows `offs_m`, the N_EXPTS_ACT largest entries of X
    # and their column indices by walking the columns in blocks of BLOCK_N,
    # starting from the last block and moving towards column 0.
    #
    # Every element is turned into a single unsigned integer: the value key
    # (fpval_to_key) shifted left by 16 bits OR-ed with the index key
    # (indx_to_key), so that integer comparison orders by value first and, on
    # ties, prefers the smaller index.
    #
    # Returns (y_values, y_indices).
    x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
    x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}")
    if x_nbits < 16:
        # this ensures that we leave at least 16 bits for expert index
        # even if the input dtype is smaller than 16 bits:
        y_nbits: tl.constexpr = 32
    else:
        y_nbits: tl.constexpr = x_nbits * 2
    x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}")
    x_dtype: tl.constexpr = X.dtype.element_ty

    # subtract 1 from loop iterations because we peel the first (masked) iteration:
    loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1
    offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_n = offs_x_n[None, :] < n_expts_tot

    # first iteration:
    # (the last column block; the only one that needs the column mask)
    X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :]
    x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf"))
    x = fpval_to_key(x.to(x_utype, bitcast=True))
    x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
    acc = tl.topk(x, N_EXPTS_ACT, dim=1)

    # subsequent iterations:
    # (statically unrolled when there are at most 4 of them)
    for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations):
        acc = tl.bitonic_merge(acc)  # ensure sorted ascending for the merge
        X_ptrs -= BLOCK_N
        offs_x_n -= BLOCK_N
        x = tl.load(X_ptrs, mask=mask_m, other=float("-inf"))
        x = fpval_to_key(x.to(x_utype, bitcast=True))
        x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
        acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1))

    # rotate expert index into upper 16 bits:
    # 0000vvvvvvvviiii --> iiii0000vvvvvvvv
    acc = (acc << (y_nbits - 16)) | (acc >> 16)
    # sort in ascending order of expert (descending order of key)
    acc = tl.sort(acc, dim=1, descending=True)

    # iiii0000vvvvvvvv --> 0000iiii:
    y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32)
    y_indices = key_to_indx(y_indices_raw, N_EXPTS_PAD)
    # iiii0000vvvvvvvv --> vvvvvvvv:
    y_values_raw = acc.to(x_utype)
    y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)

    return y_values, y_indices
@triton.jit
def _topk_forward(
    X, stride_xm,  # inputs
    Yv, Yi, stride_ym,  # topk values/indices
    USE_PROVIDED_INDX: tl.constexpr,
    Bits, stride_rm: tl.constexpr, stride_rn: tl.constexpr,  # bitmatrix
    n_rows, n_expts_tot,  # shape
    S, BLOCK_S: tl.constexpr, s_blocks,  # thing to memset
    APPLY_SOFTMAX: tl.constexpr,  # constant
    BLOCK_M: tl.constexpr,
    N_EXPTS_PAD: tl.constexpr,
    N_EXPTS_ACT: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Each program handles BLOCK_M rows. It
    #   1) zeroes one BLOCK_S-sized chunk of the scratchpad S (programs with
    #      pid < s_blocks only),
    #   2) obtains the selected values/indices, either by gathering through
    #      the caller-supplied indices or via streaming_topk,
    #   3) optionally applies softmax over the selected values,
    #   4) writes values (and indices, when they were computed here), and
    #   5) packs the selected indices into the bitmatrix, 32 columns per word.
    # NOTE: stride_ym is used for both Yv and Yi.
    pid = tl.program_id(0)
    # n_rows may be passed as a pointer, in which case it is loaded here.
    if isinstance(n_rows, tl.tensor) and n_rows.dtype.is_ptr():
        n_rows = tl.load(n_rows)

    if pid < s_blocks:
        tl.store(S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32))

    if pid * BLOCK_M >= n_rows:
        # early exit:
        return

    tl.static_assert(BLOCK_N % 32 == 0)
    tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
    x_dtype: tl.constexpr = X.dtype.element_ty

    # load logits
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_y_n = tl.arange(0, N_EXPTS_ACT)
    mask_m = offs_m[:, None] < n_rows
    if USE_PROVIDED_INDX:
        # Gather the values at the indices the caller already chose.
        Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
        y_indices = tl.load(Yi_ptrs, mask=mask_m)
        Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices
        y_values = tl.load(Xv_ptrs, mask=mask_m)
    else:
        y_values, y_indices = streaming_topk(
            X, stride_xm, n_expts_tot, offs_m, mask_m,  #
            N_EXPTS_PAD, N_EXPTS_ACT, BLOCK_N,
        )

    # normalize selected values
    if APPLY_SOFTMAX:
        y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype)

    # write back
    Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :]
    tl.store(Yv_ptrs, y_values, mask=mask_m)
    if not USE_PROVIDED_INDX:
        Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
        tl.store(Yi_ptrs, y_indices, mask=mask_m)

    # pack into bitmatrix
    # y_div selects the 32-bit word, y_rem the bit inside that word.
    y_div = y_indices // 32
    y_rem = y_indices % 32
    loop_iterations = N_EXPTS_PAD // BLOCK_N
    for i in range(loop_iterations):
        offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32)
        # For every word handled in this iteration, keep the bit of each
        # selected index that falls into it, then OR them together per row.
        y2 = tl.where(y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0)
        r = tl.reduce_or(y2, axis=1)
        BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn
        tl.store(BitsPtrs, r, mask=mask_m)
vllm/kvprune_legacy_save/utils/__init__.py
deleted
100644 → 0
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Shared helpers: Triton compat, layout bridge, context, sequences."""
from
vllm.kvprune.utils.layout_bridge
import
(
block_table_to_global_page_table
,
build_batch_mapping
,
build_page_table_head_major
,
flatten_kv_cache_head_major
,
flatten_kv_cache_plane
,
write_head_major_flat_to_interleaved
,
)
from
vllm.kvprune.utils.triton_compat
import
(
autotune
as
triton_autotune
,
cuda_capability_geq
,
maybe_set_allocator
,
)
__all__
=
[
"block_table_to_global_page_table"
,
"build_batch_mapping"
,
"build_page_table_head_major"
,
"cuda_capability_geq"
,
"flatten_kv_cache_head_major"
,
"flatten_kv_cache_plane"
,
"write_head_major_flat_to_interleaved"
,
"maybe_set_allocator"
,
"triton_autotune"
,
]
vllm/kvprune_legacy_save/utils/arguments.py
deleted
100644 → 0
View file @
2b7160c6
import
itertools
import
math
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
import
torch
from
vllm.kvprune.compression
import
CompressionMethod
from
vllm.kvprune.compression.compression_config
import
BatchCompressionParams
from
vllm.kvprune.config.engine_config
import
LLMConfig
from
vllm.kvprune.utils.sequence
import
Sequence
from
vllm.kvprune.utils.kv_dist
import
broadcast_from_tp_rank0
from
vllm.kvprune.utils.tp_utils
import
kv_heads_shard_divisor
@dataclass
class PrefillBatchArguments:
    """Arguments describing one packed prefill batch.

    Instances are produced by ``PackedTensorArguments``; the tensor fields are
    slices of its packed device buffers.
    """

    # Number of sequences in the batch.
    B: int
    # Total number of prompt tokens across all sequences in the batch.
    N: int
    # True when at least one sequence requests a ratio < 1.0 and the method is
    # not ``CompressionMethod.NONE``.
    do_compression: bool
    compression_method: CompressionMethod
    # Chunk size for chunked compression, or -1 when chunking is disabled.
    compression_chunk_size: int
    seq_ids: torch.Tensor
    # Prompt token ids of all sequences, concatenated.
    input_ids: torch.Tensor
    # Per-token positions, restarting at 0 for every sequence.
    positions: torch.Tensor
    # Cumulative prompt lengths (length B + 1); q and k carry the same values.
    cu_seqlens_q: torch.Tensor
    cu_seqlens_k: torch.Tensor
    # Longest prompt length in the batch.
    max_seqlen_q: int
    max_seqlen_k: int
    # Per-sequence retain budget.
    batch_tokens_to_retain: Optional[torch.Tensor]
    # Upper bound on the top-k prefix requested for the whole batch.
    max_tokens_to_retain: Optional[int]
    # Per-sequence count of protected leading / trailing tokens.
    protected_first: Optional[List[int]]
    protected_last: Optional[List[int]]
    # Random sketch matrix of shape (head_dim, sketch_dim).
    PHI: Optional[torch.Tensor]
    # args needed for memory reservation
    context_lens: torch.Tensor
    max_new_tokens: torch.Tensor
    # Aligned with the kvpress ``CompactorPress`` blending default
    # (``compression_ratio`` is used when blending is not given explicitly).
    compression_ratio: float = 1.0
class PackedTensorArguments:
    """Pack prefill batch metadata into two flat device buffers shared across TP ranks.

    Rank 0 fills ``packed_context_i64`` / ``packed_context_i32`` from the
    scheduled sequences and broadcasts them; every other rank receives the
    buffers and decodes the same :class:`PrefillBatchArguments` from them.

    Buffer layouts (``BMAX`` = ``max_num_batches``, ``NMAX`` = ``max_batched_tokens``)::

        i64: [seq_ids (B)] [input_ids (N)] [positions (N)] [max_new_tokens (B)]
        i32: [header (6)] [cu_q (B+1)] [cu_k (B+1)] [retain (B)]
             [context_lens (B)] [protected_first (B)] [protected_last (B)]
    """

    # Number of int32 header fields:
    # B, N, do_compression, method, chunk_size, cr_scaled.
    _HEADER_LEN = 6
    # Fixed-point scale used to carry ``compression_ratio`` in the int32 header.
    _CR_SCALE = 1_000_000.0

    def __init__(
        self,
        rank: int,
        max_batched_tokens: int,
        config: LLMConfig,
        seed: int = 42,
        *,
        device: torch.device | None = None,
        use_tp_group_for_collectives: bool = False,
    ) -> None:
        """Allocate the packed buffers and the sketch matrix ``PHI``.

        Args:
            rank: Rank of this process; rank 0 builds batches, others receive.
            max_batched_tokens: Upper bound on prompt tokens per prefill batch.
            config: Engine configuration (provides ``hf_config``, ``max_num_seqs``,
                ``tensor_parallel_size``, ``kvcache_page_size``,
                ``leverage_sketch_size``).
            seed: Seed of the generator used to draw ``PHI``.
            device: Device for the buffers; defaults to ``cuda:{rank}``.
            use_tp_group_for_collectives: Forwarded to ``broadcast_from_tp_rank0``.

        Raises:
            ValueError: If the head dimension cannot be determined from ``hf_config``.
        """
        hf_config = config.hf_config
        self.rank = rank
        self.device = device if device is not None else torch.device(f"cuda:{rank}")
        self._use_tp_group = use_tp_group_for_collectives

        self.max_num_batches = config.max_num_seqs
        self.max_batched_tokens = max_batched_tokens
        _ws = kv_heads_shard_divisor()
        self.num_kv_heads = hf_config.num_key_value_heads // _ws
        self.world_size = config.tensor_parallel_size
        self.page_size = int(config.kvcache_page_size)

        # Not every HF config exposes ``head_dim``; fall back to the usual
        # hidden_size / num_attention_heads instead of failing later in randn.
        head_dim = getattr(hf_config, "head_dim", None)
        if head_dim is None:
            hidden_size = getattr(hf_config, "hidden_size", None)
            num_heads = getattr(hf_config, "num_attention_heads", None)
            if not hidden_size or not num_heads:
                raise ValueError(
                    "hf_config must define head_dim, or both hidden_size and "
                    "num_attention_heads"
                )
            head_dim = hidden_size // num_heads
        self.head_dim = head_dim
        self.sketch_dim = config.leverage_sketch_size
        self.model_dtype = hf_config.torch_dtype

        # i64 pack = [seq_ids (BMAX)] || [input_ids (NMAX)] || [positions (NMAX)] || max_new_tok (BMAX)
        self.i64_len_max = (
            self.max_num_batches + 2 * self.max_batched_tokens + self.max_num_batches
        )
        self.packed_context_i64 = torch.empty(
            self.i64_len_max, dtype=torch.int64, device=self.device
        )

        # i32 pack = [header (6): ... + compression_ratio*1e6] || [cu_q (BMAX+1)]
        # || [cu_k (BMAX+1)] || [retain (BMAX)] || [context_lens (BMAX)]
        # || [protected_first_tokens (BMAX)] || [protected_last_tokens (BMAX)]
        self.i32_len_max = (
            self._HEADER_LEN
            + (self.max_num_batches + 1)
            + (self.max_num_batches + 1)
            + self.max_num_batches
            + self.max_num_batches
            + self.max_num_batches
            + self.max_num_batches
        )
        self.packed_context_i32 = torch.empty(
            self.i32_len_max, dtype=torch.int32, device=self.device
        )

        self.generator = torch.Generator(device=self.device).manual_seed(seed)
        self.PHI = torch.randn(
            (self.head_dim, self.sketch_dim),
            device=self.packed_context_i32.device,
            generator=self.generator,
        ).to(self.model_dtype) * (1 / math.sqrt(self.sketch_dim))

    @staticmethod
    def _pinned(values, dtype: torch.dtype) -> torch.Tensor:
        """Build a pinned host tensor so the following H2D copy can be non-blocking."""
        return torch.tensor(values, dtype=dtype, device="cpu", pin_memory=True)

    def _max_tokens_to_retain(self, max_seq_len: int, keep_budget: int) -> int:
        """Return the top-k prefix length to request for the batch.

        `prefill_store_topk_kv(..., PAD_TO_PAGE_SIZE=True)` may scan beyond the
        top-k prefix to fill per-head lengths up to a page boundary. Using a
        full ranking (top_k = max_seq_len * HKV) makes `torch.topk` degenerate
        into a full sort, which is very expensive for long contexts.

        Instead, request only a prefix that is large enough for:
          1) the maximum "keep" budget in the batch, plus
          2) a conservative extra window for page-padding candidates.
        """
        full_budget = max_seq_len * self.num_kv_heads
        pad_search_budget = (self.page_size - 1) * (self.num_kv_heads**2)
        return min(full_budget, keep_budget + pad_search_budget)

    def _broadcast_packed(self) -> None:
        """Broadcast i64, i32 and (for TP > 1) PHI from rank 0, in that fixed order."""
        broadcast_from_tp_rank0(
            self.packed_context_i64, use_tp_group=self._use_tp_group
        )
        broadcast_from_tp_rank0(
            self.packed_context_i32, use_tp_group=self._use_tp_group
        )
        if self.world_size > 1:
            broadcast_from_tp_rank0(self.PHI, use_tp_group=self._use_tp_group)

    def _assemble_prefill_args(
        self,
        *,
        B: int,
        N: int,
        do_compression: bool,
        compression_method,
        compression_chunk_size: int,
        max_seqlen: int,
        max_retain: int,
        protected_first,
        protected_last,
        compression_ratio: float,
    ) -> PrefillBatchArguments:
        """Create ``PrefillBatchArguments`` whose tensors are slices of the packed buffers."""
        s64 = self.packed_i64_slices(B, N)
        s32 = self.packed_i32_slices(B)
        i64 = self.packed_context_i64
        i32 = self.packed_context_i32
        return PrefillBatchArguments(
            B=B,
            N=N,
            do_compression=do_compression,
            compression_method=compression_method,
            compression_chunk_size=compression_chunk_size,
            seq_ids=i64[s64["seq_ids"]],
            input_ids=i64[s64["input_ids"]],
            positions=i64[s64["positions"]],
            cu_seqlens_q=i32[s32["cu_q"]],
            cu_seqlens_k=i32[s32["cu_k"]],
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            batch_tokens_to_retain=i32[s32["retain"]],
            max_tokens_to_retain=max_retain,
            PHI=self.PHI,
            context_lens=i32[s32["context_lens"]],
            max_new_tokens=i64[s64["max_new_tokens"]],
            protected_first=protected_first,
            protected_last=protected_last,
            compression_ratio=compression_ratio,
        )

    def _master_build_prefill(
        self, seqs: List[Sequence], batch_compression_params: BatchCompressionParams
    ) -> PrefillBatchArguments:
        """Fill the packed buffers from ``seqs`` on rank 0 and broadcast them.

        Raises:
            ValueError: If the batch is empty or exceeds the buffer capacity.
        """
        B = len(seqs)
        if B == 0:
            raise ValueError(
                "prefill batch is empty (scheduler should not call build_prefill with "
                "no sequences)"
            )
        if B > self.max_num_batches:
            raise ValueError(
                f"prefill batch has {B} sequences but buffers hold at most "
                f"{self.max_num_batches}"
            )
        Ls = [x.prompt_len for x in seqs]
        N = sum(Ls)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if N > self.max_batched_tokens:
            raise ValueError(
                f"prefill batch has {N} tokens but max_batched_tokens is "
                f"{self.max_batched_tokens}"
            )

        do_compression = any(
            x.compression_params.compression_ratio < 1.0 for x in seqs
        )
        do_compression = (
            do_compression
            and batch_compression_params.compression_method != CompressionMethod.NONE
        )

        pack_slices_64 = self.packed_i64_slices(B, N)
        pack_slices_32 = self.packed_i32_slices(B)
        i64 = self.packed_context_i64
        i32 = self.packed_context_i32

        protected_first_list = [
            x.compression_params.protected_first_tokens for x in seqs
        ]
        protected_last_list = [
            x.compression_params.protected_last_tokens for x in seqs
        ]
        # Per-sequence budget over the unprotected tokens of all local KV heads,
        # never below one entry.
        retain = self._pinned(
            [
                max(
                    int(
                        round(
                            x.compression_params.compression_ratio
                            * (L - s - e)
                            * self.num_kv_heads
                        )
                    ),
                    1,
                )
                for s, e, L, x in zip(
                    protected_first_list, protected_last_list, Ls, seqs
                )
            ],
            torch.int32,
        )
        i32[pack_slices_32["protected_first"]].copy_(
            self._pinned(protected_first_list, torch.int32), non_blocking=True
        )
        i32[pack_slices_32["protected_last"]].copy_(
            self._pinned(protected_last_list, torch.int32), non_blocking=True
        )

        compression_chunk_size = (
            batch_compression_params.chunk_size
            if batch_compression_params.do_chunked_compression
            else -1
        )
        min_compression_ratio = min(
            x.compression_params.compression_ratio for x in seqs
        )
        # The ratio travels in the int32 header as fixed point, clamped to int32 range.
        cr_scaled = int(round(float(min_compression_ratio) * self._CR_SCALE))
        cr_scaled = max(min(cr_scaled, 2_000_000_000), -2_000_000_000)
        header_host = self._pinned(
            [
                B,
                N,
                1 if do_compression else 0,
                batch_compression_params.compression_method.value,
                compression_chunk_size,
                cr_scaled,
            ],
            torch.int32,
        )
        i32[pack_slices_32["retain"]].copy_(retain, non_blocking=True)
        i32[pack_slices_32["header"]].copy_(header_host, non_blocking=True)

        max_seq_qk = max(Ls)
        cu = self._pinned(list(itertools.accumulate(Ls, initial=0)), torch.int32)
        i32[pack_slices_32["cu_q"]].copy_(cu, non_blocking=True)
        i32[pack_slices_32["cu_k"]].copy_(cu, non_blocking=True)
        i32[pack_slices_32["context_lens"]].copy_(cu.diff(), non_blocking=True)

        seq_ids = self._pinned([x.seq_id for x in seqs], torch.int64)
        input_ids = self._pinned(
            [tid for x in seqs for tid in x.prompt_token_ids], torch.int64
        )
        i64[pack_slices_64["seq_ids"]].copy_(seq_ids, non_blocking=True)
        i64[pack_slices_64["input_ids"]].copy_(input_ids, non_blocking=True)

        positions = torch.cat(
            [
                torch.arange(L, dtype=torch.int64, device="cpu", pin_memory=True)
                for L in Ls
            ]
        )
        i64[pack_slices_64["positions"]].copy_(positions, non_blocking=True)

        max_new_tokens = self._pinned(
            [seq.sampling_params.max_new_tokens for seq in seqs], torch.int64
        )
        i64[pack_slices_64["max_new_tokens"]].copy_(max_new_tokens, non_blocking=True)

        # context_lens == cu.diff() == Ls, so max(Ls) on the host is the same
        # number the device buffer holds; reading it back would only add a
        # device-to-host synchronisation.
        max_retain = self._max_tokens_to_retain(
            max_seq_qk, int(retain.max().item())
        )

        # Non-blocking H2D copies above must finish before NCCL broadcast, or peers can
        # receive stale/garbage packed buffers → wrong prefill → garbage tokens on TP>1.
        # Synchronise the buffers' own device, which need not be the current one.
        if i64.is_cuda:
            torch.cuda.synchronize(i64.device)

        # PHI: rank 0's sketch matrix is broadcast so all TP ranks share one PHI for
        # leverage / compactor scores (same order as packed_context: i64, i32, PHI).
        self._broadcast_packed()

        return self._assemble_prefill_args(
            B=B,
            N=N,
            do_compression=do_compression,
            compression_method=batch_compression_params.compression_method,
            compression_chunk_size=compression_chunk_size,
            max_seqlen=max_seq_qk,
            max_retain=max_retain,
            protected_first=protected_first_list,
            protected_last=protected_last_list,
            compression_ratio=min_compression_ratio,
        )

    def _peer_receive_prefill(self) -> PrefillBatchArguments:
        """Receive the packed buffers from rank 0 and decode the batch arguments."""
        self._broadcast_packed()

        # Header is 6 fields (B, N, do_compression, method, chunk_size, cr_scaled); must match
        # packed_i32_slices(B)["header"] for any B.
        header = self.packed_context_i32[: self._HEADER_LEN].tolist()
        B, N = int(header[0]), int(header[1])
        do_compression = bool(int(header[2]))
        compression_method = CompressionMethod(int(header[3]))
        compression_chunk_size = int(header[4])
        compression_ratio = int(header[5]) / self._CR_SCALE

        pack_slices_32 = self.packed_i32_slices(B)
        i32 = self.packed_context_i32

        # Must match _master_build_prefill: max_seqlen_{q,k} = max(Ls), not cu_q.max()
        # (which equals total batch tokens N and breaks varlen attention on peers).
        max_seq_len = int(i32[pack_slices_32["context_lens"]].max())
        keep_budget = int(i32[pack_slices_32["retain"]].max().item())
        max_retain = self._max_tokens_to_retain(max_seq_len, keep_budget)

        return self._assemble_prefill_args(
            B=B,
            N=N,
            do_compression=do_compression,
            compression_method=compression_method,
            compression_chunk_size=compression_chunk_size,
            max_seqlen=max_seq_len,
            max_retain=max_retain,
            protected_first=i32[pack_slices_32["protected_first"]].tolist(),
            protected_last=i32[pack_slices_32["protected_last"]].tolist(),
            compression_ratio=compression_ratio,
        )

    @torch.inference_mode()
    def build_prefill_args(
        self,
        seqs: Optional[List[Sequence]] = None,
        batch_compression_params: Optional[BatchCompressionParams] = None,
    ) -> PrefillBatchArguments:
        """Build (rank 0) or receive (other ranks) the arguments of a prefill batch.

        Args:
            seqs: Sequences to prefill; required on rank 0, ignored elsewhere.
            batch_compression_params: Batch-level compression settings; required
                on rank 0, ignored elsewhere.

        Raises:
            ValueError: On rank 0 when either argument is missing.
        """
        if self.rank == 0:
            if seqs is None or batch_compression_params is None:
                raise ValueError(
                    "rank 0 requires both seqs and batch_compression_params"
                )
            return self._master_build_prefill(seqs, batch_compression_params)
        return self._peer_receive_prefill()

    def broadcast(self):
        """Broadcast the i64 buffer from rank 0 when running with more than one rank."""
        if self.world_size > 1:
            return broadcast_from_tp_rank0(
                self.packed_context_i64, use_tp_group=self._use_tp_group
            )
        return None

    @staticmethod
    def packed_i64_slices(B: int, N: int):
        """Return the slice of every region of the i64 buffer for ``B`` sequences / ``N`` tokens."""
        return {
            "seq_ids": slice(0, B),
            "input_ids": slice(B, B + N),
            "positions": slice(B + N, B + 2 * N),
            "max_new_tokens": slice(B + 2 * N, 2 * B + 2 * N),
        }

    @staticmethod
    def packed_i32_slices(B: int):
        """Return the slice of every region of the i32 buffer for ``B`` sequences."""
        h0, h1 = 0, PackedTensorArguments._HEADER_LEN
        q0 = h1
        q1 = q0 + (B + 1)
        k0 = q1
        k1 = k0 + (B + 1)
        r0 = k1
        r1 = r0 + B
        c0 = r1
        c1 = r1 + B
        pf0 = c1
        pf1 = c1 + B
        pl0 = pf1
        pl1 = pf1 + B
        return {
            "header": slice(h0, h1),
            "cu_q": slice(q0, q1),
            "cu_k": slice(k0, k1),
            "retain": slice(r0, r1),
            "context_lens": slice(c0, c1),
            "protected_first": slice(pf0, pf1),
            "protected_last": slice(pl0, pl1),
        }
@dataclass
class DecodeBatchOutput:
    """Pair of optional tensors produced by a decode step."""

    # presumably the token ids emitted by the step — TODO confirm against caller
    output_tokens: Optional[torch.Tensor]
    # presumably the sequence ids matching ``output_tokens`` — TODO confirm
    output_seq_ids: Optional[torch.Tensor]
@dataclass
class DecodeBatchArguments:
    """Accumulator for the per-sequence tensors of a decode batch.

    Every tensor field starts as ``None`` and grows along dim 0 through
    :meth:`update`.
    """

    batch_mapping: Optional[torch.Tensor] = None
    token_ids: Optional[torch.Tensor] = None
    positions: Optional[torch.Tensor] = None
    max_ctx_lens: Optional[torch.Tensor] = None
    seq_ids: Optional[torch.Tensor] = None
    temps: Optional[torch.Tensor] = None
    desired_batch_occupancy: int = -1
    num_stashed_batches: int = 0

    def update(
        self,
        batch_mapping,
        token_ids,
        positions,
        max_ctx_lens,
        seq_ids,
        temps=None,
        desired_batch_occupancy: int = None,
    ):
        """Append new rows to every stored tensor and return ``self``.

        A field that is still ``None`` receives a clone of the incoming tensor;
        otherwise the incoming tensor is concatenated after the stored one.
        ``temps`` and ``desired_batch_occupancy`` are left untouched when the
        corresponding argument is ``None``.
        """
        required = (
            ("batch_mapping", batch_mapping),
            ("token_ids", token_ids),
            ("positions", positions),
            ("max_ctx_lens", max_ctx_lens),
            ("seq_ids", seq_ids),
        )
        for attr, incoming in required:
            stored = getattr(self, attr)
            if stored is None:
                setattr(self, attr, incoming.clone())
            else:
                setattr(self, attr, torch.cat([stored, incoming], dim=0))

        if temps is not None:
            if self.temps is None:
                self.temps = temps.clone()
            else:
                self.temps = torch.cat([self.temps, temps], dim=0)

        if desired_batch_occupancy is not None:
            self.desired_batch_occupancy = desired_batch_occupancy
        return self
vllm/kvprune_legacy_save/utils/context.py
deleted
100644 → 0
View file @
2b7160c6
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
import
torch
# Import from compression_config, not compression.__init__, to avoid circular imports
# (compression -> compactor -> context -> compression).
from
vllm.kvprune.compression.compression_config
import
CompressionMethod
from
vllm.kvprune.config.engine_config
import
KvpruneAttentionSchedule
@dataclass
class CompressionContext:
    """Compression settings and per-batch compression state carried by ``Context``."""

    compression_method: CompressionMethod = CompressionMethod.COMPACTOR
    # -1 is the default; presumably "no chunking" — TODO confirm against callers.
    compression_chunk_size: int = -1
    batch_tokens_to_retain: torch.Tensor | None = None
    max_tokens_to_retain: int = 0
    context_lens: List[int] | None = None
    PHI: torch.Tensor | None = None
    # Compactor (optional hyperparameters aligned with kvpress ``CompactorPress``)
    sketch_dimension: int = 48
    sink_size_start: int = 8
    sink_size_end: int = 4
    compactor_blending: Optional[float] = None
    # Same as kvpress: this value (taken from the request's compression_ratio)
    # is used when ``compactor_blending`` is not set.
    compression_ratio: Optional[float] = None
    protected_first_tokens: List[int] | None = None
    protected_last_tokens: List[int] | None = None
    # CriticalAdaKV
    wo_weight: Optional[torch.Tensor] = None
    critical_ada_epsilon: float = 1e-4
    critical_ada_first_stage_ratio: float = 0.5
    critical_ada_alpha_safeguard: float = 0.2
@dataclass
class Context:
    """Module-wide forward-pass state, replaced as a whole by ``set_context``.

    NOTE: ``set_context`` constructs this class positionally, so the field
    order below must stay in sync with that call.
    """

    is_prefill: bool = False
    do_compression: bool = False
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    # Set in ModelRunner.run_prefill before forward — avoids D2H inside compactor kernels.
    cu_seqlens_q_host: Optional[Tuple[int, ...]] = None
    cu_seqlens_k_host: Optional[Tuple[int, ...]] = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    batch_mapping: torch.Tensor | None = None
    max_bh_len: int = 0
    compression_context: CompressionContext | None = None
    # Optional side CUDA stream; presumably used for KV-store work — TODO confirm.
    STORE_STREAM: torch.cuda.Stream | None = None
    key_split: int | None = None
    attention_schedule: KvpruneAttentionSchedule = (
        KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
    )
# Module-level context instance; starts with all defaults and is rebound by
# ``set_context`` / ``reset_context``.
_CONTEXT = Context()
def get_context():
    """Return the module-level ``Context`` currently in effect."""
    return _CONTEXT
def set_context(
    *,
    is_prefill,
    do_compression=False,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    cu_seqlens_q_host: Optional[Tuple[int, ...]] = None,
    cu_seqlens_k_host: Optional[Tuple[int, ...]] = None,
    max_seqlen_q=0,
    max_seqlen_k=0,
    batch_mapping=None,
    max_bh_len=0,
    compression_context: CompressionContext = None,
    STORE_STREAM=None,
    key_split=None,
    attention_schedule=KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE,
):
    """Replace the module-level context with a new ``Context`` built from the arguments.

    All parameters are keyword-only and map one-to-one onto the ``Context``
    fields of the same name.
    """
    global _CONTEXT
    _CONTEXT = Context(
        is_prefill=is_prefill,
        do_compression=do_compression,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        cu_seqlens_q_host=cu_seqlens_q_host,
        cu_seqlens_k_host=cu_seqlens_k_host,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        batch_mapping=batch_mapping,
        max_bh_len=max_bh_len,
        compression_context=compression_context,
        STORE_STREAM=STORE_STREAM,
        key_split=key_split,
        attention_schedule=attention_schedule,
    )
def reset_context():
    """Rebind the module-level context to a fresh, all-default ``Context``."""
    global _CONTEXT
    _CONTEXT = Context()
vllm/kvprune_legacy_save/utils/helpers.py
deleted
100644 → 0
View file @
2b7160c6
from
collections.abc
import
Callable
import
torch
def maybe_execute_in_stream(
    fn: Callable, *args, STORE_STREAM: torch.cuda.Stream = None, **kwargs
):
    """Call ``fn(*args, **kwargs)``, on ``STORE_STREAM`` when one is given.

    Without a stream the call is made directly. With a stream:

    1. ``STORE_STREAM`` first waits on the default CUDA stream;
    2. ``fn`` runs inside the stream context;
    3. every tensor among ``args``, ``kwargs`` values and a tensor bound as
       ``fn.__self__`` is recorded on ``STORE_STREAM``;
    4. a tensor result (or each tensor inside a tuple result) is recorded on
       the default stream.

    Returns:
        Whatever ``fn`` returns.
    """
    if STORE_STREAM is None:
        return fn(*args, **kwargs)

    # Collect every tensor the call touches as an input.
    inputs = [a for a in args if isinstance(a, torch.Tensor)]
    inputs.extend(v for v in kwargs.values() if isinstance(v, torch.Tensor))
    bound_self = getattr(fn, "__self__", None)
    if isinstance(bound_self, torch.Tensor):
        inputs.append(bound_self)

    STORE_STREAM.wait_stream(torch.cuda.default_stream())

    # Some PyTorch builds don't make `torch.cuda.Stream` a context manager.
    # The portable API is `torch.cuda.stream(stream)`.
    if hasattr(STORE_STREAM, "__enter__"):
        side_stream_ctx = STORE_STREAM
    else:
        side_stream_ctx = torch.cuda.stream(STORE_STREAM)

    with side_stream_ctx:
        result = fn(*args, **kwargs)

    for tensor in inputs:
        tensor.record_stream(STORE_STREAM)

    # Treat a lone result like a 1-tuple so both shapes share one loop.
    produced = result if isinstance(result, tuple) else (result,)
    for item in produced:
        if isinstance(item, torch.Tensor):
            item.record_stream(torch.cuda.default_stream())
    return result
vllm/kvprune_legacy_save/utils/kv_dist.py
deleted
100644 → 0
View file @
2b7160c6
"""Distributed helpers for kvprune when embedded in vLLM (use TP process group)."""
from
__future__
import
annotations
import
torch
import
torch.distributed
as
dist
def broadcast_from_tp_rank0(tensor: torch.Tensor, *, use_tp_group: bool) -> None:
    """Broadcast ``tensor`` in place from group-local rank 0.

    Args:
        tensor: Tensor to send (on rank 0) or to overwrite (elsewhere).
        use_tp_group: ``True`` when embedded in a vLLM worker; the broadcast
            then goes through vLLM's tensor-parallel group so DP/PP ranks of a
            global default group are not involved. ``False`` for standalone
            compactor subprocesses, where the default process group is used
            (world == tensor parallel size).
    """
    if use_tp_group:
        # Imported lazily, only on the path that needs vLLM's parallel state.
        from vllm.distributed.parallel_state import get_tp_group

        get_tp_group().broadcast(tensor, src=0)
    else:
        dist.broadcast(tensor, src=0)
def barrier_sync(*, use_tp_group: bool) -> None:
    """Block until every rank of the chosen group reaches this call.

    The group is vLLM's tensor-parallel group when ``use_tp_group`` is true and
    the default process group otherwise (see :func:`broadcast_from_tp_rank0`).
    """
    if use_tp_group:
        # Imported lazily, only on the path that needs vLLM's parallel state.
        from vllm.distributed.parallel_state import get_tp_group

        get_tp_group().barrier()
    else:
        dist.barrier()
vllm/kvprune_legacy_save/utils/layout_bridge.py
deleted
100644 → 0
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Bridge vLLM paged KV layout to compactor Triton kernels.
vLLM FlashAttention KV cache is shaped
[num_blocks, block_size, num_kv_heads, head_dim].
Compactor kernels expect a flat buffer [CACHE_SIZE, head_dim] and a page table
global_page_table[batch, kv_head, logical_page] -> physical_page_id
where each physical page holds ``block_size`` consecutive rows belonging to that
KV head only.
When num_kv_heads == 1 (MQA), a vLLM block maps 1:1 to compactor rows:
row_index = physical_block_id * block_size + offset_in_block.
When ``num_kv_heads > 1``, we permute to head-major
``[num_kv_heads, num_blocks, block_size, head_dim]`` and flatten to
``[num_kv_heads * num_blocks * block_size, head_dim]`` so each KV head occupies
a disjoint row range in the flat buffer. The page table is built so each
logical compression page maps to ``global_row // PAGE_SIZE`` in that layout
(see ``build_page_table_head_major``).
"""
from
__future__
import
annotations
import
torch
def
_cdiv
(
n
:
int
,
d
:
int
)
->
int
:
return
(
n
+
d
-
1
)
//
d
def flatten_kv_cache_head_major(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``[nb, bs, H, D]`` caches rearranged as ``[H*nb*bs, D]`` in head-major order.

    Row ``h * nb * bs + block * bs + offset`` of a result holds
    ``cache[block, offset, h]``. The permuted tensor is made contiguous, so the
    results do not alias the inputs unless the permutation was already
    contiguous.

    Raises:
        ValueError: If the two caches differ in shape.
    """
    if key_cache.shape != value_cache.shape:
        raise ValueError("key_cache and value_cache must match")

    num_blocks, block_size, num_heads, head_dim = key_cache.shape
    total_rows = num_heads * num_blocks * block_size

    flattened = [
        cache.permute(2, 0, 1, 3).contiguous().reshape(total_rows, head_dim)
        for cache in (key_cache, value_cache)
    ]
    return flattened[0], flattened[1]
def write_head_major_flat_to_interleaved(
    k_flat: torch.Tensor,
    v_flat: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> None:
    """Copy ``[H*nb*bs, D]`` head-major flats back into ``[nb, bs, H, D]`` caches.

    ``key_cache`` and ``value_cache`` are overwritten in place; the target
    layout is taken from ``key_cache.shape``.
    """
    num_blocks, block_size, num_heads, head_dim = key_cache.shape
    head_major_shape = (num_heads, num_blocks, block_size, head_dim)

    # Build both views before writing anything.
    k_head_major, v_head_major = (
        flat.view(*head_major_shape) for flat in (k_flat, v_flat)
    )
    for target, head_major in (
        (key_cache, k_head_major),
        (value_cache, v_head_major),
    ):
        # [H, nb, bs, D] -> [nb, bs, H, D]
        target.copy_(head_major.permute(1, 2, 0, 3))
def build_page_table_head_major(
    block_table: torch.Tensor,
    num_kv_heads: int,
    num_blocks: int,
    block_size: int,
    page_size: int,
    max_batches: int,
) -> torch.Tensor:
    """Build ``[max_batches, H, max_chain]`` page table for head-major flat KV.

    Chains physical page ids in ``block_table`` order for each (batch, head).
    Each entry is ``global_row // page_size`` where ``global_row`` indexes rows
    in the head-major flat buffer (see ``flatten_kv_cache_head_major``).
    Negative block ids are skipped, so a chain may be shorter than
    ``max_chain``; unused entries stay 0.

    The block table is read to the host once and the table is filled on the
    host before a single transfer to ``block_table.device``. This avoids one
    device synchronisation per (batch, head, block) entry.

    Args:
        block_table: ``[batch, max_blocks]`` integer tensor of block ids.
        num_kv_heads: Number of KV heads ``H``.
        num_blocks: Number of blocks per head in the flat buffer.
        block_size: Rows per block.
        page_size: Rows per compression page.
        max_batches: Size of the leading dimension of the result.

    Returns:
        ``int32`` tensor of shape ``[max_batches, num_kv_heads, max_chain]`` on
        ``block_table.device``.

    Raises:
        ValueError: If the batch exceeds ``max_batches`` or a block id is
            ``>= num_blocks``.
    """
    bsz, max_blocks = block_table.shape
    if bsz > max_batches:
        raise ValueError("batch size exceeds max_batches for page table")

    num_pages_per_block = _cdiv(block_size, page_size)
    max_chain = max_blocks * num_pages_per_block
    # NOTE(review): entries are only distinct per page when block_size is a
    # multiple of page_size — confirm callers guarantee that.
    rows_per_head = num_blocks * block_size

    # One device-to-host transfer for the whole table.
    block_ids = block_table.to(torch.int64).tolist()

    host_table = torch.zeros(
        (max_batches, num_kv_heads, max_chain), dtype=torch.int32
    )
    for b, row in enumerate(block_ids):
        # Validate once per batch row; the result is shared by every head.
        valid_ids = []
        for blk_i, bid in enumerate(row):
            if bid < 0:
                continue
            if bid >= num_blocks:
                raise ValueError(
                    f"block_table[{b},{blk_i}]={bid} out of range "
                    f"num_blocks={num_blocks}"
                )
            valid_ids.append(bid)
        if not valid_ids:
            continue

        for h in range(num_kv_heads):
            head_base = h * rows_per_head
            chain = [
                (head_base + bid * block_size + p * page_size) // page_size
                for bid in valid_ids
                for p in range(num_pages_per_block)
            ]
            host_table[b, h, : len(chain)] = torch.tensor(chain, dtype=torch.int32)

    return host_table.to(block_table.device)
def flatten_kv_cache_plane(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Flatten ``(num_blocks, block_size, 1, D)`` caches to ``[num_blocks*block_size, D]``.

    Only the single-KV-head case is supported, because that is the only case
    in which this row order matches the compactor layout (see module doc).

    Raises:
        ValueError: If ``num_kv_heads`` is not 1, the cache shapes differ, or
            the head dimension of the caches is not 1.
    """
    if num_kv_heads != 1:
        raise ValueError(
            "flatten_kv_cache_plane requires num_kv_heads==1 for compactor layout"
        )
    if key_cache.shape != value_cache.shape:
        raise ValueError("key_cache and value_cache must match")

    # [num_blocks, block_size, 1, D] -> [num_blocks * block_size, D]
    num_blocks, block_size, cache_heads, head_dim = key_cache.shape
    if cache_heads != 1:
        raise ValueError("expected num_kv_heads==1")

    num_rows = num_blocks * block_size
    flats = []
    for cache in (key_cache, value_cache):
        flats.append(cache.reshape(num_rows, head_dim))
    # Guarantee contiguous results, copying only when necessary.
    flats = [flat if flat.is_contiguous() else flat.contiguous() for flat in flats]
    return flats[0], flats[1]
def block_table_to_global_page_table(
    block_table: torch.Tensor,
    num_kv_heads: int,
    max_batches: int,
) -> torch.Tensor:
    """Build a ``[max_batches, HKV, num_logical_pages]`` int32 page table.

    Every KV head of a batch row receives the same block ids as the
    corresponding row of ``block_table``; rows beyond the batch stay 0.

    Raises:
        ValueError: If ``block_table`` has more rows than ``max_batches``.
    """
    # block_table: [num_reqs_padded, max_num_blocks]
    num_rows, num_pages = block_table.shape
    if num_rows > max_batches:
        raise ValueError("batch size exceeds max_batches for page table")

    page_table = torch.zeros(
        (max_batches, num_kv_heads, num_pages),
        dtype=torch.int32,
        device=block_table.device,
    )
    block_ids = block_table.to(torch.int32)[:num_rows]
    # Broadcast the block ids across the head axis in one assignment.
    page_table[:num_rows, :, :num_pages] = block_ids.unsqueeze(1)
    return page_table
def build_batch_mapping(num_reqs: int, device: torch.device) -> torch.Tensor:
    """Return the identity mapping ``[0, 1, ..., num_reqs - 1]`` as int32 on ``device``.

    Local batch index -> global batch row.
    """
    identity = torch.arange(0, num_reqs, dtype=torch.int32, device=device)
    return identity
vllm/kvprune_legacy_save/utils/sequence.py
deleted
100644 → 0
View file @
2b7160c6
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
,
auto
from
itertools
import
count
from
typing
import
List
from
vllm.kvprune.compression.compression_config
import
SequenceCompressionParams
from
vllm.kvprune.config.sampling_params
import
SamplingParams
class SequenceStatus(Enum):
    """State of a ``Sequence``; new sequences default to ``WAITING``."""

    WAITING = auto()
    RUNNING = auto()
    FINISHED = auto()
@
dataclass
class
Sequence
:
"""
Represents a single user request / sequence being generated.
"""
_counter
=
count
()
prompt_token_ids
:
List
[
int
]
completion_token_ids
:
List
[
int
]
=
field
(
default_factory
=
list
)
sampling_params
:
SamplingParams
=
field
(
default_factory
=
SamplingParams
)
compression_params
:
SequenceCompressionParams
=
field
(
default_factory
=
SequenceCompressionParams
)
status
:
SequenceStatus
=
SequenceStatus
.
WAITING
seq_id
:
int
=
field
(
default_factory
=
lambda
:
next
(
Sequence
.
_counter
),
init
=
False
)
num_tokens_processed
:
int
=
0
@
property
def
num_prompt_tokens
(
self
)
->
int
:
return
len
(
self
.
prompt_token_ids
)
@
property
def
num_generated_tokens
(
self
)
->
int
:
return
len
(
self
.
completion_token_ids
)
def
add_new_token
(
self
,
token_id
:
int
)
->
None
:
if
len
(
self
.
completion_token_ids
)
==
0
:
self
.
num_tokens_processed
+=
self
.
num_prompt_tokens
self
.
completion_token_ids
.
append
(
token_id
)
self
.
num_tokens_processed
+=
1
def
tokens_to_retain_per_layer
(
self
,
num_kv_heads
:
int
)
->
int
:
n
=
int
(
self
.
compression_params
.
compression_ratio
*
self
.
num_prompt_tokens
*
num_kv_heads
)
return
max
(
1
,
n
)
def
__getstate__
(
self
):
return
dict
(
prompt_token_ids
=
list
(
self
.
prompt_token_ids
),
completion_token_ids
=
list
(
self
.
completion_token_ids
),
sampling_params
=
self
.
sampling_params
,
compression_params
=
self
.
compression_params
,
status
=
self
.
status
,
seq_id
=
self
.
seq_id
,
num_tokens_processed
=
self
.
num_tokens_processed
,
)
def
__setstate__
(
self
,
state
):
self
.
prompt_token_ids
=
list
(
state
[
"prompt_token_ids"
])
self
.
completion_token_ids
=
list
(
state
[
"completion_token_ids"
])
self
.
sampling_params
=
state
[
"sampling_params"
]
self
.
compression_params
=
state
[
"compression_params"
]
self
.
status
=
state
[
"status"
]
self
.
seq_id
=
state
[
"seq_id"
]
self
.
num_tokens_processed
=
state
[
"num_tokens_processed"
]
@
property
def
prompt_len
(
self
)
->
int
:
return
len
(
self
.
prompt_token_ids
)
@property
def completion_len(self) -> int:
    """Return the length of ``completion_token_ids``."""
    generated_ids = self.completion_token_ids
    return len(generated_ids)
vllm/kvprune_legacy_save/utils/tp_collectives.py
deleted
100644 → 0
View file @
2b7160c6
"""Tensor-parallel collectives for kvprune (match vLLM TP process group when embedded)."""
from
__future__
import
annotations
import
torch.distributed
as
dist
def tensor_parallel_all_reduce(tensor: torch.Tensor) -> torch.Tensor:
    """All-reduce ``tensor`` across tensor-parallel ranks, updating it in place.

    When vLLM's :mod:`vllm.distributed.parallel_state` is initialized (e.g.
    kvprune runs inside a vLLM GPU worker), the reduction goes through
    :func:`~vllm.distributed.communication_op.tensor_model_parallel_all_reduce`
    so it uses the same TP group as the main model.

    That vLLM call is **out-of-place** and returns a new tensor. Call sites
    such as :class:`~vllm.kvprune.layers.linear.RowParallelLinear` historically
    invoked ``tensor_parallel_all_reduce(y)`` without using the return value,
    which left ``y`` holding the unreduced per-rank partial output under TP>1.
    The reduced values are therefore copied back into ``tensor`` so those call
    sites remain correct.

    Standalone kvprune subprocesses only have the default process group
    (world == ``tensor_parallel_size``); in that case the function falls back
    to :func:`torch.distributed.all_reduce` on the default group.

    Only the *availability probe* for vLLM is best-effort. An error raised by
    the collective itself (or by the copy-back) propagates to the caller
    instead of being swallowed and followed by a second all-reduce on the
    default group, which could reduce over the wrong set of ranks.

    Args:
        tensor: Tensor to reduce; modified in place.

    Returns:
        The same ``tensor`` object. It is returned untouched when
        ``torch.distributed`` is not initialized or the world size is <= 1.
    """
    if not dist.is_initialized() or dist.get_world_size() <= 1:
        return tensor

    # Resolve the vLLM TP all-reduce if (and only if) vLLM is importable and
    # its model-parallel state is initialized. Importing vLLM can fail in
    # several ways in a standalone process, so this probe stays best-effort.
    vllm_tp_all_reduce = None
    try:
        from vllm.distributed.parallel_state import model_parallel_is_initialized

        if model_parallel_is_initialized():
            from vllm.distributed.communication_op import (
                tensor_model_parallel_all_reduce as vllm_tp_all_reduce,
            )
    except Exception:
        vllm_tp_all_reduce = None

    if vllm_tp_all_reduce is not None:
        reduced = vllm_tp_all_reduce(tensor)
        if reduced is not tensor:
            # vLLM TP all_reduce is out-of-place: `reduced` holds the
            # cross-rank sum. Call sites ignore the return value and expect
            # `tensor` to be updated, so materialize the result here.
            tensor.copy_(reduced)
        return tensor

    dist.all_reduce(tensor)
    return tensor
vllm/kvprune_legacy_save/utils/tp_utils.py
deleted
100644 → 0
View file @
2b7160c6
"""Tensor-parallel helpers for kvprune when embedded in a vLLM worker."""
from
__future__
import
annotations
import
torch.distributed
as
dist
def tensor_parallel_rank_for_sharding() -> int:
    """Return this process's rank inside the tensor-parallel group.

    vLLM's ``get_tensor_model_parallel_rank`` is preferred. If importing or
    calling it fails for any reason, the rank of the default
    ``torch.distributed`` group is used instead, or 0 when no process group
    is initialized.
    """
    try:
        from vllm.distributed.parallel_state import (
            get_tensor_model_parallel_rank as _vllm_tp_rank,
        )

        rank = int(_vllm_tp_rank())
    except Exception:
        # vLLM parallel state unavailable: use the default process group.
        rank = int(dist.get_rank()) if dist.is_initialized() else 0
    return rank
def tensor_parallel_world_size_for_sharding() -> int:
    """Return the number of ranks in the tensor-parallel group.

    vLLM's ``get_tensor_model_parallel_world_size`` is preferred. If importing
    or calling it fails for any reason, the world size of the default
    ``torch.distributed`` group is used instead, or 1 when no process group
    is initialized.
    """
    try:
        from vllm.distributed.parallel_state import (
            get_tensor_model_parallel_world_size as _vllm_tp_world_size,
        )

        size = int(_vllm_tp_world_size())
    except Exception:
        # vLLM parallel state unavailable: use the default process group.
        size = int(dist.get_world_size()) if dist.is_initialized() else 1
    return size
def kv_heads_shard_divisor() -> int:
    """Return the divisor used to shard KV heads.

    This is whatever ``tensor_parallel_world_size_for_sharding`` reports.
    """
    divisor = tensor_parallel_world_size_for_sharding()
    return divisor
vllm/kvprune_legacy_save/utils/triton_compat.py
deleted
100644 → 0
View file @
2b7160c6
from
__future__
import
annotations
import
inspect
from
typing
import
Any
,
Callable
,
Mapping
import
torch
def _filter_kwargs_for_callable(fn: Callable[..., Any], kwargs: Mapping[str, Any]) -> dict[str, Any]:
    """Return the subset of ``kwargs`` that ``fn`` can accept by keyword.

    Args:
        fn: Callable whose signature is inspected.
        kwargs: Candidate keyword arguments.

    Returns:
        A new dict. All of ``kwargs`` is returned when the signature cannot be
        introspected or when ``fn`` declares a ``**kwargs`` catch-all (it
        accepts any keyword, so nothing should be dropped). Otherwise only
        keys naming a parameter that may be passed by keyword are kept.
    """
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # No introspectable signature (e.g. some builtins): pass everything.
        return dict(kwargs)

    kinds = {p.kind for p in params.values()}
    if inspect.Parameter.VAR_KEYWORD in kinds:
        return dict(kwargs)

    # Positional-only and *args parameters cannot be supplied by keyword, so a
    # name match against them must not let the key through.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    accepted = {name for name, p in params.items() if p.kind in keyword_kinds}
    return {k: v for k, v in kwargs.items() if k in accepted}
def autotune(*, configs, key, **kwargs):
    """
    Drop-in wrapper for `triton.autotune` that tolerates older Triton builds.

    Some Triton builds (e.g., custom vendor builds) do not accept newer
    keyword arguments such as `cache_results`. Extra keyword arguments are
    filtered against the runtime `triton.autotune` signature before being
    forwarded; `configs` and `key` are always passed through.
    """
    import triton

    native_autotune = triton.autotune
    supported = _filter_kwargs_for_callable(native_autotune, kwargs)
    return native_autotune(configs=configs, key=key, **supported)
def
maybe_set_allocator
(
alloc_fn
:
Callable
[[
int
,
int
,
int
|
None
],
Any
])
->
bool
:
"""
Call `triton.set_allocator(alloc_fn)` if present; otherwise no-op.
Returns True if the allocator was set.
"""
import
triton
setter
=
getattr
(
triton
,
"set_allocator"
,
None
)
if
setter
is
None
:
return
False
setter
(
alloc_fn
)
return
True
def
cuda_capability_geq
(
major
:
int
,
minor
:
int
=
0
,
device
:
int
|
None
=
None
)
->
bool
:
"""
Host-side CUDA capability check that works even when `tl.target_info` is absent.
"""
if
not
torch
.
cuda
.
is_available
():
return
False
if
device
is
None
:
try
:
device
=
torch
.
cuda
.
current_device
()
except
Exception
:
device
=
0
cap
=
torch
.
cuda
.
get_device_capability
(
device
)
return
cap
>=
(
major
,
minor
)
Prev
1
…
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment