"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "0f3f3c86ec44467fa80b60cb9f971f9ede028f76"
Commit 2df94aa9 authored by laibao's avatar laibao
Browse files

feat: kvpress runner 侧按 num_kv_tokens 计算 KV 写入位置

parent ad069e33
...@@ -146,6 +146,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -146,6 +146,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.attention_chunk_size = model_config.attention_chunk_size self.attention_chunk_size = model_config.attention_chunk_size
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
if envs.VLLM_ENABLE_KV_COMPRESSION:
# KV compression changes the effective KV sequence layout and
# invalidates cascade attention assumptions (common-prefix blocks).
self.cascade_attn_enabled = False
# Whether the current step needs KV compaction work (score/topk/dst).
# This is set per-step in `_prepare_inputs`.
self.kv_compression_needs_compaction: bool = False
# Multi-modal data support # Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
...@@ -673,6 +680,17 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -673,6 +680,17 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
np.add(self.input_batch.num_computed_tokens_cpu[req_indices], np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange, arange,
out=positions_np) out=positions_np)
# KV positions (where the KV for each scheduled token is temporarily
# written). When KV compression is enabled, KV positions are decoupled
# from logical positions.
use_kv_compression = envs.VLLM_ENABLE_KV_COMPRESSION
if use_kv_compression:
kv_positions_np = self.kv_positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_kv_tokens_cpu[req_indices],
arange,
out=kv_positions_np)
else:
kv_positions_np = None
# Calculate M-RoPE positions. # Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
...@@ -700,6 +718,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -700,6 +718,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
block_size = kv_cache_group_spec.kv_cache_spec.block_size block_size = kv_cache_group_spec.kv_cache_spec.block_size
block_table: BlockTable = self.input_batch.block_table[ block_table: BlockTable = self.input_batch.block_table[
kv_cache_group_id] kv_cache_group_id]
slot_positions_np = (kv_positions_np
if use_kv_compression else positions_np)
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2. # where K is the max_num_blocks_per_req and the block size is 2.
...@@ -708,11 +728,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -708,11 +728,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# block_size. # block_size.
block_table_indices = ( block_table_indices = (
req_indices * block_table.max_num_blocks_per_req + req_indices * block_table.max_num_blocks_per_req +
positions_np // block_size) slot_positions_np // block_size)
block_table_cpu = block_table.get_cpu_tensor() block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten( block_numbers = block_table_cpu.flatten(
)[block_table_indices].numpy() )[block_table_indices].numpy()
block_offsets = positions_np % block_size block_offsets = slot_positions_np % block_size
np.add( np.add(
block_numbers * block_size, block_numbers * block_size,
block_offsets, block_offsets,
...@@ -722,9 +742,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -722,9 +742,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.query_start_loc_np[0] = 0 self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
self.seq_lens_np[:num_reqs] = ( if use_kv_compression:
self.input_batch.num_computed_tokens_cpu[:num_reqs] + self.seq_lens_np[:num_reqs] = (
num_scheduled_tokens) self.input_batch.num_kv_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
else:
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
# Copy the tensors to the GPU. # Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids[:total_num_scheduled_tokens].copy_(
...@@ -2547,6 +2572,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2547,6 +2572,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
assert len(self.attn_backends) == 0 and len( assert len(self.attn_backends) == 0 and len(
self.attn_metadata_builders self.attn_metadata_builders
) == 0, "Attention backends are already initialized" ) == 0, "Attention backends are already initialized"
if envs.VLLM_ENABLE_KV_COMPRESSION and self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA "
"graph mode.")
for i, kv_cache_group_spec in enumerate( for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups): kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec kv_cache_spec = kv_cache_group_spec.kv_cache_spec
...@@ -2570,7 +2599,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2570,7 +2599,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
raise NotImplementedError( raise NotImplementedError(
"Non-Attention backend is not supported by V1 " "Non-Attention backend is not supported by V1 "
"GPUModelRunner.") "GPUModelRunner.")
if (envs.VLLM_ENABLE_KV_COMPRESSION
and attn_backend_i.get_name() != "FLASH_ATTN_VLLM_V1"):
raise ValueError(
"KV compression currently requires "
"VLLM_ATTENTION_BACKEND=FLASH_ATTN_VLLM_V1.")
elif isinstance(kv_cache_spec, MambaSpec): elif isinstance(kv_cache_spec, MambaSpec):
if envs.VLLM_ENABLE_KV_COMPRESSION:
raise ValueError(
"KV compression is currently only supported for "
"Transformer attention layers.")
attn_backend_i = Mamba2AttentionBackend attn_backend_i = Mamba2AttentionBackend
else: else:
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment