# selector.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import inspect
5
from functools import cache
6
from typing import cast, get_args
7
8
9
10

import torch

from vllm.attention.backends.abstract import AttentionBackend
11
12
13
14
from vllm.attention.backends.registry import (
    MAMBA_TYPE_TO_BACKEND_MAP,
    MambaAttentionBackendEnum,
)
15
from vllm.config.cache import CacheDType
16
from vllm.logger import init_logger
17
from vllm.utils.import_utils import resolve_obj_by_qualname
18
19
20
21

# Module-level logger shared by the backend-selection helpers below.
logger = init_logger(__name__)


def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    The backend preference is read from the current vLLM config
    (``attention_config.backend``) and forwarded, together with the other
    arguments, to the memoized ``_cached_get_attn_backend`` resolver.

    Args:
        head_size: Attention head size, forwarded to the platform selector.
        dtype: Model dtype, forwarded to the platform selector.
        kv_cache_dtype: KV-cache dtype name. When not ``None`` it must be
            one of the literal values of ``CacheDType``.
        block_size: KV-cache block size, forwarded unchanged.
        use_mla: Forwarded unchanged to the platform selector.
        has_sink: Forwarded unchanged to the platform selector.
        use_sparse: Forwarded unchanged to the platform selector.
        attn_type: Forwarded unchanged to the platform selector.

    Returns:
        The resolved attention backend class.

    Raises:
        AssertionError: If ``kv_cache_dtype`` is not a valid ``CacheDType``
            (this check is skipped when Python runs with ``-O``).
        ValueError: Propagated from the resolver when the platform returns
            no backend class.
    """
    if kv_cache_dtype is not None:
        valid_cache_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in valid_cache_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {valid_cache_dtypes}"
        )

    # Imported locally, presumably to avoid an import cycle with vllm.config.
    from vllm.config import get_current_vllm_config

    # The configured backend is read here, outside the cached helper, and
    # passed in as an argument so it becomes part of the cache key.
    vllm_config = get_current_vllm_config()
    backend_enum = vllm_config.attention_config.backend

    return _cached_get_attn_backend(
        backend=backend_enum,
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        attn_type=attn_type,
    )


@cache
def _cached_get_attn_backend(
    backend,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Resolve and import the attention backend class for the current platform.

    The function:

    * asks ``current_platform.get_attn_backend_cls`` for a qualified class
      name;
    * imports that class via ``resolve_obj_by_qualname``;
    * if the class reports a required KV-cache layout, applies it with
      ``set_kv_cache_layout``.

    Because the function is wrapped in ``functools.cache``, all of this work,
    including the layout side effect and its log line, runs only once per
    distinct combination of arguments.

    Args:
        backend: Backend selection taken from the attention config; passed
            to the platform hook unchanged.
        head_size, dtype, kv_cache_dtype, block_size, use_mla, has_sink,
        use_sparse, attn_type: Passed positionally to the platform hook.

    Returns:
        The imported attention backend class.

    Raises:
        ValueError: If the platform hook returns a falsy class name.
    """
    from vllm.platforms import current_platform

    # Positional arguments common to both the current and the legacy
    # platform-hook signatures; the legacy one has an extra ``use_v1`` flag
    # between ``block_size`` and ``use_mla``.
    leading_args = (backend, head_size, dtype, kv_cache_dtype, block_size)
    trailing_args = (use_mla, has_sink, use_sparse, attn_type)

    sig = inspect.signature(current_platform.get_attn_backend_cls)
    if "use_v1" in sig.parameters:
        logger.warning_once(
            "use_v1 parameter for get_attn_backend_cls is deprecated and will "
            "be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
            "remove it from your plugin code."
        )
        attention_cls = current_platform.get_attn_backend_cls(
            *leading_args,
            True,  # use_v1
            *trailing_args,
        )
    else:
        attention_cls = current_platform.get_attn_backend_cls(
            *leading_args,
            *trailing_args,
        )

    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )

    # Distinct name: ``backend`` (the parameter) is the config selection,
    # this is the imported backend class.
    backend_cls = resolve_obj_by_qualname(attention_cls)

    # Adjust kv cache layout if the selected backend requires a specific one
    required_layout = backend_cls.get_required_kv_cache_layout()
    if required_layout is not None:
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout(required_layout)
        logger.info(
            "Using %s KV cache layout for %s backend.",
            required_layout,
            backend_cls.get_name(),
        )

    return backend_cls


def get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Select which mamba attention backend to use and lazily import it.

    Public entry point: the lookup, import and memoization are all handled
    by ``_cached_get_mamba_attn_backend``, whose result (or exception) is
    passed straight through.

    Args:
        mamba_type: Mamba layer type used as the lookup key.

    Returns:
        The mamba attention backend class for ``mamba_type``.
    """
    backend_cls = _cached_get_mamba_attn_backend(mamba_type)
    return backend_cls


@cache
def _cached_get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Map a mamba layer type to its attention backend class.

    The function:

    * looks up the backend enum name for ``mamba_type`` in
      ``MAMBA_TYPE_TO_BACKEND_MAP``;
    * resolves that name to a ``MambaAttentionBackendEnum`` member;
    * returns the member's ``get_class()``.

    Results are memoized with ``functools.cache``.

    Args:
        mamba_type: Non-empty mamba layer type string.

    Returns:
        The attention backend class for ``mamba_type``.

    Raises:
        AssertionError: If ``mamba_type`` is empty or not a ``str`` (this
            check is skipped when Python runs with ``-O``).
        ValueError: If ``mamba_type`` has no entry in
            ``MAMBA_TYPE_TO_BACKEND_MAP``, or the mapped name is not a
            ``MambaAttentionBackendEnum`` member.
    """
    assert mamba_type and isinstance(mamba_type, str)

    # The two lookups are guarded separately so each error reports the name
    # that was actually missing; ``backend_name`` is only bound once the first
    # lookup succeeds.
    try:
        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba type: '{mamba_type}'. Valid "
            f"types are: {list(MAMBA_TYPE_TO_BACKEND_MAP)}"
        ) from e

    try:
        selected_backend = MambaAttentionBackendEnum[backend_name]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{backend_name}'. Valid "
            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
        ) from e

    return selected_backend.get_class()