# selector.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from functools import cache
5
from typing import NamedTuple, cast, get_args
6
7
8

import torch

9
from vllm.config.cache import CacheDType
10
from vllm.logger import init_logger
11
from vllm.utils.import_utils import resolve_obj_by_qualname
12
13
14
15
16
from vllm.v1.attention.backend import AttentionBackend, AttentionType
from vllm.v1.attention.backends.registry import (
    MAMBA_TYPE_TO_BACKEND_MAP,
    MambaAttentionBackendEnum,
)
17
18
19
20

logger = init_logger(__name__)


21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class AttentionSelectorConfig(NamedTuple):
    """Immutable bundle of layer properties used to select an attention backend.

    As a ``NamedTuple`` the instance is hashable whenever all of its fields
    are, so it can be passed to ``functools.cache``-decorated functions.
    """

    # Attention head size of the layer.
    head_size: int
    # Torch dtype forwarded to the platform backend selector.
    dtype: torch.dtype
    # KV-cache dtype name, or None.
    kv_cache_dtype: CacheDType | None
    # KV-cache block size, or None.
    block_size: int | None
    # Feature flags forwarded unchanged to the platform selector.
    use_mla: bool = False
    has_sink: bool = False
    use_sparse: bool = False
    use_mm_prefix: bool = False
    # Attention type; defaults to decoder attention.
    attn_type: str = AttentionType.DECODER

    def __repr__(self):
        # Render every field as ``name=value`` in declaration order, using
        # f-string (``format``) conversion rather than ``repr`` for values.
        rendered = ", ".join(
            f"{field_name}={getattr(self, field_name)}"
            for field_name in self._fields
        )
        return f"AttentionSelectorConfig({rendered})"


46
47
48
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Select which attention backend to use and lazily import it.

    The backend requested in the current vLLM config
    (``attention_config.backend``) is combined with the layer properties
    given here into an ``AttentionSelectorConfig`` and handed to the
    memoized resolver, so repeated calls with the same inputs reuse the
    previously resolved class.

    Args:
        head_size: Attention head size of the layer.
        dtype: Torch dtype forwarded to the selector config.
        kv_cache_dtype: KV-cache dtype name. When not ``None`` it must be one
            of the literal values of ``CacheDType``.
        block_size: KV-cache block size, or ``None``.
        use_mla: Feature flag forwarded to the selector config.
        has_sink: Feature flag forwarded to the selector config.
        use_sparse: Feature flag forwarded to the selector config.
        use_mm_prefix: Feature flag forwarded to the selector config.
        attn_type: Attention type. Any falsy value (``None`` by default)
            is replaced with ``AttentionType.DECODER``.

    Returns:
        The resolved attention backend class.

    Raises:
        AssertionError: If ``kv_cache_dtype`` is not a valid ``CacheDType``.
        ValueError: Propagated from the resolver when the platform reports
            no backend for this configuration.
    """
    if kv_cache_dtype is not None:
        valid_cache_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in valid_cache_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {valid_cache_dtypes}"
        )

    # Imported at call time rather than module load (presumably to avoid an
    # import cycle with the config package — confirm before hoisting).
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    backend_enum = vllm_config.attention_config.backend

    attn_selector_config = AttentionSelectorConfig(
        head_size=head_size,
        dtype=dtype,
        # Validated against CacheDType above, so the narrowing cast is safe.
        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        use_mm_prefix=use_mm_prefix,
        attn_type=attn_type or AttentionType.DECODER,
    )

    return _cached_get_attn_backend(
        backend=backend_enum,
        attn_selector_config=attn_selector_config,
    )


89
@cache
def _cached_get_attn_backend(
    backend,
    attn_selector_config: AttentionSelectorConfig,
) -> type[AttentionBackend]:
    """Resolve, import, and post-configure the attention backend class.

    Memoized with ``functools.cache`` on ``(backend, attn_selector_config)``,
    so the platform lookup, the import, and the KV-cache layout side effect
    below run at most once per distinct argument pair.

    Args:
        backend: Backend requested via the vLLM attention config, passed
            unchanged to ``current_platform.get_attn_backend_cls``
            (presumably an enum member or ``None`` — confirm against the
            attention config).
        attn_selector_config: Layer properties the platform uses to choose
            a backend.

    Returns:
        The backend class obtained by resolving the qualified name returned
        by the platform.

    Raises:
        ValueError: If the platform returns no backend for this
            configuration.
    """
    from vllm.platforms import current_platform

    backend_qualname = current_platform.get_attn_backend_cls(
        backend,
        attn_selector_config=attn_selector_config,
    )
    if not backend_qualname:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )

    # Distinct name so the resolved class does not shadow the ``backend``
    # argument, which holds the *requested* backend rather than a class.
    backend_cls = resolve_obj_by_qualname(backend_qualname)

    # Adjust kv cache layout if the selected backend requires a specific one.
    # Because this function is cached, this happens only on the first call
    # for a given argument pair.
    required_layout = backend_cls.get_required_kv_cache_layout()
    if required_layout is not None:
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout(required_layout)
        logger.info(
            "Using %s KV cache layout for %s backend.",
            required_layout,
            backend_cls.get_name(),
        )

    return backend_cls
119
120


121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Select which mamba attention backend to use and lazily import it.

    Public entry point. The lookup and import are delegated to
    ``_cached_get_mamba_attn_backend``, which memoizes the result per
    ``mamba_type``.
    """
    backend_cls = _cached_get_mamba_attn_backend(mamba_type)
    return backend_cls


@cache
def _cached_get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Map a mamba type to its attention backend class (memoized per type).

    The type is first translated to a backend name via
    ``MAMBA_TYPE_TO_BACKEND_MAP``. The name is then looked up in
    ``MambaAttentionBackendEnum``, whose member provides the class via
    ``get_class()``.

    Args:
        mamba_type: Non-empty mamba type string.

    Returns:
        The backend class returned by the enum member's ``get_class()``.

    Raises:
        AssertionError: If ``mamba_type`` is empty or not a string.
        ValueError: If ``mamba_type`` is not a key of
            ``MAMBA_TYPE_TO_BACKEND_MAP``, or its mapped backend name is not
            a ``MambaAttentionBackendEnum`` member.
    """
    assert mamba_type and isinstance(mamba_type, str)

    # The two lookups are handled separately. Previously a single handler
    # formatted ``backend_name``, which is unbound when the first lookup
    # fails, so callers got UnboundLocalError instead of ValueError.
    try:
        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention type: '{mamba_type}'. Valid types are: "
            f"{list(MAMBA_TYPE_TO_BACKEND_MAP)}"
        ) from e

    try:
        selected_backend = MambaAttentionBackendEnum[backend_name]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{backend_name}'. Valid "
            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
        ) from e

    return selected_backend.get_class()