selector.py 3.75 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from functools import cache
5
from typing import cast, get_args
6
7
8
9

import torch

from vllm.attention.backends.abstract import AttentionBackend
10
11
12
13
from vllm.attention.backends.registry import (
    MAMBA_TYPE_TO_BACKEND_MAP,
    MambaAttentionBackendEnum,
)
14
from vllm.config.cache import CacheDType
15
from vllm.logger import init_logger
16
from vllm.utils.import_utils import resolve_obj_by_qualname
17
18
19
20

# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)


21
22
23
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Pick the attention backend class for the given settings.

    The backend preference is read from the current vLLM config
    (``attention_config.backend``); every argument of this function is
    forwarded unchanged to the memoized ``_cached_get_attn_backend`` helper,
    which performs the actual (lazy) import.

    Args:
        head_size: Forwarded to the cached selector.
        dtype: Forwarded to the cached selector.
        kv_cache_dtype: KV cache dtype name, or ``None``. When given, it must
            be one of the values listed by ``get_args(CacheDType)``.
        block_size: Forwarded to the cached selector.
        use_mla: Forwarded to the cached selector.
        has_sink: Forwarded to the cached selector.
        use_sparse: Forwarded to the cached selector.
        use_mm_prefix: Forwarded to the cached selector.
        attn_type: Forwarded to the cached selector.

    Returns:
        The attention backend class returned by the cached selector.

    Raises:
        AssertionError: If ``kv_cache_dtype`` is not ``None`` and is not a
            valid ``CacheDType`` value. Being an ``assert``, this check is
            skipped when Python runs with ``-O``.
    """
    # ``None`` is always accepted; anything else must be a known CacheDType.
    allowed_cache_dtypes = get_args(CacheDType)
    assert kv_cache_dtype is None or kv_cache_dtype in allowed_cache_dtypes, (
        f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
        f"Valid values are: {allowed_cache_dtypes}"
    )

    # Function-scoped import, as in the original
    # (presumably avoids an import cycle -- TODO confirm).
    from vllm.config import get_current_vllm_config

    configured_backend = get_current_vllm_config().attention_config.backend

    # Keyword arguments are kept (in this order) so the memoization key of
    # the cached helper stays the same as before.
    return _cached_get_attn_backend(
        backend=configured_backend,
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        use_mm_prefix=use_mm_prefix,
        attn_type=attn_type,
    )


60
@cache
def _cached_get_attn_backend(
    backend,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Resolve and return the attention backend class (memoized).

    Asks the current platform for the qualified name of the backend class
    matching the arguments, imports it, and -- if that backend reports a
    required KV cache layout -- applies the layout.

    Because of ``@cache`` the body runs once per distinct argument
    combination, so the layout adjustment and its log line are only
    performed on the first call for that combination. All arguments must be
    hashable.

    Args:
        backend: Backend selection passed through to the platform.
        head_size, dtype, kv_cache_dtype, block_size, use_mla, has_sink,
        use_sparse, use_mm_prefix, attn_type: Passed positionally, in this
            order, to ``current_platform.get_attn_backend_cls``.

    Returns:
        The backend class resolved from the platform-provided qualified name.

    Raises:
        ValueError: If the platform returns a falsy backend name.
    """
    from vllm.platforms import current_platform

    backend_qualname = current_platform.get_attn_backend_cls(
        backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla,
        has_sink,
        use_sparse,
        use_mm_prefix,
        attn_type,
    )
    if not backend_qualname:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )

    # Use a distinct name for the resolved class rather than rebinding the
    # ``backend`` parameter.
    backend_cls = resolve_obj_by_qualname(backend_qualname)

    # Nothing more to do unless the backend demands a specific KV layout.
    kv_layout = backend_cls.get_required_kv_cache_layout()
    if kv_layout is None:
        return backend_cls

    from vllm.v1.attention.backends.utils import set_kv_cache_layout

    set_kv_cache_layout(kv_layout)
    logger.info(
        "Using %s KV cache layout for %s backend.",
        kv_layout,
        backend_cls.get_name(),
    )
    return backend_cls
106
107


108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Return the mamba attention backend class for ``mamba_type``.

    Thin public entry point; the lookup itself is delegated to the memoized
    ``_cached_get_mamba_attn_backend`` helper.

    Args:
        mamba_type: Mamba type name used to look up the backend.

    Returns:
        The backend class produced by the cached helper.
    """
    backend_cls = _cached_get_mamba_attn_backend(mamba_type)
    return backend_cls


@cache
def _cached_get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Map a mamba type to its attention backend class (memoized).

    The lookup is two-staged: ``mamba_type`` is translated to a backend name
    through ``MAMBA_TYPE_TO_BACKEND_MAP``, and that name is then resolved to
    a ``MambaAttentionBackendEnum`` member whose ``get_class()`` result is
    returned.

    Args:
        mamba_type: Non-empty string key of ``MAMBA_TYPE_TO_BACKEND_MAP``.

    Returns:
        The class returned by the selected enum member's ``get_class()``.

    Raises:
        AssertionError: If ``mamba_type`` is empty or not a ``str``.
        ValueError: If ``mamba_type`` is not a known mamba type, or if the
            mapped backend name is not a ``MambaAttentionBackendEnum`` member.
    """
    assert mamba_type and isinstance(mamba_type, str)

    # The two lookups are guarded separately. With a single shared handler,
    # a miss on the first lookup left ``backend_name`` unbound, so building
    # the error message raised UnboundLocalError instead of ValueError.
    try:
        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba type: '{mamba_type}'. Valid "
            f"types are: {list(MAMBA_TYPE_TO_BACKEND_MAP)}"
        ) from e

    try:
        selected_backend = MambaAttentionBackendEnum[backend_name]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{backend_name}'. Valid "
            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
        ) from e

    return selected_backend.get_class()