# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for attention-related v1 tests."""

from dataclasses import dataclass
from typing import Optional, Union

import pytest
import torch

from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                         LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                         SchedulerConfig, VllmConfig)
from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec


@dataclass
class BatchSpec:
    """Workload shape of a test batch (no model parameters).

    ``seq_lens[i]`` and ``query_lens[i]`` belong to request ``i``; the two
    lists must therefore be the same length.
    """
    seq_lens: list[int]
    query_lens: list[int]

    name: str = "unnamed"

    def __post_init__(self):
        # Each request needs both a sequence length and a query length.
        assert len(self.seq_lens) == len(self.query_lens)

    @property
    def batch_size(self):
        """Number of requests in the batch."""
        return len(self.seq_lens)

    def compute_num_tokens(self):
        """Total number of query tokens summed over all requests."""
        total = 0
        for q_len in self.query_lens:
            total += q_len
        return total


def create_common_attn_metadata(
        batch_spec: BatchSpec,
        block_size: int,
        device: torch.device,
        max_block_idx: int = 1000,
        arange_block_indices: bool = False) -> CommonAttentionMetadata:
    """Create CommonAttentionMetadata for the workload in ``batch_spec``.

    Args:
        batch_spec: Per-request sequence and query lengths.
        block_size: Tokens per KV-cache block; determines how many
            block-table columns are needed for the longest sequence.
        device: Device for the tensors that do not carry a ``_cpu`` suffix.
        max_block_idx: Exclusive upper bound of the random values used for
            the block table and slot mapping. Unused when
            ``arange_block_indices`` is True.
        arange_block_indices: If True, fill the block table and slot mapping
            with consecutive indices (deterministic) instead of random ones.

    Returns:
        CommonAttentionMetadata describing the batch, with ``causal=True``.
    """
    num_reqs = batch_spec.batch_size
    num_tokens = batch_spec.compute_num_tokens()

    # Query start locations: [0, q0, q0 + q1, ...] (num_reqs + 1 entries).
    query_start_loc = torch.zeros(num_reqs + 1,
                                  dtype=torch.int32,
                                  device=device)
    query_start_loc[1:] = torch.tensor(batch_spec.query_lens,
                                       dtype=torch.int32,
                                       device=device).cumsum(0)
    query_start_loc_cpu = query_start_loc.cpu()

    # Sequence lengths, on the target device and mirrored on the CPU.
    seq_lens = torch.tensor(batch_spec.seq_lens,
                            dtype=torch.int32,
                            device=device)
    seq_lens_cpu = seq_lens.cpu()
    max_seq_len = int(seq_lens_cpu.max())

    # Tokens already computed per request (context length): the part of
    # each sequence that is not in the current query. Kept on the CPU.
    context_lens = [
        seq_len - query_len
        for seq_len, query_len in zip(batch_spec.seq_lens,
                                      batch_spec.query_lens)
    ]
    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

    # Enough block-table columns to cover the longest sequence.
    max_blocks = (max_seq_len + block_size - 1) // block_size
    if arange_block_indices:
        # Request r owns blocks [r * max_blocks, (r + 1) * max_blocks).
        num_blocks = num_reqs * max_blocks
        block_table_tensor = torch.arange(num_blocks,
                                          dtype=torch.int32,
                                          device=device).view(
                                              num_reqs, max_blocks)
        slot_mapping = torch.arange(num_tokens,
                                    dtype=torch.int64,
                                    device=device)
    else:
        block_table_tensor = torch.randint(0,
                                           max_block_idx,
                                           (num_reqs, max_blocks),
                                           dtype=torch.int32,
                                           device=device)
        slot_mapping = torch.randint(0,
                                     max_block_idx, (num_tokens, ),
                                     dtype=torch.int64,
                                     device=device)

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=num_tokens,
        max_query_len=max(batch_spec.query_lens),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        causal=True,
    )


def get_attention_backend(backend_name: _Backend):
    """Resolve the builder and implementation classes of an attention backend.

    Args:
        backend_name: ``_Backend`` enum member selecting the backend
            (e.g. ``_Backend.FLASH_ATTN``, ``_Backend.FLASHINFER``).

    Returns:
        Tuple of (backend_builder_class, backend_impl_class).

    Raises:
        ValueError: If ``backend_name`` has no entry in the mapping below.

    An ImportError while resolving the backend class skips the calling test
    via ``pytest.skip`` instead of failing it.
    """
    backend_map = {
        # On CUDA this is the FlashAttention backend; on any other platform
        # the entry points at the rocm_aiter_fa AiterFlashAttentionBackend.
        _Backend.FLASH_ATTN:
        ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
         if current_platform.is_cuda() else
         "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
         ),
        _Backend.FLASHINFER:
        "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
        _Backend.FLEX_ATTENTION:
        "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
        _Backend.TRITON_ATTN:
        "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
        _Backend.TREE_ATTN:
        "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
        _Backend.XFORMERS:
        "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
        # MLA backends.
        _Backend.CUTLASS_MLA:
        "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
        _Backend.FLASHMLA:
        "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
        _Backend.FLASH_ATTN_MLA:
        "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
        _Backend.FLASHINFER_MLA:
        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
        _Backend.TRITON_MLA:
        "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
    }

    backend_class_name = backend_map.get(backend_name)
    if backend_class_name is None:
        raise ValueError(f"Unknown backend: {backend_name}")

    try:
        # Resolved lazily so a backend whose module cannot be imported
        # (e.g. a missing optional dependency) skips rather than fails.
        backend_class = resolve_obj_by_qualname(backend_class_name)
        return backend_class.get_builder_cls(), backend_class.get_impl_cls()
    except ImportError as e:
        pytest.skip(f"{backend_name} not available: {e}")


def create_standard_kv_cache_spec(
        vllm_config: VllmConfig) -> FullAttentionSpec:
    """Build a FullAttentionSpec matching the settings of ``vllm_config``.

    The block size is taken from the cache config. The KV head count (with
    respect to the parallel config), head size, dtype, and sliding window
    are taken from the model config.
    """
    model_cfg = vllm_config.model_config
    block_size = vllm_config.cache_config.block_size
    kv_heads = model_cfg.get_num_kv_heads(vllm_config.parallel_config)
    return FullAttentionSpec(
        block_size=block_size,
        num_kv_heads=kv_heads,
        head_size=model_cfg.get_head_size(),
        dtype=model_cfg.dtype,
        sliding_window=model_cfg.get_sliding_window(),
    )


def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
                       tensor_parallel_size: int = 1,
                       max_model_len: int = 1024,
                       dtype: Union[ModelDType, torch.dtype] = "auto",
                       num_gpu_blocks: int = 1000,
                       block_size: int = 16,
                       max_num_seqs: int = 256,
                       max_num_batched_tokens: int = 8192,
                       enable_chunked_prefill: bool = True,
                       add_mock_model_methods: bool = True) -> VllmConfig:
    """Create a VllmConfig for testing with reasonable defaults.

    Args:
        model_name: Used as both ``model`` and ``tokenizer`` of ModelConfig.
        tensor_parallel_size: Forwarded to ParallelConfig.
        max_model_len: Forwarded to ModelConfig.
        dtype: Model dtype, a ModelDType value or a torch.dtype.
        num_gpu_blocks: Assigned to ``cache_config.num_gpu_blocks``.
        block_size: Forwarded to CacheConfig.
        max_num_seqs: Forwarded to SchedulerConfig.
        max_num_batched_tokens: Forwarded to SchedulerConfig.
        enable_chunked_prefill: Forwarded to SchedulerConfig.
        add_mock_model_methods: If True, attach fixed per-layer query
            methods to the ModelConfig instance (see below).

    Returns:
        A VllmConfig assembled from the individual sub-configs.
    """
    model_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        trust_remote_code=False,
        dtype=dtype,
        seed=0,
        max_model_len=max_model_len,
    )

    cache_config = CacheConfig(
        block_size=block_size,
        cache_dtype="auto",
        swap_space=0,
    )
    # Set cache blocks for testing
    #   (these may be set during initialization normally)
    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = 0

    parallel_config = ParallelConfig(
        tensor_parallel_size=tensor_parallel_size)

    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        enable_chunked_prefill=enable_chunked_prefill,
    )

    device_config = DeviceConfig()
    load_config = LoadConfig()
    compilation_config = CompilationConfig()

    if add_mock_model_methods:
        # Add mock methods to satisfy backends that need them.
        # This is a workaround because tests don't build full, real models,
        # but some backends expect to query the model for layer-specific
        # parameters. The mocks report a single layer, no sliding window,
        # a logits soft cap of 0.0 and a softmax scale of 1/sqrt(head_size).
        import types
        model_config.get_num_layers = types.MethodType(lambda self: 1,
                                                       model_config)
        model_config.get_sliding_window_for_layer = types.MethodType(
            lambda self, i: None, model_config)
        model_config.get_logits_soft_cap_for_layer = types.MethodType(
            lambda self, i: 0.0, model_config)
        model_config.get_sm_scale_for_layer = types.MethodType(
            lambda self, i: 1.0 / self.get_head_size()**0.5, model_config)

    return VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        load_config=load_config,
        compilation_config=compilation_config,
    )


def create_dummy_kv_cache(block_size: int,
                          num_kv_heads: int,
                          head_size: int,
                          dtype: torch.dtype,
                          device: torch.device,
                          num_blocks: int = 100) -> torch.Tensor:
    """Return a randomly initialised KV cache tensor for tests.

    Layout: ``(num_blocks, 2, block_size, num_kv_heads, head_size)``, where
    dimension 1 holds K and V. Values are drawn with ``torch.randn``.
    """
    shape = (num_blocks, 2, block_size, num_kv_heads, head_size)
    return torch.randn(shape, dtype=dtype, device=device)


@dataclass
class BackendConfig:
    """One attention-backend setup exercised by the full-cudagraph tests."""
    # Identifier of this configuration.
    name: str
    # Environment variables to apply for this backend (str -> str).
    env_vars: dict[str, str]
    # Compilation config options, e.g. {"cudagraph_mode": ...}.
    comp_config: dict
    # Required GPU architecture as a (major, minor) tuple such as (9, 0);
    # None when no specific architecture is required.
    specific_gpu_arch: Optional[tuple[int, int]] = None


# Full-cudagraph backend configurations to test, keyed by each entry's name.
full_cg_backend_configs = {
    cfg.name: cfg
    for cfg in (
        # FA3 on Hopper
        BackendConfig(
            name="FA3",
            env_vars={
                "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
                "VLLM_FLASH_ATTN_VERSION": "3",
                "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
            },
            comp_config={"cudagraph_mode": "FULL"},
            specific_gpu_arch=(9, 0),
        ),
        # FlashMLA on Hopper
        BackendConfig(
            name="FlashMLA",
            env_vars={"VLLM_ATTENTION_BACKEND": "FLASHMLA"},
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
            specific_gpu_arch=(9, 0),
        ),
        # Cutlass MLA on Blackwell
        BackendConfig(
            name="CutlassMLA",
            env_vars={
                "VLLM_USE_V1": "1",
                "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
                # TODO: remove this when hang issue is fixed
                "FORCE_NUM_KV_SPLITS": "1",
            },
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
            specific_gpu_arch=(10, 0),
        ),
        # FlashAttention MLA on Hopper
        BackendConfig(
            name="FlashAttentionMLA",
            env_vars={
                "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
                "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
            },
            comp_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
            specific_gpu_arch=(9, 0),
        ),
        # FA2
        BackendConfig(
            name="FA2",
            env_vars={
                "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
                "VLLM_FLASH_ATTN_VERSION": "2",
                "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
            },
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
        ),
        # Triton Attention
        BackendConfig(
            name="TritonAttn",
            env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
        ),
        # FlashInfer
        BackendConfig(
            name="FlashInfer",
            env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
        ),
    )
}