context.py 3.05 KB
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch

# Import from compression_config, not compression.__init__, to avoid circular imports
# (compression -> compactor -> context -> compression).
from vllm.kvprune.compression.compression_config import CompressionMethod
from vllm.kvprune.config.engine_config import KvpruneAttentionSchedule


@dataclass
class CompressionContext:
    """Settings and optional tensors describing a KV-cache compression step.

    Every field has a default, so ``CompressionContext()`` is valid. The
    generated ``__init__`` also accepts the fields positionally in
    declaration order, so reordering them can break positional callers.
    """

    compression_method: CompressionMethod = CompressionMethod.COMPACTOR

    # NOTE(review): -1 looks like a "not set / no chunking" sentinel — confirm
    # against the code that consumes it.
    compression_chunk_size: int = -1
    batch_tokens_to_retain: torch.Tensor | None = None
    max_tokens_to_retain: int = 0
    context_lens: List[int] | None = None
    PHI: torch.Tensor | None = None

    # Compactor: optional hyperparameters aligned with kvpress
    # ``CompactorPress``.
    sketch_dimension: int = 48
    sink_size_start: int = 8
    sink_size_end: int = 4
    compactor_blending: float | None = None
    # Same as kvpress: this value is used when ``compactor_blending`` is not
    # set (it comes from the request's compression_ratio).
    compression_ratio: float | None = None

    protected_first_tokens: List[int] | None = None
    protected_last_tokens: List[int] | None = None

    # CriticalAdaKV
    wo_weight: torch.Tensor | None = None
    critical_ada_epsilon: float = 1e-4
    critical_ada_first_stage_ratio: float = 0.5
    critical_ada_alpha_safeguard: float = 0.2


@dataclass
class Context:
    """State record held in this module's ``_CONTEXT`` slot.

    Every field has a default, so ``Context()`` yields a blank context. The
    generated ``__init__`` accepts the fields positionally in declaration
    order, so reordering them can silently break positional callers.
    """

    is_prefill: bool = False
    do_compression: bool = False

    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    # Host-side copies of the cumulative sequence lengths. Set in
    # ModelRunner.run_prefill before forward — avoids D2H inside compactor
    # kernels.
    cu_seqlens_q_host: Tuple[int, ...] | None = None
    cu_seqlens_k_host: Tuple[int, ...] | None = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    batch_mapping: torch.Tensor | None = None
    max_bh_len: int = 0

    # Compression settings for this pass; ``None`` by default.
    compression_context: CompressionContext | None = None
    STORE_STREAM: torch.cuda.Stream | None = None

    key_split: int | None = None
    attention_schedule: KvpruneAttentionSchedule = (
        KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
    )


# Module-level slot holding the active Context. It starts as an all-defaults
# instance and is rebound (not mutated) by set_context() / reset_context().
_CONTEXT = Context()


def get_context() -> Context:
    """Return the ``Context`` currently bound to the module-level ``_CONTEXT``.

    The object is returned as-is (no copy). ``set_context`` and
    ``reset_context`` rebind ``_CONTEXT`` to a new instance, so a reference
    obtained from an earlier call keeps pointing at the previous object.
    """
    return _CONTEXT


def set_context(
    *,
    is_prefill: bool,
    do_compression: bool = False,
    cu_seqlens_q: torch.Tensor | None = None,
    cu_seqlens_k: torch.Tensor | None = None,
    cu_seqlens_q_host: Optional[Tuple[int, ...]] = None,
    cu_seqlens_k_host: Optional[Tuple[int, ...]] = None,
    max_seqlen_q: int = 0,
    max_seqlen_k: int = 0,
    batch_mapping: torch.Tensor | None = None,
    max_bh_len: int = 0,
    compression_context: CompressionContext | None = None,
    STORE_STREAM: torch.cuda.Stream | None = None,
    key_split: int | None = None,
    attention_schedule: KvpruneAttentionSchedule = (
        KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
    ),
) -> None:
    """Replace the module-level context with a newly built ``Context``.

    All parameters are keyword-only and map one-to-one onto the ``Context``
    field of the same name. A brand-new ``Context`` is created on every call,
    so any field not passed here falls back to this function's default rather
    than keeping the value from the previous context.

    Args:
        is_prefill: Required; stored as ``Context.is_prefill``.
        do_compression: Stored as ``Context.do_compression``.
        cu_seqlens_q: Stored as ``Context.cu_seqlens_q``.
        cu_seqlens_k: Stored as ``Context.cu_seqlens_k``.
        cu_seqlens_q_host: Stored as ``Context.cu_seqlens_q_host``.
        cu_seqlens_k_host: Stored as ``Context.cu_seqlens_k_host``.
        max_seqlen_q: Stored as ``Context.max_seqlen_q``.
        max_seqlen_k: Stored as ``Context.max_seqlen_k``.
        batch_mapping: Stored as ``Context.batch_mapping``.
        max_bh_len: Stored as ``Context.max_bh_len``.
        compression_context: Stored as ``Context.compression_context``;
            ``None`` when not supplied.
        STORE_STREAM: Stored as ``Context.STORE_STREAM``.
        key_split: Stored as ``Context.key_split``.
        attention_schedule: Stored as ``Context.attention_schedule``.
    """
    global _CONTEXT
    # Build by keyword so the mapping does not depend on the declaration order
    # of the Context dataclass fields.
    _CONTEXT = Context(
        is_prefill=is_prefill,
        do_compression=do_compression,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        cu_seqlens_q_host=cu_seqlens_q_host,
        cu_seqlens_k_host=cu_seqlens_k_host,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        batch_mapping=batch_mapping,
        max_bh_len=max_bh_len,
        compression_context=compression_context,
        STORE_STREAM=STORE_STREAM,
        key_split=key_split,
        attention_schedule=attention_schedule,
    )


def reset_context() -> None:
    """Rebind the module-level ``_CONTEXT`` to a fresh all-defaults ``Context``.

    The previous ``Context`` object is not modified; it is simply no longer
    the one returned by ``get_context``.
    """
    global _CONTEXT
    _CONTEXT = Context()