# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Standard attention benchmark runner - shared utilities for non-MLA benchmarks.

This module provides helpers for running standard attention backends
(FlashAttention, Triton, FlashInfer) with real vLLM integration.
"""

import logging
import types
from contextlib import contextmanager

import numpy as np
import torch
from batch_spec import parse_batch_spec, reorder_for_flashinfer
from common import BenchmarkConfig, BenchmarkResult, MockLayer, get_attention_scale

from vllm.config import (
    CacheConfig,
    CompilationConfig,
    DeviceConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata,
    get_kv_cache_layout,
    set_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec

# ============================================================================
# Backend Configuration
# ============================================================================


def _get_backend_config(backend: str) -> dict:
    """
    Get backend configuration from AttentionBackendEnum.

    Args:
        backend: Backend name matching AttentionBackendEnum exactly
                 (e.g., "FLASH_ATTN", "TRITON_ATTN", "FLASHINFER")

    Returns:
        Dict with a single "backend_class" entry holding the class returned
        by ``AttentionBackendEnum[backend].get_class()``.

    Raises:
        ValueError: If ``backend`` is not an AttentionBackendEnum member, or
            if resolving its class raises ValueError. The message lists every
            member name except "CUSTOM".
    """
    from vllm.v1.attention.backends.registry import AttentionBackendEnum

    try:
        backend_enum = AttentionBackendEnum[backend]
        backend_class = backend_enum.get_class()
    except (KeyError, ValueError) as e:
        valid_backends = [b.name for b in AttentionBackendEnum if b.name != "CUSTOM"]
        raise ValueError(
            f"Unknown backend: {backend}. Valid backends: {valid_backends}"
        ) from e

    return {"backend_class": backend_class}


@contextmanager
def log_warnings_and_errors_only():
    """Silence vLLM INFO/DEBUG output for the duration of a ``with`` block.

    The ``"vllm"`` logger is switched to WARNING on entry. Its previously
    configured level is put back on exit, including when the body raises.
    """
    vllm_logger = logging.getLogger("vllm")
    saved_level = vllm_logger.level
    vllm_logger.setLevel(logging.WARNING)
    try:
        yield
    finally:
        # Restore whatever level was set before, even on error.
        vllm_logger.setLevel(saved_level)


# ============================================================================
# Metadata Building Helpers
# ============================================================================


def _build_common_attn_metadata(
    q_lens: list[int],
    kv_lens: list[int],
    block_size: int,
    device: torch.device,
) -> CommonAttentionMetadata:
    """Build causal CommonAttentionMetadata from per-request query/KV lengths.

    Layout of the synthetic batch:

    * ``query_start_loc`` is ``0`` followed by the running sum of ``q_lens``
      (length ``batch_size + 1``), on ``device`` plus a CPU copy.
    * ``seq_lens`` is ``kv_lens`` as an int32 tensor on ``device``.
    * Request ``i`` owns block ids ``[i * max_blocks, (i + 1) * max_blocks)``
      in the block table, where ``max_blocks = ceil(max(kv_lens) / block_size)``.
    * ``slot_mapping`` is simply ``0 .. sum(q_lens) - 1``. New K/V therefore
      land in the first cache slots rather than in each request's own blocks,
      which is presumably fine for timing but not for checking outputs.

    Args:
        q_lens: Number of query tokens per request.
        kv_lens: KV sequence length per request (used as ``seq_lens``).
        block_size: Tokens per KV cache block.
        device: Device on which the metadata tensors are created.

    Returns:
        The populated CommonAttentionMetadata with ``causal=True``.

    Raises:
        ValueError: If ``q_lens`` is empty or its length differs from
            ``kv_lens``.
    """
    if not q_lens:
        raise ValueError("q_lens must contain at least one request")
    if len(q_lens) != len(kv_lens):
        raise ValueError(
            "q_lens and kv_lens must have the same length, got "
            f"{len(q_lens)} and {len(kv_lens)}"
        )

    batch_size = len(q_lens)
    total_tokens = sum(q_lens)
    max_query_len = max(q_lens)
    # Taken from the Python list so no device->host sync (.item()) is needed.
    max_seq_len = int(max(kv_lens))

    # Build the cumulative query offsets on the host, then copy once to device.
    query_start_loc_cpu = torch.zeros(batch_size + 1, dtype=torch.int32)
    query_start_loc_cpu[1:] = torch.tensor(q_lens, dtype=torch.int32).cumsum(0)
    query_start_loc = query_start_loc_cpu.to(device)

    seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device)

    # Give every request its own contiguous, non-overlapping run of blocks.
    max_blocks = (max_seq_len + block_size - 1) // block_size
    num_blocks = batch_size * max_blocks
    block_table_tensor = torch.arange(
        num_blocks, dtype=torch.int32, device=device
    ).view(batch_size, max_blocks)
    slot_mapping = torch.arange(total_tokens, dtype=torch.int64, device=device)

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        num_reqs=batch_size,
        num_actual_tokens=total_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        causal=True,
    )


def _create_vllm_config(
    config: BenchmarkConfig,
    max_num_blocks: int,
) -> VllmConfig:
    """Create a VllmConfig for benchmarking with mock model methods.

    A Llama-3-8B ``ModelConfig`` is used with ``dtype="auto"`` so the model's
    native dtype is picked up. Its layer/shape query methods are then
    overridden on the instance, so layer count, head counts, head size,
    softmax scale, sliding window and logits soft cap all come from the
    benchmark ``config`` instead of the checkpoint.

    Args:
        config: Benchmark configuration (block size, KV cache dtype, layer
            count and head geometry).
        max_num_blocks: Number of GPU KV cache blocks to set on the cache
            config (CPU blocks are set to 0).

    Returns:
        A VllmConfig combining the model, cache, parallel, scheduler, device,
        load and compilation configs built here.
    """
    # NOTE(review): max_model_len here (1024) differs from the SchedulerConfig
    # value below (8192); confirm which one the backends under test read when
    # sizing internal buffers for long kv_lens.
    model_config = ModelConfig(
        model="meta-llama/Meta-Llama-3-8B",
        tokenizer="meta-llama/Meta-Llama-3-8B",
        trust_remote_code=False,
        dtype="auto",  # Use model's native dtype
        seed=0,
        max_model_len=1024,
    )

    cache_config = CacheConfig(
        block_size=config.block_size,
        cache_dtype=config.kv_cache_dtype,
    )
    cache_config.num_gpu_blocks = max_num_blocks
    cache_config.num_cpu_blocks = 0

    parallel_config = ParallelConfig(tensor_parallel_size=1)
    scheduler_config = SchedulerConfig(
        max_num_seqs=256,
        max_num_batched_tokens=8192,
        max_model_len=8192,
        is_encoder_decoder=False,
        enable_chunked_prefill=True,
    )
    device_config = DeviceConfig()
    load_config = LoadConfig()
    compilation_config = CompilationConfig()

    # Instance-level overrides so vLLM code querying the model config sees the
    # benchmark's geometry rather than the real checkpoint's.
    mocked_methods = {
        "get_num_layers": lambda self: config.num_layers,
        "get_sliding_window_for_layer": lambda self, i: None,
        "get_logits_soft_cap_for_layer": lambda self, i: 0.0,
        "get_sm_scale_for_layer": lambda self, i: 1.0 / config.head_dim**0.5,
        "get_num_attention_heads": (
            lambda self, parallel_config=None: config.num_q_heads
        ),
        "get_num_kv_heads": (
            lambda self, parallel_config=None: config.num_kv_heads
        ),
        "get_head_size": lambda self: config.head_dim,
        "get_sliding_window": lambda self: None,
    }
    for method_name, method_fn in mocked_methods.items():
        setattr(model_config, method_name, types.MethodType(method_fn, model_config))

    return VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        load_config=load_config,
        compilation_config=compilation_config,
    )


# ============================================================================
# Backend Initialization
# ============================================================================


def _create_backend_impl(
    backend_cfg: dict,
    config: BenchmarkConfig,
    device: torch.device,
    dtype: torch.dtype,
):
    """Create backend implementation instance and a matching mock layer.

    Args:
        backend_cfg: Dict from ``_get_backend_config`` holding "backend_class".
        config: Benchmark configuration supplying head counts, head size,
            block size and KV cache dtype.
        device: Device passed to the MockLayer.
        dtype: Element dtype recorded in the layer's FullAttentionSpec.

    Returns:
        Tuple ``(backend_class, impl, layer)``. ``impl`` is an instance of
        ``backend_class.get_impl_cls()`` with no ALiBi and no sliding window.
        ``layer`` is a MockLayer whose FullAttentionSpec mirrors ``config``'s
        block size, KV heads, head size and ``dtype``.
    """
    backend_class = backend_cfg["backend_class"]

    scale = get_attention_scale(config.head_dim)

    impl = backend_class.get_impl_cls()(
        num_heads=config.num_q_heads,
        head_size=config.head_dim,
        scale=scale,
        num_kv_heads=config.num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype=config.kv_cache_dtype,
    )

    kv_cache_spec = FullAttentionSpec(
        block_size=config.block_size,
        num_kv_heads=config.num_kv_heads,
        head_size=config.head_dim,
        dtype=dtype,
    )

    layer = MockLayer(device, kv_cache_spec=kv_cache_spec)

    return backend_class, impl, layer


def _flashinfer_per_layer_params_patch():
    """Return a ``unittest.mock.patch`` that stubs FlashInfer's per-layer lookup.

    There are no real model layers registered in the benchmark, so
    ``vllm.v1.attention.backends.flashinfer.get_per_layer_parameters`` is
    replaced with a stub. For every requested layer name it returns no sliding
    window (``window_left=-1``), no logits soft cap, and a 1/sqrt(head_size)
    softmax scale, where head_size comes from the model config.
    """
    import unittest.mock

    from vllm.v1.attention.backends.utils import PerLayerParameters

    def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
        head_size = vllm_config.model_config.get_head_size()
        return {
            layer_name: PerLayerParameters(
                window_left=-1,  # No sliding window
                logits_soft_cap=0.0,  # No soft cap
                sm_scale=1.0 / (head_size**0.5),  # Standard scale
            )
            for layer_name in layer_names
        }

    return unittest.mock.patch(
        "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
        mock_get_per_layer_parameters,
    )


def _create_metadata_builder(
    backend_class,
    kv_cache_spec: FullAttentionSpec,
    vllm_config: VllmConfig,
    device: torch.device,
    backend_name: str = "",
):
    """Create metadata builder instance for a single mock layer ("layer_0").

    Args:
        backend_class: Backend whose ``get_builder_cls()`` is instantiated.
        kv_cache_spec: KV cache spec passed to the builder.
        vllm_config: Config passed to the builder.
        device: Device passed to the builder.
        backend_name: Backend enum name. For "FLASHINFER",
            ``get_per_layer_parameters`` is patched while the builder is
            constructed.

    Returns:
        The constructed metadata builder.
    """
    import contextlib

    layer_names = ["layer_0"]
    builder_cls = backend_class.get_builder_cls()

    # Flashinfer needs get_per_layer_parameters mocked since we don't have
    # real model layers registered; the patch is active only during
    # construction. Other backends need no patching.
    if backend_name == "FLASHINFER":
        patch_ctx = _flashinfer_per_layer_params_patch()
    else:
        patch_ctx = contextlib.nullcontext()

    with patch_ctx:
        return builder_cls(
            kv_cache_spec=kv_cache_spec,
            layer_names=layer_names,
            vllm_config=vllm_config,
            device=device,
        )


# ============================================================================
# Tensor Creation Helpers
# ============================================================================


def _create_input_tensors(
    config: BenchmarkConfig,
    total_q: int,
    device: torch.device,
    dtype: torch.dtype,
    quantize_query: bool = False,
) -> tuple:
    """Create random Q, K, V input tensors for all layers.

    Every tensor is sampled with ``torch.randn`` in ``dtype`` on ``device``.
    When ``quantize_query`` is True, queries are then cast to the platform
    fp8 dtype (``current_platform.fp8_dtype()``) to match an fp8 KV cache.
    Keys and values always stay in ``dtype``.

    Args:
        config: Supplies num_layers, num_q_heads, num_kv_heads and head_dim.
        total_q: Number of query tokens across the batch (leading dimension).
        device: Device for all tensors.
        dtype: Sampling dtype (and the K/V dtype).
        quantize_query: Cast queries to fp8 after sampling.

    Returns:
        ``(q_list, k_list, v_list)``, each with ``config.num_layers`` tensors.
        Queries have shape ``(total_q, num_q_heads, head_dim)`` and keys and
        values have shape ``(total_q, num_kv_heads, head_dim)``.
    """
    q_dtype = dtype
    if quantize_query:
        from vllm.platforms import current_platform

        q_dtype = current_platform.fp8_dtype()

    def _per_layer(num_heads: int, out_dtype: torch.dtype) -> list:
        # Sample in the model dtype, then cast; .to() is a no-op when the
        # dtypes already match.
        return [
            torch.randn(
                total_q, num_heads, config.head_dim, device=device, dtype=dtype
            ).to(out_dtype)
            for _ in range(config.num_layers)
        ]

    # Same sampling order as before (all Q, then all K, then all V) so the
    # RNG stream produces identical tensors for a given seed.
    q_list = _per_layer(config.num_q_heads, q_dtype)
    k_list = _per_layer(config.num_kv_heads, dtype)
    v_list = _per_layer(config.num_kv_heads, dtype)
    return q_list, k_list, v_list


def _create_kv_cache(
    config: BenchmarkConfig,
    max_num_blocks: int,
    backend_class,
    device: torch.device,
    dtype: torch.dtype,
) -> list:
    """Create zero-filled KV cache tensors for all layers using the backend's methods.

    The logical shape comes from ``backend_class.get_kv_cache_shape()``. If
    the backend provides ``get_kv_cache_stride_order()``, storage is allocated
    contiguously in that physical order and then permuted back. Each returned
    tensor therefore has the logical shape but the backend's expected memory
    layout. Backends without a stride order (AttributeError or
    NotImplementedError) get a plain contiguous tensor.

    Any ``config.kv_cache_dtype`` beginning with "fp8" allocates the cache in
    the platform fp8 dtype; otherwise ``dtype`` is used.

    Returns:
        One cache tensor per layer (``config.num_layers`` entries).

    Raises:
        ValueError: If the stride order's length differs from the cache rank.
    """
    # Get the logical shape from the backend
    cache_shape = backend_class.get_kv_cache_shape(
        num_blocks=max_num_blocks,
        block_size=config.block_size,
        num_kv_heads=config.num_kv_heads,
        head_size=config.head_dim,
    )
    ndim = len(cache_shape)

    # Get the stride order for custom memory layout (identity if unsupported)
    try:
        stride_order = tuple(backend_class.get_kv_cache_stride_order())
    except (AttributeError, NotImplementedError):
        stride_order = tuple(range(ndim))
    if len(stride_order) != ndim:
        raise ValueError(
            f"KV cache stride order {stride_order} does not match cache shape "
            f"{tuple(cache_shape)}"
        )

    # Physical (allocation) shape lists the logical dims in stride order.
    physical_shape = tuple(cache_shape[dim] for dim in stride_order)
    # Inverse permutation maps the physical tensor back to the logical view.
    inv_order = [stride_order.index(dim) for dim in range(ndim)]

    # Use startswith so every fp8 variant gets an fp8 cache, consistent with
    # the startswith("fp8") test used to decide whether queries are quantized.
    # NOTE(review): fp8_dtype() is a single fp8 format; confirm backends
    # reinterpret the cache for other variants (e.g. "fp8_e5m2").
    cache_dtype = dtype
    if config.kv_cache_dtype.startswith("fp8"):
        from vllm.platforms import current_platform

        cache_dtype = current_platform.fp8_dtype()

    cache_list = []
    for _ in range(config.num_layers):
        # Allocate in physical layout order (contiguous), then view logically.
        physical = torch.zeros(*physical_shape, device=device, dtype=cache_dtype)
        cache_list.append(physical.permute(*inv_order))
    return cache_list


# ============================================================================
# Benchmark Execution
# ============================================================================


def _run_single_benchmark(
    config: BenchmarkConfig,
    impl,
    layer,
    q_list: list,
    k_list: list,
    v_list: list,
    cache_list: list,
    attn_metadata,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple:
    """Run warmup, optional CUDA-graph capture, and the timed repeat loop.

    One "pass" calls ``impl.forward`` once per layer, each writing into a
    shared output buffer of shape ``(total_q, num_q_heads, head_dim)``.

    1. ``config.warmup_iters`` untimed passes are run.
    2. If ``config.use_cuda_graphs`` is set, one pass is captured into a CUDA
       graph and replayed once untimed.
    3. ``config.repeats`` passes are timed with CUDA events.

    Returns:
        ``(times, mem_stats)``. ``times`` has one entry per repeat, in seconds
        per layer. ``mem_stats`` is empty unless ``config.profile_memory`` is
        set, in which case it holds "allocated_mb" and "reserved_mb".
    """
    total_q = q_list[0].shape[0]
    out = torch.empty(
        total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype
    )

    def run_all_layers():
        for layer_idx in range(config.num_layers):
            impl.forward(
                layer,
                q_list[layer_idx],
                k_list[layer_idx],
                v_list[layer_idx],
                cache_list[layer_idx],
                attn_metadata,
                output=out,
            )

    # Warmup
    for _ in range(config.warmup_iters):
        run_all_layers()
    torch.accelerator.synchronize()

    # Optionally capture a CUDA graph after warmup.
    # Graph replay eliminates CPU launch overhead so timings reflect pure
    # kernel time.
    if config.use_cuda_graphs:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            run_all_layers()
        # Replay once untimed so one-off first-launch cost of the freshly
        # captured graph is not charged to the first timed repeat.
        graph.replay()
        torch.accelerator.synchronize()
        benchmark_fn = graph.replay
    else:
        benchmark_fn = run_all_layers

    # Benchmark
    times = []
    for _ in range(config.repeats):
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)

        start_evt.record()
        benchmark_fn()
        end_evt.record()
        torch.accelerator.synchronize()

        # elapsed_time() is in milliseconds; store seconds per layer.
        elapsed_ms = start_evt.elapsed_time(end_evt)
        times.append(elapsed_ms / 1000.0 / config.num_layers)

    mem_stats = {}
    if config.profile_memory:
        mem_stats = {
            "allocated_mb": torch.accelerator.memory_allocated(device) / 1024**2,
            "reserved_mb": torch.accelerator.memory_reserved(device) / 1024**2,
        }

    return times, mem_stats


# ============================================================================
# Public API
# ============================================================================


@contextmanager
def _kv_cache_layout_override(layout):
    """Temporarily force the global KV cache layout, restoring it on exit.

    ``layout`` of ``None`` means no requirement, and nothing is changed.
    Otherwise the currently effective layout (``get_kv_cache_layout()``) is
    saved, ``layout`` is applied, and the saved layout is re-applied on exit,
    including on error. This keeps one backend's requirement from leaking
    into later benchmarks run in the same process.
    """
    if layout is None:
        yield
        return

    previous_layout = get_kv_cache_layout()
    set_kv_cache_layout(layout)
    # get_kv_cache_layout() is cached; clear it so the new value is seen.
    get_kv_cache_layout.cache_clear()
    try:
        yield
    finally:
        set_kv_cache_layout(previous_layout)
        get_kv_cache_layout.cache_clear()


def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
    """
    Run standard attention benchmark with real kernels.

    Supports: FLASH_ATTN, TRITON_ATTN, FLASHINFER

    Builds a mock vLLM config and the backend impl and metadata builder. It
    then creates random Q/K/V inputs and a zeroed KV cache for
    ``config.num_layers`` layers and times the forward passes. Any KV cache
    layout the backend requires is applied only for the duration of the run.

    Args:
        config: Benchmark configuration

    Returns:
        BenchmarkResult with per-layer timing statistics (seconds), token
        throughput (``total_q / mean_time``), and memory statistics when
        ``config.profile_memory`` is set.
    """
    device = torch.device(config.device)
    torch.accelerator.set_device_index(device)

    backend_cfg = _get_backend_config(config.backend)

    requests = parse_batch_spec(config.batch_spec)
    if config.backend == "FLASHINFER":
        requests = reorder_for_flashinfer(requests)

    q_lens = [r.q_len for r in requests]
    kv_lens = [r.kv_len for r in requests]
    total_q = sum(q_lens)
    max_kv = max(kv_lens)
    batch_size = len(q_lens)

    # Calculate total blocks needed: batch_size * max_blocks_per_request
    max_blocks_per_request = (max_kv + config.block_size - 1) // config.block_size
    max_num_blocks = batch_size * max_blocks_per_request

    # Suppress vLLM logs during setup to reduce spam
    with log_warnings_and_errors_only():
        # Create vllm_config first - uses model's native dtype via "auto"
        vllm_config = _create_vllm_config(config, max_num_blocks)
        dtype = vllm_config.model_config.dtype

        # Wrap everything in set_current_vllm_config context
        # This is required for backends like flashinfer that need global config
        with set_current_vllm_config(vllm_config):
            backend_class, impl, layer = _create_backend_impl(
                backend_cfg, config, device, dtype
            )

            # Set KV cache layout if the backend requires a specific one
            # (e.g., FlashInfer requires HND on SM100/Blackwell for TRTLLM
            # attention). The override is undone when this block exits.
            required_layout = backend_class.get_required_kv_cache_layout()
            with _kv_cache_layout_override(required_layout):
                common_metadata = _build_common_attn_metadata(
                    q_lens, kv_lens, config.block_size, device
                )

                kv_cache_spec = FullAttentionSpec(
                    block_size=config.block_size,
                    num_kv_heads=config.num_kv_heads,
                    head_size=config.head_dim,
                    dtype=dtype,
                )

                builder = _create_metadata_builder(
                    backend_class, kv_cache_spec, vllm_config, device, config.backend
                )

                attn_metadata = builder.build(
                    common_prefix_len=0,
                    common_attn_metadata=common_metadata,
                )

                # Only quantize queries when the impl supports it
                quantize_query = config.kv_cache_dtype.startswith("fp8") and (
                    getattr(impl, "supports_quant_query_input", False)
                )
                q_list, k_list, v_list = _create_input_tensors(
                    config, total_q, device, dtype, quantize_query=quantize_query
                )

                cache_list = _create_kv_cache(
                    config, max_num_blocks, backend_class, device, dtype
                )

                times, mem_stats = _run_single_benchmark(
                    config,
                    impl,
                    layer,
                    q_list,
                    k_list,
                    v_list,
                    cache_list,
                    attn_metadata,
                    device,
                    dtype,
                )

    mean_time = np.mean(times)
    throughput = total_q / mean_time if mean_time > 0 else 0

    return BenchmarkResult(
        config=config,
        mean_time=mean_time,
        std_time=np.std(times),
        min_time=np.min(times),
        max_time=np.max(times),
        throughput_tokens_per_sec=throughput,
        memory_allocated_mb=mem_stats.get("allocated_mb"),
        memory_reserved_mb=mem_stats.get("reserved_mb"),
    )