compression_params.py 2.11 KB
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Per-request KV compression for :meth:`vllm.LLM.generate` (``compression=`` kwarg)."""

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class CompressionParams:
    """Per-prompt compression intent for :meth:`vllm.LLM.generate`.

    If **any** prompt in the batch has ``compression_ratio < 1.0``, the **whole** batch
    is run on the compactor ``LLMEngine`` (same stack as standalone compactor-vllm:
    ``PagedKVCache`` + pruning kernels). If all prompts have ``compression_ratio >= 1.0``,
    the batch stays on standard vLLM.

    ``compression_method`` follows :mod:`vllm.kvprune.core.compression_bridge` aliases:
    ``none``, ``criticaladakv``, ``compactor``, ``snapkv`` (ignored when
    ``compression_ratio`` is effectively 1).

    ``protected_*`` map to compactor :class:`~vllm.kvprune.compression.compression_config.SequenceCompressionParams`
    (defaults match standalone compactor-vllm-style usage).

    Normalization performed in ``__post_init__``:

    * ``compression_method`` is stripped and lower-cased; ``None``, ``""`` and
      whitespace-only strings all fall back to ``"compactor"``.
    * A ``compression_ratio`` within ``1e-9`` of 1 is snapped to exactly ``1.0``
      and ``compression_method`` is forced to ``"none"``. A plain
      ``compression_ratio < 1.0`` routing check therefore always agrees with
      the method.

    Raises:
        ValueError: if ``compression_ratio`` is not in ``(0, 1]`` (NaN included),
            if either ``protected_*`` count is negative, if ``compression_method``
            is not in ``VALID_ALIASES_FOR_SAMPLING``, or if ``compression_method``
            is ``"none"`` while ``compression_ratio < 1.0``.
    """

    # Target compression ratio in (0, 1]; 1.0 (or anything within 1e-9 of it)
    # disables compression for this prompt.
    compression_ratio: float = 1.0
    # Method alias (case- and surrounding-whitespace-insensitive); forced to
    # "none" when the prompt is effectively uncompressed.
    compression_method: str = "compactor"
    # Leading / trailing token counts handed to the compactor's
    # SequenceCompressionParams (presumably exempt from pruning - confirm
    # against compression_config); must be >= 0.
    protected_first_tokens: int = 16
    protected_last_tokens: int = 64

    def __post_init__(self) -> None:
        # Ratios at or above (1 - eps) are treated as "no compression".
        uncompressed_eps = 1e-9

        # Negated chained comparison so that NaN is rejected as well.
        if not 0.0 < self.compression_ratio <= 1.0:
            raise ValueError(
                f"compression_ratio must be in (0, 1], got {self.compression_ratio}"
            )
        for attr in ("protected_first_tokens", "protected_last_tokens"):
            count = getattr(self, attr)
            if count < 0:
                raise ValueError(f"{attr} must be >= 0, got {count}")

        # None, "" and whitespace-only strings all select the default method.
        self.compression_method = (
            (self.compression_method or "").strip().lower() or "compactor"
        )
        # Imported lazily (presumably to avoid an import cycle or a heavy import
        # at module load - confirm before hoisting).
        from vllm.kvprune.core.compression_bridge import VALID_ALIASES_FOR_SAMPLING

        if self.compression_method not in VALID_ALIASES_FOR_SAMPLING:
            raise ValueError(
                f"compression_method must be one of {sorted(VALID_ALIASES_FOR_SAMPLING)}, "
                f"got {self.compression_method!r}"
            )
        if self.compression_ratio >= 1.0 - uncompressed_eps:
            # Snap to exactly 1.0. Without this, a ratio like 0.9999999999 would
            # still satisfy the ``< 1.0`` batch-routing test while carrying
            # method "none".
            self.compression_ratio = 1.0
            self.compression_method = "none"
        elif self.compression_method == "none":
            raise ValueError(
                "When compression_ratio < 1.0, compression_method cannot be 'none'."
            )