test_fusion_attn.py 21.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
5
import logging
from typing import Any
6
7

import pytest
8
import regex as re
9
10
import torch._dynamo

11
from tests.compile.backend import LazyInitPass, TestBackend
12
13
14
15
16
17
18
19
20
21
from tests.compile.fusion_test_utils import (
    CUSTOM_OPS_FP8,
    MODELS_FP4,
    MODELS_FP8,
    Matches,
    has_cuda_graph_wrapper_metadata,
    is_blackwell,
    run_model,
)
from tests.utils import cuda_device_count_stateless, flat_product
22
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
23
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
24
from vllm.attention.layer import Attention
25
26
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
27
from vllm.compilation.matcher_utils import QUANT_OPS
28
from vllm.compilation.noop_elimination import NoOpEliminationPass
29
from vllm.compilation.post_cleanup import PostCleanupPass
30
from vllm.config import (
31
    AttentionConfig,
32
33
    CacheConfig,
    CompilationConfig,
34
    CompilationMode,
35
    CUDAGraphMode,
36
37
38
39
40
41
    ModelConfig,
    PassConfig,
    SchedulerConfig,
    VllmConfig,
    set_current_vllm_config,
)
42
43
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
44
    QuantKey,
45
46
47
48
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
49
from vllm.platforms import current_platform
50
from vllm.utils.flashinfer import has_flashinfer
51
from vllm.utils.torch_utils import is_torch_equal_or_newer
52
53
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
54
55
56
from vllm.v1.kv_cache_interface import AttentionSpec

FP8_DTYPE = current_platform.fp8_dtype()
57
FP4_DTYPE = torch.uint8
58
59


60
61
class AttentionQuantPatternModel(torch.nn.Module):
    """Base model for AttentionQuantPattern fusion tests.

    Holds a single vLLM ``Attention`` layer plus a metadata builder for the
    attention backend that layer selected. Subclasses set ``quant_key`` and
    implement a ``forward`` that quantizes the attention output.
    """

    def __init__(
        self,
        num_qo_heads: int,
        num_kv_heads: int,
        head_size: int,
        kv_cache_dtype: torch.dtype,
        device: torch.device,
        vllm_config: VllmConfig,
        **kwargs,
    ):
        """Create the attention layer and its metadata builder.

        Args:
            num_qo_heads: Number of query/output heads.
            num_kv_heads: Number of key/value heads.
            head_size: Size of each attention head.
            kv_cache_dtype: dtype of the dummy KV cache created in
                ``build_attn_metadata`` and of the builder's ``AttentionSpec``.
            device: Device for the k/v scales, the KV cache and the builder.
            vllm_config: Config handed to the metadata builder; its
                ``cache_config`` is passed to ``Attention``.
            **kwargs: Not used by this base class; subclasses read extra
                options (such as ``w``) from it.
        """
        super().__init__()
        self.num_qo_heads = num_qo_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.device = device
        self.vllm_config = vllm_config

        self.attn = Attention(
            num_heads=self.num_qo_heads,
            head_size=self.head_size,
            scale=1.0 / (self.head_size**0.5),
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            prefix="model.layers.0.self_attn.attn",
        )

        # Move the layer's k/v scale tensors onto the test device.
        self.attn._k_scale = self.attn._k_scale.to(device)
        self.attn._v_scale = self.attn._v_scale.to(device)

        # Shared by the AttentionSpec below and by the dummy KV cache and
        # common metadata created in build_attn_metadata.
        self.block_size = 16

        # Initialize attn MetadataBuilder for the backend the layer selected.
        self.builder = self.attn.attn_backend.get_builder_cls()(
            kv_cache_spec=AttentionSpec(
                block_size=self.block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_dtype,
            ),
            layer_names=[self.attn.layer_name],
            vllm_config=self.vllm_config,
            device=self.device,
        )

    def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
        """Initialize attention metadata and a zeroed dummy KV cache.

        Builds metadata for ``batch_size`` requests, each with query length 1
        and sequence length 1. It also allocates a zero KV cache laid out for
        the layer's backend and installs it as ``self.attn.kv_cache``.

        Args:
            batch_size: Number of requests in the batch.

        Returns:
            The built metadata, which is also stored on ``self.attn_metadata``.

        Raises:
            ValueError: If the layer's backend is not one of ROCM_ATTN,
                ROCM_AITER_UNIFIED_ATTN, TRITON_ATTN or FLASHINFER.
        """

        # Create common attn metadata
        batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
        common_attn_metadata = create_common_attn_metadata(
            batch_spec, self.block_size, self.device, arange_block_indices=True
        )

        # Ceil-divide the longest sequence by block_size to get blocks/request.
        max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
        num_blocks = batch_size * max_blocks
        backend = self.attn.backend

        # TODO(luka) use get_kv_cache_stride_order
        # Create dummy KV cache for the selected backend
        if backend == AttentionBackendEnum.ROCM_ATTN:
            # k/v as 1st dimension
            # HND: [num_blocks, num_kv_heads, block_size, head_size]
            kv_cache = torch.zeros(
                2,
                num_blocks,
                self.num_kv_heads,
                self.block_size,
                self.head_size,
                dtype=self.kv_cache_dtype,
                device=self.device,
            )
        elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            # k/v as 1st dimension
            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
            kv_cache = torch.zeros(
                2,
                num_blocks,
                self.block_size,
                self.num_kv_heads,
                self.head_size,
                dtype=self.kv_cache_dtype,
                device=self.device,
            )
        elif backend == AttentionBackendEnum.TRITON_ATTN:
            # k/v as 2nd dimension
            # Allocated as [num_blocks, 2, num_kv_heads, block_size, head_size],
            # i.e. HND per k/v.
            # NOTE(review): confirm this matches the layout TRITON_ATTN expects;
            # an NHD [num_blocks, block_size, num_kv_heads, head_size] layout
            # may have been intended here.
            kv_cache = torch.zeros(
                num_blocks,
                2,
                self.num_kv_heads,
                self.block_size,
                self.head_size,
                dtype=self.kv_cache_dtype,
                device=self.device,
            )
        elif backend == AttentionBackendEnum.FLASHINFER:
            # Logical shape [num_blocks, 2, block_size, num_kv_heads, head_size]
            # (after the permute) backed by HND-ordered memory.
            kv_cache = torch.zeros(
                num_blocks,
                2,
                self.num_kv_heads,
                self.block_size,
                self.head_size,
                dtype=self.kv_cache_dtype,
                device=self.device,
            ).permute(0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unsupported backend: {backend}")
        self.attn.kv_cache = [kv_cache]

        # Build attn metadata
        self.attn_metadata = self.builder.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

        return self.attn_metadata

179
180
181
182
183
184
185
186
187
188
189

class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
    """Test model for AttentionFp8StaticQuantPattern fusion.

    ``forward`` feeds the attention output into ``Fp8LinearOp.apply`` with a
    static activation scale. This produces the attention -> static FP8 quant
    pattern that the fusion pass targets.
    """

    quant_key = kFp8StaticTensorSym

    def __init__(self, *args, **kwargs):
        """Build the FP8 linear op and its weights.

        Accepts the base-class arguments plus an optional ``w`` keyword: a
        dict with ``weight``, ``wscale`` and ``scale`` entries, used to share
        weights with another instance. When ``w`` is absent (or None), random
        weights are created.
        """
        super().__init__(*args, **kwargs)

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.quant_key.scale.static,
            act_quant_group_shape=self.quant_key.scale.group_shape,
        )

        # Build the random default only when no weights were supplied.
        # ``kwargs.get("w", {...})`` would evaluate the default eagerly: it
        # allocates a hidden_size x hidden_size tensor and advances the RNG
        # even when the caller passes shared weights.
        w = kwargs.get("w")
        if w is None:
            hidden_size = self.num_qo_heads * self.head_size
            w = {
                "weight": torch.randn(hidden_size, hidden_size)
                .to(dtype=FP8_DTYPE, device=self.device)
                .t(),
                "wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
                "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
            }
        self.w = w

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Forward pass that creates the pattern to be fused."""
        attn_output = self.attn(q, k, v)
        return self.fp8_linear.apply(
            input=attn_output,
            weight=self.w["weight"],
            weight_scale=self.w["wscale"],
            input_scale=self.w["scale"],
        )
214
215
216
217
218
219
220
221
222
223
224
225


class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
    """Test model for AttentionNvfp4QuantPattern fusion.

    ``forward`` quantizes the attention output with ``scaled_fp4_quant``. It
    then multiplies the result by an FP4 weight via ``cutlass_scaled_fp4_mm``,
    producing the attention -> NVFP4 quant pattern that the fusion pass
    targets.
    """

    quant_key = kNvfp4Quant

    def __init__(self, *args, **kwargs):
        """Build the FP4 GEMM weights.

        Accepts the base-class arguments plus an optional ``w`` keyword: a
        dict with ``weight``, ``wscale_swizzled``, ``wscale`` and ``scale``
        entries, used to share weights with another instance. When ``w`` is
        absent (or None), random weights are created.
        """
        super().__init__(*args, **kwargs)

        # Build the random default only when no weights were supplied.
        # ``kwargs.get("w", {...})`` would evaluate the default eagerly: it
        # allocates device tensors and advances the RNG even when the caller
        # passes shared weights.
        w = kwargs.get("w")
        if w is None:
            hidden_size = self.num_qo_heads * self.head_size
            w = {
                # presumably two FP4 values packed per uint8 byte (hence // 2)
                "weight": torch.randint(
                    256,
                    (hidden_size, hidden_size // 2),
                    dtype=FP4_DTYPE,
                    device=self.device,
                ),
                # presumably one FP8 block scale per 16 weights (hence // 16)
                "wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to(
                    dtype=FP8_DTYPE, device=self.device
                ),
                "wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
                "scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
            }
        self.w = w

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Forward pass that creates the pattern to be fused."""
        attn_output = self.attn(q, k, v)
        quant_output, output_block_scale = scaled_fp4_quant(
            attn_output, 1 / self.w["scale"]
        )
        return cutlass_scaled_fp4_mm(
            a=quant_output,
            b=self.w["weight"],
            block_scale_a=output_block_scale,
            block_scale_b=self.w["wscale_swizzled"],
            alpha=self.w["scale"] * self.w["wscale"],
            out_dtype=attn_output.dtype,
        )
256
257


258
259
PATTERN_TEST_MODELS_FP8: list[tuple[str, type]] = []
PATTERN_TEST_MODELS_FP4: list[tuple[str, type]] = []

HEADS: list[tuple[int, int]] = []
SPLIT_ATTENTION: list[bool] = []

BACKENDS_FP8: list[AttentionBackendEnum] = []
BACKENDS_FP4: list[AttentionBackendEnum] = []

# Per-platform parametrization for test_attention_quant_pattern. On platforms
# matching neither branch every list stays empty and no cases are generated.
if current_platform.is_cuda():
    HEADS = [(64, 8), (40, 8)]
    PATTERN_TEST_MODELS_FP8 = [
        (
            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
            TestAttentionFp8StaticQuantPatternModel,
        )
    ]
    PATTERN_TEST_MODELS_FP4 = [
        (
            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
            TestAttentionNvfp4QuantPatternModel,
        )
    ]
    BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER]
    BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER]

elif current_platform.is_rocm():
    HEADS = [(32, 8), (40, 8)]
    PATTERN_TEST_MODELS_FP8 = [
        ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
    ]
    # Must be BACKENDS_FP8: the test is parametrized over BACKENDS_FP8, so
    # assigning any other name leaves ROCm with zero FP8 test cases.
    BACKENDS_FP8 = [
        AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
        AttentionBackendEnum.ROCM_ATTN,
        AttentionBackendEnum.TRITON_ATTN,
    ]
292
293
294


@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize(
    "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
    "backend, model_name, model_class, custom_ops",
    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
    list(
        flat_product(
            BACKENDS_FP8, PATTERN_TEST_MODELS_FP8, ["+quant_fp8", "-quant_fp8"]
        )
    )
    # quant_fp4 only has the custom impl
    + list(flat_product(BACKENDS_FP4, PATTERN_TEST_MODELS_FP4, [""])),
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
def test_attention_quant_pattern(
    num_qo_heads: int,
    num_kv_heads: int,
    head_size: int,
    batch_size: int,
    dtype: torch.dtype,
    custom_ops: str,
    model_name: str,
    model_class: type[AttentionQuantPatternModel],
    backend: AttentionBackendEnum,
    dist_init,
    monkeypatch,
    use_fresh_inductor_cache,
):
    """Test AttentionStaticQuantPattern fusion pass.

    Compiles ``model_class`` twice. The first compile runs without fusion
    passes and serves as the reference. The second goes through a TestBackend
    running NoOpEliminationPass, AttnFusionPass and PostCleanupPass. The test
    then checks the graph before and after the passes, and checks that both
    outputs agree within atol=rtol=1e-2.
    """
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    if backend == AttentionBackendEnum.FLASHINFER and (
        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
    ):
        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
    if "Llama-4-Scout" in model_name and cuda_device_count_stateless() < 2:
        pytest.skip("Llama-4-Scout requires at least 2 GPUs")

    custom_ops_list = custom_ops.split(",") if custom_ops else []

    device = torch.device("cuda:0")
    # torch.set_default_dtype is process-global: remember the current value
    # and restore it in ``finally`` so later tests in this worker are
    # unaffected.
    prev_default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        torch.manual_seed(42)

        model_config = ModelConfig(
            model=model_name,
            max_model_len=2048,
            dtype=dtype,
        )
        vllm_config = VllmConfig(
            model_config=model_config,
            scheduler_config=SchedulerConfig(
                max_num_seqs=1024,
                max_model_len=model_config.max_model_len,
                is_encoder_decoder=model_config.is_encoder_decoder,
            ),
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                custom_ops=custom_ops_list,
            ),
            cache_config=CacheConfig(cache_dtype="fp8"),
            attention_config=AttentionConfig(backend=backend),
        )

        # Create test inputs
        q_width = num_qo_heads * head_size
        kv_width = num_kv_heads * head_size
        q = torch.randn(batch_size, q_width, dtype=dtype, device=device)
        k = torch.randn(batch_size, kv_width, dtype=dtype, device=device)
        v = torch.randn(batch_size, kv_width, dtype=dtype, device=device)

        # Mark first dimension as dynamic for realistic testing
        torch._dynamo.mark_dynamic(q, 0)
        torch._dynamo.mark_dynamic(k, 0)
        torch._dynamo.mark_dynamic(v, 0)

        # Reference run on its own config copy, without fusion passes
        vllm_config_unfused = copy.deepcopy(vllm_config)
        with (
            set_current_vllm_config(vllm_config_unfused),
            set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
        ):
            model_unfused = model_class(
                num_qo_heads=num_qo_heads,
                num_kv_heads=num_kv_heads,
                head_size=head_size,
                kv_cache_dtype=FP8_DTYPE,
                device=device,
                vllm_config=vllm_config_unfused,
            )
            model_unfused = model_unfused.to(device)
            # HACK: See https://github.com/vllm-project/vllm/issues/31044
            result_unfused_0 = model_unfused(q, k, v)  # noqa: F841

            forward_ctx = get_forward_context()
            forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)

            # Run model directly without fusion
            # Still compile so query QuantFP8 has closer numerics
            compiled_unfused = torch.compile(model_unfused, fullgraph=True)
            result_unfused = compiled_unfused(q, k, v)

        # Run model with attn fusion enabled
        vllm_config.compilation_config.pass_config = PassConfig(
            fuse_attn_quant=True, eliminate_noops=True
        )
        with (
            set_current_vllm_config(vllm_config),
            set_forward_context(attn_metadata=None, vllm_config=vllm_config),
        ):
            model_fused = model_class(
                num_qo_heads=num_qo_heads,
                num_kv_heads=num_kv_heads,
                head_size=head_size,
                kv_cache_dtype=FP8_DTYPE,
                device=device,
                vllm_config=vllm_config,
                w=model_unfused.w,
            )
            model_fused = model_fused.to(device)

            forward_ctx = get_forward_context()
            forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)

            # Create test backend with fusion passes enabled
            noop_pass = NoOpEliminationPass(vllm_config)
            attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
            cleanup_pass = PostCleanupPass(vllm_config)

            test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
            # HACK: See https://github.com/vllm-project/vllm/issues/31044
            result_fused_0 = model_fused(q, k, v)  # noqa: F841

            # Compile model with fusion enabled
            compiled_fused = torch.compile(
                model_fused, backend=test_backend, fullgraph=True
            )
            assert compiled_fused.attn._o_scale_float is None

            result_fused = compiled_fused(q, k, v)

            if backend == AttentionBackendEnum.FLASHINFER:
                # With the Flashinfer backend after the 1st round of the forward
                # pass, output quant scale should be loaded into the attn layer's
                # _o_scale_float, the 2nd round should reuse the loaded
                # _o_scale_float
                assert compiled_fused.attn._o_scale_float is not None
                result_fused_2 = compiled_fused(q, k, v)

                assert compiled_fused.attn._o_scale_float is not None

                torch.testing.assert_close(
                    result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
                )

        # Check attn fusion support
        quant_key: QuantKey = model_class.quant_key
        static_ctx = vllm_config.compilation_config.static_forward_context
        attn_fusion_supported = [
            layer.impl.fused_output_quant_supported(quant_key)
            for layer in static_ctx.values()
        ]
        assert all(attn_fusion_supported), (
            "All layers should support attention fusion"
        )

        # Check quantization ops in the graph before and after fusion
        quant_op = (
            torch.ops.aten.reciprocal
            if "-quant_fp8" in custom_ops_list
            else QUANT_OPS[quant_key]
        )

        # Note: for fp8, fully_replaced=False because query quant ops remain in
        # the graph. Only output quant ops are fused into attention.
        test_backend.check_before_ops(
            [quant_op], fully_replaced=quant_key is kNvfp4Quant
        )

        # access the underlying `AttnFusionPass` on the `LazyInitPass`
        assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

        # Check attention ops in the graph before and after fusion
        attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
        attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass))

        assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
        assert len(attn_nodes_pre) == len(attn_nodes_post), (
            "Should have same number of attention nodes before and after fusion"
        )
        assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
            "Attention should not have output_scale before fusion"
        )
        assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
            "Attention should have output_scale after fusion"
        )

        assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
            "Attention should not have output_block_scale before fusion"
        )
        if quant_key.dtype == FP8_DTYPE:
            assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
                "Attention should not have output_block_scale after FP8 fusion"
            )
        elif quant_key.dtype == FP4_DTYPE:
            assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
                "Attention should have output_block_scale after FP4 fusion"
            )

        # Check that results are close
        torch.testing.assert_close(
            result_unfused, result_fused, atol=1e-2, rtol=1e-2
        )
    finally:
        torch.set_default_dtype(prev_default_dtype)
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590


@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, matches, custom_ops",
    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
    list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
    # quant_fp4 only has the custom impl
    + list(flat_product(MODELS_FP4, [""])),
)
@pytest.mark.parametrize(
    "inductor_graph_partition",
    [
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                not has_cuda_graph_wrapper_metadata(),
                reason="This test requires "
                "torch._inductor.utils.CUDAGraphWrapperMetadata to run",
            ),
        ),
        False,
    ],
)
def test_attn_quant(
    model_name: str,
    model_kwargs: dict[str, Any],
    backend: AttentionBackendEnum,
    matches: Matches,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    """End-to-end check that attention+quant fusion fires in a real model run.

    Runs ``model_name`` with ``fuse_attn_quant`` enabled and captures DEBUG
    logs from the spawned workers. Expects exactly one fusion_attn.py log
    line "Fused quant onto N attention nodes" with
    N == ``matches.attention_fusion``.
    """
    if not current_platform.has_device_capability(90):
        pytest.skip("test_attn_quant requires H100 (SM90) or B200 (SM100) GPU")
    if backend == AttentionBackendEnum.FLASHINFER and (
        not is_blackwell() or not has_flashinfer()
    ):
        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        # FIXME: Llama-4-Scout-17B-16E-Instruct-FP8 + FlashInfer + Blackwell end at
        # CUDAGraphMode.NONE here because it derives an attention backend that
        # does not support full cudagraphs
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

    # Work on a copy: model_kwargs is the dict object from the shared
    # MODELS_FP8/MODELS_FP4 lists. Mutating it in place would leak
    # "attention_config" into other parametrizations and other importers.
    model_kwargs = {**model_kwargs, "attention_config": {"backend": backend.name}}

    compilation_config = CompilationConfig(
        # Testing properties
        custom_ops=custom_ops_list,
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        splitting_ops=splitting_ops,
        # Common
        mode=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(compilation_config, model_name, **model_kwargs)

    log_matches = re.findall(
        r"fusion_attn\.py:\d+] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(log_matches) == 1, log_holder.text
    assert int(log_matches[0]) == matches.attention_fusion