Commit d41ca128 authored by laibao's avatar laibao
Browse files

feat(kvpress): 扩展 InputBatch 请求状态

parent c44fcded
...@@ -38,6 +38,7 @@ class CachedRequestState: ...@@ -38,6 +38,7 @@ class CachedRequestState:
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
num_computed_tokens: int num_computed_tokens: int
num_kv_tokens: int
output_token_ids: list[int] output_token_ids: list[int]
mrope_positions: torch.Tensor | None = None mrope_positions: torch.Tensor | None = None
...@@ -51,6 +52,12 @@ class CachedRequestState: ...@@ -51,6 +52,12 @@ class CachedRequestState:
# Used when both async_scheduling and spec_decode are enabled. # Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0 prev_num_draft_len: int = 0
# Chunked prefill (scheme 3): cached prompt compaction plan.
# Computed on the last prompt chunk; applied before the first decode step.
kv_compression_prompt_idx_sorted: torch.Tensor | None = None # [K] int32
kv_compression_prompt_keep_len: int | None = None
kv_compression_prompt_prompt_len: int | None = None
# for pooling models # for pooling models
pooling_params: PoolingParams | None = None pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None pooling_states: PoolingStates | None = None
...@@ -143,6 +150,13 @@ class InputBatch: ...@@ -143,6 +150,13 @@ class InputBatch:
pin_memory=pin_memory, pin_memory=pin_memory,
) )
self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy() self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()
self.num_kv_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_kv_tokens_cpu = self.num_kv_tokens_cpu_tensor.numpy()
# Block table. # Block table.
self.block_table = MultiGroupBlockTable( self.block_table = MultiGroupBlockTable(
...@@ -346,6 +360,7 @@ class InputBatch: ...@@ -346,6 +360,7 @@ class InputBatch:
self.num_tokens_no_spec[req_index] = request.num_tokens self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.num_kv_tokens_cpu[req_index] = request.num_kv_tokens
self.block_table.add_row(request.block_ids, req_index) self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params: if sampling_params := request.sampling_params:
...@@ -556,6 +571,10 @@ class InputBatch: ...@@ -556,6 +571,10 @@ class InputBatch:
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i2],
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i1],
) )
self.num_kv_tokens_cpu[i1], self.num_kv_tokens_cpu[i2] = (
self.num_kv_tokens_cpu[i2],
self.num_kv_tokens_cpu[i1],
)
# NOTE: the following is unsafe # NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
...@@ -706,6 +725,7 @@ class InputBatch: ...@@ -706,6 +725,7 @@ class InputBatch:
self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[ self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
last_req_index last_req_index
] ]
self.num_kv_tokens_cpu[empty_index] = self.num_kv_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index) self.block_table.move_row(last_req_index, empty_index)
self.request_lora_mapping[empty_index] = self.request_lora_mapping[ self.request_lora_mapping[empty_index] = self.request_lora_mapping[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment