test_fusion_attn.py 21 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
5
import logging
from typing import Any
6
7

import pytest
8
import regex as re
9
10
import torch._dynamo

11
from tests.compile.backend import LazyInitPass, TestBackend
12
13
14
15
16
17
18
19
20
21
from tests.compile.fusion_test_utils import (
    CUSTOM_OPS_FP8,
    MODELS_FP4,
    MODELS_FP8,
    Matches,
    has_cuda_graph_wrapper_metadata,
    is_blackwell,
    run_model,
)
from tests.utils import cuda_device_count_stateless, flat_product
22
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
23
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
24
from vllm.attention.layer import Attention
25
26
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
27
from vllm.compilation.matcher_utils import QUANT_OPS
28
from vllm.compilation.noop_elimination import NoOpEliminationPass
29
from vllm.compilation.post_cleanup import PostCleanupPass
30
from vllm.config import (
31
    AttentionConfig,
32
33
    CacheConfig,
    CompilationConfig,
34
    CompilationMode,
35
    CUDAGraphMode,
36
37
38
39
40
41
    ModelConfig,
    PassConfig,
    SchedulerConfig,
    VllmConfig,
    set_current_vllm_config,
)
42
43
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
44
    QuantKey,
45
46
47
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
48
from vllm.platforms import current_platform
49
from vllm.utils.flashinfer import has_flashinfer
50
from vllm.utils.torch_utils import is_torch_equal_or_newer
51
52
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
53
54
from vllm.v1.kv_cache_interface import AttentionSpec

55
56
from ..utils import TestFP8Layer

57
# FP8 dtype reported by the current platform. In this file it is used as the
# KV-cache dtype of the test models and as the dtype of the NVFP4 block scales.
FP8_DTYPE = current_platform.fp8_dtype()
58
# Storage dtype used for FP4 tensors in this file.
# NOTE(review): presumably two 4-bit values are packed per byte (the NVFP4
# test weight is created with `hidden_size // 2` columns) - confirm.
FP4_DTYPE = torch.uint8
59
60


61
62
class AttentionQuantPatternModel(torch.nn.Module):
    """Base model for AttentionQuantPattern fusion.

    Holds a single ``Attention`` layer together with the metadata builder of
    that layer's attention backend. Subclasses add the quantized GEMM that
    consumes the attention output.
    """

    def __init__(
        self,
        num_qo_heads: int,
        num_kv_heads: int,
        head_size: int,
        kv_cache_dtype: torch.dtype,
        device: torch.device,
        vllm_config: VllmConfig,
        **kwargs,
    ):
        """Create the attention layer and its metadata builder.

        Args:
            num_qo_heads: Number of query/output heads.
            num_kv_heads: Number of key/value heads.
            head_size: Size of each attention head.
            kv_cache_dtype: Dtype used for the KV cache spec and the dummy
                KV cache allocated in ``build_attn_metadata``.
            device: Device for the k/v scales, the KV cache and the builder.
            vllm_config: Supplies ``cache_config`` for the attention layer and
                is forwarded to the metadata builder.
            **kwargs: Accepted so subclasses can receive extra arguments
                (e.g. ``w``); not used by this base class.
        """
        super().__init__()
        self.num_qo_heads = num_qo_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.device = device
        self.vllm_config = vllm_config

        self.attn = Attention(
            num_heads=self.num_qo_heads,
            head_size=self.head_size,
            scale=1.0 / (self.head_size**0.5),
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            prefix="model.layers.0.self_attn.attn",
        )
        # Move the k/v scales onto the target device explicitly.
        self.attn._k_scale = self.attn._k_scale.to(device)
        self.attn._v_scale = self.attn._v_scale.to(device)

        self.block_size = 16

        # Initialize attn MetadataBuilder
        self.builder = self.attn.attn_backend.get_builder_cls()(
            kv_cache_spec=AttentionSpec(
                block_size=self.block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_dtype,
            ),
            layer_names=[self.attn.layer_name],
            vllm_config=self.vllm_config,
            device=self.device,
        )

    def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
        """Initialize attention metadata and a dummy KV cache.

        Builds metadata for ``batch_size`` sequences, each with a sequence
        length and query length of 1, allocates a zero-filled KV cache in the
        layout used for the selected backend and stores it on
        ``self.attn.kv_cache``.

        Args:
            batch_size: Number of sequences in the batch.

        Returns:
            The metadata produced by the backend builder (also stored on
            ``self.attn_metadata``).

        Raises:
            ValueError: If the attention backend is not one of the backends
                handled below.
        """

        # Create common attn metadata
        batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
        common_attn_metadata = create_common_attn_metadata(
            batch_spec, self.block_size, self.device, arange_block_indices=True
        )

        # Ceil-divide the longest sequence by the block size.
        max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
        num_blocks = batch_size * max_blocks
        backend = self.attn.backend

        # TODO(luka) use get_kv_cache_stride_order
        # Create dummy KV cache for the selected backend
        if backend == AttentionBackendEnum.ROCM_ATTN:
            # k/v as 1st dimension
            # HND: [num_blocks, num_kv_heads, block_size, head_size]
            kv_cache = torch.zeros(
                2,
                num_blocks,
                self.num_kv_heads,
                self.block_size,
                self.head_size,
                dtype=self.kv_cache_dtype,
                device=self.device,
            )
        elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            # k/v as 1st dimension
            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
            kv_cache = torch.zeros(
                2,
                num_blocks,
                self.block_size,
                self.num_kv_heads,
                self.head_size,
                dtype=self.kv_cache_dtype,
                device=self.device,
            )
        elif backend == AttentionBackendEnum.TRITON_ATTN:
            # k/v as 2nd dimension
            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
            kv_cache = torch.zeros(
                num_blocks,
                2,
                self.block_size,
                self.num_kv_heads,
                self.head_size,
                dtype=self.kv_cache_dtype,
                device=self.device,
            )
        elif backend == AttentionBackendEnum.FLASHINFER:
            # Allocated as [num_blocks, 2, num_kv_heads, block_size, head_size]
            # and permuted so the returned view is
            # [num_blocks, 2, block_size, num_kv_heads, head_size].
            kv_cache = torch.zeros(
                num_blocks,
                2,
                self.num_kv_heads,
                self.block_size,
                self.head_size,
                dtype=self.kv_cache_dtype,
                device=self.device,
            ).permute(0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unsupported backend: {backend}")
        self.attn.kv_cache = [kv_cache]

        # Build attn metadata
        self.attn_metadata = self.builder.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

        return self.attn_metadata

180
181
182
183
184
185
186
187
188
189

class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
    """Attention followed by an FP8 linear layer (static per-tensor quant key).

    Used to exercise the AttentionFp8StaticQuantPattern fusion.
    """

    quant_key = kFp8StaticTensorSym

    def __init__(self, *args, **kwargs):
        """Build the base attention model and append the FP8 linear layer.

        If a ``w`` dict is passed as a keyword argument, its ``weight``,
        ``wscale`` and ``scale`` entries replace the layer's own tensors so
        that two model instances can share identical weights.
        """
        super().__init__(*args, **kwargs)

        dim = self.num_qo_heads * self.head_size
        linear = TestFP8Layer(
            weight_shape=(dim, dim),
            activation_quant_key=self.quant_key,
            weight_quant_key=self.quant_key,
            device=self.device,
        )
        self.fp8_linear = linear

        shared = kwargs.get("w")
        if shared is not None:
            linear.weight = shared["weight"]
            linear.weight_scale = shared["wscale"]
            linear.input_scale = shared["scale"]

        # Expose the (possibly shared) tensors so another instance can reuse them.
        self.w = dict(
            weight=linear.weight,
            wscale=linear.weight_scale,
            scale=linear.input_scale,
        )

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Run attention and feed its output through the FP8 linear layer."""
        return self.fp8_linear(self.attn(q, k, v))
213
214
215
216
217
218
219
220
221
222
223
224


class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
    """Attention followed by NVFP4 quantization and an FP4 GEMM.

    Used to exercise the AttentionNvfp4QuantPattern fusion.
    """

    quant_key = kNvfp4Quant

    def __init__(self, *args, **kwargs):
        """Build the base attention model and set up the FP4 weights.

        Fresh weights are always generated; they are only used when no ``w``
        dict is passed as a keyword argument.
        """
        super().__init__(*args, **kwargs)

        dim = self.num_qo_heads * self.head_size

        # Created in this exact order so the RNG streams are consumed the
        # same way regardless of whether ``w`` was supplied.
        weight = torch.randint(
            256,
            (dim, dim // 2),
            dtype=FP4_DTYPE,
            device=self.device,
        )
        wscale_swizzled = torch.randn(dim, dim // 16).to(
            dtype=FP8_DTYPE, device=self.device
        )
        wscale = torch.tensor([500], dtype=torch.float32, device=self.device)
        scale = torch.tensor([0.002], dtype=torch.float32, device=self.device)

        fresh_weights = {
            "weight": weight,
            "wscale_swizzled": wscale_swizzled,
            "wscale": wscale,
            "scale": scale,
        }
        self.w = kwargs.get("w", fresh_weights)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Run attention, quantize its output to FP4 and apply the FP4 GEMM."""
        attn_output = self.attn(q, k, v)

        # The quant op receives the reciprocal of the stored input scale.
        inv_scale = 1 / self.w["scale"]
        quant_output, output_block_scale = scaled_fp4_quant(attn_output, inv_scale)

        alpha = self.w["scale"] * self.w["wscale"]
        return cutlass_scaled_fp4_mm(
            a=quant_output,
            b=self.w["weight"],
            block_scale_a=output_block_scale,
            block_scale_b=self.w["wscale_swizzled"],
            alpha=alpha,
            out_dtype=attn_output.dtype,
        )
255
256


257
258
# Parametrization tables for test_attention_quant_pattern. They default to
# empty lists and are filled in per platform below; an empty table yields no
# parametrized cases for the corresponding quantization scheme.
PATTERN_TEST_MODELS_FP8: list[tuple[str, type]] = []
PATTERN_TEST_MODELS_FP4: list[tuple[str, type]] = []
HEADS: list[tuple[int, int]] = []
SPLIT_ATTENTION: list[bool] = []
BACKENDS_FP8: list[AttentionBackendEnum] = []
BACKENDS_FP4: list[AttentionBackendEnum] = []

if current_platform.is_cuda():
    HEADS = [(64, 8), (40, 8)]
    PATTERN_TEST_MODELS_FP8 = [
        (
            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
            TestAttentionFp8StaticQuantPatternModel,
        )
    ]
    PATTERN_TEST_MODELS_FP4 = [
        (
            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
            TestAttentionNvfp4QuantPatternModel,
        )
    ]
    BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER]
    BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER]

elif current_platform.is_rocm():
    HEADS = [(32, 8), (40, 8)]
    PATTERN_TEST_MODELS_FP8 = [
        ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
    ]
    # Must be BACKENDS_FP8: that is the name the parametrization reads. The
    # FP4 table stays empty here, so only FP8 cases are generated.
    BACKENDS_FP8 = [
        AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
        AttentionBackendEnum.ROCM_ATTN,
        AttentionBackendEnum.TRITON_ATTN,
    ]
291
292
293


@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize(
    "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
    "backend, model_name, model_class, custom_ops",
    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
    list(
        flat_product(
            BACKENDS_FP8, PATTERN_TEST_MODELS_FP8, ["+quant_fp8", "-quant_fp8"]
        )
    )
    # quant_fp4 only has the custom impl
    + list(flat_product(BACKENDS_FP4, PATTERN_TEST_MODELS_FP4, [""])),
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
def test_attention_quant_pattern(
    num_qo_heads: int,
    num_kv_heads: int,
    head_size: int,
    batch_size: int,
    dtype: torch.dtype,
    custom_ops: str,
    model_name: str,
    model_class: type[AttentionQuantPatternModel],
    backend: AttentionBackendEnum,
    dist_init,
    monkeypatch,
    use_fresh_inductor_cache,
):
    """Test AttentionStaticQuantPattern fusion pass.

    Builds the same small attention + quant model twice with shared weights:
    once compiled without the fusion pass (reference) and once compiled
    through a ``TestBackend`` running the no-op elimination, attention fusion
    and cleanup passes. The test then checks that the quant op was fused into
    the attention op (via graph inspection) and that both results are close.
    """
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    if backend == AttentionBackendEnum.FLASHINFER and (
        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
    ):
        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
    if "Llama-4-Scout" in model_name and cuda_device_count_stateless() < 2:
        pytest.skip("Llama-4-Scout requires at least 2 GPUs")

    # An empty custom_ops string means "no custom op overrides".
    custom_ops_list = custom_ops.split(",") if custom_ops else []

    device = torch.device("cuda:0")
    torch.set_default_dtype(dtype)
    torch.manual_seed(42)

    model_config = ModelConfig(
        model=model_name,
        max_model_len=2048,
        dtype=dtype,
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        scheduler_config=SchedulerConfig(
            max_num_seqs=1024,
            max_model_len=model_config.max_model_len,
            is_encoder_decoder=model_config.is_encoder_decoder,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops_list,
        ),
        cache_config=CacheConfig(cache_dtype="fp8"),
        attention_config=AttentionConfig(backend=backend),
    )

    # Create test inputs
    q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device)
    k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
    v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)

    # Mark first dimension as dynamic for realistic testing
    torch._dynamo.mark_dynamic(q, 0)
    torch._dynamo.mark_dynamic(k, 0)
    torch._dynamo.mark_dynamic(v, 0)

    # Run model directly without compilation and fusion
    vllm_config_unfused = copy.deepcopy(vllm_config)
    with (
        set_current_vllm_config(vllm_config_unfused),
        set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
    ):
        model_unfused = model_class(
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            kv_cache_dtype=FP8_DTYPE,
            device=device,
            vllm_config=vllm_config_unfused,
        )
        model_unfused = model_unfused.to(device)
        # HACK: See https://github.com/vllm-project/vllm/issues/31044
        # (same eager warm-up as for the fused model below)
        result_unfused_0 = model_unfused(q, k, v)  # noqa: F841

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)

        # Run model directly without fusion
        # Still compile so query QuantFP8 has closer numerics
        compiled_unfused = torch.compile(model_unfused, fullgraph=True)
        result_unfused = compiled_unfused(q, k, v)

    # Run model with attn fusion enabled
    vllm_config.compilation_config.pass_config = PassConfig(
        fuse_attn_quant=True, eliminate_noops=True
    )
    with (
        set_current_vllm_config(vllm_config),
        set_forward_context(attn_metadata=None, vllm_config=vllm_config),
    ):
        # Reuse the unfused model's weights so both models compute the same
        # function.
        model_fused = model_class(
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            kv_cache_dtype=FP8_DTYPE,
            device=device,
            vllm_config=vllm_config,
            w=model_unfused.w,
        )
        model_fused = model_fused.to(device)

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)

        # Create test backend with fusion passes enabled
        noop_pass = NoOpEliminationPass(vllm_config)
        attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)

        # HACK: See https://github.com/vllm-project/vllm/issues/31044
        result_fused_0 = model_fused(q, k, v)  # noqa: F841

        # Compile model with fusion enabled
        compiled_fused = torch.compile(
            model_fused, backend=test_backend, fullgraph=True
        )
        assert compiled_fused.attn._o_scale_float is None

        result_fused = compiled_fused(q, k, v)

        if backend == AttentionBackendEnum.FLASHINFER:
            # With the Flashinfer backend after the 1st round of the forward
            # pass, output quant scale should be loaded into the attn layer's
            # _o_scale_float, the 2nd round should reuse the loaded
            # _o_scale_float
            assert compiled_fused.attn._o_scale_float is not None
            result_fused_2 = compiled_fused(q, k, v)

            assert compiled_fused.attn._o_scale_float is not None

            torch.testing.assert_close(
                result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
            )

    # Check attn fusion support
    quant_key: QuantKey = model_class.quant_key
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key)
        for layer in vllm_config.compilation_config.static_forward_context.values()
    ]
    assert all(attn_fusion_supported), "All layers should support attention fusion"

    # Check quantization ops in the graph before and after fusion
    quant_op = (
        torch.ops.aten.reciprocal
        if "-quant_fp8" in custom_ops_list
        else QUANT_OPS[quant_key]
    )

    # Note: for fp8, fully_replaced=False because query quant ops remain in graph.
    # Only output quant ops are fused into attention.
    test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant)

    # access the underlying `AttnFusionPass` on the `LazyInitPass`
    assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

    # Check attention ops in the graph before and after fusion
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass))

    assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
    assert len(attn_nodes_pre) == len(attn_nodes_post), (
        "Should have same number of attention nodes before and after fusion"
    )
    assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
        "Attention should not have output_scale before fusion"
    )
    assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
        "Attention should have output_scale after fusion"
    )

    assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
        "Attention should not have output_block_scale before fusion"
    )
    if quant_key.dtype == FP8_DTYPE:
        assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
            "Attention should not have output_block_scale after FP8 fusion"
        )
    elif quant_key.dtype == FP4_DTYPE:
        assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
            "Attention should have output_block_scale after FP4 fusion"
        )

    # Check that results are close
    torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2)
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589


@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, matches, custom_ops",
    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
    list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
    # quant_fp4 only has the custom impl
    + list(flat_product(MODELS_FP4, [""])),
)
@pytest.mark.parametrize(
    "inductor_graph_partition",
    [
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                not has_cuda_graph_wrapper_metadata(),
                reason="This test requires "
                "torch._inductor.utils.CUDAGraphWrapperMetadata to run",
            ),
        ),
        False,
    ],
)
def test_attn_quant(
    model_name: str,
    model_kwargs: dict[str, Any],
    backend: AttentionBackendEnum,
    matches: Matches,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    """Check that attention+quant fusion fires when running a full model.

    Runs ``run_model`` with ``fuse_attn_quant`` enabled while capturing debug
    logs, then asserts that exactly one "Fused quant onto N attention nodes"
    message was logged and that N equals ``matches.attention_fusion``.
    """
    if not current_platform.has_device_capability(90):
        pytest.skip("test_attn_quant requires H100 (SM90) or B200 (SM100) GPU")
    if backend == AttentionBackendEnum.FLASHINFER and (
        not is_blackwell() or not has_flashinfer()
    ):
        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    # An empty custom_ops string means "no custom op overrides".
    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        # FIXME: Llama-4-Scout-17B-16E-Instruct-FP8 + FlashInfer + Blackwell end at
        # CUDAGraphMode.NONE here because it derives an attention backend that
        # does not support full cudagraphs
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

    # Work on a copy: the incoming dict comes from the parametrization table
    # and is shared between test cases, so it must not be mutated.
    model_kwargs = {**model_kwargs, "attention_config": {"backend": backend.name}}

    compilation_config = CompilationConfig(
        # Testing properties
        custom_ops=custom_ops_list,
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        splitting_ops=splitting_ops,
        # Common
        mode=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(compilation_config, model_name, **model_kwargs)

    # The captured group is the number of attention nodes the pass fused.
    log_matches = re.findall(
        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(log_matches) == 1, log_holder.text
    assert int(log_matches[0]) == matches.attention_fusion