utils.py 10.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for attention-related v1 tests."""

from dataclasses import dataclass

import pytest
import torch

10
from vllm.attention.backends.abstract import AttentionImpl
11
from vllm.attention.backends.registry import AttentionBackendEnum
12
13
14
15
16
17
18
19
20
21
from vllm.config import (
    CacheConfig,
    CompilationConfig,
    DeviceConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
22
from vllm.config.model import ModelDType
23
24
25
26
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
27
28
29
30
31
32
from vllm.v1.kv_cache_interface import FullAttentionSpec


@dataclass
class BatchSpec:
    """Workload shape of one attention batch (no model or cache parameters).

    ``seq_lens`` and ``query_lens`` are parallel per-request lists: entry
    ``i`` of each describes request ``i``. Their lengths must match.
    """

    # Per-request sequence length.
    seq_lens: list[int]
    # Per-request number of query tokens in this step.
    query_lens: list[int]

    # Free-form label for the configuration.
    name: str = "unnamed"

    def __post_init__(self) -> None:
        # Both per-request lists must describe the same number of requests.
        num_seq_entries, num_query_entries = len(self.seq_lens), len(self.query_lens)
        assert num_seq_entries == num_query_entries

    @property
    def batch_size(self) -> int:
        """Number of requests in the batch (one per ``seq_lens`` entry)."""
        num_requests = len(self.seq_lens)
        return num_requests

    def compute_num_tokens(self) -> int:
        """Return the total number of query tokens across all requests."""
        total_query_tokens = sum(self.query_lens)
        return total_query_tokens


def create_common_attn_metadata(
    batch_spec: BatchSpec,
    block_size: int,
    device: torch.device,
    max_block_idx: int = 1000,
    arange_block_indices: bool = False,
) -> CommonAttentionMetadata:
    """Build a causal CommonAttentionMetadata describing ``batch_spec``.

    Args:
        batch_spec: Per-request sequence and query lengths of the batch.
        block_size: Tokens per KV-cache block; used to size the block table.
        device: Device on which the device-side tensors are allocated.
        max_block_idx: Exclusive upper bound for the random block ids and
            slot ids drawn when ``arange_block_indices`` is False.
        arange_block_indices: If True, fill the block table and slot mapping
            with consecutive indices instead of random ones, so the layout is
            deterministic.

    Returns:
        CommonAttentionMetadata whose block table has shape
        ``(batch_size, ceil(max(seq_lens) / block_size))`` (int32) and whose
        slot mapping holds one int64 entry per query token.
    """
    num_reqs = batch_spec.batch_size
    num_tokens = batch_spec.compute_num_tokens()

    # Query start offsets: a leading 0 followed by the running sum of
    # query_lens, so query_start_loc[i + 1] - query_start_loc[i] equals
    # query_lens[i].
    query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.int32, device=device)
    query_start_loc[1:] = torch.tensor(
        batch_spec.query_lens, dtype=torch.int32, device=device
    ).cumsum(0)
    query_start_loc_cpu = query_start_loc.cpu()

    # Sequence lengths on the target device plus a CPU mirror.
    seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device)
    seq_lens_cpu = seq_lens.cpu()
    max_seq_len = int(seq_lens_cpu.max())

    # Tokens already computed per request: total length minus this step's
    # query tokens.
    context_lens = [
        seq_len - query_len
        for seq_len, query_len in zip(batch_spec.seq_lens, batch_spec.query_lens)
    ]
    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

    # Enough blocks per row to hold the longest sequence (ceil division).
    max_blocks = (max_seq_len + block_size - 1) // block_size
    if arange_block_indices:
        num_blocks = num_reqs * max_blocks
        block_table_tensor = torch.arange(
            num_blocks, dtype=torch.int32, device=device
        ).view(num_reqs, max_blocks)
        slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
    else:
        block_table_tensor = torch.randint(
            0,
            max_block_idx,
            (num_reqs, max_blocks),
            dtype=torch.int32,
            device=device,
        )
        # NOTE(review): slot ids are drawn from [0, max_block_idx), not
        # [0, max_block_idx * block_size); confirm this range is intended.
        slot_mapping = torch.randint(
            0, max_block_idx, (num_tokens,), dtype=torch.int64, device=device
        )

    max_query_len = max(batch_spec.query_lens)

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=num_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        causal=True,
    )


121
def try_get_attention_backend(
    backend: AttentionBackendEnum,
) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]:
    """Resolve ``backend`` to its (metadata builder class, impl class) pair.

    If resolving the backend raises ImportError, the calling test is skipped
    via ``pytest.skip`` rather than failing.
    """
    try:
        resolved = backend.get_class()
        builder_cls = resolved.get_builder_cls()
        impl_cls = resolved.get_impl_cls()
    except ImportError as err:
        pytest.skip(f"{backend.name} not available: {err}")
        # pytest.skip raises; this line only tells type checkers the branch
        # never falls through.
        raise AssertionError("unreachable") from None
    return builder_cls, impl_cls
131
132


133
def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec:
    """Derive a FullAttentionSpec from a VllmConfig.

    The block size comes from ``cache_config``. The KV-head count (for the
    config's ``parallel_config``), head size, dtype and sliding window come
    from ``model_config``.
    """
    model_cfg = vllm_config.model_config
    block_size = vllm_config.cache_config.block_size
    num_kv_heads = model_cfg.get_num_kv_heads(vllm_config.parallel_config)
    head_size = model_cfg.get_head_size()
    kv_dtype = model_cfg.dtype
    sliding_window = model_cfg.get_sliding_window()
    return FullAttentionSpec(
        block_size=block_size,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=kv_dtype,
        sliding_window=sliding_window,
    )


146
147
148
149
def _attach_mock_layer_methods(model_config: ModelConfig) -> None:
    """Bind stub per-layer query methods onto ``model_config`` in place.

    This is a workaround: tests don't build full, real models, but some
    backends expect to query the model for layer-specific parameters. Each
    stub accepts and ignores any extra positional/keyword arguments, so it
    works whatever arguments the caller passes. NOTE(review): the real
    ``get_num_layers`` presumably takes a ``parallel_config`` argument, which
    a zero-argument stub would reject; confirm against ModelConfig.
    """
    import types

    stubs = {
        "get_num_layers": lambda self, *args, **kwargs: 1,
        "get_sliding_window_for_layer": lambda self, *args, **kwargs: None,
        "get_logits_soft_cap_for_layer": lambda self, *args, **kwargs: 0.0,
        # Standard 1/sqrt(head_size) softmax scale, evaluated at call time.
        "get_sm_scale_for_layer": (
            lambda self, *args, **kwargs: 1.0 / self.get_head_size() ** 0.5
        ),
    }
    for attr_name, stub in stubs.items():
        setattr(model_config, attr_name, types.MethodType(stub, model_config))


def create_vllm_config(
    model_name: str = "meta-llama/Meta-Llama-3-8B",
    tensor_parallel_size: int = 1,
    max_model_len: int = 1024,
    dtype: ModelDType | torch.dtype = "auto",
    num_gpu_blocks: int = 1000,
    block_size: int = 16,
    max_num_seqs: int = 256,
    max_num_batched_tokens: int = 8192,
    enable_chunked_prefill: bool = True,
    add_mock_model_methods: bool = True,
    hf_config_override: dict | None = None,
) -> VllmConfig:
    """Create a VllmConfig for testing with reasonable defaults.

    Args:
        model_name: Used as both the model and the tokenizer name.
        tensor_parallel_size: Passed to ParallelConfig.
        max_model_len: Passed to ModelConfig.
        dtype: Model dtype passed to ModelConfig.
        num_gpu_blocks: Written directly onto the CacheConfig. The CPU block
            count is always set to 0.
        block_size: KV-cache block size for CacheConfig.
        max_num_seqs: Passed to SchedulerConfig.
        max_num_batched_tokens: Passed to SchedulerConfig.
        enable_chunked_prefill: Passed to SchedulerConfig.
        add_mock_model_methods: If True, bind stub per-layer query methods
            onto the ModelConfig instance.
        hf_config_override: If non-empty, merged into
            ``model_config.hf_config`` via ``update``.

    Returns:
        A VllmConfig assembled from the configs above, with default device,
        load and compilation configs.
    """
    model_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        trust_remote_code=False,
        dtype=dtype,
        seed=0,
        max_model_len=max_model_len,
    )

    cache_config = CacheConfig(
        block_size=block_size,
        cache_dtype="auto",
        swap_space=0,
    )
    # Set cache blocks for testing
    #   (these may be set during initialization normally)
    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = 0

    parallel_config = ParallelConfig(
        tensor_parallel_size=tensor_parallel_size,
    )

    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        enable_chunked_prefill=enable_chunked_prefill,
    )

    device_config = DeviceConfig()
    load_config = LoadConfig()
    compilation_config = CompilationConfig()

    if add_mock_model_methods:
        _attach_mock_layer_methods(model_config)

    if hf_config_override:
        # NOTE(review): this runs after ModelConfig construction. If
        # ModelConfig derives values from hf_config during __init__, those
        # values will not reflect the override; confirm this is acceptable.
        model_config.hf_config.update(hf_config_override)

    return VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        load_config=load_config,
        compilation_config=compilation_config,
    )


226
227
228
229
230
231
232
233
def create_dummy_kv_cache(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: torch.device,
    num_blocks: int = 100,
) -> torch.Tensor:
    """Return a randomly filled KV cache tensor for tests.

    The layout is ``(num_blocks, 2, block_size, num_kv_heads, head_size)``.
    The size-2 axis holds K and V. Values are drawn with ``torch.randn``.
    """
    cache_shape = (num_blocks, 2, block_size, num_kv_heads, head_size)
    return torch.randn(*cache_shape, dtype=dtype, device=device)
245
246
247
248
249
250
251


@dataclass
class BackendConfig:
    name: str
    env_vars: dict
    comp_config: dict  # compilation config
252
    specific_gpu_arch: tuple | None = None
253
254
255
256
257


# Define all backend configurations of full cudagraph to be tested
# Keyed by each config's own ``name`` so the key and the name cannot drift.
full_cg_backend_configs = {
    config.name: config
    for config in (
        # FA3 on Hopper
        BackendConfig(
            name="FA3",
            env_vars={
                "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
                "VLLM_FLASH_ATTN_VERSION": "3",
                "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
            },
            comp_config={"cudagraph_mode": "FULL"},
            specific_gpu_arch=(9, 0),
        ),
        # FlashMLA on Hopper
        BackendConfig(
            name="FlashMLA",
            env_vars={"VLLM_ATTENTION_BACKEND": "FLASHMLA"},
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
            specific_gpu_arch=(9, 0),
        ),
        # Cutlass MLA on Blackwell
        BackendConfig(
            name="CutlassMLA",
            env_vars={"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA"},
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
            specific_gpu_arch=(10, 0),
        ),
        # FlashInfer MLA on Blackwell
        BackendConfig(
            name="FlashInferMLA",
            env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA"},
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
            specific_gpu_arch=(10, 0),
        ),
        # FlashAttention MLA on Hopper
        BackendConfig(
            name="FlashAttentionMLA",
            env_vars={
                "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
                "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
            },
            comp_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
            specific_gpu_arch=(9, 0),
        ),
        # FA2 (no architecture restriction)
        BackendConfig(
            name="FA2",
            env_vars={
                "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
                "VLLM_FLASH_ATTN_VERSION": "2",
                "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
            },
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
        ),
        # Triton Attention
        BackendConfig(
            name="TritonAttn",
            env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
        ),
        # FlashInfer
        BackendConfig(
            name="FlashInfer",
            env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
            comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
        ),
    )
}