Commit 8d3d07fc authored by laibao
Browse files

feat: kvpress新增 KV 压缩状态与元数据打通

parent 5036e878
...@@ -31,6 +31,7 @@ class NewRequestData: ...@@ -31,6 +31,7 @@ class NewRequestData:
pooling_params: Optional[PoolingParams] pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
num_computed_tokens: int num_computed_tokens: int
num_kv_tokens: int
lora_request: Optional[LoRARequest] lora_request: Optional[LoRARequest]
@classmethod @classmethod
...@@ -49,6 +50,7 @@ class NewRequestData: ...@@ -49,6 +50,7 @@ class NewRequestData:
pooling_params=request.pooling_params, pooling_params=request.pooling_params,
block_ids=block_ids, block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens, num_computed_tokens=request.num_computed_tokens,
num_kv_tokens=request.num_kv_tokens,
lora_request=request.lora_request, lora_request=request.lora_request,
) )
...@@ -62,6 +64,7 @@ class NewRequestData: ...@@ -62,6 +64,7 @@ class NewRequestData:
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids}," f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens}," f"num_computed_tokens={self.num_computed_tokens},"
f"num_kv_tokens={self.num_kv_tokens},"
f"lora_request={self.lora_request}" f"lora_request={self.lora_request}"
")") ")")
...@@ -76,6 +79,7 @@ class NewRequestData: ...@@ -76,6 +79,7 @@ class NewRequestData:
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids}," f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens}," f"num_computed_tokens={self.num_computed_tokens},"
f"num_kv_tokens={self.num_kv_tokens},"
f"lora_request={self.lora_request}" f"lora_request={self.lora_request}"
")") ")")
...@@ -93,6 +97,7 @@ class CachedRequestData: ...@@ -93,6 +97,7 @@ class CachedRequestData:
new_token_ids: list[list[int]] new_token_ids: list[list[int]]
new_block_ids: list[tuple[list[int], ...]] new_block_ids: list[tuple[list[int], ...]]
num_computed_tokens: list[int] num_computed_tokens: list[int]
num_kv_tokens: list[int]
@property @property
def num_reqs(self) -> int: def num_reqs(self) -> int:
...@@ -106,6 +111,7 @@ class CachedRequestData: ...@@ -106,6 +111,7 @@ class CachedRequestData:
new_token_ids=[], new_token_ids=[],
new_block_ids=[], new_block_ids=[],
num_computed_tokens=[], num_computed_tokens=[],
num_kv_tokens=[],
) )
......
...@@ -79,6 +79,10 @@ class Request: ...@@ -79,6 +79,10 @@ class Request:
self._all_token_ids: list[int] = self.prompt_token_ids.copy() self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.spec_token_ids: list[int] = [] self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0 self.num_computed_tokens = 0
# Number of tokens currently stored in the KV cache for this request.
# This can be different from `num_computed_tokens` when KV compression
# is enabled (e.g., token-shared prefill compression).
self.num_kv_tokens = 0
self.num_generated_token_ids = 0 self.num_generated_token_ids = 0
self.cache_salt: Optional[str] = cache_salt self.cache_salt: Optional[str] = cache_salt
......
...@@ -63,6 +63,11 @@ class BlockTable: ...@@ -63,6 +63,11 @@ class BlockTable:
def add_row(self, block_ids: list[int], row_idx: int) -> None: def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0 self.num_blocks_per_row[row_idx] = 0
# Keep the invariant that "unused" entries map to the null block (id=0).
# This matters when we *shrink* a request's block list (e.g. KV
# compression tail-block truncation) and later re-use freed blocks for
# other requests.
self.block_table_np[row_idx, :].fill(0)
self.append_row(block_ids, row_idx) self.append_row(block_ids, row_idx)
def move_row(self, src: int, tgt: int) -> None: def move_row(self, src: int, tgt: int) -> None:
......
...@@ -38,6 +38,7 @@ class CachedRequestState: ...@@ -38,6 +38,7 @@ class CachedRequestState:
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
num_computed_tokens: int num_computed_tokens: int
num_kv_tokens: int
output_token_ids: list[int] output_token_ids: list[int]
spec_token_ids: list[int] = None spec_token_ids: list[int] = None
...@@ -51,6 +52,12 @@ class CachedRequestState: ...@@ -51,6 +52,12 @@ class CachedRequestState:
repr=False, repr=False,
compare=False) compare=False)
# Chunked prefill (scheme 3): cached prompt compaction plan.
# Computed on the last prompt chunk; applied before the first decode step.
kv_compression_prompt_idx_sorted: Optional[torch.Tensor] = None # [K] int32
kv_compression_prompt_keep_len: Optional[int] = None
kv_compression_prompt_prompt_len: Optional[int] = None
def __post_init__(self): def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids) self.num_prompt_tokens = len(self.prompt_token_ids)
...@@ -114,6 +121,13 @@ class InputBatch: ...@@ -114,6 +121,13 @@ class InputBatch:
) )
self.num_computed_tokens_cpu = \ self.num_computed_tokens_cpu = \
self.num_computed_tokens_cpu_tensor.numpy() self.num_computed_tokens_cpu_tensor.numpy()
self.num_kv_tokens_cpu_tensor = torch.zeros(
(max_num_reqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_kv_tokens_cpu = self.num_kv_tokens_cpu_tensor.numpy()
# Block table. # Block table.
self.block_table = MultiGroupBlockTable( self.block_table = MultiGroupBlockTable(
...@@ -348,6 +362,7 @@ class InputBatch: ...@@ -348,6 +362,7 @@ class InputBatch:
self.num_tokens_no_spec[req_index] = request.num_tokens self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.num_kv_tokens_cpu[req_index] = request.num_kv_tokens
self.block_table.add_row(request.block_ids, req_index) self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params: if sampling_params := request.sampling_params:
...@@ -504,6 +519,8 @@ class InputBatch: ...@@ -504,6 +519,8 @@ class InputBatch:
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
self.num_kv_tokens_cpu[i1], self.num_kv_tokens_cpu[i2] =\
self.num_kv_tokens_cpu[i2], self.num_kv_tokens_cpu[i1]
self.temperature_cpu[i1], self.temperature_cpu[i2] =\ self.temperature_cpu[i1], self.temperature_cpu[i2] =\
self.temperature_cpu[i2], self.temperature_cpu[i1] self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] =\ self.top_p_cpu[i1], self.top_p_cpu[i2] =\
...@@ -602,6 +619,8 @@ class InputBatch: ...@@ -602,6 +619,8 @@ class InputBatch:
last_req_index] last_req_index]
self.num_computed_tokens_cpu[ self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index] empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.num_kv_tokens_cpu[
empty_index] = self.num_kv_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index) self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[ self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index] last_req_index]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment