# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Literal

from pydantic import field_validator

from vllm.config.utils import config
from vllm.v1.attention.backends.registry import AttentionBackendEnum


@config
class AttentionConfig:
    """Configuration for attention mechanisms in vLLM."""

    backend: AttentionBackendEnum | None = None
    """Attention backend to use. Use "auto" or None for automatic selection."""

    flash_attn_version: Literal[2, 3, 4] | None = None
    """Force vllm to use a specific flash-attention version (2, 3, or 4).
    Only valid when using the flash-attention backend."""

    use_prefill_decode_attention: bool = False
    """Use separate prefill and decode kernels for attention instead of
    the unified triton kernel."""

    flash_attn_max_num_splits_for_cuda_graph: int = 32
    """Flash Attention max number splits for cuda graph decode."""

    tq_max_kv_splits_for_cuda_graph: int = 32
    """TurboQuant max NUM_KV_SPLITS for cuda graph decode.
    Fixes the split count so grid dimensions are constant across captures,
    and buffers can be pre-allocated to avoid inflating the memory estimate."""

    use_cudnn_prefill: bool = False
    """Whether to use cudnn prefill."""

    use_trtllm_ragged_deepseek_prefill: bool = False
    """Whether to use TRTLLM ragged deepseek prefill."""

    use_trtllm_attention: bool | None = None
    """If set to True/False, use or don't use the TRTLLM attention backend
    in flashinfer. If None, auto-detect the attention backend in flashinfer."""

    disable_flashinfer_prefill: bool = True
    """Whether to disable flashinfer prefill."""

    disable_flashinfer_q_quantization: bool = False
    """If set, when using fp8 kv, do not quantize Q to fp8."""

    use_prefill_query_quantization: bool = False
    """If set, quantize query for attention in prefill."""

    def compute_hash(self) -> str:
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        from vllm.config.utils import get_hash_factors, hash_factors

        # No field is excluded: the ignore set passed to get_hash_factors is
        # empty, so presumably every field above contributes to the hash.
        ignored_factors: set[str] = set()
        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)

    @field_validator("backend", mode="before")
    @classmethod
    def validate_backend_before(cls, value: Any) -> Any:
        """Enable parsing of the `backend` enum type from string.

        The special value "auto" (case-insensitive) is treated as None, which
        triggers automatic backend selection. Any other string is looked up
        case-insensitively by member name in `AttentionBackendEnum`.
        Non-string values are returned unchanged for pydantic to validate.

        Raises:
            ValueError: If `value` is a string that is neither "auto" nor the
                name of an `AttentionBackendEnum` member.
        """
        if not isinstance(value, str):
            return value
        if value.lower() == "auto":
            return None
        try:
            return AttentionBackendEnum[value.upper()]
        except KeyError:
            # Pydantic only wraps ValueError/AssertionError raised by
            # validators into a ValidationError; a raw KeyError would escape
            # without context, so convert it and list the accepted names.
            valid = ", ".join(member.name for member in AttentionBackendEnum)
            raise ValueError(
                f"Unknown attention backend {value!r}. Valid options are "
                f"'auto' or one of: {valid}."
            ) from None