# selector.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from functools import cache
5
from typing import NamedTuple, cast, get_args
6
7
8

import torch

9
from vllm.config.cache import CacheDType
10
from vllm.logger import init_logger
11
from vllm.utils.import_utils import resolve_obj_by_qualname
12
13
14
15
16
from vllm.v1.attention.backend import AttentionBackend, AttentionType
from vllm.v1.attention.backends.registry import (
    MAMBA_TYPE_TO_BACKEND_MAP,
    MambaAttentionBackendEnum,
)
17
18
19
20

logger = init_logger(__name__)


class AttentionSelectorConfig(NamedTuple):
    """Attention requirements used to pick a backend.

    Instances are passed to ``_cached_get_attn_backend``, which is wrapped in
    ``functools.cache``; being a NamedTuple, the config is hashable as long
    as every field value is hashable.
    """

    head_size: int
    dtype: torch.dtype
    kv_cache_dtype: CacheDType | None
    block_size: int | None
    use_mla: bool = False
    has_sink: bool = False
    use_sparse: bool = False
    use_mm_prefix: bool = False
    use_alibi_sqrt: bool = False
    attn_type: str = AttentionType.DECODER

    def __repr__(self) -> str:
        # Generated from ``_fields`` so newly added fields can never be left
        # out. Values are formatted like an f-string (``str()``-style), not
        # with ``repr()`` as the default NamedTuple repr would, so e.g.
        # string values appear unquoted.
        rendered = ", ".join(
            f"{name}={getattr(self, name)}" for name in self._fields
        )
        return f"AttentionSelectorConfig({rendered})"


def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    use_alibi_sqrt: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    The backend preference is read from the current vLLM config
    (``attention_config.backend``); the arguments are packed into an
    ``AttentionSelectorConfig`` and resolution is delegated to the memoized
    ``_cached_get_attn_backend``.

    Args:
        head_size: Forwarded unchanged into the selector config.
        dtype: Forwarded unchanged into the selector config.
        kv_cache_dtype: KV-cache dtype name. When not None it must be one of
            the literal values of ``CacheDType``.
        block_size: Forwarded unchanged into the selector config.
        use_mla, has_sink, use_sparse, use_mm_prefix, use_alibi_sqrt:
            Feature flags forwarded unchanged into the selector config.
        attn_type: Attention type. Any falsy value (including None) falls
            back to ``AttentionType.DECODER``.

    Returns:
        The attention backend class chosen for the current platform.

    Raises:
        AssertionError: If ``kv_cache_dtype`` is not a valid ``CacheDType``
            (only while assertions are enabled).
        ValueError: If the platform reports no backend (raised by
            ``_cached_get_attn_backend``).
    """
    if kv_cache_dtype is not None:
        valid_cache_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in valid_cache_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {valid_cache_dtypes}"
        )

    # Function-local import, kept from the original code (presumably to
    # avoid an import cycle with vllm.config -- TODO confirm).
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    backend_enum = vllm_config.attention_config.backend

    attn_selector_config = AttentionSelectorConfig(
        head_size=head_size,
        dtype=dtype,
        # Checked by the assert above (while assertions are enabled); the
        # cast only narrows the static type.
        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        use_mm_prefix=use_mm_prefix,
        use_alibi_sqrt=use_alibi_sqrt,
        attn_type=attn_type or AttentionType.DECODER,
    )

    return _cached_get_attn_backend(
        backend=backend_enum,
        attn_selector_config=attn_selector_config,
    )


@cache
def _cached_get_attn_backend(
    backend,
    attn_selector_config: AttentionSelectorConfig,
) -> type[AttentionBackend]:
    """Ask the current platform for a backend and import it.

    Memoized with ``functools.cache`` on ``(backend, attn_selector_config)``.
    The body, including the KV-cache-layout side effect below, therefore
    runs only once per distinct argument pair.

    Args:
        backend: Backend preference taken from the attention config; passed
            through to ``current_platform.get_attn_backend_cls``.
        attn_selector_config: Requirements of the attention layer.

    Returns:
        The backend class resolved from the name the platform returned.

    Raises:
        ValueError: If the platform returns no backend.
    """
    from vllm.platforms import current_platform

    attention_cls = current_platform.get_attn_backend_cls(
        backend,
        attn_selector_config=attn_selector_config,
    )
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )

    # ``attention_cls`` is presumably a dotted qualified name (it is handed to
    # resolve_obj_by_qualname); importing it only here keeps backends lazy.
    # Bound to a new name so the ``backend`` argument is not shadowed.
    backend_cls = resolve_obj_by_qualname(attention_cls)

    # Adjust kv cache layout if the selected backend requires a specific one
    required_layout = backend_cls.get_required_kv_cache_layout()
    if required_layout is not None:
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout(required_layout)
        logger.info(
            "Using %s KV cache layout for %s backend.",
            required_layout,
            backend_cls.get_name(),
        )

    return backend_cls


def get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Return the attention backend class mapped to ``mamba_type``.

    Public entry point; lookup and lazy import are delegated to the memoized
    ``_cached_get_mamba_attn_backend``.
    """
    backend_cls = _cached_get_mamba_attn_backend(mamba_type)
    return backend_cls


@cache
def _cached_get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Resolve ``mamba_type`` to its backend class (memoized per type).

    ``mamba_type`` is first mapped to a backend name through
    ``MAMBA_TYPE_TO_BACKEND_MAP``. That name is then looked up in
    ``MambaAttentionBackendEnum``, whose ``get_class()`` supplies the class.

    Raises:
        AssertionError: If ``mamba_type`` is empty or not a ``str``.
        ValueError: If ``mamba_type`` has no mapped backend name, or the
            mapped name is not a member of ``MambaAttentionBackendEnum``.
    """
    assert mamba_type and isinstance(mamba_type, str)

    # The two lookups are guarded separately. With a shared handler, a miss in
    # the first map would leave ``backend_name`` unbound, and formatting the
    # message would raise UnboundLocalError instead of ValueError.
    try:
        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba type: '{mamba_type}'. Valid mamba types are: "
            f"{list(MAMBA_TYPE_TO_BACKEND_MAP.keys())}"
        ) from e

    try:
        selected_backend = MambaAttentionBackendEnum[backend_name]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{backend_name}'. Valid "
            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
        ) from e

    return selected_backend.get_class()