utils.py 11.5 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for attention-related v1 tests."""

from dataclasses import dataclass
6
from typing import Optional, Union
7
8
9
10

import pytest
import torch

11
from vllm.attention.backends.registry import _Backend
12
13
14
15
16
17
18
19
20
21
from vllm.config import (
    CacheConfig,
    CompilationConfig,
    DeviceConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
22
from vllm.config.model import ModelDType
23
from vllm.platforms import current_platform
24
25
26
27
28
29
30
31
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec


@dataclass
class BatchSpec:
    """Shape-only description of a batch of attention requests.

    ``seq_lens`` and ``query_lens`` are parallel lists holding one entry per
    request; ``name`` is a free-form label that defaults to ``"unnamed"``.
    """

    seq_lens: list[int]
    query_lens: list[int]

    name: str = "unnamed"

    def __post_init__(self):
        # Both lists are indexed per request, so they must be the same length.
        num_seq_entries = len(self.seq_lens)
        num_query_entries = len(self.query_lens)
        assert num_seq_entries == num_query_entries

    @property
    def batch_size(self):
        """Number of requests in the batch (length of ``seq_lens``)."""
        return len(self.seq_lens)

    def compute_num_tokens(self):
        """Return the total number of query tokens across all requests."""
        return sum(self.query_lens)


def create_common_attn_metadata(
    batch_spec: BatchSpec,
    block_size: int,
    device: torch.device,
    max_block_idx: int = 1000,
    arange_block_indices: bool = False,
) -> CommonAttentionMetadata:
    """Build a CommonAttentionMetadata for the batch described by ``batch_spec``.

    Args:
        batch_spec: Per-request sequence and query lengths.
        block_size: Block size used to work out how many block-table columns
            the longest sequence needs.
        device: Device on which the non-CPU tensors are created.
        max_block_idx: Exclusive upper bound for the random block-table and
            slot-mapping values (only used when ``arange_block_indices`` is
            False).
        arange_block_indices: When True, fill the block table and the slot
            mapping with consecutive integers instead of random values.

    Returns:
        A CommonAttentionMetadata populated from the batch, with
        ``causal=True``.
    """
    num_reqs = batch_spec.batch_size
    num_tokens = batch_spec.compute_num_tokens()

    # Query start offsets: entry 0 stays 0, the remaining entries are the
    # running total of the per-request query lengths.
    query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.int32, device=device)
    query_lens = torch.tensor(batch_spec.query_lens, dtype=torch.int32, device=device)
    query_start_loc[1:] = query_lens.cumsum(0)
    query_start_loc_cpu = query_start_loc.cpu()

    # Sequence lengths, on the target device and mirrored on the CPU.
    seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device)
    seq_lens_cpu = seq_lens.cpu()
    max_seq_len = int(seq_lens_cpu.max())

    # Already-computed (context) tokens per request: sequence length minus
    # the tokens being queried now.
    context_lens = [
        seq_len - query_len
        for seq_len, query_len in zip(batch_spec.seq_lens, batch_spec.query_lens)
    ]
    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

    # One block-table row per request, wide enough for the longest sequence
    # (ceiling division by the block size).
    max_blocks = (max_seq_len + block_size - 1) // block_size
    table_shape = (num_reqs, max_blocks)
    if arange_block_indices:
        block_table_tensor = torch.arange(
            num_reqs * max_blocks, dtype=torch.int32, device=device
        ).view(*table_shape)
        slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
    else:
        # The block table is drawn before the slot mapping so the random
        # stream is consumed in a fixed order.
        block_table_tensor = torch.randint(
            0, max_block_idx, table_shape, dtype=torch.int32, device=device
        )
        slot_mapping = torch.randint(
            0, max_block_idx, (num_tokens,), dtype=torch.int64, device=device
        )

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=num_tokens,
        max_query_len=max(batch_spec.query_lens),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        causal=True,
    )


def get_attention_backend(backend_name: _Backend):
    """Look up the builder and implementation classes of an attention backend.

    Args:
        backend_name: ``_Backend`` member identifying the backend to load.

    Returns:
        Tuple of (backend_builder_class, backend_impl_class)

    Raises:
        ValueError: If ``backend_name`` has no entry in the backend map.

    If resolving the backend raises ImportError, the calling test is skipped
    via ``pytest.skip`` instead of failing.
    """
    # FLASH_ATTN maps to a different backend class depending on whether the
    # current platform reports CUDA.
    if current_platform.is_cuda():
        flash_attn_qualname = (
            "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
        )
    else:
        flash_attn_qualname = (
            "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
        )

    backend_map = {
        _Backend.FLASH_ATTN: flash_attn_qualname,
        _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
        _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",  # noqa: E501
        _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",  # noqa: E501
        _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
        _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",  # noqa: E501
        _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",  # noqa: E501
        _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
        _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",  # noqa: E501
        _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",  # noqa: E501
        _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",  # noqa: E501
    }

    backend_class_name = backend_map.get(backend_name)
    if backend_class_name is None:
        raise ValueError(f"Unknown backend: {backend_name}")

    try:
        backend_class = resolve_obj_by_qualname(backend_class_name)
        builder_cls = backend_class.get_builder_cls()
        impl_cls = backend_class.get_impl_cls()
    except ImportError as e:
        # pytest.skip raises, so execution never falls through to the return.
        pytest.skip(f"{backend_name} not available: {e}")

    return builder_cls, impl_cls


160
def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec:
    """Build a FullAttentionSpec from the fields of ``vllm_config``.

    The block size comes from the cache config; the KV-head count (queried
    with the parallel config), head size, dtype and sliding window come from
    the model config.
    """
    model_config = vllm_config.model_config
    block_size = vllm_config.cache_config.block_size
    num_kv_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)

    return FullAttentionSpec(
        block_size=block_size,
        num_kv_heads=num_kv_heads,
        head_size=model_config.get_head_size(),
        dtype=model_config.dtype,
        sliding_window=model_config.get_sliding_window(),
    )


173
174
175
176
177
178
179
180
181
182
183
def create_vllm_config(
    model_name: str = "meta-llama/Meta-Llama-3-8B",
    tensor_parallel_size: int = 1,
    max_model_len: int = 1024,
    dtype: Union[ModelDType, torch.dtype] = "auto",
    num_gpu_blocks: int = 1000,
    block_size: int = 16,
    max_num_seqs: int = 256,
    max_num_batched_tokens: int = 8192,
    enable_chunked_prefill: bool = True,
    add_mock_model_methods: bool = True,
    hf_config_override: Optional[dict] = None,
) -> VllmConfig:
    """Assemble a VllmConfig for tests from a handful of common knobs.

    Args:
        model_name: Used as both the model and the tokenizer identifier.
        tensor_parallel_size: Forwarded to ParallelConfig.
        max_model_len: Forwarded to ModelConfig.
        dtype: Forwarded to ModelConfig.
        num_gpu_blocks: Written onto the cache config after construction.
        block_size: Forwarded to CacheConfig.
        max_num_seqs: Forwarded to SchedulerConfig.
        max_num_batched_tokens: Forwarded to SchedulerConfig.
        enable_chunked_prefill: Forwarded to SchedulerConfig.
        add_mock_model_methods: When True, attach stand-in per-layer query
            methods to the model config.
        hf_config_override: Optional mapping applied to the model's HF config
            via ``update``; ignored when empty or None.

    Returns:
        A VllmConfig wrapping the model, cache, parallel, scheduler, device,
        load and compilation configs built here.
    """
    import types

    model_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        trust_remote_code=False,
        dtype=dtype,
        seed=0,
        max_model_len=max_model_len,
    )

    cache_config = CacheConfig(block_size=block_size, cache_dtype="auto", swap_space=0)
    # Block counts are filled in by hand for testing
    # (they may be set during initialization normally).
    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = 0

    parallel_config = ParallelConfig(tensor_parallel_size=tensor_parallel_size)

    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        enable_chunked_prefill=enable_chunked_prefill,
    )

    device_config = DeviceConfig()
    load_config = LoadConfig()
    compilation_config = CompilationConfig()

    if add_mock_model_methods:
        # Workaround: tests don't build full, real models, but some backends
        # query the model config for layer-specific parameters. Bind simple
        # stand-ins directly onto this config instance.
        mock_methods = {
            "get_num_layers": lambda self: 1,
            "get_sliding_window_for_layer": lambda self, i: None,
            "get_logits_soft_cap_for_layer": lambda self, i: 0.0,
            # 1 / sqrt(head_size)
            "get_sm_scale_for_layer": (
                lambda self, i: 1.0 / self.get_head_size() ** 0.5
            ),
        }
        for method_name, method_fn in mock_methods.items():
            setattr(
                model_config, method_name, types.MethodType(method_fn, model_config)
            )

    if hf_config_override:
        model_config.hf_config.update(hf_config_override)

    return VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        load_config=load_config,
        compilation_config=compilation_config,
    )


253
254
255
256
257
258
259
260
def create_dummy_kv_cache(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: torch.device,
    num_blocks: int = 100,
) -> torch.Tensor:
    """Return a randomly initialised KV cache tensor for testing.

    The result has shape
    ``(num_blocks, 2, block_size, num_kv_heads, head_size)``, where the
    second dimension separates K and V, and is filled by ``torch.randn``
    with the requested dtype and device.
    """
    kv_pair = 2  # K and V
    cache_shape = (num_blocks, kv_pair, block_size, num_kv_heads, head_size)
    return torch.randn(cache_shape, dtype=dtype, device=device)
272
273
274
275
276
277
278
279
280
281
282
283
284


@dataclass
class BackendConfig:
    """Named bundle of settings describing one attention-backend test setup.

    How these fields are consumed is not visible in this file; the notes
    below describe only how the entries defined in this module fill them in.
    """

    # Label for the configuration (the entries in this module reuse it as
    # their dictionary key).
    name: str
    # Environment variable names mapped to string values, e.g.
    # "VLLM_ATTENTION_BACKEND".
    env_vars: dict
    comp_config: dict  # compilation config
    # Optional GPU architecture tuple such as (9, 0); None (the default) when
    # no architecture is given. Presumably a (major, minor) compute
    # capability -- TODO confirm against the code that reads it.
    specific_gpu_arch: Optional[tuple] = None


# All backend configurations of full cudagraph to be tested, keyed by each
# configuration's own name.
full_cg_backend_configs = {
    cfg.name: cfg
    for cfg in (
        # FA3 on Hopper
        BackendConfig(
            name="FA3",
            env_vars={
                "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
                "VLLM_FLASH_ATTN_VERSION": "3",
                "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
            },
            comp_config={"cudagraph_mode": "FULL"},
            specific_gpu_arch=(9, 0),
        ),
        # FlashMLA on Hopper
        BackendConfig(
            name="FlashMLA",
            env_vars={"VLLM_ATTENTION_BACKEND": "FLASHMLA"},
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
            specific_gpu_arch=(9, 0),
        ),
        # Cutlass MLA on Blackwell
        BackendConfig(
            name="CutlassMLA",
            env_vars={
                "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
                "FORCE_NUM_KV_SPLITS": "1",  # TODO: remove this when hang issue is fixed
            },
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
            specific_gpu_arch=(10, 0),
        ),
        # FlashAttention MLA on Hopper
        BackendConfig(
            name="FlashAttentionMLA",
            env_vars={
                "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
                "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
            },
            comp_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
            specific_gpu_arch=(9, 0),
        ),
        # FA2
        BackendConfig(
            name="FA2",
            env_vars={
                "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
                "VLLM_FLASH_ATTN_VERSION": "2",
                "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
            },
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
        ),
        # Triton Attention
        BackendConfig(
            name="TritonAttn",
            env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
        ),
        # FlashInfer
        BackendConfig(
            name="FlashInfer",
            env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
        ),
    )
}