# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from typing import Optional

import pytest
import torch._dynamo

from tests.compile.backend import LazyInitPass, TestBackend
from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
    CacheConfig,
    CompilationConfig,
    CompilationLevel,
    ModelConfig,
    PassConfig,
    SchedulerConfig,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.v1.kv_cache_interface import AttentionSpec

# Storage dtypes for quantized tensors: FP8 data uses the platform's FP8
# dtype; NVFP4 data is stored packed in uint8 tensors.
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8

# globals needed for string-import custom Dynamo backend field
# (CompilationConfig(backend=...) refers to these by their import path, e.g.
# "tests.compile.test_fusion_attn.backend", so they must be module attributes).
backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None


@pytest.mark.parametrize(
    "model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
)
@pytest.mark.parametrize("use_triton_fa", [True, False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(
    not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
)
def test_attention_fusion_v0(
    example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
):
    """Compare greedy generations with and without ``AttnFusionPass``.

    The model is run twice through ``LLM``. The first run uses a backend that
    only applies ``NoOpEliminationPass``. The second uses one that also
    applies ``AttnFusionPass``.

    The test then checks two things:

    * Exactly those attention nodes whose layer impl reports
      ``fused_output_quant_supported(quant_key)`` gained an ``output_scale``
      kwarg after the pass.
    * Both runs produce identical token ids and text.
    """
    # Clean Dynamo cache to avoid reusing other test cases
    # (for some reason the reset at the end is not enough)
    torch._dynamo.reset()

    # Use global backends: the compilation configs below reference them by
    # import path ("tests.compile.test_fusion_attn.backend[_unfused]").
    global backend, backend_unfused

    monkeypatch.setenv("VLLM_USE_V1", "1")
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))

    # Prompt 4 seems too open-ended, differs between fused and unfused
    # (both outputs look reasonable though)
    prompts = example_prompts[:4] + example_prompts[5:]

    # ---- Unfused reference run ----
    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend_unfused",
        custom_ops=["+quant_fp8"],
    )
    vllm_config = VllmConfig(
        compilation_config=compile_config,
        model_config=ModelConfig(
            model=model,
            dtype=torch.bfloat16,
        ),
    )
    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))

    llm = LLM(
        model,
        enforce_eager=True,
        compilation_config=compile_config,
        gpu_memory_utilization=0.5,
        max_model_len=2048,
    )

    sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)

    unfused_output = llm.generate(prompts, sampling_params)
    backend_unfused = None  # Reset backend to make sure llm gets released
    del llm

    # ---- Fused run ----
    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend",
        custom_ops=["+quant_fp8"],
    )
    vllm_config = VllmConfig(
        compilation_config=compile_config,
        model_config=ModelConfig(
            model=model,
            dtype=torch.bfloat16,
        ),
    )

    # AttnFusionPass needs attention layers to be registered in config upon init
    # so we initialize it during compilation.
    attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
    backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
    llm2 = LLM(
        model,
        enforce_eager=True,
        compilation_config=compile_config,
        gpu_memory_utilization=0.5,
        max_model_len=2048,
    )

    # check support
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key)
        for layer in compile_config.static_forward_context.values()
    ]

    print(f"{attn_fusion_supported=}")
    if any(attn_fusion_supported):
        # Check quant ops
        backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)

    # attention ops present in both, just output_scale param changes
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
    assert len(attn_nodes_pre) == len(attn_nodes_post)

    for i, (node_pre, node_post) in enumerate(zip(attn_nodes_pre, attn_nodes_post)):
        assert node_pre.kwargs["output_scale"] is None
        fused = node_post.kwargs["output_scale"] is not None
        expected = attn_fusion_supported[i]
        # Phrase the message from the expectation (not the observed state) so
        # it is correct whichever way the mismatch goes.
        assert fused == expected, (
            f"Node {i} {'' if expected else 'not '}expected to have fused "
            "output quant"
        )

    # check outputs
    fused_output = llm2.generate(prompts, sampling_params)

    # transform outputs to format expected by check_outputs_equal
    sample_outs = lambda s: (list(s.token_ids), s.text)
    outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]

    check_outputs_equal(
        outputs_0_lst=outs_lst(unfused_output),
        outputs_1_lst=outs_lst(fused_output),
        name_0="unfused",
        name_1="fused",
    )

    # Clean Dynamo cache to avoid polluting other case(s)
    torch._dynamo.reset()

    # Reset backend to make sure llm2 gets released
    backend = None


class AttentionQuantPatternModel(torch.nn.Module):
    """Base model for AttentionQuantPattern fusion.

    Wraps a single ``Attention`` layer together with its backend's metadata
    builder, so the layer can be run outside a full model. Subclasses append
    the output-quantization ops (in ``forward``) that the fusion pass should
    fold into the attention op.
    """

    def __init__(
        self,
        num_qo_heads: int,
        num_kv_heads: int,
        head_size: int,
        kv_cache_dtype: torch.dtype,
        device: torch.device,
        vllm_config: VllmConfig,
        **kwargs,
    ):
        # ``kwargs`` is not used here; it lets subclasses accept extra options
        # (e.g. pre-built weights via ``w=``) through the same constructor.
        super().__init__()
        self.num_qo_heads = num_qo_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.device = device
        self.vllm_config = vllm_config

        self.attn = Attention(
            num_heads=self.num_qo_heads,
            head_size=self.head_size,
            scale=1.0 / (self.head_size**0.5),
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            prefix="model.layers.0.self_attn.attn",
        )

        # Move the layer's k/v scale tensors onto the target device.
        self.attn._k_scale = self.attn._k_scale.to(device)
        self.attn._v_scale = self.attn._v_scale.to(device)

        self.block_size = 16

        # Initialize attn MetadataBuilder
        self.builder = self.attn.attn_backend.get_builder_cls()(
            kv_cache_spec=AttentionSpec(
                block_size=self.block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_dtype,
            ),
            layer_names=[self.attn.layer_name],
            vllm_config=self.vllm_config,
            device=self.device,
        )

    def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
        """Initialize attention metadata and a zero-filled dummy KV cache.

        Every request in the batch has sequence length 1 and query length 1.

        Args:
            batch_size: Number of requests in the batch.
            use_hnd: ROCm only. Selects the HND KV-cache layout instead of
                NHD. Ignored on other platforms.

        Returns:
            The metadata produced by the backend builder. It is also stored
            on ``self.attn_metadata``. As a side effect, the dummy cache is
            installed as ``self.attn.kv_cache``.
        """
        # Create common attn metadata
        batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
        common_attn_metadata = create_common_attn_metadata(
            batch_spec, self.block_size, self.device, arange_block_indices=True
        )

        # ceil(max_seq_len / block_size) blocks per request
        max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
        num_blocks = batch_size * max_blocks

        # Create a dummy KV cache, allocated as
        #   [num_blocks, 2, num_kv_heads, block_size, head_size]  (HND, k/v at dim 1)
        # and then permuted into the layout each platform's backend expects:
        #   - NHD: [..., block_size, num_kv_heads, head_size]
        #   - HND: [..., num_kv_heads, block_size, head_size]
        kv_cache = torch.zeros(
            num_blocks,
            2,
            self.num_kv_heads,
            self.block_size,
            self.head_size,
            dtype=self.kv_cache_dtype,
            device=self.device,
        )
        if current_platform.is_rocm():
            # k/v as 1st dimension
            if use_hnd:
                # -> [2, num_blocks, num_kv_heads, block_size, head_size]
                kv_cache = kv_cache.permute(1, 0, 2, 3, 4)
            else:
                # -> [2, num_blocks, block_size, num_kv_heads, head_size]
                kv_cache = kv_cache.permute(1, 0, 3, 2, 4)
        else:
            # k/v as 2nd dimension
            # Create kv_cache in HND layout and permute to NHD layout
            # (later will be permuted back to HND layout in forward pass)
            # -> [num_blocks, 2, block_size, num_kv_heads, head_size]
            kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
        self.attn.kv_cache = [kv_cache]

        # Build attn metadata
        self.attn_metadata = self.builder.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

        return self.attn_metadata


class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
    """Test model for AttentionFp8StaticQuantPattern fusion.

    ``forward`` is attention followed by ``Fp8LinearOp``: static FP8
    activation quant of the attention output, then an FP8 GEMM.
    """

    # Despite the ``Test`` prefix this is a helper module, not a pytest test
    # class; opt out of collection (pytest would otherwise warn that it
    # cannot collect a class with ``__init__``).
    __test__ = False

    quant_key = kFp8StaticTensorSym

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.quant_key.scale.static,
            act_quant_group_shape=self.quant_key.scale.group_shape,
        )

        # Reuse caller-provided weights (used to share identical weights
        # between the unfused and fused models). Only build random ones when
        # none are given, so no throwaway device tensors are allocated.
        w = kwargs.get("w")
        if w is None:
            hidden_size = self.num_qo_heads * self.head_size
            w = {
                "weight": torch.randn(hidden_size, hidden_size)
                .to(dtype=FP8_DTYPE, device=self.device)
                .t(),
                "wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
                "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
            }
        self.w = w

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Forward pass that creates the pattern to be fused."""
        attn_output = self.attn(q, k, v)
        return self.fp8_linear.apply(
            input=attn_output,
            weight=self.w["weight"],
            weight_scale=self.w["wscale"],
            input_scale=self.w["scale"],
        )


class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
    """Test model for AttentionNvfp4QuantPattern fusion.

    ``forward`` is attention, then ``scaled_fp4_quant`` of its output, then a
    ``cutlass_scaled_fp4_mm`` GEMM against a packed FP4 weight.
    """

    # Despite the ``Test`` prefix this is a helper module, not a pytest test
    # class; opt out of collection (pytest would otherwise warn that it
    # cannot collect a class with ``__init__``).
    __test__ = False

    quant_key = kNvfp4Quant

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Reuse caller-provided weights (used to share identical weights
        # between the unfused and fused models). Only build random ones when
        # none are given, so no throwaway device tensors are allocated.
        w = kwargs.get("w")
        if w is None:
            hidden_size = self.num_qo_heads * self.head_size
            w = {
                # random packed FP4 weight (hidden_size // 2 bytes per row)
                "weight": torch.randint(
                    256,
                    (hidden_size, hidden_size // 2),
                    dtype=FP4_DTYPE,
                    device=self.device,
                ),
                # random stand-in for the weight block scales
                "wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to(
                    dtype=FP8_DTYPE, device=self.device
                ),
                "wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
                "scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
            }
        self.w = w

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Forward pass that creates the pattern to be fused."""
        attn_output = self.attn(q, k, v)
        # The quant op is given the reciprocal of the activation scale.
        quant_output, output_block_scale = scaled_fp4_quant(
            attn_output, 1 / self.w["scale"]
        )
        # ``alpha`` combines the activation and weight scales.
        return cutlass_scaled_fp4_mm(
            a=quant_output,
            b=self.w["weight"],
            block_scale_a=output_block_scale,
            block_scale_b=self.w["wscale_swizzled"],
            alpha=self.w["scale"] * self.w["wscale"],
            out_dtype=attn_output.dtype,
        )


# Per-platform parametrization for test_attention_quant_pattern:
#   MODELS: (checkpoint name used for ModelConfig, test model class exercising
#            the matching quant pattern)
#   HEADS:  (num_qo_heads, num_kv_heads) combinations
# Empty lists on other platforms, so the parametrized test gets no cases.
if current_platform.is_cuda():
    MODELS = [
        (
            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
            TestAttentionFp8StaticQuantPatternModel,
        ),
        (
            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
            TestAttentionNvfp4QuantPatternModel,
        ),
    ]
    HEADS = [(64, 8), (40, 8)]
elif current_platform.is_rocm():
    MODELS = [
        ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
    ]
    HEADS = [(32, 8), (40, 8)]
else:
    MODELS = []
    HEADS = []


@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize(
    "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize(
    "backend",
    [_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
)
@pytest.mark.parametrize(
    "split_attention", [False, True] if current_platform.is_rocm() else [False]
)
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
    "use_inductor_graph_partition",
    [False] if current_platform.is_rocm() else [False, True],
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(
    current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
    reason="On CUDA only test on SM100(Blackwell)",
)
def test_attention_quant_pattern(
    num_qo_heads: int,
    num_kv_heads: int,
    head_size: int,
    batch_size: int,
    dtype: torch.dtype,
    model_name: str,
    model_class: type[AttentionQuantPatternModel],
    backend: _Backend,
    split_attention: bool,
    use_inductor_graph_partition: bool,
    monkeypatch,
    dist_init,
    caplog_vllm,
):
    """Test the attention + output-quant fusion pass on a single layer.

    ``model_class`` (static FP8 or NVFP4 output quant) is first run eagerly,
    unfused. It is then compiled with ``NoOpEliminationPass``,
    ``AttnFusionPass`` and ``PostCleanupPass``.

    The test checks three things:

    * The quant op was folded into the attention op: it gains an
      ``output_scale`` kwarg, plus ``output_block_scale`` for NVFP4.
    * The pass matched once per supported layer.
    * The fused and unfused results are close.
    """
    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

    monkeypatch.setenv("VLLM_USE_V1", "1")
    if split_attention:
        monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")

    device = torch.device("cuda:0")
    torch.manual_seed(42)

    vllm_config = VllmConfig(
        model_config=ModelConfig(
            model=model_name,
            max_model_len=2048,
            dtype=dtype,
        ),
        scheduler_config=SchedulerConfig(max_num_seqs=1024),
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            custom_ops=["+quant_fp8"],
            use_inductor_graph_partition=use_inductor_graph_partition,
        ),
        cache_config=CacheConfig(cache_dtype="fp8"),
    )

    # Create test inputs
    q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device)
    k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
    v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)

    # Mark first dimension as dynamic for realistic testing
    torch._dynamo.mark_dynamic(q, 0)
    torch._dynamo.mark_dynamic(k, 0)
    torch._dynamo.mark_dynamic(v, 0)

    # The unfused model gets its own copy of the config. Presumably this keeps
    # the two models' identically-prefixed attention layers from clashing in
    # one config; TODO confirm.
    vllm_config_unfused = copy.deepcopy(vllm_config)
    with (
        set_current_vllm_config(vllm_config_unfused),
        set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
        global_force_attn_backend_context_manager(backend),
    ):
        model_unfused = model_class(
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            kv_cache_dtype=FP8_DTYPE,
            device=device,
            vllm_config=vllm_config_unfused,
        )
        model_unfused = model_unfused.to(device)

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
            batch_size, use_hnd=split_attention
        )

        # Run model directly without compilation and fusion
        result_unfused = model_unfused(q, k, v)

    # Run model with attn fusion enabled
    vllm_config.compilation_config.pass_config = PassConfig(
        enable_attn_fusion=True, enable_noop=True
    )
    with (
        set_current_vllm_config(vllm_config),
        set_forward_context(attn_metadata=None, vllm_config=vllm_config),
        global_force_attn_backend_context_manager(backend),
    ):
        # Share the unfused model's weights so results are comparable.
        model_fused = model_class(
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            kv_cache_dtype=FP8_DTYPE,
            device=device,
            vllm_config=vllm_config,
            w=model_unfused.w,
        )
        model_fused = model_fused.to(device)

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_fused.build_attn_metadata(
            batch_size, use_hnd=split_attention
        )

        # Create test backend with fusion passes enabled
        noop_pass = NoOpEliminationPass(vllm_config)
        attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)

        # Compile model with fusion enabled
        model_compiled = torch.compile(
            model_fused, backend=test_backend, fullgraph=True
        )
        # No output quant scale has been loaded before the first forward pass.
        assert model_compiled.attn._o_scale_float is None

        result_fused_1 = model_compiled(q, k, v)

        if backend == _Backend.FLASHINFER:
            # With the Flashinfer backend after the 1st round of the forward
            # pass, output quant scale should be loaded into the attn layer's
            # _o_scale_float, the 2nd round should reuse the loaded
            # _o_scale_float
            assert model_compiled.attn._o_scale_float is not None
            result_fused_2 = model_compiled(q, k, v)

            assert model_compiled.attn._o_scale_float is not None

            torch.testing.assert_close(
                result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
            )

    # Check attn fusion support
    quant_key = model_class.quant_key
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key)
        for layer in vllm_config.compilation_config.static_forward_context.values()
    ]
    if any(attn_fusion_supported):
        # Check quantization ops in the graph before and after fusion
        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)

    # access the underlying `AttnFusionPass` on the `LazyInitPass`
    assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

    # Check attention ops in the graph before and after fusion
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass))

    assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
    assert len(attn_nodes_pre) == len(attn_nodes_post), (
        "Should have same number of attention nodes before and after fusion"
    )
    assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
        "Attention should not have output_scale before fusion"
    )
    assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
        "Attention should have output_scale after fusion"
    )

    assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
        "Attention should not have output_block_scale before fusion"
    )
    if quant_key.dtype == FP8_DTYPE:
        assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
            "Attention should not have output_block_scale after FP8 fusion"
        )
    elif quant_key.dtype == FP4_DTYPE:
        assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
            "Attention should have output_block_scale after FP4 fusion"
        )

    # Check that results are close
    torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2)