Commit d41ca128 authored by laibao
Browse files

feat(kvpress): 扩展 InputBatch 请求状态

parent c44fcded
......@@ -38,6 +38,7 @@ class CachedRequestState:
block_ids: tuple[list[int], ...]
num_computed_tokens: int
num_kv_tokens: int
output_token_ids: list[int]
mrope_positions: torch.Tensor | None = None
......@@ -51,6 +52,12 @@ class CachedRequestState:
# Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0
# Chunked prefill (scheme 3): cached prompt compaction plan.
# Computed on the last prompt chunk; applied before the first decode step.
kv_compression_prompt_idx_sorted: torch.Tensor | None = None # [K] int32
kv_compression_prompt_keep_len: int | None = None
kv_compression_prompt_prompt_len: int | None = None
# for pooling models
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None
......@@ -143,6 +150,13 @@ class InputBatch:
pin_memory=pin_memory,
)
self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()
self.num_kv_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_kv_tokens_cpu = self.num_kv_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = MultiGroupBlockTable(
......@@ -346,6 +360,7 @@ class InputBatch:
self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.num_kv_tokens_cpu[req_index] = request.num_kv_tokens
self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params:
......@@ -556,6 +571,10 @@ class InputBatch:
self.num_computed_tokens_cpu[i2],
self.num_computed_tokens_cpu[i1],
)
self.num_kv_tokens_cpu[i1], self.num_kv_tokens_cpu[i2] = (
self.num_kv_tokens_cpu[i2],
self.num_kv_tokens_cpu[i1],
)
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
......@@ -706,6 +725,7 @@ class InputBatch:
self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
last_req_index
]
self.num_kv_tokens_cpu[empty_index] = self.num_kv_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment