Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2b7160c6
Commit
2b7160c6
authored
Apr 23, 2026
by
chenzk
Browse files
vllm kvprune:v1.0.0
parent
fa718036
Changes
305
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3295 additions
and
0 deletions
+3295
-0
vllm/compactor-vllm/src/compactor_vllm/compression/snapkv_origin.py
...ctor-vllm/src/compactor_vllm/compression/snapkv_origin.py
+449
-0
vllm/compactor-vllm/src/compactor_vllm/config/__init__.py
vllm/compactor-vllm/src/compactor_vllm/config/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/config/constants.py
vllm/compactor-vllm/src/compactor_vllm/config/constants.py
+5
-0
vllm/compactor-vllm/src/compactor_vllm/config/engine_config.py
...compactor-vllm/src/compactor_vllm/config/engine_config.py
+100
-0
vllm/compactor-vllm/src/compactor_vllm/config/sampling_params.py
...mpactor-vllm/src/compactor_vllm/config/sampling_params.py
+11
-0
vllm/compactor-vllm/src/compactor_vllm/core/__init__.py
vllm/compactor-vllm/src/compactor_vllm/core/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/core/llm_engine.py
vllm/compactor-vllm/src/compactor_vllm/core/llm_engine.py
+404
-0
vllm/compactor-vllm/src/compactor_vllm/core/memory_manager.py
.../compactor-vllm/src/compactor_vllm/core/memory_manager.py
+182
-0
vllm/compactor-vllm/src/compactor_vllm/core/model_runner.py
vllm/compactor-vllm/src/compactor_vllm/core/model_runner.py
+584
-0
vllm/compactor-vllm/src/compactor_vllm/core/scheduler.py
vllm/compactor-vllm/src/compactor_vllm/core/scheduler.py
+215
-0
vllm/compactor-vllm/src/compactor_vllm/kv_cache/__init__.py
vllm/compactor-vllm/src/compactor_vllm/kv_cache/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/kv_cache/page_table.py
.../compactor-vllm/src/compactor_vllm/kv_cache/page_table.py
+313
-0
vllm/compactor-vllm/src/compactor_vllm/kv_cache/store_kv_cache.py
...pactor-vllm/src/compactor_vllm/kv_cache/store_kv_cache.py
+468
-0
vllm/compactor-vllm/src/compactor_vllm/kv_cache/write_page_table.py
...ctor-vllm/src/compactor_vllm/kv_cache/write_page_table.py
+110
-0
vllm/compactor-vllm/src/compactor_vllm/layers/__init__.py
vllm/compactor-vllm/src/compactor_vllm/layers/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/layers/activation.py
vllm/compactor-vllm/src/compactor_vllm/layers/activation.py
+13
-0
vllm/compactor-vllm/src/compactor_vllm/layers/attention.py
vllm/compactor-vllm/src/compactor_vllm/layers/attention.py
+170
-0
vllm/compactor-vllm/src/compactor_vllm/layers/embed_head.py
vllm/compactor-vllm/src/compactor_vllm/layers/embed_head.py
+69
-0
vllm/compactor-vllm/src/compactor_vllm/layers/layernorm.py
vllm/compactor-vllm/src/compactor_vllm/layers/layernorm.py
+49
-0
vllm/compactor-vllm/src/compactor_vllm/layers/linear.py
vllm/compactor-vllm/src/compactor_vllm/layers/linear.py
+153
-0
No files found.
Too many changes to show.
To preserve performance only
305 of 305+
files are displayed.
Plain diff
Email patch
vllm/compactor-vllm/src/compactor_vllm/compression/snapkv_origin.py
0 → 100644
View file @
2b7160c6
import
math
from
typing
import
Optional
import
torch
import
triton
from
triton
import
language
as
tl
from
compactor_vllm.compression.common
import
BaseCompressionMethod
from
compactor_vllm.utils.helpers
import
maybe_execute_in_stream
from
compactor_vllm.utils.triton_compat
import
autotune
as
triton_autotune
class SnapKVCompression(BaseCompressionMethod):
    """Compression method whose key scores come from post-RoPE attention only.

    No score is produced before RoPE; after RoPE the work is delegated to
    ``query_aware_key_scores`` with a fixed window of 32.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Return ``None``: this method does not score keys before RoPE."""
        return None

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Score keys from the post-RoPE queries and keys.

        ``v`` and ``pre_rope_scores`` are accepted for interface
        compatibility but are not read here. The sequence boundaries and the
        stream handle are taken from ``context``.

        NOTE(review): ``maybe_execute_in_stream`` is defined elsewhere;
        presumably it runs the callable on ``STORE_STREAM`` when one is
        given -- confirm against ``compactor_vllm.utils.helpers``.
        """
        # Positional arguments are forwarded to query_aware_key_scores as
        # (q, k, cu_seqlens_q, cu_seqlens_k); the window size is fixed at 32.
        return maybe_execute_in_stream(
            query_aware_key_scores,
            q,
            k,
            context.cu_seqlens_q,
            context.cu_seqlens_k,
            w=32,
            STORE_STREAM=context.STORE_STREAM,
        )
# Pass 1 of the key-scoring pipeline.
#
# Grid: axis 0 = sequence index ``b``, axis 1 = KV head ``hk``, axis 2 = tile
# of BLOCK_Q "rows". A row is one (query position, query head inside the
# group of ``hk``) pair taken from the last ``win`` query positions of the
# sequence, so a sequence owns ``win * QUERY_GROUP_SIZE`` rows.
#
# For every row the kernel walks the keys in [k_beg, k_end - win) and
#   * writes the scaled logits (already multiplied by log2(e), i.e. base-2
#     logits) to LOGITS[key, hk, row];
#   * maintains a streaming maximum ``m`` and sum of ``exp2(logit - m)``
#     ``S``, stored at the end to out_m / out_S[b, hk, row].
#
# Programs return without writing anything when the window is empty, when no
# key lies outside the window, or when the row tile starts past the last row.
@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
        )
        for bq in [32, 64]
        for bk in [32, 64]
        for num_warps in [4, 8]
        for num_stages in [3, 4]
    ],
    key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
    cache_results=True,
)
@triton.jit
def _lse_and_store_logits_kernel(
    Q,
    K,
    cu_q,
    cu_k,
    w_b,  # int32 pointers
    out_m,
    out_S,  # [B, Hk, ROWS_MAX] float32
    LOGITS,  # [Nk, Hk, ROWS_MAX] float32
    sm_scale,  # float
    QUERY_GROUP_SIZE: tl.constexpr,
    D: tl.constexpr,
    STRIDE_Q_NQ,
    STRIDE_Q_HQ,
    STRIDE_K_NK,
    STRIDE_K_HK,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    ROWS_MAX,
):
    # program ids
    b = tl.program_id(0)
    hk = tl.program_id(1)
    rid = tl.program_id(2)  # row-tile id
    # batch segment bounds
    q_end = tl.load(cu_q + b + 1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    # The window is anchored at the END of the query segment; the start of
    # the query segment is never loaded.
    # NOTE(review): assumes win <= number of queries in the sequence,
    # otherwise q_win_beg reaches into the previous sequence -- confirm.
    q_win_beg = q_end - win
    # Keys in the last ``win`` positions are excluded from scoring.
    k_eff_end = k_end - win
    if (win <= 0) or (k_eff_end <= k_beg):
        return
    # rows for this (b,hk)
    rows_b = win * QUERY_GROUP_SIZE
    row0 = rid * BLOCK_Q
    # The grid is sized for ROWS_MAX rows; tiles beyond this sequence's rows
    # have nothing to do.
    if row0 >= rows_b:
        return
    # exp(x) = exp2(x * 1/ln2)
    qk_scale = sm_scale * 1.4426950408889634
    offs_qrow = row0 + tl.arange(0, BLOCK_Q)
    row_mask = offs_qrow < rows_b
    # map row -> (q_idx, hq_local)
    hq_local = offs_qrow % QUERY_GROUP_SIZE
    q_off = offs_qrow // QUERY_GROUP_SIZE
    q_idx = q_win_beg + q_off
    hq_glob = hk * QUERY_GROUP_SIZE + hq_local
    # The last dimension is addressed with unit stride (no D stride is passed).
    offs_d = tl.arange(0, D)
    q_ptrs = (
        Q
        + q_idx[:, None] * STRIDE_Q_NQ
        + hq_glob[:, None] * STRIDE_Q_HQ
        + offs_d[None, :]
    )
    q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
    # Running max starts at -inf and running sum at 0.
    m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
    S = tl.zeros([BLOCK_Q], dtype=tl.float32)
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
        k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0)  # [BK, D]
        s = tl.dot(q_rows, k_blk.T) * qk_scale  # [BQ, BK]
        # Out-of-range keys must not contribute to the max or the sum.
        s = tl.where(kmask[None, :], s, -float("inf"))
        # store into LOGITS[nk, hk, row] -> [BK, BQ]
        log_ptrs = (
            LOGITS
            + nk[:, None] * STRIDE_LG_NK
            + hk * STRIDE_LG_HK
            + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
        )
        tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
        # log2 streaming LSE update
        cur_max = tl.max(s, 1)  # [BQ]
        n_m = tl.maximum(m, cur_max)
        # Rescale the sum accumulated so far to the new maximum before adding
        # this block's contribution.
        rescale = tl.math.exp2(m - n_m)
        S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
        m = n_m
    # store m,S for these rows
    m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
    S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
    tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
    tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
# Pass 2 of the key-scoring pipeline.
#
# Grid: axis 0 = sequence index ``b``, axis 1 = KV head ``hk``.
#
# For every key in [k_beg, k_end - win) the stored base-2 logits are turned
# back into probabilities with the per-row max/sum from pass 1 and summed
# over all rows of the sequence, giving one score per (key, hk). When
# DO_POOL is set and KPOOL > 1 the scores of a tile are then replaced by the
# mean over a trailing window of up to KPOOL keys. The window is built from
# tile-local indices, so it never reaches into the previous BLOCK_K tile.
# The last ``win`` keys of the sequence are written as +inf.
#
# The early return below writes nothing, so OUT keeps whatever it held for
# sequences with ``win <= 0`` or with no key outside the window.
#
# NOTE(review): the autotune key names "HK" and "HQ" are not arguments of
# this kernel; how ``triton_autotune`` treats unknown key names depends on
# ``compactor_vllm.utils.triton_compat`` -- confirm.
@triton_autotune(
    configs=[
        triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
        for bq in [16, 32, 64]
        for bk in [32, 64, 128]
    ],
    key=["HK", "HQ"],
    cache_results=True,
)
@triton.jit
def _scores_from_logits_kernel(
    cu_k,
    w_b,
    in_m,
    in_S,  # [B, Hk, ROWS_MAX] f32
    LOGITS,  # [Nk, Hk, ROWS_MAX] f32, base-2 logits
    OUT,  # [Nk, Hk] f32
    #
    QUERY_GROUP_SIZE: tl.constexpr,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    #
    DO_POOL: tl.constexpr,  # set True to enable in-place avg pool
    KPOOL: tl.constexpr,  # kernel size for avg pool (stride=1)
):
    b = tl.program_id(0)
    hk = tl.program_id(1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    # Same scored-key range as pass 1: everything before the last ``win`` keys.
    k_eff_end = k_end - win
    if (win <= 0) or (k_eff_end <= k_beg):
        return
    rows_b = win * QUERY_GROUP_SIZE
    # === scores over computed region ===
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        scores = tl.zeros([BLOCK_K], dtype=tl.float32)
        for row0 in tl.range(0, rows_b, BLOCK_Q):
            r_idx = row0 + tl.arange(0, BLOCK_Q)
            rmask = r_idx < rows_b
            # load m, S for rows
            m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
            S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
            m = tl.load(
                m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
                mask=rmask,
                other=-float("inf"),
            )
            S = tl.load(
                S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
            )
            # Rows with a non-positive sum (including masked-out rows, which
            # load S = 0) are neutralised so the division below is safe, and
            # are zeroed again after it.
            valid_row = S > 0
            m = tl.where(valid_row, m, 0.0)
            S = tl.where(valid_row, S, 1.0)
            # load stored logits^T: [BK, BQ]
            log_ptrs = (
                LOGITS
                + nk[:, None] * STRIDE_LG_NK
                + hk * STRIDE_LG_HK
                + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
            )
            s_T = tl.load(
                log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
            )  # [BK, BQ]
            # probs^T = exp2(s_T - m) / S, sum over rows
            probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
            probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
            scores += tl.sum(probs_T, 1)  # [BK]
        if DO_POOL and (KPOOL > 1):
            # i indexes the output key, j the contributing key, both relative
            # to the start of this tile.
            i = tl.arange(0, BLOCK_K)[:, None]
            j = tl.arange(0, BLOCK_K)[None, :]
            # Key i averages keys j with i - KPOOL < j <= i.
            band = (j <= i) & ((i - j) < KPOOL)
            band = band & kmask[None, :]
            # sum within band
            sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1)  # [BK]
            denom = tl.sum(band, 1).to(tl.float32)  # [BK]
            # Guard against an empty band (division by zero).
            denom = tl.where(denom > 0, denom, 1.0)
            scores = sums / denom
        out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
        tl.store(out_ptrs, scores, mask=kmask)
    # The trailing ``win`` keys were not scored; they are filled with +inf.
    pad_beg = k_eff_end
    pad_end = k_end
    if pad_end > pad_beg:
        for ks in tl.range(pad_beg, pad_end, BLOCK_K):
            nk = ks + tl.arange(0, BLOCK_K)
            kmask = nk < pad_end
            out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
            tl.store(
                out_ptrs, tl.full([BLOCK_K], float("inf"), dtype=tl.float32), mask=kmask
            )
# Optional epilogue: in-place z-score normalisation of OUT.
#
# Grid: axis 0 = sequence index ``b``. One program handles every KV head of
# its sequence. Mean and variance are taken jointly over all keys in
# [k_beg, k_end - win) and all HK heads; the trailing ``win`` keys are not
# read or modified.
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
    key=["HK"],
    cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue(
    OUT,  # [Nk, Hk], float32
    cu_k,
    w_b,  # [B+1], [B] int32
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    HK: tl.constexpr,  # Hk
    EPS: tl.constexpr,  # e.g., 1e-12
    BLOCK_K: tl.constexpr,  # e.g., 128
):
    b = tl.program_id(0)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    k_eff_end = k_end - win
    # Nothing to normalise when no key lies outside the window.
    if k_eff_end <= k_beg:
        return
    # Scalar accumulators for the first and second moments.
    sumv = tl.zeros([], dtype=tl.float32)
    sumsq = tl.zeros([], dtype=tl.float32)
    count = ((k_eff_end - k_beg) * HK).to(tl.float32)
    # First sweep: accumulate sum and sum of squares.
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            sumv += tl.sum(vals, 0)
            sumsq += tl.sum(vals * vals, 0)
    mean = sumv / count
    # E[x^2] - E[x]^2 can come out slightly negative in floating point; clamp.
    var = tl.maximum(sumsq / count - mean * mean, 0.0)
    invstd = 1.0 / tl.sqrt(var + EPS)
    # Second sweep: rewrite every value as (x - mean) / std.
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            vals = (vals - mean) * invstd
            tl.store(ptrs, vals, mask=kmask)
def
query_aware_key_scores
(
q
:
torch
.
Tensor
,
# [N_q, Hq, D]
k
:
torch
.
Tensor
,
# [N_k, Hk, D]
cu_seqlens_q
:
torch
.
Tensor
,
# [B+1], int32
cu_seqlens_k
:
torch
.
Tensor
,
# [B+1], int32
w
:
torch
.
Tensor
|
int
,
# [B], int32
sm_scale
:
float
=
None
,
# defaults to 1/sqrt(D)
*
,
accum_scores
:
torch
.
Tensor
=
None
,
accum_blending
:
float
=
None
,
normalize
:
bool
=
False
,
)
->
Optional
[
torch
.
Tensor
]:
assert
q
.
stride
(
-
1
)
==
1
and
k
.
stride
(
-
1
)
==
1
,
"last dim must be contiguous"
device
=
q
.
device
N_q
,
Hq
,
D
=
q
.
shape
N_k
,
Hk
,
Dk
=
k
.
shape
assert
(
Hq
%
Hk
)
==
0
,
"Hq must be a multiple of Hk"
if
sm_scale
is
None
:
sm_scale
=
1.0
/
math
.
sqrt
(
D
)
B
=
cu_seqlens_q
.
numel
()
-
1
assert
B
==
cu_seqlens_k
.
numel
()
-
1
G
=
Hq
//
Hk
if
type
(
w
)
is
int
:
max_w
=
w
w
=
torch
.
full
((
B
,),
fill_value
=
w
,
device
=
device
,
dtype
=
torch
.
int32
)
else
:
max_w
=
int
(
w
.
max
().
item
())
assert
w
.
numel
()
==
B
ROWS_MAX
=
max_w
*
G
if
ROWS_MAX
==
0
:
return
torch
.
zeros
((
N_k
,
Hk
),
dtype
=
torch
.
float32
,
device
=
device
)
out
=
torch
.
empty
((
N_k
,
Hk
),
dtype
=
torch
.
float32
,
device
=
device
)
m_scratch
=
torch
.
empty
((
B
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
S_scratch
=
torch
.
empty
((
B
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
logits_buf
=
torch
.
empty
((
N_k
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
# strides
STRIDE_Q_NQ
,
STRIDE_Q_HQ
,
_
=
q
.
stride
()
STRIDE_K_NK
,
STRIDE_K_HK
,
_
=
k
.
stride
()
STRIDE_M_B
,
STRIDE_M_H
,
STRIDE_M_R
=
m_scratch
.
stride
()
STRIDE_S_B
,
STRIDE_S_H
,
STRIDE_S_R
=
S_scratch
.
stride
()
STRIDE_LG_NK
,
STRIDE_LG_HK
,
STRIDE_LG_R
=
logits_buf
.
stride
()
STRIDE_OUT_NK
,
STRIDE_OUT_HK
=
out
.
stride
()
def
grid
(
META
):
return
B
,
Hk
,
triton
.
cdiv
(
ROWS_MAX
,
META
[
"BLOCK_Q"
])
_lse_and_store_logits_kernel
[
grid
](
q
,
k
,
cu_seqlens_q
,
cu_seqlens_k
,
w
,
m_scratch
,
S_scratch
,
logits_buf
,
sm_scale
,
QUERY_GROUP_SIZE
=
Hq
//
Hk
,
D
=
D
,
STRIDE_Q_NQ
=
STRIDE_Q_NQ
,
STRIDE_Q_HQ
=
STRIDE_Q_HQ
,
STRIDE_K_NK
=
STRIDE_K_NK
,
STRIDE_K_HK
=
STRIDE_K_HK
,
STRIDE_M_B
=
STRIDE_M_B
,
STRIDE_M_H
=
STRIDE_M_H
,
STRIDE_M_R
=
STRIDE_M_R
,
STRIDE_S_B
=
STRIDE_S_B
,
STRIDE_S_H
=
STRIDE_S_H
,
STRIDE_S_R
=
STRIDE_S_R
,
STRIDE_LG_NK
=
STRIDE_LG_NK
,
STRIDE_LG_HK
=
STRIDE_LG_HK
,
STRIDE_LG_R
=
STRIDE_LG_R
,
ROWS_MAX
=
ROWS_MAX
,
)
_scores_from_logits_kernel
[(
B
,
Hk
)](
cu_seqlens_k
,
w
,
m_scratch
,
S_scratch
,
logits_buf
,
out
,
QUERY_GROUP_SIZE
=
Hq
//
Hk
,
STRIDE_M_B
=
STRIDE_M_B
,
STRIDE_M_H
=
STRIDE_M_H
,
STRIDE_M_R
=
STRIDE_M_R
,
STRIDE_S_B
=
STRIDE_S_B
,
STRIDE_S_H
=
STRIDE_S_H
,
STRIDE_S_R
=
STRIDE_S_R
,
STRIDE_LG_NK
=
STRIDE_LG_NK
,
STRIDE_LG_HK
=
STRIDE_LG_HK
,
STRIDE_LG_R
=
STRIDE_LG_R
,
STRIDE_OUT_NK
=
STRIDE_OUT_NK
,
STRIDE_OUT_HK
=
STRIDE_OUT_HK
,
DO_POOL
=
True
,
KPOOL
=
5
,
)
if
normalize
:
_zscore_per_batch_epilogue
[(
B
,)](
out
,
cu_seqlens_k
,
w
,
STRIDE_OUT_NK
,
STRIDE_OUT_HK
,
HK
=
Hk
,
EPS
=
1e-12
,
)
if
accum_scores
is
not
None
:
if
accum_blending
is
not
None
:
accum_scores
.
mul_
(
accum_blending
)
accum_scores
.
add_
(
out
)
return
accum_scores
else
:
return
out
vllm/compactor-vllm/src/compactor_vllm/config/__init__.py
0 → 100644
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/config/constants.py
0 → 100644
View file @
2b7160c6
# Batch index that is treated as reserved.
# NOTE(review): the consumers live outside this file; presumably slot 0 is
# never handed out to a real sequence -- confirm against the page table code.
RESERVED_BATCH = 0
# NOTE: Triton `tl.constexpr` is intended for use in kernel signatures/annotations.
# Some Triton builds reject passing `tl.constexpr(...)` objects as constexpr values.
# Keep the runtime value as a plain int and let kernel signatures declare constexpr.
TRITON_RESERVED_BATCH = RESERVED_BATCH
vllm/compactor-vllm/src/compactor_vllm/config/engine_config.py
0 → 100644
View file @
2b7160c6
import
os
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
typing
import
List
,
Optional
from
transformers
import
AutoConfig
class AttentionBackend(Enum):
    """Closed set of attention implementations the engine can be configured with.

    Member values are assigned by ``auto()`` in declaration order, so the
    order of the members below is part of the value mapping.
    """

    # FlashAttention-based backend.
    FLASH_ATTENTION = auto()

    # Custom Triton kernels shipped with Compactor.
    COMPACTOR_TRITON = auto()
@dataclass
class LLMConfig:
    """Configuration for the :class:`LLM` engine.

    Parameters
    ----------
    model : str
        Hugging Face model identifier (e.g. ``"meta-llama/Meta-Llama-3-8B"``) or
        a local model name that can be resolved by
        :func:`transformers.AutoConfig.from_pretrained`.
    path : str, optional
        Local directory containing the model weights. If ``None``, the engine
        will attempt to resolve a local snapshot for ``model`` using
        :func:`huggingface_hub.snapshot_download`.
    nccl_port : int, optional, default 1218
        Not read inside this class.
        NOTE(review): presumably the rendezvous port of the tensor-parallel
        process group -- confirm against the model runner.
    max_num_seqs : int, default 256
        Upper bound on the number of concurrent batches that the scheduler and
        KV-cache manager are allowed to handle. This affects the size of the
        page table and some internal buffers.
    max_model_len : int, default 40960
        Maximum context length (in tokens) that the engine will allocate KV cache
        and CUDA graphs for. During initialization this value is clamped to
        ``hf_config.max_position_embeddings`` for the chosen model.
    gpu_memory_utilization : float, default 0.9
        Fraction of the total GPU memory that may be used for KV cache and model
        activations. Values should be in ``(0, 1]``. If this budget is too small,
        the KV-cache manager may raise an error at warmup time due
        to insufficient memory.
    tensor_parallel_size : int, default 1
        Number of tensor-parallel workers to shard the model
        across. Must be between 1 and 8, and must evenly divide the model's
        number of key/value heads.
    enforce_eager : bool, default False
        If ``True``, disable CUDA graph capture and always run the model in
        eager mode during decoding. This reduces throughput. When ``False``,
        the engine will capture and reuse CUDA graphs for supported
        batch sizes and sequence lengths.
    hf_config : transformers.AutoConfig, optional
        Pre-loaded Hugging Face configuration for the model. If ``None``,
        it is populated automatically from ``model``.
    eos : int, default -1
        Primary stop token id (warmup / single-id paths). If ``-1``, the
        engine constructor fills this and :attr:`eos_token_ids` from the
        tokenizer.
    eos_token_ids : list of int, optional
        All token ids that terminate generation (e.g. HF tokenizers may expose
        ``eos_token_id`` as a list for chat models). If ``None``, inferred by
        the engine from the tokenizer and model type.
    kvcache_page_size : int, default 128
        Number of tokens stored in a single KV-cache page. Smaller pages improve
        allocation flexibility but increase page-table overhead; larger pages
        reduce overhead but have coarser granularity.
    leverage_sketch_size : int, default 48
        Sketch dimension used by the Compactor leverage-score estimator.
    attention_backend : AttentionBackend, default AttentionBackend.COMPACTOR_TRITON
        Attention implementation to use. ``COMPACTOR_TRITON`` selects the custom
        Triton kernels used by Compactor; ``FLASH_ATTENTION`` selects the
        FlashAttention3 varlen backend. The COMPACTOR_TRITON tends to be faster
        for longer sequence lengths, while FA3 is faster at shorter lengths.
    show_progress_bar : bool, default True
        Not read inside this class.
        NOTE(review): presumably toggles progress reporting during
        generation -- confirm against the consumer.

    Raises
    ------
    NotADirectoryError
        If ``path`` is given but is not an existing directory.
    ValueError
        If ``tensor_parallel_size`` is outside ``[1, 8]``.
    """

    model: str
    path: Optional[str] = None
    nccl_port: Optional[int] = 1218
    max_num_seqs: int = 256
    max_model_len: int = 40960
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    enforce_eager: bool = False
    hf_config: AutoConfig | None = None
    eos: int = -1
    eos_token_ids: Optional[List[int]] = None
    kvcache_page_size: int = 128
    leverage_sketch_size: int = 48
    attention_backend: AttentionBackend = AttentionBackend.COMPACTOR_TRITON
    show_progress_bar: bool = True

    def __post_init__(self):
        """Validate the fields and resolve the Hugging Face configuration."""
        if self.path is not None and not os.path.isdir(self.path):
            raise NotADirectoryError(f"Engine config dir {self.path} does not exist")
        # Raise the documented ValueError directly. A preceding ``assert`` on
        # the same condition would always fail inside this branch and mask
        # the message with a bare AssertionError.
        if self.tensor_parallel_size <= 0 or self.tensor_parallel_size > 8:
            raise ValueError("tensor_parallel_size must be >= 1 and <= 8")
        if self.hf_config is None:
            self.hf_config = AutoConfig.from_pretrained(self.model)
        # Never promise a longer context than the model supports.
        self.max_model_len = min(
            self.max_model_len, self.hf_config.max_position_embeddings
        )
vllm/compactor-vllm/src/compactor_vllm/config/sampling_params.py
0 → 100644
View file @
2b7160c6
from
dataclasses
import
dataclass
@dataclass
class SamplingParams:
    """Per-request sampling settings.

    A negative ``temperature`` is rejected at construction time with
    ``ValueError``; ``max_new_tokens`` is stored without validation.
    """

    # Sampling temperature; must not be negative.
    temperature: float = 1.0
    # Token budget for the completion.
    max_new_tokens: int = 256

    def __post_init__(self):
        """Reject negative temperatures."""
        requested = self.temperature
        if requested < 0:
            raise ValueError("Temperature cannot be negative")
vllm/compactor-vllm/src/compactor_vllm/core/__init__.py
0 → 100644
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/core/llm_engine.py
0 → 100644
View file @
2b7160c6
import
atexit
import
inspect
import
logging
from
typing
import
Any
,
List
,
Optional
,
Union
import
torch.multiprocessing
as
mp
from
compactor_vllm.compression.compression_config
import
(
BatchCompressionParams
,
SequenceCompressionParams
,
)
from
compactor_vllm.config.engine_config
import
LLMConfig
from
compactor_vllm.config.sampling_params
import
SamplingParams
from
compactor_vllm.core.model_runner
import
ModelRunner
from
compactor_vllm.models
import
MODEL_REGISTRY
from
compactor_vllm.utils.sequence
import
Sequence
from
transformers
import
AutoTokenizer
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# A prompt is either raw text or an already tokenized list of token ids.
PromptLike = Union[str, List[int]]
def
_infer_stop_token_ids
(
tokenizer
,
hf_config
)
->
list
[
int
]:
"""
Build the set of token ids that should end generation.
Newer HF chat tokenizers often expose ``eos_token_id`` as a *list* of ids.
The engine must not compare generated ids to that list as a single ``int``;
see :attr:`LLMConfig.eos_token_ids` and decode-time ``torch.isin``.
Qwen chat uses ``</think>`` (im_end) as the assistant turn boundary; include it
when present in ``additional_special_tokens`` / ``added_tokens_encoder``. We
avoid loose substring matches like ``
\"
end
\"
`` that can tag unrelated tokens.
"""
raw
=
tokenizer
.
eos_token_id
ids
:
list
[
int
]
=
[]
if
isinstance
(
raw
,
(
list
,
tuple
)):
ids
.
extend
(
int
(
x
)
for
x
in
raw
)
elif
raw
is
not
None
:
ids
.
append
(
int
(
raw
))
unk_id
=
getattr
(
tokenizer
,
"unk_token_id"
,
None
)
def
_maybe_add_tid
(
tid
:
int
)
->
None
:
if
not
isinstance
(
tid
,
int
)
or
tid
<
0
:
return
if
unk_id
is
not
None
and
tid
==
unk_id
:
return
if
tid
not
in
ids
:
ids
.
append
(
tid
)
model_type
=
getattr
(
hf_config
,
"model_type"
,
None
)
if
model_type
in
(
"qwen2"
,
"qwen3"
,
"qwen2_moe"
,
"qwen3_moe"
):
enc
=
getattr
(
tokenizer
,
"added_tokens_encoder"
,
None
)
if
isinstance
(
enc
,
dict
):
for
key
,
tid
in
enc
.
items
():
if
isinstance
(
key
,
str
)
and
"im_end"
in
key
:
_maybe_add_tid
(
int
(
tid
))
for
extra
in
getattr
(
tokenizer
,
"additional_special_tokens"
,
[])
or
[]:
if
not
isinstance
(
extra
,
str
)
or
"im_end"
not
in
extra
:
continue
try
:
tid
=
tokenizer
.
convert_tokens_to_ids
(
extra
)
except
(
TypeError
,
ValueError
,
KeyError
):
continue
_maybe_add_tid
(
tid
)
if
not
ids
:
raise
ValueError
(
"Could not infer stop token ids from the tokenizer; set "
"LLMConfig(eos_token_ids=[...]) explicitly."
)
return
ids
def
_merge_apply_chat_template_kwargs
(
tokenizer
,
user_kwargs
:
Optional
[
dict
[
str
,
Any
]],
)
->
dict
[
str
,
Any
]:
"""
Merge user kwargs with defaults for HF chat templates that support them.
Qwen3 (and similar) instruct models expect `add_generation_prompt=True` so
the first generated token continues the assistant turn; without it, output
can repeat punctuation / template fragments. `enable_thinking=False` avoids
the Qwen3 reasoning channel when the tokenizer supports it.
"""
out
=
dict
(
user_kwargs
or
{})
try
:
sig
=
inspect
.
signature
(
tokenizer
.
apply_chat_template
)
except
(
TypeError
,
ValueError
):
return
out
if
"add_generation_prompt"
in
sig
.
parameters
and
"add_generation_prompt"
not
in
out
:
out
[
"add_generation_prompt"
]
=
True
if
"enable_thinking"
in
sig
.
parameters
and
"enable_thinking"
not
in
out
:
out
[
"enable_thinking"
]
=
False
return
out
def _runner_entry(config: LLMConfig, rank: int, evt):
    """Process entry point for one non-master rank.

    Builds a ``ModelRunner`` for ``rank`` and runs its loop. Any exception is
    logged with its traceback rather than propagated, and the runner's
    ``exit`` is always called if construction succeeded.
    """
    worker = None
    try:
        worker = ModelRunner(config, rank, evt)
        worker.loop()
    except Exception as err:
        logging.exception(f"Rank {rank}: {repr(err)}")
    finally:
        # ``worker`` is still None when the constructor itself failed.
        if worker is not None:
            worker.exit()
class LLMEngine:
    """High-level engine coordinating model runners and scheduling"""

    def __init__(self, config: LLMConfig):
        """Resolve the model snapshot and stop tokens, then start the runners.

        ``config`` is updated in place: ``path``, ``eos`` and
        ``eos_token_ids`` are filled in or normalised. One process is spawned
        per rank ``1 .. tensor_parallel_size - 1``; rank 0 runs in this
        process.

        Raises:
            ValueError: if the model type is not in ``MODEL_REGISTRY``.
        """
        self.config = config
        if self.config.hf_config.model_type not in MODEL_REGISTRY:
            raise ValueError(f"Unknown model {self.config.model}")
        if config.path is None:
            from huggingface_hub import snapshot_download

            # local_files_only=True: the snapshot must already be cached.
            self.config.path = snapshot_download(
                repo_id=config.model, local_files_only=True
            )
            logger.info(f"Using {self.config.model} snapshot @ {self.config.path}")
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model, use_fast=True)

        # --- stop tokens -------------------------------------------------
        if self.config.eos_token_ids is None:
            if self.config.eos != -1:
                self.config.eos_token_ids = [int(self.config.eos)]
            else:
                self.config.eos_token_ids = _infer_stop_token_ids(
                    self.tokenizer, self.config.hf_config
                )
        else:
            self.config.eos_token_ids = [int(x) for x in self.config.eos_token_ids]
        self.config.eos_token_ids = sorted(set(self.config.eos_token_ids))
        if self.config.eos == -1:
            self.config.eos = int(self.config.eos_token_ids[0])
        else:
            self.config.eos = int(self.config.eos)
        # An explicitly configured primary id must itself be a stop id.
        if self.config.eos not in self.config.eos_token_ids:
            self.config.eos_token_ids = sorted(
                self.config.eos_token_ids + [self.config.eos]
            )

        # --- worker processes ---------------------------------------------
        self.ps = []
        world_size = int(self.config.tensor_parallel_size)
        self.events = []
        if world_size > 1:
            ctx = mp.get_context("spawn")
            for r in range(1, world_size):
                event = ctx.Event()
                p = ctx.Process(
                    target=_runner_entry,
                    args=(self.config, r, event),
                    daemon=True,
                )
                p.start()
                self.ps.append(p)
                self.events.append(event)
        self.master_model_runner = ModelRunner(
            self.config, rank=0, peer_events=self.events
        )
        atexit.register(self.exit)

    def exit(self):
        """Shut down the master runner and worker processes; safe to call twice."""
        if getattr(self, "_exited", False):
            return
        self._exited = True
        runner = getattr(self, "master_model_runner", None)
        if runner is not None:
            try:
                runner.exit()
            except Exception:
                # Best effort: keep going so the workers are still reaped.
                logger.exception("Failed to exit master ModelRunner cleanly")
        for p in self.ps:
            if p.is_alive():
                p.terminate()
            p.join(timeout=1.0)
        if hasattr(self, "events"):
            self.events.clear()

    def tokenize_prompt(self, prompt: PromptLike, **tokenizer_kwargs) -> List[int]:
        """
        Turn a raw prompt into token IDs.

        A string is passed through the tokenizer with ``tokenizer_kwargs``;
        anything else is assumed to be token ids already and is copied into
        a new list.
        """
        if isinstance(prompt, str):
            return self.tokenizer(prompt, **tokenizer_kwargs)["input_ids"]
        return list(prompt)

    def detokenize_prompt(
        self, sequences: List[Sequence], **detokenizer_kwargs
    ) -> List[str]:
        """
        Turn completed Sequences into strings.

        ``skip_special_tokens=True`` is the default and can be overridden
        through ``detokenizer_kwargs``.
        """
        defaults: dict[str, Any] = {"skip_special_tokens": True}
        merged = {**defaults, **detokenizer_kwargs}
        return self.tokenizer.batch_decode(
            [s.completion_token_ids for s in sequences], **merged
        )

    def _build_sequences(
        self,
        prompts: List[PromptLike] | PromptLike,
        sampling_params: SamplingParams | List[SamplingParams],
        per_sequence_compression_params: Optional[
            SequenceCompressionParams | List[SequenceCompressionParams]
        ] = None,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
    ) -> List[Sequence]:
        """
        Build Sequence objects from prompts, sampling params, and optional
        per-sequence compression parameters.

        When a prompt is no longer than its protected head and tail
        together, its compression ratio is forced to ``1.0``. The override
        is applied to a shallow copy so that a params object supplied by the
        caller (possibly shared by the whole batch) is never modified.
        """
        import copy  # used only for the per-sequence override below

        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs

        # A single prompt is either a string or a flat list of token ids;
        # a list of ints must not be mistaken for a batch of prompts.
        is_single_token_list = (
            isinstance(prompts, list)
            and len(prompts) > 0
            and isinstance(prompts[0], int)
        )
        if not isinstance(prompts, list) or is_single_token_list:
            prompts = [prompts]

        if isinstance(sampling_params, SamplingParams):
            sampling_params_list: List[SamplingParams] = [sampling_params] * len(
                prompts
            )
        else:
            sampling_params_list = sampling_params
            assert len(sampling_params_list) == len(prompts), (
                "sampling_params list must match prompts length"
            )

        if per_sequence_compression_params is None:
            compression_params_list: List[SequenceCompressionParams] = [
                SequenceCompressionParams(1.0) for _ in prompts
            ]
        elif isinstance(per_sequence_compression_params, SequenceCompressionParams):
            compression_params_list = [per_sequence_compression_params] * len(prompts)
        else:
            # list-like
            assert len(per_sequence_compression_params) == len(prompts), (
                "per_sequence_compression_params list must match prompts length"
            )
            compression_params_list = list(per_sequence_compression_params)

        seqs: List[Sequence] = []
        for prompt, sparams, cparams in zip(
            prompts, sampling_params_list, compression_params_list
        ):
            token_ids = self.tokenize_prompt(prompt, **tokenizer_kwargs)
            protected = cparams.protected_first_tokens + cparams.protected_last_tokens
            if protected >= len(token_ids):
                # Nothing outside the protected regions: disable compression
                # for this sequence only.
                cparams = copy.copy(cparams)
                cparams.compression_ratio = 1.0
            seqs.append(
                Sequence(
                    prompt_token_ids=token_ids,
                    sampling_params=sparams,
                    compression_params=cparams,
                )
            )
        return seqs

    def generate(
        self,
        prompts: List[PromptLike] | PromptLike,
        sampling_params: SamplingParams | List[SamplingParams],
        batch_compression_params: BatchCompressionParams,
        *,
        per_sequence_compression_params: Union[
            List[SequenceCompressionParams], SequenceCompressionParams
        ] = None,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
        detokenizer_kwargs: Optional[dict[str, Any]] = None,
        return_sequences: bool = False,
    ) -> List[str] | tuple[List[str], List[Sequence]]:
        """
        Generate completions for prompts and return them as strings.

        Args:
            :param prompts:
                Single prompt or list of prompts, each either a raw text prompt,
                or pre-tokenized input IDs.
            :param sampling_params:
                A single SamplingParams for all prompts in this batch or a list of
                SamplingParams with the same length as ``prompts``.
            :param batch_compression_params:
                Compression settings for this batch.
            :param per_sequence_compression_params:
                Per-sequence compression parameters, including the compression
                ratio to be applied and the size of the protected regions of the
                sequence (how many start tokens and end tokens to keep uncompressed).
                If a SequenceCompressionParams instance, the same params will be
                applied to all sequences in this batch; if a list is provided,
                each SequenceCompressionParams will be attached to the corresponding
                prompt in the batch.
            :param tokenizer_kwargs:
                Extra kwargs forwarded to ``tokenizer(...)`` when tokenizing
                string prompts.
            :param detokenizer_kwargs:
                Passed through to `tokenizer.batch_decode`.
            :param return_sequences:
                Whether to return sequence objects or not

        Returns:
            :return List[str] or tuple[List[str], List[Sequence]]:
                One decoded completion per input prompt. With
                ``return_sequences=True`` the Sequence objects, with
                `completion_token_ids` filled in, are returned as well.
        """
        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
        detokenizer_kwargs = {} if detokenizer_kwargs is None else detokenizer_kwargs
        seqs = self._build_sequences(
            prompts,
            sampling_params=sampling_params,
            per_sequence_compression_params=per_sequence_compression_params,
            tokenizer_kwargs=tokenizer_kwargs,
        )
        # The runner fills the sequences in place.
        self.master_model_runner.generate(seqs, batch_compression_params)
        output_strings = self.detokenize_prompt(seqs, **detokenizer_kwargs)
        if return_sequences:
            return output_strings, seqs
        return output_strings

    def generate_chat(
        self,
        messages_batch: List[List[dict]],
        sampling_params: SamplingParams | List[SamplingParams],
        batch_compression_params: BatchCompressionParams,
        per_sequence_compression_params: Union[
            SequenceCompressionParams, List[SequenceCompressionParams]
        ],
        *,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
        detokenizer_kwargs: Optional[dict[str, Any]] = None,
        return_sequences: bool = False,
    ) -> List[str] | tuple[List[str], List[Sequence]]:
        """
        Convenience API for chat-style prompts using HF `apply_chat_template`.

        Args:
            :param messages_batch:
                List of conversations, where each conversation is a list of
                message dicts like:
                    {"role": "system" | "user" | "assistant", "content": str}
            :param sampling_params:
                A single SamplingParams for all prompts in this batch or a list of
                SamplingParams with the same length as ``messages_batch``.
            :param batch_compression_params:
                Batch Level compression settings. Can set compression_method.
            :param per_sequence_compression_params:
                Per-sequence compression parameters, including the compression
                ratio to be applied and the size of the protected regions of the
                sequence (how many start tokens and end tokens to keep uncompressed).
                If a SequenceCompressionParams instance, the same params will be
                applied to all sequences in this batch; if a list is provided,
                each SequenceCompressionParams will be attached to the corresponding
                conversation in the batch.
            :param tokenizer_kwargs:
                Passed through to `tokenizer.apply_chat_template`, after the
                defaults from `_merge_apply_chat_template_kwargs` are added.
            :param detokenizer_kwargs:
                Passed through to `tokenizer.batch_decode`.
            :param return_sequences:
                Whether to return sequence objects or not

        Returns:
            :return List[str] or tuple[List[str], List[Sequence]]:
                One string per conversation.
        """
        prompts_token_ids: List[List[int]] = []
        tokenizer_kwargs = _merge_apply_chat_template_kwargs(
            self.tokenizer, tokenizer_kwargs
        )
        detokenizer_kwargs = {} if detokenizer_kwargs is None else detokenizer_kwargs
        for messages in messages_batch:
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                **tokenizer_kwargs,
            )
            # Tensor / array results are converted to plain Python lists.
            if hasattr(input_ids, "tolist"):
                input_ids = input_ids.tolist()
            prompts_token_ids.append(input_ids)
        return self.generate(
            prompts_token_ids,
            sampling_params=sampling_params,
            batch_compression_params=batch_compression_params,
            per_sequence_compression_params=per_sequence_compression_params,
            tokenizer_kwargs=tokenizer_kwargs,
            detokenizer_kwargs=detokenizer_kwargs,
            return_sequences=return_sequences,
        )

    def generate_from_sequences(
        self,
        seqs: List[Sequence],
        batch_compression_params: BatchCompressionParams,
    ) -> List[Sequence]:
        """
        Run generation on already constructed sequences.

        Args:
            :param seqs:
                List of Sequence instances
            :param batch_compression_params:
                Compression settings.

        Returns:
            :return List[Sequence]:
                Same list, mutated in-place with completions.
        """
        self.master_model_runner.generate(seqs, batch_compression_params)
        return seqs
vllm/compactor-vllm/src/compactor_vllm/core/memory_manager.py
0 → 100644
View file @
2b7160c6
import
logging
from
typing
import
Iterable
,
List
,
Optional
import
torch
import
torch.distributed
as
dist
from
compactor_vllm.config.engine_config
import
LLMConfig
from
compactor_vllm.kv_cache.page_table
import
KVAllocationStatus
,
PagedKVCache
from
torch
import
nn
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class
KVCacheManager
:
def __init__(self, rank: int, config: LLMConfig):
    """Record the cache geometry taken from ``config``; allocate nothing yet.

    ``paged_cache`` and ``num_pages`` stay ``None`` until ``init_cache`` is
    called. ``torch.distributed`` is queried for the world size here.
    """
    super().__init__()
    model_cfg = config.hf_config

    # Engine-level settings.
    self.rank = rank
    self.gpu_frac = config.gpu_memory_utilization
    self.page_size = config.kvcache_page_size
    self.world_size = config.tensor_parallel_size
    self.max_num_batches = config.max_num_seqs
    self.max_model_len = config.max_model_len

    # Model geometry.
    self.num_layers = model_cfg.num_hidden_layers
    self.model_dtype = model_cfg.torch_dtype
    # Stays None when the HF config does not define ``head_dim``.
    self.head_dim = getattr(model_cfg, "head_dim", None)

    # Ceiling division: pages needed to hold ``max_model_len`` tokens.
    padded_len = self.max_model_len + self.page_size - 1
    self.max_pages_per_batch = padded_len // self.page_size

    # KV heads are split evenly across the ranks.
    self.num_kv_heads = model_cfg.num_key_value_heads // dist.get_world_size()
    assert model_cfg.num_key_value_heads % dist.get_world_size() == 0, (
        "world size needs to divide num_kv_heads"
    )

    # Filled in later.
    self.num_pages = None
    self.paged_cache: Optional[PagedKVCache] = None
    self.max_batched_tokens = None
    # Maps a sequence id to the batch slot allocated for it.
    self.seq_id_to_batch = {}
def
allocate_sequences
(
self
,
seq_ids
:
List
[
int
],
max_positions
:
List
[
int
]
)
->
(
bool
,
Optional
[
torch
.
Tensor
]):
batch_mapping
=
[]
for
seq_id
,
len_to_alloc
in
zip
(
seq_ids
,
max_positions
):
if
seq_id
not
in
self
.
seq_id_to_batch
:
batch_id
=
self
.
paged_cache
.
new_batch
()
if
batch_id
is
None
:
logger
.
warning
(
"Failed to allocate batch!"
)
return
False
,
None
self
.
seq_id_to_batch
[
seq_id
]
=
int
(
batch_id
)
batch_mapping
.
append
(
self
.
seq_id_to_batch
[
seq_id
])
if
(
alloc_status
:
=
self
.
paged_cache
.
reserve_tokens
(
self
.
seq_id_to_batch
[
seq_id
],
len_to_alloc
)
)
!=
KVAllocationStatus
.
SUCCESS
:
logger
.
warning
(
f
"Failed to allocate pages (
{
alloc_status
}
)!"
)
return
False
,
None
batch_mapping
=
torch
.
as_tensor
(
batch_mapping
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
return
True
,
batch_mapping
def
free_sequences
(
self
,
seq_ids
:
Iterable
[
int
]):
for
seq_id
in
seq_ids
:
global_batch_id
=
self
.
seq_id_to_batch
.
pop
(
seq_id
,
None
)
self
.
paged_cache
.
free_batch
(
global_batch_id
)
def
init_cache
(
self
,
model
:
nn
.
Module
):
self
.
num_pages
=
self
.
get_num_pages
(
self
.
gpu_frac
,
self
.
max_pages_per_batch
)
self
.
paged_cache
=
PagedKVCache
(
num_layers
=
self
.
num_layers
,
H_kv
=
self
.
num_kv_heads
,
head_dim
=
self
.
head_dim
,
page_size
=
self
.
page_size
,
num_pages
=
int
(
self
.
num_pages
),
max_num_batches
=
self
.
max_num_batches
,
device
=
f
"cuda:
{
self
.
rank
}
"
,
dtype
=
self
.
model_dtype
,
max_logical_pages_per_head
=
int
(
self
.
max_pages_per_batch
),
)
self
.
_assign_cache_to_layers
(
model
)
def
_assign_cache_to_layers
(
self
,
model
)
->
None
:
for
layer_index
,
layer
in
enumerate
(
model
.
model
.
layers
):
attn
=
layer
.
self_attn
.
attn
k
,
v
,
pt
,
bh
=
self
.
paged_cache
.
layer_slices
(
layer_index
)
attn
.
k_cache
=
k
attn
.
v_cache
=
v
attn
.
page_table
=
pt
attn
.
bh_seq_lens
=
bh
attn
.
page_size
=
self
.
page_size
def
get_num_pages
(
self
,
frac
:
float
,
n_logical_pages_max
:
int
):
free
,
total
=
torch
.
cuda
.
mem_get_info
()
used
=
total
-
free
stats
=
torch
.
cuda
.
memory_stats
()
peak
=
int
(
stats
[
"allocated_bytes.all.peak"
])
current
=
int
(
stats
[
"allocated_bytes.all.current"
])
bytes_for_kv_budget
=
int
(
total
*
frac
*
0.9
)
-
used
-
peak
+
current
if
bytes_for_kv_budget
<=
0
:
raise
RuntimeError
(
f
"Insufficient memory for KV cache."
f
"Try increasing gpu_memory_utilization (currently
{
frac
:.
2
f
}
)."
)
# page_table[L, B, H_kv, N_LOGICAL_PAGES_MAX] + bh_seq_lens[L, B, H_kv]
int32_sz
=
torch
.
empty
((),
dtype
=
torch
.
int32
).
element_size
()
# 4
page_table_bytes_per_layer
=
(
self
.
max_num_batches
*
self
.
num_kv_heads
*
n_logical_pages_max
*
int32_sz
# page_table
+
self
.
max_num_batches
*
self
.
num_kv_heads
*
int32_sz
)
total_page_table_bytes
=
self
.
num_layers
*
page_table_bytes_per_layer
kv_bytes_net
=
bytes_for_kv_budget
-
total_page_table_bytes
if
kv_bytes_net
<=
0
:
raise
RuntimeError
(
"page-table footprint exceeds KV cache budget. "
f
"reduce max_num_seqs (
{
self
.
max_num_batches
}
) "
f
"or increase kv_cache_mem_fraction (currently
{
frac
:.
2
f
}
)."
)
dtype_sz
=
torch
.
empty
((),
dtype
=
self
.
model_dtype
).
element_size
()
bytes_per_page_across_layers
=
self
.
num_layers
*
(
2
*
self
.
page_size
*
self
.
head_dim
*
dtype_sz
)
return
max
(
1
,
kv_bytes_net
//
bytes_per_page_across_layers
)
def
estimate_max_batched_tokens
(
self
,
warmup_tokens
:
int
,
bytes_used_before_warmup
:
int
,
bytes_peak_after_warmup
:
int
,
)
->
int
:
"""
Estimate the max total number of tokens that can be processed concurrently
without OOM.
"""
assert
warmup_tokens
>
0
,
"warmup_tokens must be > 0"
# activation bytes per token
warmup_delta
=
max
(
0
,
int
(
bytes_peak_after_warmup
)
-
int
(
bytes_used_before_warmup
)
)
bytes_per_token
=
max
(
1
,
(
warmup_delta
+
warmup_tokens
-
1
)
//
warmup_tokens
)
free
,
total
=
torch
.
cuda
.
mem_get_info
()
target
=
int
(
total
*
self
.
gpu_frac
)
used_now
=
int
(
total
-
free
)
# reserve headroom equal to the gap between peak and current allocations seen so far
stats
=
torch
.
cuda
.
memory_stats
()
peak_cur
=
int
(
stats
.
get
(
"allocated_bytes.all.peak"
,
0
))
cur_now
=
int
(
stats
.
get
(
"allocated_bytes.all.current"
,
0
))
cushion
=
max
(
0
,
peak_cur
-
cur_now
)
activation_budget
=
int
(
max
(
0
,
target
-
used_now
-
cushion
)
*
0.95
)
max_tokens_per_batch
=
activation_budget
//
bytes_per_token
max_tokens_in_cache
=
(
self
.
num_pages
*
self
.
page_size
)
//
self
.
num_kv_heads
# round to lower multiple of page size
max_tokens_per_batch
=
(
max_tokens_per_batch
//
self
.
page_size
)
*
self
.
page_size
max_tokens_in_cache
=
(
max_tokens_in_cache
//
self
.
page_size
)
*
self
.
page_size
self
.
max_batched_tokens
=
min
(
max_tokens_in_cache
,
max_tokens_per_batch
)
return
self
.
max_batched_tokens
@
property
def
num_free_batches
(
self
)
->
int
:
return
len
(
self
.
paged_cache
.
free_batches
)
@
property
def
num_free_pages
(
self
)
->
int
:
return
min
(
len
(
fp
)
for
fp
in
self
.
paged_cache
.
free_pages
)
def
reclaim_pages
(
self
,
seq_ids_to_reclaim
:
Iterable
[
int
],
future_reserved_buffer
:
List
[
int
]
|
torch
.
Tensor
,
)
->
int
:
approximate_bytes_freed
=
0
for
i
,
seq_id
in
enumerate
(
seq_ids_to_reclaim
):
batch_idx
=
self
.
seq_id_to_batch
[
seq_id
]
approximate_bytes_freed
+=
self
.
paged_cache
.
reclaim_pages
(
batch_idx
,
future_reserved_buffer
[
i
]
)
return
approximate_bytes_freed
vllm/compactor-vllm/src/compactor_vllm/core/model_runner.py
0 → 100644
View file @
2b7160c6
import
atexit
import
logging
import
inspect
from
typing
import
List
,
Optional
import
torch
import
torch.distributed
as
dist
from
compactor_vllm.attention.sparse_decode_kernel
import
num_splits_heuristic
from
compactor_vllm.compression.compression_config
import
BatchCompressionParams
from
compactor_vllm.config.constants
import
RESERVED_BATCH
from
compactor_vllm.config.engine_config
import
AttentionBackend
,
LLMConfig
from
compactor_vllm.core.memory_manager
import
KVCacheManager
from
compactor_vllm.core.scheduler
import
Scheduler
from
compactor_vllm.layers.sampler
import
Sampler
from
compactor_vllm.models
import
MODEL_REGISTRY
from
compactor_vllm.utils.arguments
import
(
DecodeBatchArguments
,
DecodeBatchOutput
,
PackedTensorArguments
,
PrefillBatchArguments
,
)
from
compactor_vllm.utils.context
import
CompressionContext
,
reset_context
,
set_context
from
compactor_vllm.utils.sequence
import
Sequence
from
torch.multiprocessing
import
Event
from
tqdm
import
tqdm
logger
=
logging
.
getLogger
(
__name__
)
class
ModelRunner
:
"""Per-rank execution loop. Manages model, sampler, KV cache, and warmup"""
def __init__(
    self,
    config: LLMConfig,
    rank: int,
    batch_ready: Optional[Event] = None,
    peer_events: List[Event] = None,
):
    """
    Set up one rank: process group, model, sampler, KV cache, CUDA graphs.

    Args:
        :param config:
            Engine configuration. ``eos_token_ids`` must be non-empty.
        :param rank:
            Rank of this process; also used as its CUDA device index.
        :param batch_ready:
            Event stored on the runner; peers wait on it in ``loop``.
        :param peer_events:
            Events stored on the runner; ``generate`` sets them and
            ``maybe_release_peers`` clears them. ``None`` becomes ``[]``.
    """
    self.rank = rank
    self.config = config
    _dev = torch.device(f"cuda:{rank}")
    assert config.eos_token_ids is not None and len(config.eos_token_ids) > 0, (
        "LLMConfig.eos_token_ids must be set (filled in LLMEngine from tokenizer)."
    )
    # Kept on-device so the decode loop can test for stop tokens with torch.isin.
    self._stop_token_ids = torch.tensor(
        config.eos_token_ids, dtype=torch.int64, device=_dev
    )
    hf_config = config.hf_config
    self.enforce_eager = config.enforce_eager
    self.world_size = config.tensor_parallel_size
    self.leverage_sketch_size = config.leverage_sketch_size
    self.show_progress_bar = config.show_progress_bar
    self.max_num_batches = config.max_num_seqs
    self.max_model_len = config.max_model_len
    self.num_layers = hf_config.num_hidden_layers
    self.model_dtype = hf_config.torch_dtype
    self.head_dim = getattr(hf_config, "head_dim", None)
    # Pass ``device_id`` only when this torch version's init_process_group accepts it.
    init_kwargs = {}
    if "device_id" in inspect.signature(dist.init_process_group).parameters:
        init_kwargs["device_id"] = torch.device(f"cuda:{rank}")
    dist.init_process_group(
        "nccl",
        f"tcp://localhost:{config.nccl_port}",
        world_size=self.world_size,
        rank=rank,
        **init_kwargs,
    )
    torch.cuda.set_device(rank)
    # Temporarily make the model dtype / CUDA the defaults while the model,
    # sampler and cache are created; both defaults are restored below.
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(hf_config.torch_dtype)
    torch.set_default_device("cuda")
    model_type = hf_config.model_type
    self.model = MODEL_REGISTRY[model_type](hf_config)
    self.model.load_model(
        config.path, use_tqdm=self.is_master and self.show_progress_bar
    )
    self.sampler = Sampler()
    # Measure allocator usage around a full-length warmup pass; the two
    # numbers feed estimate_max_batched_tokens further down.
    pre_warmup_mem = torch.cuda.memory_stats().get("allocated_bytes.all.current", 0)
    self.warmup(
        num_warmup_tokens=self.max_model_len,
        attention_backend=AttentionBackend.FLASH_ATTENTION,
    )
    post_warmup_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
    # The KV cache is created only after the warmup pass above.
    self.kv_manager = KVCacheManager(rank, config)
    self.kv_manager.init_cache(self.model)
    self.store_stream: Optional[torch.cuda.Stream] = torch.cuda.Stream()
    torch.set_default_device("cpu")
    torch.set_default_dtype(default_dtype)
    self.batch_ready = batch_ready
    self.peer_events = peer_events if peer_events is not None else []
    # captured_graphs[batch_size][max_seqlen_k] -> graph variables;
    # min_captured_len[batch_size] -> smallest captured max_seqlen_k.
    self.captured_graphs = {}
    self.min_captured_len = {}
    self.max_batched_tokens = self.kv_manager.estimate_max_batched_tokens(
        self.max_model_len, pre_warmup_mem, post_warmup_peak
    )
    if self.is_master:
        logger.info(f"Estimated max batched tokens of {self.max_batched_tokens}")
    # Second warmup, this time through the Triton backend, which needs the KV cache.
    if self.config.attention_backend == AttentionBackend.COMPACTOR_TRITON:
        self.warmup(
            num_warmup_tokens=self.max_model_len,
            attention_backend=AttentionBackend.COMPACTOR_TRITON,
        )
    if not self.enforce_eager:
        # Powers of two 1, 2, 4, ... up to the largest one <= max_num_batches.
        # NOTE(review): if max_num_batches is not a power of two, batch sizes
        # above the largest captured size have no graph to fall back on.
        bs = [1 << i for i in range(self.max_num_batches.bit_length())]
        # The loop variable reuses the name ``bs``; the list is evaluated
        # before the first rebinding, so iteration is unaffected.
        for bs in (
            tqdm(bs, desc="Capturing CUDA Graphs")
            if self.is_master and self.show_progress_bar
            else bs
        ):
            for seq_len in [1024, 4096, 8192, 16384]:
                self.capture_cudagraph(bs, seq_len)
    self.packed_args = PackedTensorArguments(
        rank=self.rank,
        max_batched_tokens=self.max_batched_tokens,
        config=self.config,
    )
    # Make sure graphs and the process group are torn down at interpreter exit.
    atexit.register(self.exit)
@torch.inference_mode()
def warmup(self, num_warmup_tokens: int, attention_backend: AttentionBackend):
    """
    Run two dummy prefill passes of ``num_warmup_tokens`` tokens.

    Args:
        :param num_warmup_tokens:
            Length of the single dummy sequence fed to the model.
        :param attention_backend:
            Backend to exercise. For ``COMPACTOR_TRITON`` a temporary KV
            allocation under sequence id ``-1`` is made and released again.
    """
    if self.rank == 0:
        if attention_backend == AttentionBackend.COMPACTOR_TRITON:
            backend_name = "Compactor Triton"
        else:
            backend_name = "Flash"
        logger.info(f"Warming up with {backend_name} Attention Backend")
    device = torch.device(f"cuda:{self.rank}")
    # One sequence made of ``config.eos`` repeated, positions 0..N-1.
    input_ids = torch.tensor(
        [self.config.eos] * num_warmup_tokens, device=device, dtype=torch.int64
    )
    positions = torch.arange(num_warmup_tokens, device=device, dtype=torch.int64)
    # Cumulative sequence boundaries for a batch holding exactly one sequence.
    cu_seqlens_q = torch.tensor(
        [0, num_warmup_tokens], device=device, dtype=torch.int32
    )
    cu_seqlens_k = torch.tensor(
        [0, num_warmup_tokens], device=device, dtype=torch.int32
    )
    if attention_backend == AttentionBackend.COMPACTOR_TRITON:
        # -1 serves as a throwaway sequence id for the warmup allocation.
        success, batch_mapping = self.kv_manager.allocate_sequences(
            [-1], [num_warmup_tokens]
        )
        assert success
    else:
        batch_mapping = None
    set_context(
        is_prefill=True,
        do_compression=False,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=num_warmup_tokens,
        max_seqlen_k=num_warmup_tokens,
        batch_mapping=batch_mapping,
        attention_backend=attention_backend,
    )
    for _ in range(2):
        # Reset the peak counter so the caller's post-warmup peak reading
        # reflects a warmup pass rather than earlier allocations.
        torch.cuda.reset_peak_memory_stats()
        self.model.compute_logits(self.model(input_ids, positions))
        dist.barrier()
        if attention_backend == AttentionBackend.COMPACTOR_TRITON:
            # Zero the stored lengths of the warmup slot (index_fill_ needs a long index).
            self.kv_manager.paged_cache.bh_seq_lens.index_fill_(
                1, batch_mapping.to(torch.long), 0
            )
    reset_context()
    if attention_backend == AttentionBackend.COMPACTOR_TRITON:
        self.kv_manager.free_sequences([-1])
def exit(self):
    """
    Tear the runner down once: drop captured CUDA graphs and destroy the
    process group. Later calls are no-ops.
    """
    if not getattr(self, "_exited", False):
        # Flag first so a re-entrant call cannot tear down twice.
        self._exited = True
        try:
            if hasattr(self, "captured_graphs"):
                self.captured_graphs.clear()
        finally:
            # The process group is destroyed even if clearing the graphs fails.
            if dist.is_initialized():
                dist.destroy_process_group()
def loop(self):
    """
    Serve forever: wait on ``batch_ready`` in one-second slices and run the
    peer batch loop each time the wait reports the event as set.
    """
    while True:
        signalled = self.batch_ready.wait(1.0)
        if not signalled:
            continue
        self._process_batches_peer()
@torch.inference_mode()
def run_prefill(
    self, prefill_args: PrefillBatchArguments, batch_mapping: torch.Tensor
):
    """
    Run one prefill forward pass and return its logits.

    Args:
        :param prefill_args:
            Packed prefill inputs plus compression settings for the batch.
        :param batch_mapping:
            KV-cache batch slots of the sequences in this batch.
    Returns:
        Output of ``model.compute_logits`` for the batch.
    """
    assert prefill_args.B > 0 and prefill_args.N > 0
    # Largest length currently stored for any selected slot (host sync via .item()).
    stored_lens = self.kv_manager.paged_cache.bh_seq_lens
    max_bh_len = stored_lens.index_select(1, index=batch_mapping).max().item()
    compression_context = CompressionContext(
        compression_method=prefill_args.compression_method,
        compression_chunk_size=prefill_args.compression_chunk_size,
        batch_tokens_to_retain=prefill_args.batch_tokens_to_retain,
        max_tokens_to_retain=prefill_args.max_tokens_to_retain,
        context_lens=prefill_args.context_lens.tolist(),
        PHI=prefill_args.PHI,
        sketch_dimension=self.leverage_sketch_size,
        protected_first_tokens=prefill_args.protected_first,
        protected_last_tokens=prefill_args.protected_last,
        compression_ratio=prefill_args.compression_ratio,
    )
    set_context(
        is_prefill=True,
        do_compression=prefill_args.do_compression,
        cu_seqlens_q=prefill_args.cu_seqlens_q,
        cu_seqlens_k=prefill_args.cu_seqlens_k,
        max_seqlen_q=prefill_args.max_seqlen_q,
        max_seqlen_k=prefill_args.max_seqlen_k,
        batch_mapping=batch_mapping,
        max_bh_len=max_bh_len,
        compression_context=compression_context,
        STORE_STREAM=self.store_stream,
        attention_backend=self.config.attention_backend,
    )
    hidden = self.model(prefill_args.input_ids, prefill_args.positions)
    logits = self.model.compute_logits(hidden)
    reset_context()
    return logits
def maybe_broadcast(self, tensor: torch.Tensor):
    """
    Broadcast ``tensor`` from rank 0 when more than one rank is configured.

    Returns whatever ``dist.broadcast`` returns in that case, otherwise
    ``None`` without touching the tensor.
    """
    if self.world_size <= 1:
        return None
    return dist.broadcast(tensor, src=0)
def maybe_release_peers(self, do_release=False):
    """
    Rendezvous with the other ranks; the master may also release them.

    Does nothing when only one rank is configured. Otherwise every rank
    enters ``dist.barrier()``; before that, the master clears all
    ``peer_events`` when ``do_release`` is true.
    """
    if self.world_size <= 1:
        return
    if self.is_master and do_release:
        for event in self.peer_events:
            event.clear()
    dist.barrier()
@torch.inference_mode()
def generate(
    self,
    all_sequences: List[Sequence],
    batch_compression_params: Optional[BatchCompressionParams] = None,
):
    """
    Entry point on the master rank: wake the peers and process all sequences.

    Args:
        :param all_sequences:
            Sequences to run; handed to the master batch loop.
        :param batch_compression_params:
            Compression settings; a default ``BatchCompressionParams()`` is
            used when omitted.
    """
    assert self.is_master, "generate can only be called on the master process"
    # Signal every peer that execution is starting.
    for peer_event in self.peer_events:
        peer_event.set()
    params = (
        BatchCompressionParams()
        if batch_compression_params is None
        else batch_compression_params
    )
    self._process_batches_master(all_sequences, params)
@property
def is_master(self):
    """Whether this runner is the master process, i.e. its rank is 0."""
    return self.rank == 0
@torch.inference_mode()
def _process_batches_master(
    self,
    all_sequences: List[Sequence],
    batch_compression_params: BatchCompressionParams,
):
    """
    Master-side scheduling loop: alternate prefill and decode until every
    sequence has finished.

    Args:
        :param all_sequences:
            Sequences to schedule; they are updated through the scheduler.
        :param batch_compression_params:
            Compression settings passed to the prefill argument builder.
    """
    assert self.is_master
    compression_details = (
        f"Applying Compression Method: {batch_compression_params.compression_method}"
    )
    # Only mention compression when at least one sequence actually compresses.
    if any(seq.compression_params.compression_ratio < 1.0 for seq in all_sequences):
        logger.info(compression_details)
    scheduler = Scheduler(
        all_sequences=all_sequences,
        kv_manager=self.kv_manager,
        use_tqdm=self.show_progress_bar,
    )
    decode_batch = DecodeBatchArguments()
    # [run_decode, occupancy], broadcast to the peers after each prefill.
    decode_flags = torch.empty(2, dtype=torch.int32, device="cuda")
    while not scheduler.is_finished():
        # NOTE(review): if get_prefill_batch() returns an empty list while
        # sequences are still active, run_prefill's B > 0 assertion looks
        # reachable - confirm against build_prefill_args.
        sequences = scheduler.get_prefill_batch()
        seq_ids_cpu = [seq.seq_id for seq in sequences]
        scheduler.add_running_sequence_ids(seq_ids_cpu, update_status=True)
        # Pinned host tensor so the host-to-device copy can be non-blocking.
        temps = torch.tensor(
            [s.sampling_params.temperature for s in sequences],
            dtype=torch.float32,
            pin_memory=True,
        ).cuda(non_blocking=True)
        prefill_arguments = self.packed_args.build_prefill_args(
            sequences, batch_compression_params=batch_compression_params
        )
        # Upper bound on each sequence's final length: prompt + max new tokens.
        max_ctx_lens = (
            prefill_arguments.max_new_tokens + prefill_arguments.context_lens
        )
        success, batch_mapping = self.kv_manager.allocate_sequences(
            seq_ids_cpu, max_ctx_lens.tolist()
        )
        assert success, "failed to allocate pages for sequences"
        logits = self.run_prefill(prefill_arguments, batch_mapping)
        # Must match prefill `positions` dtype (int64). `context_lens` is int32
        # from the packed buffer; using int32 here breaks RoPE indexing
        # (`cos_sin_cache[positions]`) on CUDA for decode vs prefill.
        positions = prefill_arguments.context_lens.to(dtype=torch.int64)
        token_ids = self.sampler(logits, temps)
        # Prefill KV writes + bh_seq_lens updates run on STORE_STREAM; reclaim
        # reads bh_seq_lens on the default stream and must not race.
        if self.store_stream is not None:
            torch.cuda.default_stream().wait_stream(self.store_stream)
        # TODO: synchronize page counts across dist
        if self.world_size == 1:
            self.kv_manager.reclaim_pages(
                seq_ids_cpu, prefill_arguments.max_new_tokens
            )
            # with logging_redirect_tqdm():
            #     logger.info(
            #         f"Reclaimed {reclaimed_bytes / 1e6:.2f} MB from the KV cache"
            #     )
        if scheduler.any_pending_sequences():
            # While sequences are still waiting, ask the decode loop to stop
            # once the live batch shrinks to 66% of its size at this point.
            num_pending_batches = (
                0 if decode_batch.token_ids is None else decode_batch.token_ids.shape[0]
            )
            occupancy = int((num_pending_batches + len(seq_ids_cpu)) * 0.66)
        else:
            # Nothing left to prefill: decode until the batch is empty.
            occupancy = -1
        # Keep prefilling while another batch fits; decode only once it does not.
        run_decode = not scheduler.can_prefill_another_batch()
        decode_batch = decode_batch.update(
            batch_mapping,
            token_ids,
            positions,
            max_ctx_lens,
            prefill_arguments.seq_ids,
            temps,
            occupancy,
        )
        if self.world_size > 1:
            decode_flags[0] = int(run_decode)
            decode_flags[1] = occupancy
            self.maybe_broadcast(decode_flags)
        if not run_decode:
            continue
        if self.store_stream is not None:
            torch.cuda.default_stream().wait_stream(self.store_stream)
        decode_output, decode_batch = self.run_decode_loop(decode_batch)
        # Whatever is active but no longer in the decode batch has finished.
        finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
            decode_batch.seq_ids.tolist()
        )
        scheduler.record_finished_sequence_ids(
            finished_sequence_ids, update_status=True
        )
        self.kv_manager.free_sequences(finished_sequence_ids)
        # Clear the peer events (under a barrier) once everything has finished.
        self.maybe_release_peers(scheduler.is_finished())
        scheduler.update_sequences(
            decode_output.output_tokens.tolist(),
            decode_output.output_seq_ids.tolist(),
        )
    scheduler.close()
@torch.inference_mode()
def _process_batches_peer(self):
    """
    Peer-side counterpart of the master batch loop.

    Runs the same prefill / decode steps as long as ``batch_ready`` is set,
    taking the decode decision and occupancy from the flags broadcast by
    rank 0 instead of computing them locally.
    """
    assert not self.is_master
    # Empty scheduler: used here only to track active / finished sequence ids.
    scheduler = Scheduler([], kv_manager=self.kv_manager)
    decode_batch = DecodeBatchArguments()
    decode_flags = torch.empty(2, dtype=torch.int32, device="cuda")
    while self.batch_ready.is_set():
        # Called without sequences on peers.
        # NOTE(review): presumably the packed arguments arrive from rank 0 - confirm.
        prefill_arguments = self.packed_args.build_prefill_args()
        B = prefill_arguments.B
        max_ctx_lens = (
            prefill_arguments.max_new_tokens + prefill_arguments.context_lens
        )
        seq_ids_cpu = prefill_arguments.seq_ids.tolist()
        scheduler.add_running_sequence_ids(seq_ids_cpu)
        success, batch_mapping = self.kv_manager.allocate_sequences(
            seq_ids_cpu, max_ctx_lens.tolist()
        )
        assert success, "failed to allocate pages for sequences"
        # Logits are discarded on peers; sampling happens on the master only.
        self.run_prefill(prefill_arguments, batch_mapping)
        positions = prefill_arguments.context_lens.to(dtype=torch.int64)
        # Receive [run_decode, occupancy] from rank 0.
        self.maybe_broadcast(decode_flags)
        run_decode = bool(decode_flags[0].item())
        occupancy = int(decode_flags[1].item())
        # Placeholder buffer; the decode loop fills it via broadcast from rank 0.
        token_ids = torch.empty(B, dtype=torch.int64, device="cuda")
        decode_batch = decode_batch.update(
            batch_mapping,
            token_ids,
            positions,
            max_ctx_lens,
            prefill_arguments.seq_ids,
            None,  # temps not used in peer process
            occupancy,
        )
        if not run_decode:
            continue
        if self.store_stream is not None:
            torch.cuda.default_stream().wait_stream(self.store_stream)
        _, decode_batch = self.run_decode_loop(decode_batch)
        finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
            decode_batch.seq_ids.tolist()
        )
        scheduler.record_finished_sequence_ids(finished_sequence_ids)
        self.kv_manager.free_sequences(finished_sequence_ids)
        # Barrier matching the master's maybe_release_peers call.
        self.maybe_release_peers()
    scheduler.close()
@torch.inference_mode()
def run_decode_loop(
    self,
    decode_batch: DecodeBatchArguments,
) -> tuple[DecodeBatchOutput, DecodeBatchArguments]:
    """
    Decode step by step until the batch is empty or has shrunk to the
    desired occupancy.

    Args:
        :param decode_batch:
            Current decode state; its tensors are filtered and advanced in
            place on every step.
    Returns:
        :return tuple[DecodeBatchOutput, DecodeBatchArguments]:
            On the master, all tokens produced (with their sequence ids)
            concatenated on the CPU; on peers an output holding ``None``s.
            The second element is the updated ``decode_batch``.
    """
    if self.is_master:
        # Entries before ``num_stashed_batches`` are skipped here; only the
        # newer ones are added to the output buffers.
        num_stashed_batches = decode_batch.num_stashed_batches
        tok_buffer = [
            decode_batch.token_ids[num_stashed_batches:].to("cpu", non_blocking=True)
        ]
        seq_buffer = [
            decode_batch.seq_ids[num_stashed_batches:].to("cpu", non_blocking=True)
        ]
    while True:
        # Share rank 0's latest token ids with every rank.
        self.maybe_broadcast(decode_batch.token_ids)
        not_stopped = ~torch.isin(decode_batch.token_ids, self._stop_token_ids)
        # A sequence keeps running while below its length cap and not at a stop token.
        running_batches = (decode_batch.positions < decode_batch.max_ctx_lens) & (
            not_stopped
        )
        # Compact every per-sequence tensor down to the still-running entries.
        decode_batch.token_ids = torch.masked_select(
            decode_batch.token_ids, running_batches
        )
        decode_batch.positions = torch.masked_select(
            decode_batch.positions, running_batches
        )
        decode_batch.batch_mapping = torch.masked_select(
            decode_batch.batch_mapping, running_batches
        )
        decode_batch.max_ctx_lens = torch.masked_select(
            decode_batch.max_ctx_lens, running_batches
        )
        decode_batch.seq_ids = torch.masked_select(
            decode_batch.seq_ids, running_batches
        )
        if self.is_master:
            # Temperatures exist on the master only.
            decode_batch.temps = torch.masked_select(
                decode_batch.temps, running_batches
            )
        num_remaining = decode_batch.token_ids.numel()
        if (
            num_remaining == 0
            or num_remaining <= decode_batch.desired_batch_occupancy
        ):
            # Leave the loop; survivors are recorded as stashed for the next call.
            decode_batch.num_stashed_batches = num_remaining
            break
        if self.enforce_eager:
            set_context(
                is_prefill=False,
                do_compression=False,
                batch_mapping=decode_batch.batch_mapping,
            )
            logits = self.model.compute_logits(
                self.model(decode_batch.token_ids, decode_batch.positions)
            )
        else:
            logits = self.run_graph_decode(
                decode_batch.token_ids,
                decode_batch.positions,
                decode_batch.batch_mapping,
            )
        if self.is_master:
            decode_batch.token_ids = self.sampler(logits, decode_batch.temps)
            tok_buffer.append(decode_batch.token_ids.to("cpu", non_blocking=True))
            seq_buffer.append(decode_batch.seq_ids.to("cpu", non_blocking=True))
        decode_batch.positions += 1
    if self.is_master:
        # non_blocking D2H copies must finish before cat/tolist read CPU data.
        torch.cuda.synchronize()
        output = DecodeBatchOutput(
            output_tokens=torch.cat(tok_buffer),
            output_seq_ids=torch.cat(seq_buffer),
        )
    else:
        output = DecodeBatchOutput(None, None)
    return output, decode_batch
@torch.inference_mode()
def run_graph_decode(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    batch_mapping: torch.Tensor,
):
    """
    Run one decode step by replaying a captured CUDA graph.

    Args:
        :param input_ids:
            Token ids of the live sequences; its first dimension is the
            live batch size.
        :param positions:
            Position of each live sequence; the maximum selects the graph.
        :param batch_mapping:
            KV-cache batch slot of each live sequence.
    Returns:
        The first ``len(input_ids)`` rows of the graph's logits buffer, or
        the stored value itself when that buffer is ``None``.
    """
    set_context(
        is_prefill=False,
        do_compression=False,
        batch_mapping=batch_mapping,
    )
    n_live = input_ids.shape[0]
    entry = self.get_cuda_graph(n_live, int(positions.max()))
    # Copy the live inputs into the front of the graph's static buffers.
    entry["input_ids"][:n_live] = input_ids
    entry["positions"][:n_live] = positions
    # Unused tail slots are pointed at the reserved batch.
    mapping_buffer = entry["batch_mapping"]
    mapping_buffer.fill_(RESERVED_BATCH)
    mapping_buffer[:n_live] = batch_mapping
    entry["graph"].replay()
    logits = entry["logits"]
    if logits is None:
        return logits
    return logits[:n_live]
@torch.inference_mode()
def capture_cudagraph(self, batch_size: int, max_seqlen_k: int):
    """
    Capture a decode CUDA graph for one (batch size, max key length) pair.

    Args:
        :param batch_size:
            Number of sequences the graph's static buffers hold.
        :param max_seqlen_k:
            Sequence length passed to the split heuristic and used as the
            graph's key within its batch-size bucket.
    """
    dist.barrier()
    device = torch.device("cuda")
    logger.debug(
        f"Capturing CUDA graph for batch size {batch_size} ({max_seqlen_k} tokens)"
    )
    # Static input buffers; run_graph_decode copies live data into them.
    _g_input_ids = torch.zeros(batch_size, dtype=torch.int32, device=device)
    _g_positions = torch.zeros(batch_size, dtype=torch.int64, device=device)
    _g_logits = None
    key_split = num_splits_heuristic(
        batch_size * self.kv_manager.num_kv_heads,
        max_seq_len=max_seqlen_k,
        num_sms=torch.cuda.get_device_properties(device).multi_processor_count,
        max_splits=12,
    )
    # Temporary allocation under sequence ids 0..batch_size-1 (256 tokens
    # each); released at the end of this method.
    success, _g_batch_mapping = self.kv_manager.allocate_sequences(
        list(range(batch_size)), [256] * batch_size
    )
    assert success
    set_context(
        is_prefill=False,
        do_compression=False,
        batch_mapping=_g_batch_mapping,
        key_split=key_split,
    )
    # warmup
    self.model.compute_logits(self.model(_g_input_ids, _g_positions))
    dist.barrier()
    decode_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(decode_graph):
        _g_logits = self.model.compute_logits(
            self.model(_g_input_ids, _g_positions)
        )
    graph_vars = {
        "graph": decode_graph,
        "input_ids": _g_input_ids,
        "positions": _g_positions,
        "batch_mapping": _g_batch_mapping,
        "logits": _g_logits,
        "key_split": key_split,
    }
    if batch_size not in self.captured_graphs:
        self.captured_graphs[batch_size] = {}
        self.min_captured_len[batch_size] = float("inf")
    self.captured_graphs[batch_size][max_seqlen_k] = graph_vars
    # Track the smallest captured length per batch size for get_cuda_graph's fallback.
    self.min_captured_len[batch_size] = min(
        max_seqlen_k, self.min_captured_len[batch_size]
    )
    self.kv_manager.free_sequences(list(range(batch_size)))
def get_cuda_graph(self, batch_size: int, max_seqlen_k: int):
    """
    Pick the captured CUDA graph to replay for a decode step.

    Args:
        :param batch_size:
            Number of live sequences in the step.
        :param max_seqlen_k:
            Largest position among the live sequences.
    Returns:
        The graph-variable dict from the smallest captured batch size that
        is ``>= batch_size``; within that bucket, the largest captured
        sequence length ``<= max_seqlen_k``, or the smallest captured
        length when none is that short.
    Raises:
        RuntimeError: if no captured batch size can hold ``batch_size``.
    """
    # Take the true minimum instead of relying on dict insertion order.
    bucket = min(
        (size for size in self.captured_graphs if size >= batch_size),
        default=None,
    )
    if bucket is None:
        raise RuntimeError(
            f"No CUDA graph captured for batch size >= {batch_size} "
            f"(captured sizes: {sorted(self.captured_graphs)})"
        )
    batch_size_graphs = self.captured_graphs[bucket]
    # we want the largest seq_len that does not exceed max_seqlen_k
    best = max(
        (seq_len for seq_len in batch_size_graphs if seq_len <= max_seqlen_k),
        default=self.min_captured_len[bucket],
    )
    return batch_size_graphs[best]
vllm/compactor-vllm/src/compactor_vllm/core/scheduler.py
0 → 100644
View file @
2b7160c6
import
time
from
typing
import
Iterable
,
List
from
compactor_vllm.core.memory_manager
import
KVCacheManager
from
compactor_vllm.utils.sequence
import
Sequence
,
SequenceStatus
from
tqdm
import
tqdm
def cdiv(a, b):
    """ceiling division"""
    return (a + b - 1) // b


class Scheduler:
    """
    Simple sequence scheduler for prefill + decode with a paged KV cache.

    The scheduler tracks three disjoint sets of sequence IDs:
      * ``pending_sequence_ids`` – sequences that have not yet been started.
      * ``active_sequence_ids`` – sequences currently running.
      * ``finished_sequence_ids`` – sequences that have generated all tokens.

    At prefill time, :meth:`get_prefill_batch` selects a subset of pending
    sequences that can fit into the available KV cache and per-step token
    budget, given the constraints from the associated :class:`KVCacheManager`.
    The class also handles basic bookkeeping of sequence statuses.

    Args:
        :param all_sequences:
            Iterable of :class:`Sequence` objects to be scheduled. Each
            sequence must have a unique ``seq_id``. It is consumed exactly
            once, so a one-shot iterator is fine.
        :param kv_manager:
            A :class:`KVCacheManager` instance that this scheduler will use
            to determine whether additional batches can be scheduled.
        :param use_tqdm:
            If True, a single progress bar is created with one tick per
            sequence. It advances when a sequence is recorded as finished
            with ``update_status=True``, and its description is replaced by
            a throughput figure in :meth:`update_sequences`.
    """

    def __init__(
        self,
        all_sequences: Iterable[Sequence],
        kv_manager: KVCacheManager,
        *,
        use_tqdm=False,
    ):
        self.allseq_mapping: dict[int, Sequence] = {
            s.seq_id: s for s in all_sequences
        }
        # Built from the mapping so ``all_sequences`` is iterated only once.
        self.pending_sequence_ids: set[int] = set(self.allseq_mapping)
        self.active_sequence_ids: set[int] = set()
        self.finished_sequence_ids: set[int] = set()
        self.manager = kv_manager
        self.use_tqdm = use_tqdm
        self.start_time = time.perf_counter()
        self.total_tokens_generated = 0
        self.total_tokens_input = 0
        self.pbar = None
        if use_tqdm:
            self.pbar = tqdm(
                total=len(self.pending_sequence_ids),
                desc="Completed Batches",
            )

    def get_prefill_batch(self) -> List[Sequence]:
        """
        Select a batch of pending sequences to prefill under KV/memory constraints.

        The selection is greedy over ``pending_sequence_ids`` in iteration
        order; a sequence that does not fit is skipped and scanning
        continues. A sequence is added to the batch if:
          * The sum of its prompt length and the total prompt tokens selected so
            far does not exceed ``manager.max_batched_tokens``, and
          * There is at least one free KV "batch slot" left
            (``manager.num_free_batches``), and
          * The number of KV pages required by the sequence's prompt +
            max_new_tokens (pages per head times ``manager.num_kv_heads``)
            is strictly less than the remaining free pages.

        Returns:
            :return List[Sequence]:
                The list of :class:`Sequence` objects chosen for prefill in
                this step. The caller is responsible for marking them as
                active via :meth:`add_running_sequence_ids`.
        """
        manager = self.manager
        total_tok, sequences = 0, []
        num_free_batches = manager.num_free_batches
        num_free_pages = manager.num_free_pages
        for seq_id in self.pending_sequence_ids:
            seq = self.allseq_mapping[seq_id]
            prompt_length = seq.prompt_len
            pages_needed = (
                cdiv(
                    prompt_length + seq.sampling_params.max_new_tokens,
                    manager.page_size,
                )
                * manager.num_kv_heads
            )
            if (
                prompt_length + total_tok <= manager.max_batched_tokens
                and num_free_batches > 0
                and pages_needed < num_free_pages
            ):
                sequences.append(seq)
                total_tok += prompt_length
                num_free_pages -= pages_needed
                num_free_batches -= 1
        return sequences

    def is_finished(self) -> bool:
        """
        Check whether all sequences have completed.
        """
        return not self.pending_sequence_ids and not self.active_sequence_ids

    def any_pending_sequences(self) -> bool:
        """
        Check whether any sequences are still pending (not yet started).
        """
        return bool(self.pending_sequence_ids)

    def add_running_sequence_ids(
        self, active_sequence_ids: Iterable[int], *, update_status: bool = False
    ):
        """
        Mark a set of sequences as active / running. This moves sequence IDs
        from ``pending_sequence_ids`` into ``active_sequence_ids``. Optionally,
        it also updates the per-sequence status and the input-token counter.

        Args:
            :param active_sequence_ids:
                Iterable of sequence IDs that have been scheduled for prefill
                or decode and should now be considered running.
            :param update_status:
                If True, set each corresponding :class:`Sequence`'s
                ``status = SequenceStatus.RUNNING`` and add its prompt length
                to ``total_tokens_input``.
        """
        # Materialize once: the ids are walked twice below and the argument
        # may be a one-shot iterator.
        started_ids = list(active_sequence_ids)
        self.active_sequence_ids.update(started_ids)
        self.pending_sequence_ids.difference_update(self.active_sequence_ids)
        if update_status:
            for seq_id in started_ids:
                seq = self.allseq_mapping[seq_id]
                seq.status = SequenceStatus.RUNNING
                self.total_tokens_input += seq.prompt_len

    def get_finished_sequence_ids_from_unfinished(
        self, unfinished_sequence_ids: Iterable[int]
    ) -> set[int]:
        """
        Infer which active sequences have finished given the
        unfinished set (for decode steps where the caller knows
        which sequences are still generating but not necessarily
        which have just completed).

        Args:
            :param unfinished_sequence_ids:
                Iterable of sequence IDs that are still running
        Returns:
            :return set[int]:
                The inferred set of sequence IDs that transitioned from active
                to finished.
        """
        return self.active_sequence_ids.difference(unfinished_sequence_ids)

    def record_finished_sequence_ids(
        self, finished_sequence_ids: Iterable[int], *, update_status: bool = False
    ):
        """
        Record that a set of sequences has finished generation.
        This moves IDs from ``active_sequence_ids`` into
        ``finished_sequence_ids``.

        Args:
            :param finished_sequence_ids:
                Iterable of sequence IDs that have completed generation and
                no longer require KV cache.
            :param update_status:
                If True, set each corresponding :class:`Sequence`'s
                ``status = SequenceStatus.FINISHED`` and advance the progress
                bar (when enabled) by one per sequence.
        """
        # Materialize once: the ids are walked up to three times below and
        # the argument may be a one-shot iterator.
        done_ids = list(finished_sequence_ids)
        self.active_sequence_ids.difference_update(done_ids)
        self.finished_sequence_ids.update(done_ids)
        if update_status:
            for seq_id in done_ids:
                self.allseq_mapping[seq_id].status = SequenceStatus.FINISHED
                if self.pbar is not None:
                    self.pbar.update(1)

    def update_sequences(self, tokens: Iterable[int], seq_ids: Iterable[int]):
        """
        Append newly generated tokens to their corresponding sequences.

        Args:
            :param tokens:
                Iterable of generated token IDs.
            :param seq_ids:
                Iterable of sequence IDs aligned with ``tokens``; the same
                ID may appear several times.
        """
        cur_time = time.perf_counter()
        for tok, seq_id in zip(tokens, seq_ids):
            self.allseq_mapping[seq_id].add_new_token(tok)
            self.total_tokens_generated += 1
        # Refresh the throughput label once per call rather than once per token.
        if self.pbar is not None:
            self.pbar.set_description(
                f"Throughput: {(self.total_tokens_generated + self.total_tokens_input) / (cur_time - self.start_time):.2f} tok/s"
            )

    def close(self):
        """Close the progress bar if one was created."""
        if self.pbar is not None:
            self.pbar.close()

    def can_prefill_another_batch(self) -> bool:
        """Whether :meth:`get_prefill_batch` would currently select any sequence."""
        return len(self.get_prefill_batch()) > 0
vllm/compactor-vllm/src/compactor_vllm/kv_cache/__init__.py
0 → 100644
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/kv_cache/page_table.py
0 → 100644
View file @
2b7160c6
import
heapq
import
logging
from
enum
import
Enum
,
auto
from
typing
import
List
,
Optional
,
Union
import
torch
from
compactor_vllm.config.constants
import
RESERVED_BATCH
from
compactor_vllm.kv_cache.write_page_table
import
scatter_to_page_table
logger
=
logging
.
getLogger
(
__name__
)
def cdiv(a, b):
    """Ceiling division: ``a / b`` rounded up, computed as ``(a + b - 1) // b``."""
    padded = a + b - 1
    return padded // b
def next_multiple(a, b):
    """Round ``a`` up to the nearest multiple of ``b`` (``a`` itself if already one)."""
    # Ceiling division followed by scaling back up by ``b``.
    quotient = (a + b - 1) // b
    return quotient * b
class KVAllocationStatus(Enum):
    """Outcome of a KV-cache page reservation request."""

    # Values are spelled out; they match what ``auto()`` assigns (1, 2, 3, ...).
    # The request would push some head past its per-head logical page limit.
    EXCEEDS_MAX_SEQUENCE_LENGTH = 1
    # Some layer does not currently have enough free physical pages.
    EXCEEDS_CURRENTLY_AVAILABLE_PAGES = 2
    # No batch slot is available.
    EXCEEDS_MAX_NUM_BATCHES = 3
    # The reservation was satisfied.
    SUCCESS = 4
class PagedKVCache(torch.nn.Module):
    """
    Global paged KV cache.

    This module manages:

    * A global K/V backing buffer for all layers:
      ``kv_cache[2, num_layers, n_pages * page_size, head_dim]``,
      where the first dimension indexes K vs V.
    * A per-layer page table:
      ``page_table[num_layers, max_num_seqs, H_kv, max_pages_per_head]``,
      mapping logical (batch, kv-head, logical_page) to a physical page ID
      in the global K/V buffer.
    * Per-layer, per-(batch, kv-head) logical sequence lengths
      ``bh_seq_lens[num_layers, max_num_seqs, H_kv]`` (in tokens), and
      the number of allocated pages ``bh_num_pages`` for each (layer, batch,
      head).
    * A page allocator implemented as a min-heap of free physical pages
      per layer, plus free batch indices.

    Pages are of fixed size ``page_size`` tokens.

    Args:
        :param num_layers:
            Number of transformer layers that will use this cache.
        :param max_logical_pages_per_head:
            Maximum number of logical pages that can be assigned to a single
            (batch, kv-head) pair.
        :param num_pages:
            Total number of physical pages available in the global cache per
            layer. The global K/V buffers are of length
            ``num_pages * page_size`` along the token dimension.
        :param page_size:
            Number of tokens stored per page.
        :param H_kv:
            Number of KV heads per layer.
        :param head_dim:
            Head dimension for K/V.
        :param max_num_batches:
            Maximum number of concurrent batches / sequences supported. One
            extra batch index is allocated on top of this value and reserved
            for internal use (``RESERVED_BATCH``).
        :param dtype:
            Data type of K/V entries (e.g. ``torch.float16`` or ``torch.bfloat16``).
        :param device:
            Device on which to allocate the cache (string, torch.device, or
            int; defaults to ``"cuda"``).
    """

    def __init__(
        self,
        num_layers: int,
        max_logical_pages_per_head: int,
        num_pages: int,
        page_size: int,  # tokens per page
        H_kv: int,
        head_dim: int,
        max_num_batches: int,
        dtype: torch.dtype,
        device: Union[str, torch.device, int] = "cuda",
    ):
        super().__init__()
        self.n_pages = num_pages
        self.num_layers = num_layers
        self.page_size: int = int(page_size)
        self.H_kv = int(H_kv)
        self.max_pages_per_head = max_logical_pages_per_head
        # One additional slot so that RESERVED_BATCH does not eat into the
        # caller-visible capacity.
        max_num_batches += 1
        self.max_num_batches = max_num_batches
        self.head_dim = head_dim

        cache_shape = (2, num_layers, num_pages * page_size, head_dim)
        self.kv_cache = torch.empty(cache_shape, dtype=dtype, device=device)
        # Left uninitialised: only entries below bh_num_pages are meaningful.
        self.page_table = torch.empty(
            (num_layers, max_num_batches, H_kv, self.max_pages_per_head),
            device=device,
            dtype=torch.int32,
        )
        # Per-(batch, head) logical seq length (tokens)
        self.bh_seq_lens = torch.zeros(
            (num_layers, max_num_batches, H_kv), device=device, dtype=torch.int32
        )
        # Per-(batch, head) number of allocated logical pages
        self.bh_num_pages = torch.zeros(
            (num_layers, max_num_batches, H_kv), device=device, dtype=torch.int32
        )

        # Page allocator (min-heap of free physical pages), one heap per layer
        self.free_pages: List[List[int]] = [
            list(range(num_pages)) for _ in range(num_layers)
        ]
        for free_pages in self.free_pages:
            heapq.heapify(free_pages)

        # The RESERVED_BATCH index is never handed out. The list is reversed
        # so that pop() returns the smallest free index first.
        self.free_batches: List[int] = list(reversed(range(max_num_batches)))
        self.free_batches.remove(RESERVED_BATCH)

        # Record of physical page ids owned by a batch (for freeing),
        # indexed as [batch][layer].
        self.pages_indices_per_batch: List[List[set[int]]] = [
            [set() for _ in range(num_layers)] for _ in range(max_num_batches)
        ]

    def new_batch(self) -> Optional[int]:
        """
        Reserve a new batch slot.

        A batch slot corresponds to a row in ``bh_seq_lens`` /
        ``bh_num_pages`` and a slice in ``page_table`` for all layers and KV
        heads. This method checks whether a free batch index is available, and
        whether each layer has at least ``H_kv`` free pages remaining.
        If both checks pass, it returns a batch index and removes it from
        ``free_batches``. Otherwise, it returns ``None``.

        Returns:
            :return Optional[int]:
                Newly reserved batch index, or ``None`` if no capacity is
                available.
        """
        if not self.free_batches:
            return None
        if any(len(layer_free) < self.H_kv for layer_free in self.free_pages):
            return None
        return self.free_batches.pop()

    def reserve_tokens(self, batch_index: int, add_tokens: int) -> KVAllocationStatus:
        """
        Ensure enough pages are allocated to handle ``add_tokens`` new tokens.

        Args:
            :param batch_index:
                Batch index to reserve space for.
            :param add_tokens:
                Number of additional tokens to reserve capacity for.
                All heads in this batch and all layers reserve
                the same number of extra tokens.

        Returns:
            :return KVAllocationStatus:
                ``SUCCESS`` when capacity is (now) sufficient,
                ``EXCEEDS_MAX_SEQUENCE_LENGTH`` when some head would need
                more than ``max_pages_per_head`` pages, or
                ``EXCEEDS_CURRENTLY_AVAILABLE_PAGES`` when some layer lacks
                free physical pages. Nothing is allocated on failure.
        """
        cur_bh_lens = self.bh_seq_lens[:, batch_index]  # [L, H]
        curr_pages = self.bh_num_pages[:, batch_index]  # [L, H]
        curr_cap_tokens = curr_pages * self.page_size  # [L, H]
        need_tokens = cur_bh_lens + add_tokens  # [L, H]
        if (need_tokens <= curr_cap_tokens).all():
            return KVAllocationStatus.SUCCESS

        # Heads that already hold spare capacity must contribute zero new
        # pages. Without the clamp, a head with a full spare page would
        # yield a negative count, shrinking its page count and shifting the
        # prefix-sum offsets used when scattering the new pages.
        missing_tokens = torch.clamp(need_tokens - curr_cap_tokens, min=0)
        add_pages = cdiv(missing_tokens, self.page_size)
        new_total_pages = curr_pages + add_pages
        if (new_total_pages > self.max_pages_per_head).any():
            return KVAllocationStatus.EXCEEDS_MAX_SEQUENCE_LENGTH

        # CPU work
        pages_per_layer_cpu = add_pages.sum(dim=-1).tolist()
        # Check every layer first so a failure leaves the allocator untouched.
        for layer_index, needed in enumerate(pages_per_layer_cpu):
            if needed > len(self.free_pages[layer_index]):
                return KVAllocationStatus.EXCEEDS_CURRENTLY_AVAILABLE_PAGES

        new_phys_pages = []
        for layer_index, needed in enumerate(pages_per_layer_cpu):
            layer_heap = self.free_pages[layer_index]
            this_layer_pages = [heapq.heappop(layer_heap) for _ in range(needed)]
            self.pages_indices_per_batch[batch_index][layer_index].update(
                this_layer_pages
            )
            new_phys_pages.extend(this_layer_pages)

        # Allocate on the cache's own device rather than a hard-coded "cuda".
        new_phys_pages = torch.tensor(
            new_phys_pages, dtype=torch.int32, device=self.page_table.device
        )
        scatter_to_page_table(
            add_pages=add_pages,
            new_phys_pages=new_phys_pages,
            curr_pages=curr_pages,
            page_table=self.page_table[:, batch_index],
            max_pages_per_head=self.max_pages_per_head,
        )
        self.bh_num_pages[:, batch_index, :] = new_total_pages.to(
            self.bh_num_pages.dtype
        )
        return KVAllocationStatus.SUCCESS

    def reclaim_pages(
        self,
        batch_index: int,
        future_reserve_tokens: int = 0,
    ):
        """
        Reclaim unused pages for a single batch index. This shrinks the KV
        allocation for the batch down to the minimum number of pages needed
        to hold the current (plus optional future) sequence length.

        Args:
            :param batch_index:
                Batch index whose pages should be compacted.
            :param future_reserve_tokens:
                Optional number of extra tokens to keep capacity for, beyond
                the current sequence length. This can reduce churn when
                sequences are expected to grow slightly in the near future.

        Returns:
            :return int:
                Approximate number of bytes freed across both K and V
                (``0`` when nothing could be freed).
        """
        device = self.bh_seq_lens.device
        L, B, H = self.bh_seq_lens.shape
        assert 0 <= batch_index < B

        seq = self.bh_seq_lens[:, batch_index, :] + future_reserve_tokens  # [L, H]
        alloc = self.bh_num_pages[:, batch_index, :]  # [L, H]
        # Flattened [L * H * P] copy of this batch's page table rows.
        pt = self.page_table[:, batch_index, :, :].reshape(-1)

        # Compute used pages: ceil_div(seq, page_size), clamped into [0, alloc]
        used_pages = cdiv(seq, self.page_size)
        used_pages = torch.minimum(used_pages, alloc)

        # page indices [0..P-1], broadcasted over [L, H, P]
        p = torch.arange(
            self.max_pages_per_head, device=device, dtype=torch.int32
        ).view(1, 1, self.max_pages_per_head)
        # allocated: p < alloc
        alloc_mask = p < alloc.unsqueeze(-1)  # [L, H, P]
        # to free: allocated and p in [used_pages, alloc)
        free_mask = alloc_mask & (p >= used_pages.unsqueeze(-1))
        free_mask_flat = free_mask.view(-1)  # [L*H*P]
        if not free_mask_flat.any():
            return 0

        # indices of freed slots
        idx = free_mask_flat.nonzero(as_tuple=False).squeeze(-1)
        # Freed physical page ids
        freed_pages = pt[idx]
        # Compute layer index for each freed slot:
        # layout is [L, H, P] -> flat index = ((l * H) + h) * P + p
        freed_layers = (idx // (H * self.max_pages_per_head)).to(torch.int32)
        freed_pages = freed_pages.tolist()
        layer_mapping = freed_layers.tolist()

        self.bh_num_pages[:, batch_index, :] = used_pages
        for page, layer in zip(freed_pages, layer_mapping):
            self.pages_indices_per_batch[batch_index][layer].remove(page)
            heapq.heappush(self.free_pages[layer], page)

        approximate_bytes_freed = (
            len(freed_pages)
            * (self.page_size * self.head_dim * self.kv_cache.element_size())
            * 2
        )  # multiply by two for K + V
        return approximate_bytes_freed

    def _free_batch_layer(self, layer_index: int, batch_index: int) -> None:
        """
        Return every page owned by ``batch_index`` in ``layer_index`` to that
        layer's free heap and clear the ownership record.
        """
        # Return pages to the global heap
        for phys in self.pages_indices_per_batch[batch_index][layer_index]:
            heapq.heappush(self.free_pages[layer_index], int(phys))
        self.pages_indices_per_batch[batch_index][layer_index] = set()

    def free_batch(self, batch_index: int) -> None:
        """
        Free all resources associated with a batch index.

        All owned pages are returned to the per-layer heaps, the batch's
        sequence lengths and page counts are zeroed, and the index is put
        back on ``free_batches``.

        Args:
            :param batch_index:
                Batch index to release. Must have been previously allocated
                via :meth:`new_batch`.
        """
        for layer in range(self.num_layers):
            self._free_batch_layer(layer, batch_index)
        self.bh_seq_lens[:, batch_index].zero_()
        self.bh_num_pages[:, batch_index].zero_()
        self.free_batches.append(batch_index)

    def layer_slices(self, layer: int):
        """
        Return layer-local views needed by the attention module.

        For a given ``layer`` index, this method returns the slices of the
        global K/V cache, page table, and per-(batch, head) sequence lengths
        corresponding to that layer.

        Args:
            :param layer:
                Layer index ``l`` in ``[0, num_layers)``.

        Returns:
            :return Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                ``(k, v, pt, bh)``: ``kv_cache[0, layer]``,
                ``kv_cache[1, layer]``, ``page_table[layer]`` and
                ``bh_seq_lens[layer]``.
        """
        assert 0 <= layer < self.num_layers
        k = self.kv_cache[0, layer]
        v = self.kv_cache[1, layer]
        pt = self.page_table[layer]
        bh = self.bh_seq_lens[layer]
        return k, v, pt, bh
vllm/compactor-vllm/src/compactor_vllm/kv_cache/store_kv_cache.py
0 → 100644
View file @
2b7160c6
import
torch
import
triton
import
triton.language
as
tl
from
compactor_vllm.config.constants
import
(
TRITON_RESERVED_BATCH
as
_TRITON_RESERVED_BATCH
,
)
@triton.jit
def _prefill_store_topk_kv_kernel(
    key,
    value,  # [N_total, H, D] (D stride assumed 1)
    batch_mapping,  # [B] int32 (local b -> true batch)
    num_tokens_to_retain,  # [B] int32
    indices_topk,  # [B, MAX_SEL] int32 (across all heads)
    # Lengths & page table:
    bh_lens,  # [B, H] int32 (contiguous)
    page_table,  # [B_total * H * N_LOGICAL_PAGES_MAX] int32 (flattened), read-only
    k_cache,
    v_cache,  # [N_PAGES * PAGE_SIZE, D]
    sk_n,
    sk_h,  # strides for key,value. D stride assumed 1
    sv_n,
    sv_h,
    # Runtime ints
    MAX_SEL,  # num tokens that are ranked in indices for each batch (might be bigger than num_tokens_to_retain)
    HKV: tl.constexpr,
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    K_TILE: tl.constexpr,  # how many selected tokens each program processes
    TRITON_RESERVED_BATCH: tl.constexpr,
):
    """
    Copy the selected (token, head) K/V vectors of one batch row into the
    paged cache.

    Program axis 0 is the local batch row; axis 1 is a tile of ``K_TILE``
    consecutive entries of that row's ``indices_topk``. Each entry encodes
    ``token * HKV + head``. For every entry below both
    ``num_tokens_to_retain[b_local]`` and ``MAX_SEL``, the program reserves
    the next slot of ``bh_lens[b_local, head]`` with an atomic add, resolves
    the slot to a physical page through ``page_table`` and writes one
    ``D``-wide K vector and one V vector. ``bh_lens`` is therefore updated
    in place.
    """
    b_local = tl.program_id(0)
    tile_id = tl.program_id(1)
    offs = tl.arange(0, D)
    # how many tokens we actually keep for this batch
    k_total = tl.load(num_tokens_to_retain + b_local)
    if k_total == 0:
        return
    # map to true batch row in the page table
    b_true = tl.load(batch_mapping + b_local)
    # Rows mapped to the reserved batch id are skipped entirely.
    if b_true == TRITON_RESERVED_BATCH:
        return
    base = tile_id * K_TILE
    # process up to K_TILE tokens
    for j in tl.range(0, K_TILE):
        sel_idx = base + j
        if sel_idx < k_total and sel_idx < MAX_SEL:
            # flattened selection: sel = token * H + head
            sel = tl.load(indices_topk + b_local * MAX_SEL + sel_idx)
            tok = sel // HKV
            head = sel - (tok * HKV)
            # atomically reserve one position in (b_local, head)
            # i.e. the order of tokens in the KV cache is scrambled when
            # storing: slot order follows whichever program wins the atomic.
            len_ptr = bh_lens + b_local * HKV + head
            pos = tl.atomic_add(len_ptr, 1)  # old length (int32)
            lp = pos // PAGE_SIZE
            off = pos - lp * PAGE_SIZE
            # translate logical page to physical page
            pt_base = (b_true * HKV + head) * N_LOGICAL_PAGES_MAX
            phys = tl.load(page_table + pt_base + lp).to(tl.int64)
            # destination row and element offset (int64 via phys)
            dst_row = phys * PAGE_SIZE + off
            dst_off = dst_row * D + offs
            # load one vector from [N_total, H, D]
            k_src = key + tok * sk_n + head * sk_h + offs
            v_src = value + tok * sv_n + head * sv_h + offs
            tl.store(
                k_cache + dst_off,
                tl.load(k_src, cache_modifier=".cv", eviction_policy="evict_first"),
                eviction_policy="evict_first",
            )
            tl.store(
                v_cache + dst_off,
                tl.load(v_src, cache_modifier=".cv", eviction_policy="evict_first"),
                eviction_policy="evict_first",
            )
def prefill_store_topk_kv(
    *,
    new_keys: torch.Tensor,  # [N_total, H, D]
    new_vals: torch.Tensor,  # [N_total, H, D]
    indices_topk: torch.Tensor,  # [B, MAX_SEL] int32 (global flattened token*H + head)
    num_tokens_to_retain: torch.Tensor,  # [B] int32
    page_table: torch.Tensor,  # [B_total, H, N_LOGICAL_PAGES_MAX] int32
    batch_mapping: torch.Tensor,  # [B] int32 (local -> true batch rows)
    bh_lens: torch.Tensor,  # [B, H] int32 (contiguous), UPDATED atomically
    k_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    v_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    PAGE_SIZE: int,
    PAD_TO_PAGE_SIZE: bool = True,
    cu_seqlens_k: torch.Tensor | None = None,
    K_TILE: int = 16,
    TRITON_RESERVED_BATCH: int | None = None,
):
    """
    Store the top-ranked (token, head) K/V pairs of a prefill batch into the
    paged cache, optionally topping each head up to a page boundary.

    First ``_prefill_store_topk_kv_kernel`` writes, for every batch row, the
    first ``num_tokens_to_retain[b]`` entries of ``indices_topk[b]``. When
    ``PAD_TO_PAGE_SIZE`` is set, ``_prefill_store_topk_pad_kernel`` then
    appends further ranked candidates per head until the head's length is a
    multiple of ``PAGE_SIZE`` (or candidates run out).

    Args:
        :param new_keys / new_vals:
            ``[N_total, H, D]`` tensors with unit stride on the last dim;
            ``D`` must be a power of two.
        :param indices_topk:
            ``[B, MAX_SEL]`` ranked selections encoded as ``token * H + head``.
        :param num_tokens_to_retain:
            ``[B]`` number of leading selections to store per batch row.
        :param page_table:
            Contiguous ``[B_total, H, N_LOGICAL_PAGES_MAX]`` page table.
        :param batch_mapping:
            Contiguous ``[B]`` map from local row to page-table row.
        :param bh_lens:
            Contiguous ``[B, H]`` per-head lengths, updated in place.
        :param k_cache / v_cache:
            Contiguous destination buffers.
        :param PAGE_SIZE:
            Tokens per page.
        :param PAD_TO_PAGE_SIZE:
            Run the padding pass; requires ``cu_seqlens_k``.
        :param cu_seqlens_k:
            Cumulative key lengths, required when padding.
        :param K_TILE:
            Selections handled per program of the first kernel.
        :param TRITON_RESERVED_BATCH:
            Batch id to skip; defaults to the module-level constant.

    Raises:
        AssertionError: on shape, device, contiguity or power-of-two violations.
    """
    assert new_keys.shape == new_vals.shape
    N_total, H, D = new_keys.shape
    B = indices_topk.shape[0]
    assert page_table.shape[1] == H
    assert bh_lens.shape == (B, H)
    assert new_keys.device == k_cache.device == v_cache.device
    assert page_table.is_contiguous(), "page table must be contiguous."
    assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
    assert k_cache.is_contiguous() and v_cache.is_contiguous()
    assert new_keys.stride(-1) == 1 and new_vals.stride(-1) == 1, (
        "new_keys/new_vals last dim must be contiguous."
    )
    assert (D & (D - 1)) == 0, "D must be a power of 2"

    page_table = page_table.to(torch.int32)
    # ``.to`` returns the same tensor when the dtype already matches, and a
    # copy otherwise. The kernels mutate the int32 tensor, so remember the
    # caller's tensor and write the result back below if a copy was made.
    bh_lens_i32 = bh_lens.to(torch.int32)
    batch_mapping = batch_mapping.to(torch.int32)
    # Both kernels index indices_topk as ``b * MAX_SEL + i``; it is read-only,
    # so making it contiguous here is safe for every caller.
    indices_topk = indices_topk.to(torch.int32).contiguous()
    num_tokens_to_retain = num_tokens_to_retain.to(torch.int32).contiguous()

    # strides (elements) for [N_total, H, D]
    sk_n, sk_h, _ = new_keys.stride()
    sv_n, sv_h, _ = new_vals.stride()

    # tile second grid dim
    MAX_SEL = indices_topk.shape[-1]
    N_TILES = (MAX_SEL + K_TILE - 1) // K_TILE
    grid = (B, max(1, N_TILES))
    if TRITON_RESERVED_BATCH is None:
        TRITON_RESERVED_BATCH = _TRITON_RESERVED_BATCH

    _prefill_store_topk_kv_kernel[grid](
        key=new_keys,
        value=new_vals,
        batch_mapping=batch_mapping,
        num_tokens_to_retain=num_tokens_to_retain,
        indices_topk=indices_topk,
        bh_lens=bh_lens_i32,
        page_table=page_table,
        k_cache=k_cache,
        v_cache=v_cache,
        sk_n=sk_n,
        sk_h=sk_h,
        sv_n=sv_n,
        sv_h=sv_h,
        MAX_SEL=int(MAX_SEL),
        HKV=H,
        N_LOGICAL_PAGES_MAX=page_table.shape[2],
        D=D,
        PAGE_SIZE=PAGE_SIZE,
        K_TILE=K_TILE,
        TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
    )
    if PAD_TO_PAGE_SIZE:
        assert cu_seqlens_k is not None
        assert indices_topk.is_contiguous()
        assert page_table.is_contiguous()
        _prefill_store_topk_pad_kernel[(B, H)](
            key=new_keys,
            value=new_vals,
            batch_mapping=batch_mapping,
            num_tokens_to_retain=num_tokens_to_retain,
            indices=indices_topk,
            local_lens=bh_lens_i32,
            page_table_flat=page_table,
            k_cache=k_cache,
            v_cache=v_cache,
            cu_seqlens_k=cu_seqlens_k,
            sk_n=sk_n,
            sk_h=sk_h,
            sv_n=sv_n,
            sv_h=sv_h,
            MAX_SEL=int(MAX_SEL),
            H=H,  # type: ignore
            N_LOGICAL_PAGES_MAX=page_table.shape[2],  # type: ignore
            D=D,  # type: ignore
            PAGE_SIZE=PAGE_SIZE,  # type: ignore
            TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
        )

    # Honour the in-place contract even when a dtype conversion forced a copy.
    if bh_lens_i32 is not bh_lens:
        bh_lens.copy_(bh_lens_i32)
@triton.jit
def _prefill_store_topk_pad_kernel(
    key,  # [N_total, H, D]
    value,  # [N_total, H, D]
    batch_mapping,  # [B] int32 (local b -> true batch)
    num_tokens_to_retain,  # [B] int32
    indices,  # [B, MAX_SEL] int32 (across all heads)
    local_lens,  # [B, H] int32 (contiguous)
    page_table_flat,  # [B_total*H*N_LOGICAL_PAGES_MAX] int32
    k_cache,
    v_cache,  # [N_PAGES*PAGE_SIZE, D]
    cu_seqlens_k,
    sk_n,
    sk_h,
    sv_n,
    sv_h,
    MAX_SEL,
    # Constexprs
    H: tl.constexpr,  # number of KV heads
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    TRITON_RESERVED_BATCH: tl.constexpr,
):
    """
    Top up one (batch row, head) so that its stored length ends on a page
    boundary.

    One program handles one ``(b_local, h)`` pair. If the current length
    ``L`` is not a multiple of ``PAGE_SIZE``, the program walks the ranked
    candidates of this row starting at index ``num_tokens_to_retain[b_local]``
    and appends those whose head equals ``h``. It stops when the page is
    full, the candidate list is exhausted, or the number appended reaches
    ``this_batch_ctx_len - L``. The new length is written back to
    ``local_lens``.
    """
    b_local = tl.program_id(0)
    h = tl.program_id(1)
    offs_d = tl.arange(0, D)
    L = tl.load(local_lens + b_local * H + h)
    # L mod PAGE_SIZE, written without the % operator
    modulo_page_size = L - (L // PAGE_SIZE) * PAGE_SIZE
    if modulo_page_size == 0:
        # Already page aligned: nothing to pad.
        return
    need = PAGE_SIZE - modulo_page_size
    b_true = tl.load(batch_mapping + b_local)
    if b_true == TRITON_RESERVED_BATCH:
        return
    pt_base = (b_true * H + h) * N_LOGICAL_PAGES_MAX
    written_tokens = 0
    # Start scanning right after the entries the top-k kernel consumed.
    idx = tl.load(num_tokens_to_retain + b_local)
    this_batch_ctx_len = tl.load(cu_seqlens_k + b_local + 1) - tl.load(
        cu_seqlens_k + b_local
    )
    # Upper bound on appended tokens derived from the row's context length.
    max_additional = this_batch_ctx_len - L
    while (written_tokens < need and idx < MAX_SEL) and (
        written_tokens < max_additional
    ):
        # candidate head
        cand_idx = tl.load(indices + b_local * MAX_SEL + idx)
        cand_h = cand_idx % H
        if cand_h == h:
            tok = cand_idx // H
            pos = L + written_tokens
            lp = pos // PAGE_SIZE
            off = pos - lp * PAGE_SIZE
            phys = tl.load(page_table_flat + pt_base + lp).to(tl.int32)
            dst_row = phys * PAGE_SIZE + off
            # Widen before multiplying by D for the element offset.
            dst_off = dst_row.to(tl.int64) * D + offs_d
            k_src = key + tok * sk_n + h * sk_h + offs_d
            v_src = value + tok * sv_n + h * sv_h + offs_d
            tl.store(
                k_cache + dst_off,
                tl.load(k_src),
            )
            tl.store(
                v_cache + dst_off,
                tl.load(v_src),
            )
            written_tokens += 1
        # Advance through the candidate list whether or not it matched.
        idx += 1
    tl.store(local_lens + b_local * H + h, L + written_tokens)
@triton.jit
def _prefill_store_all_kv_kernel(
    key,
    value,  # [N, H, D] (D contiguous)
    cu_seqlens_k,  # [B + 1] int32
    batch_mapping,  # [B] int32 (local b -> true batch index)
    bh_lens,  # [B * HKV] int32 (read here; the host wrapper updates it)
    pt_flat,  # [B_total * HKV * N_LOGICAL_PAGES_MAX] int32 (flattened)
    k_cache,
    v_cache,  # [N_PAGES * PAGE_SIZE, D]
    # source strides (elements)
    sk_n,
    sk_h,
    sv_n,
    sv_h,
    # constexpr
    HKV: tl.constexpr,
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    K_TILE: tl.constexpr,  # number of (token, head) pairs processed per program
):
    """
    Copy every (token, head) K/V vector of one batch row into the paged cache.

    Program axis 0 is the local batch row, axis 1 a tile of ``K_TILE``
    linearised ``(token, head)`` pairs of that row. Token ``t`` of head ``h``
    is placed at cache position ``bh_lens[b, h] + t``. The kernel only reads
    ``bh_lens``; it does not store to it.
    """
    pid_b = tl.program_id(0)
    pid_blk = tl.program_id(1)
    start = tl.load(cu_seqlens_k + pid_b)
    end = tl.load(cu_seqlens_k + pid_b + 1)
    num_toks_this_batch = end - start
    if num_toks_this_batch <= 0:
        return
    total_elems = num_toks_this_batch * HKV
    # base linear index in (token, head) grid for this program
    base = pid_blk * K_TILE
    offs_d = tl.arange(0, D)
    # Iterate K_TILE elements in this tile
    for i in tl.range(0, K_TILE):
        idx = base + i
        if idx < total_elems:
            # map linear idx -> (t, h)
            t = idx // HKV
            h = idx - t * HKV
            len_idx = pid_b * HKV + h
            # Length of this head before the current prefill.
            L0 = tl.load(bh_lens + len_idx)
            token_idx_in_cache = L0 + t
            lp = token_idx_in_cache // PAGE_SIZE  # logical page
            off_in_pg = token_idx_in_cache - lp * PAGE_SIZE  # pos in page
            # physical page
            b_true = tl.load(batch_mapping + pid_b).to(tl.int32)
            pt_base = (b_true * HKV + h) * N_LOGICAL_PAGES_MAX
            phys = tl.load(pt_flat + pt_base + lp).to(tl.int64)
            row = phys * PAGE_SIZE + off_in_pg
            dst_off = row * D + offs_d
            # Global token row in key/value, widened to int64 for addressing.
            n_global = (start + t).to(tl.int64)
            # Use strides for non-contiguous [N, H, D] (D stride == 1)
            k_src = key + n_global * sk_n + h * sk_h + offs_d
            v_src = value + n_global * sv_n + h * sv_h + offs_d
            tl.store(k_cache + dst_off, tl.load(k_src))
            tl.store(v_cache + dst_off, tl.load(v_src))
def prefill_store_all_kv(
    *,
    new_keys: torch.Tensor,
    new_values: torch.Tensor,  # [N, H_kv, D]
    cu_seqlens_k: torch.Tensor,  # [B + 1] int32
    max_seqlen_k: int,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    page_table: torch.Tensor,  # [B_total, H_kv, N_LOGICAL_PAGES_MAX] int32
    bh_lens: torch.Tensor,  # [B, H_kv] int32 (UPDATED)
    batch_mapping: torch.Tensor,  # [B] int32 (local->true)
    PAGE_SIZE: int,
    K_TILE: int = 32,  # how many (token, head) pairs per program
):
    """
    Append every new K/V vector of a prefill batch to the paged cache.

    Launches ``_prefill_store_all_kv_kernel`` over a ``(B, n_tiles)`` grid and
    afterwards advances ``bh_lens`` in place by each row's sequence length
    (``cu_seqlens_k.diff()``), applied to every head of the row.

    Raises:
        AssertionError: if the inputs violate the contiguity or power-of-two
            requirements checked below.
    """
    assert new_keys.stride(-1) == 1 and new_values.stride(-1) == 1, (
        "last dim must be contiguous"
    )
    assert page_table.is_contiguous(), "page table must be contiguous"
    assert bh_lens.is_contiguous(), "bh_lens must be contiguous"
    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous"
    assert k_cache.is_contiguous() and v_cache.is_contiguous()

    _, num_kv_heads, head_dim = new_keys.shape
    num_rows = batch_mapping.shape[0]
    assert (head_dim & (head_dim - 1)) == 0, "D must be a power of 2"

    key_stride_n, key_stride_h, _ = new_keys.stride()
    val_stride_n, val_stride_h, _ = new_values.stride()

    # Enough tiles to cover the longest row's (token, head) pairs.
    tiles_per_row = triton.cdiv(max_seqlen_k * num_kv_heads, K_TILE)
    launch_grid = (num_rows, tiles_per_row)
    _prefill_store_all_kv_kernel[launch_grid](
        new_keys,
        new_values,
        cu_seqlens_k,
        batch_mapping,
        bh_lens,
        page_table,
        k_cache,
        v_cache,
        sk_n=key_stride_n,
        sk_h=key_stride_h,
        sv_n=val_stride_n,
        sv_h=val_stride_h,
        HKV=num_kv_heads,
        N_LOGICAL_PAGES_MAX=page_table.shape[-1],
        D=head_dim,
        PAGE_SIZE=PAGE_SIZE,
        K_TILE=K_TILE,
    )
    # The kernel only reads the lengths; advance them here, broadcasting each
    # row's token count over its heads.
    tokens_per_row = cu_seqlens_k.diff()
    bh_lens += tokens_per_row[:, None]
@triton.jit
def _decode_store_kv_kernel(
    key,
    value,
    batch_mapping,  # [B] int32
    bh_lens,  # [B*HKV] int32
    page_table,  # [B_total*HKV*N_LOGICAL_PAGES_MAX]
    k_cache,
    v_cache,  # [N_PAGES*PAGE_SIZE, D]
    sk_b,
    sk_h,
    sv_b,
    sv_h,
    HKV: tl.constexpr,
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    TRITON_RESERVED_BATCH: tl.constexpr,
):
    """
    Append one decode-step K/V vector for a single (batch row, head).

    One program handles one ``(pid_b, h)`` pair: it reads the current length,
    writes the new K and V vectors at that position in the paged cache and
    stores ``length + 1`` back into ``bh_lens``. Rows mapped to
    ``TRITON_RESERVED_BATCH`` are skipped.
    """
    pid_b = tl.program_id(0)
    h = tl.program_id(1)
    mapped_b = tl.load(batch_mapping + pid_b)
    if mapped_b == TRITON_RESERVED_BATCH:
        return
    offs_d = tl.arange(0, D)
    # Current length doubles as the write position of the new token.
    length = tl.load(bh_lens + pid_b * HKV + h)
    logical_page = length // PAGE_SIZE
    internal_offset = length - logical_page * PAGE_SIZE
    # The page table is addressed with the mapped (true) batch row, while
    # bh_lens above is addressed with the local row.
    pt_base = (mapped_b * HKV + h) * N_LOGICAL_PAGES_MAX
    physical_page = tl.load(page_table + pt_base + logical_page).to(tl.int64)
    dst_row = physical_page * PAGE_SIZE + internal_offset
    # Source addressing using strides (D stride == 1)
    k_src = key + pid_b * sk_b + h * sk_h + offs_d
    v_src = value + pid_b * sv_b + h * sv_h + offs_d
    dst_off = dst_row * D + offs_d
    tl.store(k_cache + dst_off, tl.load(k_src))
    tl.store(v_cache + dst_off, tl.load(v_src))
    tl.store(bh_lens + pid_b * HKV + h, length + 1)
def decode_store_kv(
    *,
    key: torch.Tensor,  # [B, HKV, D]
    value: torch.Tensor,  # [B, HKV, D]
    batch_mapping: torch.Tensor,  # [B] int32
    bh_lens: torch.Tensor,  # [B, HKV] or flattened [B*HKV] int32
    page_table: torch.Tensor,  # [B_total, HKV, N_LOGICAL_PAGES_MAX] int32
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,  # [N_PAGES*PAGE_SIZE, D]
    PAGE_SIZE: int,
    TRITON_RESERVED_BATCH: int = None,
):
    """
    Append one new K/V vector per (batch row, head) during decoding.

    Validates the inputs and launches ``_decode_store_kv_kernel`` with one
    program per ``(batch row, head)``; the kernel increments ``bh_lens`` in
    place. When ``TRITON_RESERVED_BATCH`` is ``None`` the module-level
    default is used.

    Raises:
        AssertionError: on shape, contiguity or power-of-two violations.
    """
    assert key.shape == value.shape and key.ndim == 3, "key/value must be [B, HKV, D]"
    _, num_kv_heads, head_dim = key.shape
    assert key.stride(-1) == 1 and value.stride(-1) == 1, (
        "key/value last dim must be contiguous."
    )
    assert page_table.is_contiguous(), "page table must be contiguous."
    assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
    assert k_cache.is_contiguous() and v_cache.is_contiguous()
    assert (head_dim & (head_dim - 1)) == 0, "D must be a power of 2"

    key_stride_b, key_stride_h, _ = key.stride()
    val_stride_b, val_stride_h, _ = value.stride()

    # Resolve the reserved id up front instead of inline in the launch.
    reserved_batch = TRITON_RESERVED_BATCH
    if reserved_batch is None:
        reserved_batch = _TRITON_RESERVED_BATCH

    num_rows = int(batch_mapping.shape[0])
    launch_grid = (num_rows, num_kv_heads)
    _decode_store_kv_kernel[launch_grid](
        key=key,
        value=value,
        batch_mapping=batch_mapping,
        bh_lens=bh_lens,
        page_table=page_table,
        k_cache=k_cache,
        v_cache=v_cache,
        sk_b=key_stride_b,
        sk_h=key_stride_h,
        sv_b=val_stride_b,
        sv_h=val_stride_h,
        HKV=num_kv_heads,
        N_LOGICAL_PAGES_MAX=page_table.shape[2],
        D=head_dim,
        PAGE_SIZE=PAGE_SIZE,
        TRITON_RESERVED_BATCH=reserved_batch,
    )
vllm/compactor-vllm/src/compactor_vllm/kv_cache/write_page_table.py
0 → 100644
View file @
2b7160c6
import
torch
import
triton
import
triton.language
as
tl
def scatter_to_page_table(
    add_pages: torch.Tensor,  # [L, H] int32
    new_phys_pages: torch.Tensor,  # [N]
    curr_pages: torch.Tensor,  # [L, H] int32
    page_table: torch.Tensor,  # [L, H, max_pages_per_head] int32, NOT assumed contiguous globally
    max_pages_per_head: int,
):
    """
    Append newly allocated physical pages into a layered page table via Triton.

    For each (layer ``l``, head ``h``), the next ``add_pages[l, h]`` entries
    of ``new_phys_pages`` are written to logical pages
    ``curr_pages[l, h] .. curr_pages[l, h] + add_pages[l, h] - 1``.

    Args:
        :param add_pages:
            Tensor of shape ``[L, H]`` (int32) indicating how many pages to
            append for each (layer, head). Entries are expected to be
            non-negative; the kernel skips entries ``<= 0``.
        :param new_phys_pages:
            1D tensor of shape ``[N]`` (int32) containing physical page IDs
            for all (layer, head) pairs, concatenated in row-major (L, H)
            order. ``N`` must equal ``add_pages.sum()``.
        :param curr_pages:
            Tensor of shape ``[L, H]`` (int32) with the current logical page
            counts per (layer, head) before this update.
        :param page_table:
            Tensor of shape ``[L, H, max_pages_per_head]`` (int32) holding
            the logical to physical page mapping. The last dimension is
            logically indexed as logical_page in [0, max_pages_per_head).
        :param max_pages_per_head:
            Maximum number of logical pages permitted per (layer, head). The
            kernel skips writes beyond this bound.

    Returns:
        None. The function updates ``page_table`` in-place.
    """
    L, H = add_pages.shape
    if L == 0 or H == 0:
        return

    add_flat = add_pages.to(torch.int32).contiguous().view(-1)
    curr_flat = curr_pages.to(torch.int32).contiguous().view(-1)

    # Exclusive prefix sum: cum_page_heads[i] is where pair i's pages start
    # inside new_phys_pages. Allocated on the inputs' device instead of a
    # hard-coded "cuda" so the buffer always lives next to the data it indexes.
    cum_page_heads = torch.empty(L * H + 1, device=add_flat.device, dtype=torch.int32)
    cum_page_heads[0] = 0
    torch.cumsum(add_flat, 0, out=cum_page_heads[1:])

    # Strides are passed explicitly because page_table may be a
    # non-contiguous slice of a larger tensor.
    stride_pl, stride_ph, stride_pp = page_table.stride()
    grid = (L, H)
    _scatter_pages_kernel_lh[grid](
        add_flat,
        cum_page_heads,
        new_phys_pages,
        curr_flat,
        page_table,
        stride_pl,
        stride_ph,
        stride_pp,
        L=L,
        H=H,
        max_pages_per_head=max_pages_per_head,
    )
@triton.jit
def _scatter_pages_kernel_lh(
    add_pages,  # int32 [L*H]
    cum_page_heads,  # int32 [L*H + 1], base offset in flat_new_phys per (l,h)
    flat_new_phys,  # int32 [total_pages]
    curr_pages,  # int32 [L*H], existing logical pages per (l,h)
    page_table_ptr,  # int32* base pointer to page_table
    stride_pl,  # int, stride for layer dim
    stride_ph,  # int, stride for head dim
    stride_pp,  # int, stride for page dim
    L: tl.constexpr,
    H: tl.constexpr,
    max_pages_per_head: tl.constexpr,
):
    """
    Write the new physical page ids of one (layer, head) pair into the page
    table.

    One program per ``(layer_idx, h)``. It reads how many pages to append
    (``ap``), where its ids start in ``flat_new_phys`` (``base``) and how many
    logical pages already exist (``cp``), then stores the ids at logical
    pages ``cp .. cp + ap - 1``, skipping any index at or beyond
    ``max_pages_per_head``.
    """
    layer_idx = tl.program_id(0)
    h = tl.program_id(1)
    # Bounds guard against a grid larger than (L, H).
    if layer_idx >= L or h >= H:
        return
    lh = layer_idx * H + h
    ap = tl.load(add_pages + lh)
    if ap <= 0:
        # Nothing to append for this pair.
        return
    base = tl.load(cum_page_heads + lh)
    cp = tl.load(curr_pages + lh)
    # Append ap pages: logical pages [cp .. cp+ap)
    for i in tl.range(0, ap):
        phys = tl.load(flat_new_phys + base + i)
        lp = cp + i
        if lp < max_pages_per_head:
            offset = layer_idx * stride_pl + h * stride_ph + lp * stride_pp
            tl.store(page_table_ptr + offset, phys)
# TODO: write reclaim kernel
@triton.jit
def reclaim_page_kernel():
    # Unimplemented placeholder: takes no arguments and does nothing.
    pass
def reclaim_pages(
    batch_index: int,
    bh_seq_lens: torch.Tensor,
    bh_num_pages: torch.Tensor,
    page_table: torch.Tensor,
):
    """
    Placeholder for a kernel-backed page reclaim.

    Not implemented: the arguments are ignored and ``None`` is returned.
    """
    return None
vllm/compactor-vllm/src/compactor_vllm/layers/__init__.py
0 → 100644
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/layers/activation.py
0 → 100644
View file @
2b7160c6
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
class SiluAndMul(nn.Module):
    """
    Gated activation: split the last dimension into two halves and return
    ``silu(first_half) * second_half``.
    """

    # The module is stateless, so the inherited constructor is sufficient.

    # @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated activation along the last dimension of ``x``."""
        gate, up = x.chunk(2, -1)
        activated = F.silu(gate)
        return activated * up
vllm/compactor-vllm/src/compactor_vllm/layers/attention.py
0 → 100644
View file @
2b7160c6
from
typing
import
Optional
import
torch
from
compactor_vllm.attention.sparse_decode_kernel
import
head_sparse_decode_attention
from
compactor_vllm.attention.sparse_varlen_kernel
import
(
causal_sparse_varlen_with_cache
,
)
from
compactor_vllm.compression.common
import
extract_and_store_top_kv
from
compactor_vllm.config.engine_config
import
AttentionBackend
from
compactor_vllm.kv_cache.store_kv_cache
import
decode_store_kv
,
prefill_store_all_kv
from
compactor_vllm.utils.context
import
Context
,
get_context
from
compactor_vllm.utils.helpers
import
maybe_execute_in_stream
from
flash_attn.flash_attn_interface
import
flash_attn_varlen_func
from
torch
import
nn
class Attention(nn.Module):
    """
    Attention layer that stores K/V into a paged cache and dispatches to a
    prefill or decode attention implementation based on the current context.

    The cache tensors (``k_cache``, ``v_cache``, ``page_table``,
    ``bh_seq_lens``) and ``page_size`` start as ``None`` and are expected to
    be assigned from outside this class before decoding.
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads: int = num_heads
        self.head_dim = head_dim
        self.scale: float = scale
        self.num_kv_heads = int(num_kv_heads)
        # Cache handles; populated externally (not in this class).
        self.k_cache: Optional[torch.Tensor] = None
        self.v_cache: Optional[torch.Tensor] = None
        self.page_table: Optional[torch.Tensor] = None
        self.bh_seq_lens: Optional[torch.Tensor] = None
        self.page_size: Optional[int] = None

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        scores: Optional[torch.Tensor] = None,
    ):
        """
        Store the new K/V and compute attention output.

        Args:
            q, k, v: query/key/value tensors for the current step.
            scores: optional per-token scores; when given together with an
                initialised cache and ``context.do_compression``, only the
                top-ranked K/V are stored during prefill.

        Returns:
            The attention output ``o`` produced by the selected backend.

        Raises:
            NotImplementedError: for a prefill with an unsupported backend.
            AssertionError: when decoding without an initialised KV cache.
        """
        context: Context = get_context()
        batch_mapping = context.batch_mapping
        # Gather this batch's per-head lengths into a contiguous working copy;
        # it is written back to bh_seq_lens at the end of the call.
        seq_lens = (
            None
            if self.bh_seq_lens is None
            else self.bh_seq_lens.index_select(0, batch_mapping).contiguous()
        )
        if context.is_prefill:
            # Snapshot taken before any store call below receives seq_lens;
            # it is what the Triton prefill kernel gets as seq_lens_bh.
            seq_lens_copy = seq_lens.clone() if seq_lens is not None else None
            if (
                self.k_cache is not None
                and context.do_compression
                and scores is not None
            ):
                compression_context = context.compression_context
                assert scores is not None
                assert compression_context is not None
                # NOTE(review): maybe_execute_in_stream presumably runs the
                # callable on STORE_STREAM when one is given - confirm in
                # compactor_vllm.utils.helpers.
                maybe_execute_in_stream(
                    extract_and_store_top_kv,
                    scores=scores,
                    cu_seqlens_k=context.cu_seqlens_k,
                    max_k_len=context.max_seqlen_k,
                    top_k=compression_context.max_tokens_to_retain,
                    H=int(self.num_kv_heads),
                    new_keys=k,
                    new_vals=v,
                    num_tokens_to_retain=compression_context.batch_tokens_to_retain,
                    page_table=self.page_table,
                    batch_mapping=batch_mapping,
                    bh_lens=seq_lens,
                    k_cache=self.k_cache,
                    v_cache=self.v_cache,
                    PAGE_SIZE=self.page_size,
                    PAD_TO_PAGE_SIZE=True,
                    STORE_STREAM=context.STORE_STREAM,
                )
            elif self.k_cache is not None:
                # No compression (or no scores): store every new K/V.
                maybe_execute_in_stream(
                    prefill_store_all_kv,
                    new_keys=k,
                    new_values=v,
                    cu_seqlens_k=context.cu_seqlens_k,
                    max_seqlen_k=context.max_seqlen_k,
                    k_cache=self.k_cache,
                    v_cache=self.v_cache,
                    page_table=self.page_table,
                    bh_lens=seq_lens,
                    batch_mapping=batch_mapping,
                    PAGE_SIZE=self.page_size,
                    STORE_STREAM=context.STORE_STREAM,
                )
            # No compression: FA varlen on q,k,v (matches HF). Compressed: Triton reads paged KV.
            use_flash_prefill = context.attention_backend == AttentionBackend.FLASH_ATTENTION or (
                context.attention_backend == AttentionBackend.COMPACTOR_TRITON
                and not context.do_compression
            )
            if use_flash_prefill:
                o = flash_attn_varlen_func(
                    q,
                    k,
                    v,
                    max_seqlen_q=context.max_seqlen_q,
                    cu_seqlens_q=context.cu_seqlens_q,
                    max_seqlen_k=context.max_seqlen_k,
                    cu_seqlens_k=context.cu_seqlens_k,
                    softmax_scale=self.scale,
                    causal=True,
                )
            elif context.attention_backend == AttentionBackend.COMPACTOR_TRITON:
                # Top-k KV writes on STORE_STREAM; Triton prefill must see finished writes.
                if context.do_compression and context.STORE_STREAM is not None:
                    torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
                o = causal_sparse_varlen_with_cache(
                    q,
                    k,
                    v,
                    self.k_cache,
                    self.v_cache,
                    seq_lens_bh=seq_lens_copy,
                    global_page_table=self.page_table,
                    batch_mapping=batch_mapping,
                    cu_seqlens_q=context.cu_seqlens_q,
                    max_seqlen_q=context.max_seqlen_q,
                    max_seqlen_k_cache=context.max_bh_len,
                    HKV=int(self.num_kv_heads),
                    PAGE_SIZE=self.page_size,
                    sm_scale=self.scale,
                )
            else:
                raise NotImplementedError
        else:
            assert self.k_cache is not None, "KV Cache must be initialized for decoding"
            # Decode: append the single new K/V per head, then attend over
            # the cache using the lengths passed as seq_lens.
            decode_store_kv(
                key=k,
                value=v,
                batch_mapping=batch_mapping,
                bh_lens=seq_lens,
                page_table=self.page_table,
                k_cache=self.k_cache,
                v_cache=self.v_cache,
                PAGE_SIZE=self.page_size,
            )
            o = head_sparse_decode_attention(
                q,
                self.k_cache,
                self.v_cache,
                seq_lens,
                self.page_table,
                batch_mapping,
                int(self.num_kv_heads),
                self.page_size,
                self.scale,
                key_split=context.key_split,
            )
        if self.bh_seq_lens is not None:
            # Write the working lengths back into the global table. During
            # prefill STORE_STREAM is passed along (the same stream handed to
            # the store calls above); during decode None is passed.
            longbm = batch_mapping.to(torch.long)
            maybe_execute_in_stream(
                self.bh_seq_lens.index_copy_,
                0,
                longbm,
                seq_lens,
                STORE_STREAM=context.STORE_STREAM if context.is_prefill else None,
            )
        return o
vllm/compactor-vllm/src/compactor_vllm/layers/embed_head.py
0 → 100644
View file @
2b7160c6
import
torch
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
from
compactor_vllm.utils.context
import
get_context
from
torch
import
nn
class VocabParallelEmbedding(nn.Module):
    """
    Embedding table sharded along the vocabulary dimension across the ranks
    of the default ``torch.distributed`` process group.

    Each rank owns ``num_embeddings // world_size`` consecutive rows. With
    more than one rank, ids outside the local shard contribute zeros and the
    partial results are summed with an all-reduce.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ):
        super().__init__()
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        self.tp_rank = rank
        self.tp_size = world_size
        assert num_embeddings % self.tp_size == 0
        self.num_embeddings = num_embeddings

        rows_per_rank = self.num_embeddings // self.tp_size
        self.num_embeddings_per_partition = rows_per_rank
        # Half-open range [vocab_start_idx, vocab_end_idx) owned by this rank.
        self.vocab_start_idx = rows_per_rank * self.tp_rank
        self.vocab_end_idx = self.vocab_start_idx + rows_per_rank

        self.weight = nn.Parameter(torch.empty(rows_per_rank, embedding_dim))
        # Attach the loader to the parameter itself.
        self.weight.weight_loader = self.weight_loader

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's row slice of ``loaded_weight`` into ``param``."""
        target = param.data
        rows = target.size(0)
        first_row = self.tp_rank * rows
        target.copy_(loaded_weight.narrow(0, first_row, rows))

    def forward(self, x: torch.Tensor):
        """Look up embeddings for token ids ``x``; all-reduced when sharded."""
        if not self.tp_size > 1:
            return F.embedding(x, self.weight)

        # Ids outside this shard are remapped to row 0 and zeroed afterwards.
        owned = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
        local_ids = owned * (x - self.vocab_start_idx)
        partial = F.embedding(local_ids, self.weight)
        partial = owned.unsqueeze(1) * partial
        dist.all_reduce(partial)
        return partial
class ParallelLMHead(VocabParallelEmbedding):
    """Output projection that reuses the vocabulary-sharded weight layout.

    Every rank computes logits for its own vocabulary slice; the slices are
    gathered onto rank 0 and concatenated along the last dimension.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        bias: bool = False,
    ):
        # A bias term is not supported by this head.
        assert not bias
        super().__init__(num_embeddings, embedding_dim)

    def forward(self, x: torch.Tensor):
        """Return logits on rank 0 (``None`` on other ranks when tp_size > 1)."""
        ctx = get_context()
        if ctx.is_prefill:
            # Keep only the rows at cu_seqlens_q[1:] - 1 (presumably the last
            # token of each packed sequence -- confirm against the context).
            final_positions = ctx.cu_seqlens_q[1:] - 1
            x = x[final_positions].contiguous()

        logits = F.linear(x, self.weight)
        if self.tp_size <= 1:
            return logits

        is_root = self.tp_rank == 0
        # Only the destination rank supplies receive buffers to gather().
        gathered = None
        if is_root:
            gathered = [torch.empty_like(logits) for _ in range(self.tp_size)]
        dist.gather(logits, gathered, 0)
        if not is_root:
            return None
        return torch.cat(gathered, -1)
vllm/compactor-vllm/src/compactor_vllm/layers/layernorm.py
0 → 100644
View file @
2b7160c6
import
torch
from
torch
import
nn
class
RMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
=
1e-6
,
)
->
None
:
super
().
__init__
()
self
.
eps
=
eps
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
# @torch.compile
def
rms_forward
(
self
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
orig_dtype
=
x
.
dtype
x
=
x
.
float
()
var
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
.
mul_
(
torch
.
rsqrt
(
var
+
self
.
eps
))
x
=
x
.
to
(
orig_dtype
).
mul_
(
self
.
weight
)
return
x
# @torch.compile
def
add_rms_forward
(
self
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
orig_dtype
=
x
.
dtype
x
=
x
.
float
().
add_
(
residual
.
float
())
residual
=
x
.
to
(
orig_dtype
)
var
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
.
mul_
(
torch
.
rsqrt
(
var
+
self
.
eps
))
x
=
x
.
to
(
orig_dtype
).
mul_
(
self
.
weight
)
return
x
,
residual
def
forward
(
self
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
residual
is
None
:
return
self
.
rms_forward
(
x
)
else
:
return
self
.
add_rms_forward
(
x
,
residual
)
vllm/compactor-vllm/src/compactor_vllm/layers/linear.py
0 → 100644
View file @
2b7160c6
import
torch
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
from
torch
import
nn
def divide(numerator, denominator):
    """Return ``numerator // denominator``, asserting the division is exact."""
    quotient, remainder = divmod(numerator, denominator)
    assert remainder == 0
    return quotient
class
LinearBase
(
nn
.
Module
):
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
False
,
tp_dim
:
int
|
None
=
None
,
):
super
().
__init__
()
self
.
tp_dim
=
tp_dim
self
.
tp_rank
=
dist
.
get_rank
()
self
.
tp_size
=
dist
.
get_world_size
()
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
output_size
,
input_size
))
self
.
weight
.
weight_loader
=
self
.
weight_loader
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
(
output_size
))
self
.
bias
.
weight_loader
=
self
.
weight_loader
else
:
self
.
register_parameter
(
"bias"
,
None
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
class ReplicatedLinear(LinearBase):
    """Linear layer that keeps the complete weight on every rank."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy the checkpoint tensor into ``param`` unchanged."""
        destination = param.data
        destination.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x @ weight.T + bias`` (bias may be None)."""
        weight, bias = self.weight, self.bias
        return F.linear(x, weight, bias)
class ColumnParallelLinear(LinearBase):
    """Linear layer whose output features are split across ranks.

    Each rank holds ``output_size // world_size`` rows of the weight
    (``tp_dim == 0``) and produces only its own slice of the output.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        world_size = dist.get_world_size()
        local_output_size = divide(output_size, world_size)
        super().__init__(input_size, local_output_size, bias, 0)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` along ``tp_dim`` into ``param``."""
        target = param.data
        width = target.size(self.tp_dim)
        local_slice = loaded_weight.narrow(self.tp_dim, self.tp_rank * width, width)
        target.copy_(local_slice)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the local output slice; no cross-rank communication here."""
        weight, bias = self.weight, self.bias
        return F.linear(x, weight, bias)
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Several column-parallel projections stored in one fused weight.

    ``output_sizes`` lists the unsharded output width of each projection;
    their per-rank slices are laid out back to back along ``tp_dim``.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = False,
    ):
        # Stored before the base constructor runs, as in the original order.
        self.output_sizes = output_sizes
        fused_output_size = sum(output_sizes)
        super().__init__(input_size, fused_output_size, bias)

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int,
    ):
        """Load projection number ``loaded_shard_id`` into its region of ``param``.

        ``loaded_weight`` is the full tensor of that one projection; it is
        chunked across ranks and this rank's chunk is copied in.
        """
        world_size = self.tp_size
        # Per-rank offset and width of the requested projection.
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // world_size
        shard_size = self.output_sizes[loaded_shard_id] // world_size
        destination = param.data.narrow(self.tp_dim, shard_offset, shard_size)
        local_chunk = loaded_weight.chunk(world_size, self.tp_dim)[self.tp_rank]
        destination.copy_(local_chunk)
class QKVParallelLinear(ColumnParallelLinear):
    """Fused query/key/value projection, sharded by attention head.

    The per-rank weight is laid out along ``tp_dim`` as
    ``[q heads | k heads | v heads]``, with ``num_heads`` query heads and
    ``num_kv_heads`` key/value heads per rank.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = False,
    ):
        world_size = dist.get_world_size()
        # A falsy KV head count means "same as the number of query heads".
        if not total_num_kv_heads:
            total_num_kv_heads = total_num_heads
        self.head_size = head_size
        self.num_heads = divide(total_num_heads, world_size)
        self.num_kv_heads = divide(total_num_kv_heads, world_size)
        fused_heads = total_num_heads + 2 * total_num_kv_heads
        super().__init__(hidden_size, fused_heads * self.head_size, bias)

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str,
    ):
        """Load the ``"q"``, ``"k"`` or ``"v"`` projection into its region of ``param``."""
        assert loaded_shard_id in ["q", "k", "v"]
        q_width = self.num_heads * self.head_size
        kv_width = self.num_kv_heads * self.head_size
        if loaded_shard_id == "q":
            shard_offset, shard_size = 0, q_width
        else:
            # k sits right after q; v sits after q and k.
            shard_size = kv_width
            shard_offset = q_width if loaded_shard_id == "k" else q_width + kv_width
        destination = param.data.narrow(self.tp_dim, shard_offset, shard_size)
        local_chunk = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
        destination.copy_(local_chunk)
class RowParallelLinear(LinearBase):
    """Linear layer whose input features are split across ranks.

    Each rank holds ``input_size // world_size`` columns of the weight
    (``tp_dim == 1``), computes a partial product, and the partial results
    are combined with ``dist.all_reduce``.  The bias, when present, has the
    full ``output_size`` on every rank and is added on rank 0 only.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        tp_size = dist.get_world_size()
        super().__init__(divide(input_size, tp_size), output_size, bias, 1)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's share of ``loaded_weight`` into ``param``.

        The 2-D weight is sliced along ``tp_dim``.  The bias is 1-D and has no
        dimension ``tp_dim``; it is not sharded, so it is copied whole instead
        of hitting an out-of-range ``size(tp_dim)`` call.
        """
        param_data = param.data
        if param_data.dim() <= self.tp_dim:
            param_data.copy_(loaded_weight)
            return
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the partial product and reduce it across ranks."""
        # Only rank 0 adds the bias so it is counted once after the reduction.
        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
        if self.tp_size > 1:
            dist.all_reduce(y)
        return y
Prev
1
2
3
4
5
6
7
8
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment