# compression_config.py
import logging
from dataclasses import dataclass
from enum import Enum, auto

logger = logging.getLogger(__name__)


class CompressionMethod(Enum):
    """Identifiers for the compression methods that can be selected.

    The integer values are spelled out explicitly; they are the same values
    that ``auto()`` assigns to members declared in this order (starting at 1).
    """

    CRITICALADAKV = 1
    COMPACTOR = 2
    SNAPKV = 3
    # NOTE(review): presumably means "no compression" -- confirm with callers.
    NONE = 4


# class CachingPolicy(Enum):
#     CACHE_PROMPT = auto()
#     DONT_CACHE = auto()


# class CompressionType(Enum):
#     QUERY_AWARE = auto()
#     QUERY_AGNOSTIC = auto()


@dataclass
class SequenceCompressionParams:
    """Compression settings that apply to a single sequence.

    Attributes:
        compression_ratio: Compression ratio for the sequence; defaults to
            ``1.0``. NOTE(review): the exact definition of the ratio is not
            visible in this file -- confirm against the code that consumes it.
        protected_first_tokens: Defaults to ``16``. Presumably the number of
            leading tokens that are exempt from compression -- verify against
            the consumer.
        protected_last_tokens: Defaults to ``64``. Presumably the number of
            trailing tokens that are exempt from compression -- verify against
            the consumer.
    """

    # Field order is part of the positional constructor signature; keep it.
    compression_ratio: float = 1.0
    protected_first_tokens: int = 16
    protected_last_tokens: int = 64


@dataclass
class BatchCompressionParams:
    """Compression settings shared by a whole batch.

    Attributes:
        compression_method: The ``CompressionMethod`` to use; defaults to
            ``CompressionMethod.COMPACTOR``.
        do_chunked_compression: Whether chunked compression is enabled;
            defaults to ``True``. Always forced to ``False`` when
            ``compression_method`` is ``CompressionMethod.SNAPKV``.
        chunk_size: Defaults to ``512``. Presumably the chunk length used when
            chunked compression is enabled -- confirm the unit with the
            consumer.
    """

    # compression_type: CompressionType = CompressionType.QUERY_AGNOSTIC
    compression_method: CompressionMethod = CompressionMethod.COMPACTOR

    do_chunked_compression: bool = True
    chunk_size: int = 512

    def __post_init__(self) -> None:
        """Turn off chunked compression when SNAPKV is selected.

        The resulting field value is always ``False`` for SNAPKV. The warning
        is logged only when chunked compression was actually requested, so a
        caller that already passed ``do_chunked_compression=False`` does not
        get a misleading "Disabling it." message.
        """
        if self.compression_method == CompressionMethod.SNAPKV:
            chunking_was_requested = self.do_chunked_compression
            self.do_chunked_compression = False
            if chunking_was_requested:
                logger.warning(
                    "CompressionMethod.SNAPKV is not compatible with chunked compression. Disabling it."
                )