Commit 8d3d07fc authored by laibao's avatar laibao
Browse files

feat: kvpress新增 KV 压缩状态与元数据打通

parent 5036e878
......@@ -31,6 +31,7 @@ class NewRequestData:
pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
num_kv_tokens: int
lora_request: Optional[LoRARequest]
@classmethod
......@@ -49,6 +50,7 @@ class NewRequestData:
pooling_params=request.pooling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
num_kv_tokens=request.num_kv_tokens,
lora_request=request.lora_request,
)
......@@ -62,6 +64,7 @@ class NewRequestData:
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"num_kv_tokens={self.num_kv_tokens},"
f"lora_request={self.lora_request}"
")")
......@@ -76,6 +79,7 @@ class NewRequestData:
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"num_kv_tokens={self.num_kv_tokens},"
f"lora_request={self.lora_request}"
")")
......@@ -93,6 +97,7 @@ class CachedRequestData:
new_token_ids: list[list[int]]
new_block_ids: list[tuple[list[int], ...]]
num_computed_tokens: list[int]
num_kv_tokens: list[int]
@property
def num_reqs(self) -> int:
......@@ -106,6 +111,7 @@ class CachedRequestData:
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=[],
num_kv_tokens=[],
)
......
......@@ -79,6 +79,10 @@ class Request:
self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0
# Number of tokens currently stored in the KV cache for this request.
# This can be different from `num_computed_tokens` when KV compression
# is enabled (e.g., token-shared prefill compression).
self.num_kv_tokens = 0
self.num_generated_token_ids = 0
self.cache_salt: Optional[str] = cache_salt
......
......@@ -63,6 +63,11 @@ class BlockTable:
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
# Keep the invariant that "unused" entries map to the null block (id=0).
# This matters when we *shrink* a request's block list (e.g. KV
# compression tail-block truncation) and later re-use freed blocks for
# other requests.
self.block_table_np[row_idx, :].fill(0)
self.append_row(block_ids, row_idx)
def move_row(self, src: int, tgt: int) -> None:
......
......@@ -38,6 +38,7 @@ class CachedRequestState:
block_ids: tuple[list[int], ...]
num_computed_tokens: int
num_kv_tokens: int
output_token_ids: list[int]
spec_token_ids: list[int] = None
......@@ -51,6 +52,12 @@ class CachedRequestState:
repr=False,
compare=False)
# Chunked prefill (scheme 3): cached prompt compaction plan.
# Computed on the last prompt chunk; applied before the first decode step.
kv_compression_prompt_idx_sorted: Optional[torch.Tensor] = None # [K] int32
kv_compression_prompt_keep_len: Optional[int] = None
kv_compression_prompt_prompt_len: Optional[int] = None
def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids)
......@@ -114,6 +121,13 @@ class InputBatch:
)
self.num_computed_tokens_cpu = \
self.num_computed_tokens_cpu_tensor.numpy()
self.num_kv_tokens_cpu_tensor = torch.zeros(
(max_num_reqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_kv_tokens_cpu = self.num_kv_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = MultiGroupBlockTable(
......@@ -348,6 +362,7 @@ class InputBatch:
self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.num_kv_tokens_cpu[req_index] = request.num_kv_tokens
self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params:
......@@ -504,6 +519,8 @@ class InputBatch:
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
self.num_kv_tokens_cpu[i1], self.num_kv_tokens_cpu[i2] =\
self.num_kv_tokens_cpu[i2], self.num_kv_tokens_cpu[i1]
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
......@@ -602,6 +619,8 @@ class InputBatch:
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.num_kv_tokens_cpu[
empty_index] = self.num_kv_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment