from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch

# Import from compression_config, not compression.__init__, to avoid circular imports
# (compression -> compactor -> context -> compression).
from vllm.kvprune.compression.compression_config import CompressionMethod
from vllm.kvprune.config.engine_config import KvpruneAttentionSchedule


@dataclass
class CompressionContext:
    """Compression settings carried by a :class:`Context` (``Context.compression_context``).

    Groups the selected compression method, the token-retention budget, and
    optional per-method hyperparameters (Compactor, CriticalAdaKV).
    """

    compression_method: CompressionMethod = CompressionMethod.COMPACTOR
    # NOTE(review): -1 presumably means "no chunking" — confirm with the consumer.
    compression_chunk_size: int = -1
    batch_tokens_to_retain: torch.Tensor | None = None
    max_tokens_to_retain: int = 0
    context_lens: List[int] | None = None
    PHI: torch.Tensor | None = None

    # Compactor: optional hyperparameters, aligned with kvpress ``CompactorPress``.
    sketch_dimension: int = 48
    sink_size_start: int = 8
    sink_size_end: int = 4
    compactor_blending: Optional[float] = None
    # Matches kvpress: this value is used when ``compactor_blending`` is not set
    # (it comes from the request's compression_ratio).
    compression_ratio: Optional[float] = None
    protected_first_tokens: List[int] | None = None
    protected_last_tokens: List[int] | None = None

    # CriticalAdaKV
    wo_weight: Optional[torch.Tensor] = None
    critical_ada_epsilon: float = 1e-4
    critical_ada_first_stage_ratio: float = 0.5
    critical_ada_alpha_safeguard: float = 0.2


@dataclass
class Context:
    """Forward-pass state held in the module-global ``_CONTEXT``.

    Read it with :func:`get_context`. Replace it wholesale with
    :func:`set_context` or :func:`reset_context`.
    """

    is_prefill: bool = False
    do_compression: bool = False
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    # Set in ModelRunner.run_prefill before forward — avoids D2H inside compactor kernels.
    cu_seqlens_q_host: Optional[Tuple[int, ...]] = None
    cu_seqlens_k_host: Optional[Tuple[int, ...]] = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    batch_mapping: torch.Tensor | None = None
    max_bh_len: int = 0
    compression_context: CompressionContext | None = None
    STORE_STREAM: torch.cuda.Stream | None = None
    key_split: int | None = None
    attention_schedule: KvpruneAttentionSchedule = (
        KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
    )


_CONTEXT = Context()


def get_context() -> Context:
    """Return the current module-global :class:`Context` object."""
    return _CONTEXT


def set_context(
    *,
    is_prefill,
    do_compression=False,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    cu_seqlens_q_host: Optional[Tuple[int, ...]] = None,
    cu_seqlens_k_host: Optional[Tuple[int, ...]] = None,
    max_seqlen_q=0,
    max_seqlen_k=0,
    batch_mapping=None,
    max_bh_len=0,
    compression_context: Optional[CompressionContext] = None,
    STORE_STREAM=None,
    key_split=None,
    attention_schedule=KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE,
) -> None:
    """Replace the module-global context with a freshly built :class:`Context`.

    All arguments are keyword-only and map one-to-one onto the same-named
    ``Context`` fields. Arguments left unspecified take the defaults shown in
    the signature. No state from the previous context is carried over.
    """
    global _CONTEXT
    # Keyword construction keeps each argument bound to its field even if the
    # dataclass field order changes later (positional passing would silently
    # shift values into the wrong fields).
    _CONTEXT = Context(
        is_prefill=is_prefill,
        do_compression=do_compression,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        cu_seqlens_q_host=cu_seqlens_q_host,
        cu_seqlens_k_host=cu_seqlens_k_host,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        batch_mapping=batch_mapping,
        max_bh_len=max_bh_len,
        compression_context=compression_context,
        STORE_STREAM=STORE_STREAM,
        key_split=key_split,
        attention_schedule=attention_schedule,
    )


def reset_context() -> None:
    """Replace the module-global context with a default-constructed :class:`Context`."""
    global _CONTEXT
    _CONTEXT = Context()