__init__.py 1.33 KB
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from vllm.kvprune.compression.common import (
    BaseCompressionMethod,
    NoCompression,
)
from vllm.kvprune.compression.criticalkv import CriticalAdaKVCompression
from vllm.kvprune.compression.compactor import CompactorCompression
from vllm.kvprune.compression.compression_config import (
    BatchCompressionParams,
    CompressionMethod,
    SequenceCompressionParams,
)
from vllm.kvprune.compression.snapkv import SnapKVCompression

# Dispatch table from each CompressionMethod enum member to the class that
# implements it. apply_prerope_compression and apply_postrope_compression in
# this module look the class up by
# ``context.compression_context.compression_method`` and call its
# ``pre_rope_scoring`` / ``post_rope_scoring`` hook directly on the class
# object (no instance is created in this module).
# NOTE: entry order is preserved by dict iteration, so keep it stable if
# anything enumerates the registry.
COMPRESSION_REGISTRY: dict[CompressionMethod, type[BaseCompressionMethod]] = {
    CompressionMethod.CRITICALADAKV: CriticalAdaKVCompression,
    CompressionMethod.COMPACTOR: CompactorCompression,
    CompressionMethod.SNAPKV: SnapKVCompression,
    # Explicit entry for "no compression", so the NONE method dispatches
    # through the same lookup path as the real methods.
    CompressionMethod.NONE: NoCompression,
}


def apply_prerope_compression(q, k, v, context):
    """Dispatch to the configured compression method's pre-RoPE scoring hook.

    The method is read from ``context.compression_context.compression_method``
    and resolved through ``COMPRESSION_REGISTRY``; the registered class's
    ``pre_rope_scoring`` is then called with ``q``, ``k``, ``v`` and
    ``context`` (passed by keyword).

    Args:
        q, k, v: Forwarded unchanged to ``pre_rope_scoring``. Presumably the
            query/key/value tensors before rotary embedding -- TODO confirm
            shapes with the caller.
        context: Object exposing ``compression_context.compression_method``;
            also forwarded to the hook.

    Returns:
        Whatever the method's ``pre_rope_scoring`` returns (presumably the
        ``prerope_scores`` later handed to ``apply_postrope_compression`` --
        confirm against the caller).

    Raises:
        KeyError: If the configured method has no entry in
            ``COMPRESSION_REGISTRY``.
    """
    method = context.compression_context.compression_method
    # Guard only the registry lookup so a KeyError raised *inside* the
    # scoring hook is not misreported as an unregistered method.
    try:
        compressor = COMPRESSION_REGISTRY[method]
    except KeyError:
        raise KeyError(
            f"No compression implementation registered for {method!r}; "
            f"registered methods: {list(COMPRESSION_REGISTRY)}"
        ) from None
    return compressor.pre_rope_scoring(q, k, v, context=context)


def apply_postrope_compression(q, k, v, prerope_scores, context):
    """Dispatch to the configured compression method's post-RoPE scoring hook.

    The method is read from ``context.compression_context.compression_method``
    and resolved through ``COMPRESSION_REGISTRY``; the registered class's
    ``post_rope_scoring`` is then called with ``q``, ``k``, ``v``,
    ``prerope_scores`` and ``context`` (passed by keyword).

    Args:
        q, k, v: Forwarded unchanged to ``post_rope_scoring``. Presumably the
            query/key/value tensors after rotary embedding -- TODO confirm
            shapes with the caller.
        prerope_scores: Forwarded unchanged; presumably the result of
            ``apply_prerope_compression`` for the same step -- confirm.
        context: Object exposing ``compression_context.compression_method``;
            also forwarded to the hook.

    Returns:
        Whatever the method's ``post_rope_scoring`` returns.

    Raises:
        KeyError: If the configured method has no entry in
            ``COMPRESSION_REGISTRY``.
    """
    method = context.compression_context.compression_method
    # Guard only the registry lookup so a KeyError raised *inside* the
    # scoring hook is not misreported as an unregistered method.
    try:
        compressor = COMPRESSION_REGISTRY[method]
    except KeyError:
        raise KeyError(
            f"No compression implementation registered for {method!r}; "
            f"registered methods: {list(COMPRESSION_REGISTRY)}"
        ) from None
    return compressor.post_rope_scoring(
        q, k, v, prerope_scores, context=context
    )


# Names re-exported by ``from <this package> import *``. Everything imported
# at the top of this module (e.g. BaseCompressionMethod, NoCompression and the
# concrete compression classes) remains importable by explicit name; those
# names are simply not part of the wildcard export.
__all__ = [
    "apply_prerope_compression",
    "apply_postrope_compression",
    "CompressionMethod",
    "BatchCompressionParams",
    "SequenceCompressionParams",
    "COMPRESSION_REGISTRY",
]