selector.py 4.48 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import inspect
5
from functools import cache
6
from typing import cast, get_args
7
8
9
10

import torch

from vllm.attention.backends.abstract import AttentionBackend
11
12
13
14
from vllm.attention.backends.registry import (
    MAMBA_TYPE_TO_BACKEND_MAP,
    MambaAttentionBackendEnum,
)
15
from vllm.config.cache import CacheDType
16
from vllm.logger import init_logger
17
from vllm.utils.import_utils import resolve_obj_by_qualname
18
19
20
21

logger = init_logger(__name__)


22
23
24
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Choose the attention backend class for the given attention settings.

    A non-``None`` ``kv_cache_dtype`` is first checked against the literal
    values of ``CacheDType``. The backend preference is then read from the
    current vLLM config (``attention_config.backend``) and the actual
    selection is delegated to ``_cached_get_attn_backend``.

    Returns:
        The attention backend class produced by ``_cached_get_attn_backend``.

    Raises:
        AssertionError: If ``kv_cache_dtype`` is not ``None`` and is not one
            of the values listed by ``CacheDType``.
    """
    if kv_cache_dtype is not None:
        allowed_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in allowed_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {allowed_dtypes}"
        )

    # Imported at call time rather than at module import time.
    from vllm.config import get_current_vllm_config

    configured_backend = get_current_vllm_config().attention_config.backend

    # The value was validated above; the cast only narrows the static type.
    typed_kv_cache_dtype = cast(CacheDType | None, kv_cache_dtype)

    return _cached_get_attn_backend(
        backend=configured_backend,
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=typed_kv_cache_dtype,
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        use_mm_prefix=use_mm_prefix,
        attn_type=attn_type,
    )


61
@cache
def _cached_get_attn_backend(
    backend,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Resolve the attention backend class via the current platform.

    The current platform's ``get_attn_backend_cls`` is asked for the
    qualified name of the backend class, which is then resolved to the class
    object. If that class reports a required KV cache layout, the layout is
    applied and logged before the class is returned.

    Because the function is wrapped in ``functools.cache``, the body runs
    once per distinct combination of arguments.

    Returns:
        The resolved attention backend class.

    Raises:
        ValueError: If the platform returns a falsy value instead of a
            backend class name.
    """
    from vllm.platforms import current_platform

    select_backend_cls = current_platform.get_attn_backend_cls

    # Positional arguments shared by both call shapes. The deprecated
    # signature expects ``use_v1`` between these two groups.
    leading_args = (backend, head_size, dtype, kv_cache_dtype, block_size)
    trailing_args = (use_mla, has_sink, use_sparse, use_mm_prefix, attn_type)

    if "use_v1" in inspect.signature(select_backend_cls).parameters:
        logger.warning_once(
            "use_v1 parameter for get_attn_backend_cls is deprecated and will "
            "be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
            "remove it from your plugin code."
        )
        backend_qualname = select_backend_cls(
            *leading_args,
            True,  # use_v1
            *trailing_args,
        )
    else:
        backend_qualname = select_backend_cls(*leading_args, *trailing_args)

    if not backend_qualname:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )

    backend_cls = resolve_obj_by_qualname(backend_qualname)

    # Adjust the KV cache layout only when the backend demands a specific one.
    kv_layout = backend_cls.get_required_kv_cache_layout()
    if kv_layout is None:
        return backend_cls

    from vllm.v1.attention.backends.utils import set_kv_cache_layout

    set_kv_cache_layout(kv_layout)
    logger.info(
        "Using %s KV cache layout for %s backend.",
        kv_layout,
        backend_cls.get_name(),
    )
    return backend_cls
128
129


130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
def get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Return the mamba attention backend class for ``mamba_type``.

    Thin public entry point: the lookup itself is performed (and memoized)
    by ``_cached_get_mamba_attn_backend``.
    """
    backend_cls = _cached_get_mamba_attn_backend(mamba_type)
    return backend_cls


@cache
def _cached_get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Look up and return the mamba attention backend class for a mamba type.

    ``mamba_type`` is first mapped to a backend name through
    ``MAMBA_TYPE_TO_BACKEND_MAP``; that name is then looked up in
    ``MambaAttentionBackendEnum`` and the enum member's ``get_class()``
    result is returned. Results are memoized by ``functools.cache``.

    Args:
        mamba_type: Non-empty string key of ``MAMBA_TYPE_TO_BACKEND_MAP``.

    Returns:
        The class returned by the selected enum member's ``get_class()``.

    Raises:
        ValueError: If ``mamba_type`` is not a key of
            ``MAMBA_TYPE_TO_BACKEND_MAP``, or if the mapped backend name is
            not a member of ``MambaAttentionBackendEnum``.
    """
    assert mamba_type and isinstance(mamba_type, str)

    # The two lookups are handled separately: with a single shared handler,
    # a miss on the first lookup would reference ``backend_name`` before it
    # was ever assigned and surface as UnboundLocalError instead of the
    # intended ValueError.
    try:
        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba type: '{mamba_type}'. Valid "
            f"types are: {list(MAMBA_TYPE_TO_BACKEND_MAP)}"
        ) from e

    try:
        selected_backend = MambaAttentionBackendEnum[backend_name]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{backend_name}'. Valid "
            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
        ) from e

    return selected_backend.get_class()