attn_utils.py 8.07 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
4
from typing import Any, cast
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7
8
9

import torch

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
10
11
from vllm.v1.attention.backend import (
    AttentionBackend,
Woosuk Kwon's avatar
Woosuk Kwon committed
12
13
14
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
15
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.v1.kv_cache_interface import (
17
    AttentionSpec,
Woosuk Kwon's avatar
Woosuk Kwon committed
18
19
20
21
22
23
24
25
    KVCacheConfig,
    KVCacheSpec,
)
from vllm.v1.worker.utils import bind_kv_cache


def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
    """Collect the KV cache spec of every attention layer in the model.

    Every layer found by ``get_layers_from_vllm_config`` for the
    ``AttentionLayerBase`` type is asked for its spec via
    ``get_kv_cache_spec(vllm_config)``. Layers whose returned spec is falsy
    are omitted; these are modules that don't need a KV cache (e.g.
    encoder-only attention).

    Args:
        vllm_config: Engine configuration used both to look up the layers and
            as the argument to each layer's ``get_kv_cache_spec``.

    Returns:
        Mapping from layer name to that layer's KV cache spec, in the order
        the layers were returned by ``get_layers_from_vllm_config``.
    """
    attn_layers = get_layers_from_vllm_config(
        vllm_config, cast(type[Any], AttentionLayerBase)
    )
    return {
        name: layer_spec
        for name, module in attn_layers.items()
        # Layers that need no KV cache report a falsy spec and are skipped.
        if (layer_spec := module.get_kv_cache_spec(vllm_config))
    }


def init_attn_backend(
    kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
) -> tuple[dict[str, type[AttentionBackend]], list[AttentionMetadataBuilder]]:
    """Resolve attention backends and create one metadata builder per KV cache group.

    For each group in ``kv_cache_config.kv_cache_groups`` the backend is read
    from the group's first layer and recorded for every layer in the group.
    A metadata builder is then constructed from that backend's builder class
    with the group's KV cache spec, layer names, ``vllm_config`` and ``device``.

    All FlashInfer builders share one workspace buffer. The first FlashInfer
    builder's buffer (from ``_get_workspace_buffer()``) is installed on every
    later FlashInfer builder via ``set_workspace_buffer``.

    Args:
        kv_cache_config: Provides the KV cache groups (layer names + spec).
        vllm_config: Engine configuration, used to look up layers and passed
            to each metadata builder.
        device: Device passed to each metadata builder.

    Returns:
        ``(attn_backends, attn_metadata_builders)``, where ``attn_backends``
        maps each layer name to its backend class and
        ``attn_metadata_builders[i]`` is the builder for KV cache group ``i``.
    """
    attn_backends: dict[str, type[AttentionBackend]] = {}
    attn_metadata_builders: list[AttentionMetadataBuilder] = []
    flashinfer_workspace: torch.Tensor | None = None
    # Loop-invariant: the layer type used to select attention layers.
    layer_type = cast(type[Any], AttentionLayerBase)
    for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
        layer_names = kv_cache_group_spec.layer_names
        # The first layer's backend is used for the whole group.
        any_layer_name = next(iter(layer_names))

        attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
        attn_backend = attn_layers[any_layer_name].get_attn_backend()
        for layer_name in layer_names:
            attn_backends[layer_name] = attn_backend

        attn_metadata_builder = attn_backend.get_builder_cls()(
            kv_cache_group_spec.kv_cache_spec, layer_names, vllm_config, device
        )
        attn_metadata_builders.append(attn_metadata_builder)  # type: ignore

        if attn_backend.get_name() == "FLASHINFER":
            # Share one workspace buffer across all FlashInfer builders.
            if flashinfer_workspace is None:
                flashinfer_workspace = attn_metadata_builder._get_workspace_buffer()
            else:
                attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
    return attn_backends, attn_metadata_builders


64
def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
    """Allocate the raw, zero-filled byte buffers backing the KV cache.

    One ``torch.int8`` tensor of ``size`` elements (i.e. ``size`` bytes) is
    created on ``device`` for each entry of ``kv_cache_config.kv_cache_tensors``.
    Every layer listed in that entry's ``shared_by`` maps to the very same
    tensor object.

    Args:
        kv_cache_config: Provides ``kv_cache_tensors`` (size + sharing layers)
            and ``kv_cache_groups`` (the layer names expected to get a buffer).
        device: Device on which the buffers are allocated.

    Returns:
        Mapping from layer name to its raw int8 buffer.

    Raises:
        AssertionError: if the set of layers that received a buffer differs
            from the set of layers named in the KV cache groups.
    """
    raw_buffers: dict[str, torch.Tensor] = {}
    for tensor_spec in kv_cache_config.kv_cache_tensors:
        buffer = torch.zeros(tensor_spec.size, dtype=torch.int8, device=device)
        # All sharing layers alias the same underlying buffer.
        raw_buffers.update(dict.fromkeys(tensor_spec.shared_by, buffer))

    expected_layers = {
        name
        for group in kv_cache_config.kv_cache_groups
        for name in group.layer_names
    }
    assert expected_layers == set(raw_buffers), (
        "Some layers are not correctly initialized"
    )
    return raw_buffers


def _reshape_kv_cache(
    kv_cache_config: KVCacheConfig,
    kv_cache_raw_tensors: dict[str, torch.Tensor],
    attn_backends: dict[str, type[AttentionBackend]],
) -> dict[str, torch.Tensor]:
    """Reinterpret each layer's raw int8 buffer as a typed, shaped KV cache.

    For every layer in every KV cache group:
      * ``num_blocks`` is the raw buffer's byte count divided by the group's
        ``page_size_bytes`` (the buffer must hold a whole number of pages);
      * the logical shape comes from the layer's backend
        ``get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size)``;
      * the buffer is viewed as ``kv_cache_spec.dtype`` in the backend's stride
        order (``get_kv_cache_stride_order()``, or identity when the backend
        lacks it / raises ``NotImplementedError``), then permuted back so the
        returned tensor has the logical shape with the stride-ordered memory
        layout.

    No memory is copied; each result is a view of the raw buffer.

    Args:
        kv_cache_config: Provides the KV cache groups and their specs.
        kv_cache_raw_tensors: Raw int8 buffer per layer name.
        attn_backends: Backend class per layer name.

    Returns:
        Mapping from layer name to its shaped KV cache view.
    """
    kv_caches: dict[str, torch.Tensor] = {}
    for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
        kv_cache_spec = kv_cache_group_spec.kv_cache_spec
        assert isinstance(kv_cache_spec, AttentionSpec)
        for layer_name in kv_cache_group_spec.layer_names:
            raw_tensor = kv_cache_raw_tensors[layer_name]
            # int8 buffer: numel() is the byte count.
            assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
            num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes

            attn_backend = attn_backends[layer_name]
            kv_cache_shape = attn_backend.get_kv_cache_shape(
                num_blocks,
                kv_cache_spec.block_size,
                kv_cache_spec.num_kv_heads,
                kv_cache_spec.head_size,
            )

            # FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
            try:
                kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
            except (AttributeError, NotImplementedError):
                # Backend defines no stride order: keep the logical layout.
                kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
            else:
                assert len(kv_cache_stride_order) == len(kv_cache_shape)

            # Lay memory out in stride order, then permute back to the
            # logical dimension order with the inverse permutation.
            physical_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
            inv_order = [
                kv_cache_stride_order.index(i)
                for i in range(len(kv_cache_stride_order))
            ]

            typed_tensor = raw_tensor.view(kv_cache_spec.dtype).view(physical_shape)
            kv_caches[layer_name] = typed_tensor.permute(*inv_order)
    return kv_caches


def init_kv_cache(
    runner_kv_caches: list[torch.Tensor],
    forward_context: dict[str, Any],
    kv_cache_config: KVCacheConfig,
    attn_backends: dict[str, type[AttentionBackend]],
    device: torch.device,
) -> dict[str, torch.Tensor]:
    """Allocate, shape and bind the KV cache for every attention layer.

    Raw int8 buffers are allocated per ``kv_cache_config`` on ``device``, then
    viewed as typed tensors in each layer backend's layout. The resulting
    per-layer tensors are handed to ``bind_kv_cache`` together with
    ``forward_context`` and ``runner_kv_caches``. Presumably that call attaches
    them to the layers and fills the runner's list; see ``bind_kv_cache`` to
    confirm.

    Args:
        runner_kv_caches: The runner's KV cache list, passed to ``bind_kv_cache``.
        forward_context: Layer-name keyed context, passed to ``bind_kv_cache``.
        kv_cache_config: Describes the KV cache tensors and groups.
        attn_backends: Backend class per layer name (as built by
            ``init_attn_backend``).
        device: Device on which the KV cache is allocated.

    Returns:
        Mapping from layer name to its shaped KV cache tensor.
    """
    kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
    kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
    bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
    return kv_caches
Woosuk Kwon's avatar
Woosuk Kwon committed
134
135


136
def build_slot_mappings_by_layer(
    slot_mappings: torch.Tensor, kv_cache_config: KVCacheConfig
) -> dict[str, torch.Tensor]:
    """Fan out per-group slot mappings to every layer of each group.

    Row ``i`` of ``slot_mappings`` (iterated alongside
    ``kv_cache_config.kv_cache_groups``) is assigned to every layer name of
    group ``i``. All layers in a group share the same row object. Iteration
    stops at the shorter of the two sequences.

    Args:
        slot_mappings: One slot mapping per KV cache group (iterated row-wise).
        kv_cache_config: Provides the KV cache groups and their layer names.

    Returns:
        Mapping from layer name to its group's slot mapping.
    """
    groups = kv_cache_config.kv_cache_groups
    return {
        layer_name: group_slot_mapping
        for group_slot_mapping, group in zip(slot_mappings, groups)
        for layer_name in group.layer_names
    }


147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def prepare_dcp_local_seq_lens(
    dcp_local_seq_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    num_reqs: int,
    dcp_size: int,
    dcp_rank: int,
    cp_kv_cache_interleave_size: int,
) -> None:
    """Fill the persistent DCP local seq_lens buffer in place (CUDA graph safe).

    With ``dcp_size <= 1`` the buffer is left untouched. Otherwise:
      * the first ``num_reqs`` entries are overwritten with the local sequence
        lengths that ``get_dcp_local_seq_lens`` computes for ``dcp_rank``;
      * all remaining entries are zeroed.

    Only in-place ``copy_``/``zero_`` are used, so the buffer's storage is
    reused rather than reallocated.
    """
    if dcp_size > 1:
        rank_local_lens = get_dcp_local_seq_lens(
            seq_lens[:num_reqs],
            dcp_size=dcp_size,
            dcp_rank=dcp_rank,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )
        active_slice = dcp_local_seq_lens[:num_reqs]
        padding_slice = dcp_local_seq_lens[num_reqs:]
        active_slice.copy_(rank_local_lens, non_blocking=True)
        padding_slice.zero_()


Woosuk Kwon's avatar
Woosuk Kwon committed
169
170
171
172
def build_attn_metadata(
    attn_metadata_builders: list[AttentionMetadataBuilder],
    num_reqs: int,
    num_tokens: int,
    query_start_loc_gpu: torch.Tensor,
    query_start_loc_cpu: torch.Tensor,
    max_query_len: int,
    seq_lens: torch.Tensor,
    max_seq_len: int,
    block_tables: Sequence[torch.Tensor],
    slot_mappings: torch.Tensor,
    kv_cache_config: KVCacheConfig,
    dcp_local_seq_lens: torch.Tensor | None = None,
) -> dict[str, Any]:
    """Build attention metadata for every layer, one builder call per KV cache group.

    For KV cache group ``i`` a ``CommonAttentionMetadata`` is assembled from
    the shared query/sequence info plus ``block_tables[i]`` and
    ``slot_mappings[i]``. It is always ``causal=True``. That metadata is passed
    to ``attn_metadata_builders[i].build(common_prefix_len=0, ...)``. The built
    result is shared by every layer of the group.

    ``seq_lens`` and (when given) ``dcp_local_seq_lens`` are truncated to the
    first ``num_reqs`` entries before use.

    Returns:
        Mapping from layer name to the metadata built for its group.
    """
    active_seq_lens = seq_lens[:num_reqs]
    active_dcp_lens = (
        None if dcp_local_seq_lens is None else dcp_local_seq_lens[:num_reqs]
    )

    metadata_by_layer: dict[str, Any] = {}
    for group_idx, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
        common = CommonAttentionMetadata(
            query_start_loc=query_start_loc_gpu,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens=active_seq_lens,
            max_seq_len=max_seq_len,
            num_reqs=num_reqs,
            num_actual_tokens=num_tokens,
            max_query_len=max_query_len,
            block_table_tensor=block_tables[group_idx],
            slot_mapping=slot_mappings[group_idx],
            causal=True,
            dcp_local_seq_lens=active_dcp_lens,
        )
        group_metadata = attn_metadata_builders[group_idx].build(
            common_prefix_len=0, common_attn_metadata=common
        )
        # Every layer in the group shares the same metadata object.
        metadata_by_layer.update(
            dict.fromkeys(kv_cache_group.layer_names, group_metadata)
        )
    return metadata_by_layer