Commit ad069e33 authored by laibao's avatar laibao
Browse files

feat: kvpress新增 runner 侧 KV 压缩状态/位置打通

parent 2fde0fa2
...@@ -313,6 +313,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -313,6 +313,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.positions_np = self.positions_cpu.numpy() self.positions_np = self.positions_cpu.numpy()
# KV positions are decoupled from logical positions when KV compression
# is enabled. We keep a separate buffer to avoid recomputing or
# overwriting `positions_np` (used for RoPE / input token lookup).
self.kv_positions_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.kv_positions_np = self.kv_positions_cpu.numpy()
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32, dtype=torch.int32,
device="cpu", device="cpu",
...@@ -448,6 +456,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -448,6 +456,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
generator=generator, generator=generator,
block_ids=new_req_data.block_ids, block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens, num_computed_tokens=new_req_data.num_computed_tokens,
num_kv_tokens=new_req_data.num_kv_tokens,
output_token_ids=[], output_token_ids=[],
lora_request=new_req_data.lora_request, lora_request=new_req_data.lora_request,
) )
...@@ -497,11 +506,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -497,11 +506,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
for i, req_id in enumerate(req_data.req_ids): for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id] req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i] num_computed_tokens = req_data.num_computed_tokens[i]
num_kv_tokens = req_data.num_kv_tokens[i]
new_block_ids = req_data.new_block_ids[i] new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i] resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states. # Update the cached states.
req_state.num_computed_tokens = num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
req_state.num_kv_tokens = num_kv_tokens
spec_token_ids = ( spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
...@@ -545,7 +556,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -545,7 +556,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Update the persistent batch. # Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = ( self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens) num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index) self.input_batch.num_kv_tokens_cpu[req_index] = num_kv_tokens
if resumed_from_preemption:
self.input_batch.block_table.add_row(new_block_ids, req_index)
else:
self.input_batch.block_table.append_row(new_block_ids, req_index)
# For the last rank, we don't need to update the token_ids_cpu # For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached. # because the sampled tokens are already cached.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment