Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d3d07fc
Commit
8d3d07fc
authored
Jan 21, 2026
by
laibao
Browse files
feat: kvpress新增 KV 压缩状态与元数据打通
parent
5036e878
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
34 additions
and
0 deletions
+34
-0
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+6
-0
vllm/v1/request.py
vllm/v1/request.py
+4
-0
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+5
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+19
-0
No files found.
vllm/v1/core/sched/output.py
View file @
8d3d07fc
...
@@ -31,6 +31,7 @@ class NewRequestData:
...
@@ -31,6 +31,7 @@ class NewRequestData:
pooling_params
:
Optional
[
PoolingParams
]
pooling_params
:
Optional
[
PoolingParams
]
block_ids
:
tuple
[
list
[
int
],
...]
block_ids
:
tuple
[
list
[
int
],
...]
num_computed_tokens
:
int
num_computed_tokens
:
int
num_kv_tokens
:
int
lora_request
:
Optional
[
LoRARequest
]
lora_request
:
Optional
[
LoRARequest
]
@
classmethod
@
classmethod
...
@@ -49,6 +50,7 @@ class NewRequestData:
...
@@ -49,6 +50,7 @@ class NewRequestData:
pooling_params
=
request
.
pooling_params
,
pooling_params
=
request
.
pooling_params
,
block_ids
=
block_ids
,
block_ids
=
block_ids
,
num_computed_tokens
=
request
.
num_computed_tokens
,
num_computed_tokens
=
request
.
num_computed_tokens
,
num_kv_tokens
=
request
.
num_kv_tokens
,
lora_request
=
request
.
lora_request
,
lora_request
=
request
.
lora_request
,
)
)
...
@@ -62,6 +64,7 @@ class NewRequestData:
...
@@ -62,6 +64,7 @@ class NewRequestData:
f
"sampling_params=
{
self
.
sampling_params
}
,"
f
"sampling_params=
{
self
.
sampling_params
}
,"
f
"block_ids=
{
self
.
block_ids
}
,"
f
"block_ids=
{
self
.
block_ids
}
,"
f
"num_computed_tokens=
{
self
.
num_computed_tokens
}
,"
f
"num_computed_tokens=
{
self
.
num_computed_tokens
}
,"
f
"num_kv_tokens=
{
self
.
num_kv_tokens
}
,"
f
"lora_request=
{
self
.
lora_request
}
"
f
"lora_request=
{
self
.
lora_request
}
"
")"
)
")"
)
...
@@ -76,6 +79,7 @@ class NewRequestData:
...
@@ -76,6 +79,7 @@ class NewRequestData:
f
"sampling_params=
{
self
.
sampling_params
}
,"
f
"sampling_params=
{
self
.
sampling_params
}
,"
f
"block_ids=
{
self
.
block_ids
}
,"
f
"block_ids=
{
self
.
block_ids
}
,"
f
"num_computed_tokens=
{
self
.
num_computed_tokens
}
,"
f
"num_computed_tokens=
{
self
.
num_computed_tokens
}
,"
f
"num_kv_tokens=
{
self
.
num_kv_tokens
}
,"
f
"lora_request=
{
self
.
lora_request
}
"
f
"lora_request=
{
self
.
lora_request
}
"
")"
)
")"
)
...
@@ -93,6 +97,7 @@ class CachedRequestData:
...
@@ -93,6 +97,7 @@ class CachedRequestData:
new_token_ids
:
list
[
list
[
int
]]
new_token_ids
:
list
[
list
[
int
]]
new_block_ids
:
list
[
tuple
[
list
[
int
],
...]]
new_block_ids
:
list
[
tuple
[
list
[
int
],
...]]
num_computed_tokens
:
list
[
int
]
num_computed_tokens
:
list
[
int
]
num_kv_tokens
:
list
[
int
]
@
property
@
property
def
num_reqs
(
self
)
->
int
:
def
num_reqs
(
self
)
->
int
:
...
@@ -106,6 +111,7 @@ class CachedRequestData:
...
@@ -106,6 +111,7 @@ class CachedRequestData:
new_token_ids
=
[],
new_token_ids
=
[],
new_block_ids
=
[],
new_block_ids
=
[],
num_computed_tokens
=
[],
num_computed_tokens
=
[],
num_kv_tokens
=
[],
)
)
...
...
vllm/v1/request.py
View file @
8d3d07fc
...
@@ -79,6 +79,10 @@ class Request:
...
@@ -79,6 +79,10 @@ class Request:
self
.
_all_token_ids
:
list
[
int
]
=
self
.
prompt_token_ids
.
copy
()
self
.
_all_token_ids
:
list
[
int
]
=
self
.
prompt_token_ids
.
copy
()
self
.
spec_token_ids
:
list
[
int
]
=
[]
self
.
spec_token_ids
:
list
[
int
]
=
[]
self
.
num_computed_tokens
=
0
self
.
num_computed_tokens
=
0
# Number of tokens currently stored in the KV cache for this request.
# This can be different from `num_computed_tokens` when KV compression
# is enabled (e.g., token-shared prefill compression).
self
.
num_kv_tokens
=
0
self
.
num_generated_token_ids
=
0
self
.
num_generated_token_ids
=
0
self
.
cache_salt
:
Optional
[
str
]
=
cache_salt
self
.
cache_salt
:
Optional
[
str
]
=
cache_salt
...
...
vllm/v1/worker/block_table.py
View file @
8d3d07fc
...
@@ -63,6 +63,11 @@ class BlockTable:
...
@@ -63,6 +63,11 @@ class BlockTable:
def
add_row
(
self
,
block_ids
:
list
[
int
],
row_idx
:
int
)
->
None
:
def
add_row
(
self
,
block_ids
:
list
[
int
],
row_idx
:
int
)
->
None
:
self
.
num_blocks_per_row
[
row_idx
]
=
0
self
.
num_blocks_per_row
[
row_idx
]
=
0
# Keep the invariant that "unused" entries map to the null block (id=0).
# This matters when we *shrink* a request's block list (e.g. KV
# compression tail-block truncation) and later re-use freed blocks for
# other requests.
self
.
block_table_np
[
row_idx
,
:].
fill
(
0
)
self
.
append_row
(
block_ids
,
row_idx
)
self
.
append_row
(
block_ids
,
row_idx
)
def
move_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
def
move_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
8d3d07fc
...
@@ -38,6 +38,7 @@ class CachedRequestState:
...
@@ -38,6 +38,7 @@ class CachedRequestState:
block_ids
:
tuple
[
list
[
int
],
...]
block_ids
:
tuple
[
list
[
int
],
...]
num_computed_tokens
:
int
num_computed_tokens
:
int
num_kv_tokens
:
int
output_token_ids
:
list
[
int
]
output_token_ids
:
list
[
int
]
spec_token_ids
:
list
[
int
]
=
None
spec_token_ids
:
list
[
int
]
=
None
...
@@ -51,6 +52,12 @@ class CachedRequestState:
...
@@ -51,6 +52,12 @@ class CachedRequestState:
repr
=
False
,
repr
=
False
,
compare
=
False
)
compare
=
False
)
# Chunked prefill (scheme 3): cached prompt compaction plan.
# Computed on the last prompt chunk; applied before the first decode step.
kv_compression_prompt_idx_sorted
:
Optional
[
torch
.
Tensor
]
=
None
# [K] int32
kv_compression_prompt_keep_len
:
Optional
[
int
]
=
None
kv_compression_prompt_prompt_len
:
Optional
[
int
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
self
.
num_prompt_tokens
=
len
(
self
.
prompt_token_ids
)
self
.
num_prompt_tokens
=
len
(
self
.
prompt_token_ids
)
...
@@ -114,6 +121,13 @@ class InputBatch:
...
@@ -114,6 +121,13 @@ class InputBatch:
)
)
self
.
num_computed_tokens_cpu
=
\
self
.
num_computed_tokens_cpu
=
\
self
.
num_computed_tokens_cpu_tensor
.
numpy
()
self
.
num_computed_tokens_cpu_tensor
.
numpy
()
self
.
num_kv_tokens_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
pin_memory
,
)
self
.
num_kv_tokens_cpu
=
self
.
num_kv_tokens_cpu_tensor
.
numpy
()
# Block table.
# Block table.
self
.
block_table
=
MultiGroupBlockTable
(
self
.
block_table
=
MultiGroupBlockTable
(
...
@@ -348,6 +362,7 @@ class InputBatch:
...
@@ -348,6 +362,7 @@ class InputBatch:
self
.
num_tokens_no_spec
[
req_index
]
=
request
.
num_tokens
self
.
num_tokens_no_spec
[
req_index
]
=
request
.
num_tokens
self
.
num_computed_tokens_cpu
[
req_index
]
=
request
.
num_computed_tokens
self
.
num_computed_tokens_cpu
[
req_index
]
=
request
.
num_computed_tokens
self
.
num_kv_tokens_cpu
[
req_index
]
=
request
.
num_kv_tokens
self
.
block_table
.
add_row
(
request
.
block_ids
,
req_index
)
self
.
block_table
.
add_row
(
request
.
block_ids
,
req_index
)
if
sampling_params
:
=
request
.
sampling_params
:
if
sampling_params
:
=
request
.
sampling_params
:
...
@@ -504,6 +519,8 @@ class InputBatch:
...
@@ -504,6 +519,8 @@ class InputBatch:
self
.
num_prompt_tokens
[
i2
],
self
.
num_prompt_tokens
[
i1
]
self
.
num_prompt_tokens
[
i2
],
self
.
num_prompt_tokens
[
i1
]
self
.
num_computed_tokens_cpu
[
i1
],
self
.
num_computed_tokens_cpu
[
i2
]
=
\
self
.
num_computed_tokens_cpu
[
i1
],
self
.
num_computed_tokens_cpu
[
i2
]
=
\
self
.
num_computed_tokens_cpu
[
i2
],
self
.
num_computed_tokens_cpu
[
i1
]
self
.
num_computed_tokens_cpu
[
i2
],
self
.
num_computed_tokens_cpu
[
i1
]
self
.
num_kv_tokens_cpu
[
i1
],
self
.
num_kv_tokens_cpu
[
i2
]
=
\
self
.
num_kv_tokens_cpu
[
i2
],
self
.
num_kv_tokens_cpu
[
i1
]
self
.
temperature_cpu
[
i1
],
self
.
temperature_cpu
[
i2
]
=
\
self
.
temperature_cpu
[
i1
],
self
.
temperature_cpu
[
i2
]
=
\
self
.
temperature_cpu
[
i2
],
self
.
temperature_cpu
[
i1
]
self
.
temperature_cpu
[
i2
],
self
.
temperature_cpu
[
i1
]
self
.
top_p_cpu
[
i1
],
self
.
top_p_cpu
[
i2
]
=
\
self
.
top_p_cpu
[
i1
],
self
.
top_p_cpu
[
i2
]
=
\
...
@@ -602,6 +619,8 @@ class InputBatch:
...
@@ -602,6 +619,8 @@ class InputBatch:
last_req_index
]
last_req_index
]
self
.
num_computed_tokens_cpu
[
self
.
num_computed_tokens_cpu
[
empty_index
]
=
self
.
num_computed_tokens_cpu
[
last_req_index
]
empty_index
]
=
self
.
num_computed_tokens_cpu
[
last_req_index
]
self
.
num_kv_tokens_cpu
[
empty_index
]
=
self
.
num_kv_tokens_cpu
[
last_req_index
]
self
.
block_table
.
move_row
(
last_req_index
,
empty_index
)
self
.
block_table
.
move_row
(
last_req_index
,
empty_index
)
self
.
temperature_cpu
[
empty_index
]
=
self
.
temperature_cpu
[
self
.
temperature_cpu
[
empty_index
]
=
self
.
temperature_cpu
[
last_req_index
]
last_req_index
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment