test_fusion_attn.py 16.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
5
6
7
8
from typing import Optional

import pytest
import torch._dynamo

9
from tests.compile.backend import LazyInitPass, TestBackend
10
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
11
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
12
from vllm.attention import Attention, AttentionMetadata
13
from vllm.attention.backends.registry import _Backend
14
from vllm.attention.selector import global_force_attn_backend_context_manager
15
from vllm.compilation.fusion import QUANT_OPS
16
17
18
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
19
from vllm.compilation.post_cleanup import PostCleanupPass
20
21
22
23
24
25
26
27
28
29
from vllm.config import (
    CacheConfig,
    CompilationConfig,
    CompilationLevel,
    ModelConfig,
    PassConfig,
    SchedulerConfig,
    VllmConfig,
    set_current_vllm_config,
)
30
31
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
32
33
34
35
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
36
from vllm.platforms import current_platform
37
from vllm.utils import is_torch_equal_or_newer
38
39
40
from vllm.v1.kv_cache_interface import AttentionSpec

# Container dtype for packed FP4 data (NVFP4 tensors in this file are uint8).
FP4_DTYPE = torch.uint8
# FP8 dtype native to the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

# Module-level globals needed for the string-import custom Dynamo backend field.
backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None


class AttentionQuantPatternModel(torch.nn.Module):
    """Base model for AttentionQuantPattern fusion.

    Wraps a single ``Attention`` layer together with the metadata builder of
    the attention backend that layer selected. Subclasses append the output
    quantization (plus the consuming GEMM) that the fusion pass should fold
    into the attention op.
    """

    def __init__(
        self,
        num_qo_heads: int,
        num_kv_heads: int,
        head_size: int,
        kv_cache_dtype: torch.dtype,
        device: torch.device,
        vllm_config: VllmConfig,
        **kwargs,
    ):
        """Create the attention layer and its metadata builder.

        Args:
            num_qo_heads: Number of query/output heads.
            num_kv_heads: Number of key/value heads.
            head_size: Per-head hidden size.
            kv_cache_dtype: dtype of the dummy KV cache and of the
                ``AttentionSpec`` handed to the metadata builder.
            device: Device for the KV cache, scales and metadata.
            vllm_config: Config forwarded to the metadata builder.
            **kwargs: Ignored here; subclasses read extra options (e.g. ``w``).
        """
        super().__init__()
        self.num_qo_heads = num_qo_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.device = device
        self.vllm_config = vllm_config

        self.attn = Attention(
            num_heads=self.num_qo_heads,
            head_size=self.head_size,
            scale=1.0 / (self.head_size**0.5),
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            prefix="model.layers.0.self_attn.attn",
        )
        # Move the k/v scales onto the test device explicitly (presumably the
        # layer creates them elsewhere, e.g. on CPU — confirm).
        self.attn._k_scale = self.attn._k_scale.to(device)
        self.attn._v_scale = self.attn._v_scale.to(device)

        self.block_size = 16

        # Metadata builder for whichever backend the Attention layer selected.
        self.builder = self.attn.attn_backend.get_builder_cls()(
            kv_cache_spec=AttentionSpec(
                block_size=self.block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_dtype,
            ),
            layer_names=[self.attn.layer_name],
            vllm_config=self.vllm_config,
            device=self.device,
        )

    def _allocate_kv_cache(self, num_blocks: int) -> torch.Tensor:
        """Return a zero-filled KV cache laid out for the layer's backend.

        Raises:
            ValueError: if the selected backend has no layout defined here.
        """
        attn_backend = self.attn.backend
        nb = num_blocks
        bs = self.block_size
        nh = self.num_kv_heads
        hs = self.head_size
        dtype = self.kv_cache_dtype
        device = self.device

        if attn_backend == _Backend.ROCM_ATTN:
            # k/v as 1st dimension
            # HND: [num_blocks, num_kv_heads, block_size, head_size]
            return torch.zeros(2, nb, nh, bs, hs, dtype=dtype, device=device)
        if attn_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
            # k/v as 1st dimension
            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
            return torch.zeros(2, nb, bs, nh, hs, dtype=dtype, device=device)
        if attn_backend == _Backend.TRITON_ATTN:
            # k/v as 2nd dimension
            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
            # (assumed to match TritonAttentionBackend.get_kv_cache_shape —
            # confirm if that backend's layout changes)
            return torch.zeros(nb, 2, bs, nh, hs, dtype=dtype, device=device)
        if attn_backend == _Backend.FLASHINFER:
            # Allocated with heads before block positions, then permuted so the
            # logical shape is [num_blocks, 2, block_size, num_kv_heads,
            # head_size] while memory keeps the allocation order (presumably
            # the HND stride order FlashInfer is configured for — confirm).
            return torch.zeros(nb, 2, nh, bs, hs, dtype=dtype, device=device).permute(
                0, 1, 3, 2, 4
            )
        raise ValueError(f"Unsupported backend: {attn_backend}")

    def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
        """Initialize attention metadata for ``batch_size`` one-token requests.

        Every request has ``seq_len == query_len == 1``. A zeroed dummy KV
        cache sized for the batch is attached to ``self.attn.kv_cache``. The
        built metadata is stored on ``self.attn_metadata`` and also returned.
        """
        batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
        common_attn_metadata = create_common_attn_metadata(
            batch_spec, self.block_size, self.device, arange_block_indices=True
        )

        # ceil(max_seq_len / block_size) blocks per request.
        max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
        num_blocks = batch_size * max_blocks

        self.attn.kv_cache = [self._allocate_kv_cache(num_blocks)]

        self.attn_metadata = self.builder.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )
        return self.attn_metadata


class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
    """Test model for AttentionFp8StaticQuantPattern fusion.

    The attention output feeds an ``Fp8LinearOp`` configured for the static
    per-tensor FP8 quant key. The input quantization is the op the fusion
    pass should fold into attention.
    """

    quant_key = kFp8StaticTensorSym

    def __init__(self, *args, **kwargs):
        """Build the FP8 linear op and its weights.

        Accepts an optional ``w`` kwarg (dict with ``weight``, ``wscale`` and
        ``scale``) so several instances can share identical weights. Random
        weights are created only when ``w`` is absent or ``None``.
        """
        super().__init__(*args, **kwargs)

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.quant_key.scale.static,
            act_quant_group_shape=self.quant_key.scale.group_shape,
        )

        # Build the default lazily: an eager ``kwargs.get("w", {...})`` would
        # allocate a throwaway hidden_size x hidden_size tensor and advance the
        # global RNG even when the caller supplies ``w``.
        w = kwargs.get("w")
        if w is None:
            hidden_size = self.num_qo_heads * self.head_size
            w = {
                # Transposed view of a random FP8 matrix (presumably the layout
                # Fp8LinearOp expects for its weight — confirm).
                "weight": torch.randn(hidden_size, hidden_size)
                .to(dtype=FP8_DTYPE, device=self.device)
                .t(),
                "wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
                "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
            }
        self.w = w

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Forward pass that creates the pattern to be fused.

        Runs attention and applies the static-scale FP8 linear to its output.
        """
        attn_output = self.attn(q, k, v)
        return self.fp8_linear.apply(
            input=attn_output,
            weight=self.w["weight"],
            weight_scale=self.w["wscale"],
            input_scale=self.w["scale"],
        )


class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
    """Test model for AttentionNvfp4QuantPattern fusion.

    The attention output is quantized with ``scaled_fp4_quant`` and consumed
    by ``cutlass_scaled_fp4_mm``. That quantization is the op the fusion pass
    should fold into attention.
    """

    quant_key = kNvfp4Quant

    def __init__(self, *args, **kwargs):
        """Build the NVFP4 GEMM weights.

        Accepts an optional ``w`` kwarg (dict with ``weight``,
        ``wscale_swizzled``, ``wscale`` and ``scale``) so several instances
        can share identical weights. Random weights are created only when
        ``w`` is absent or ``None``.
        """
        super().__init__(*args, **kwargs)

        # Build the default lazily: an eager ``kwargs.get("w", {...})`` would
        # allocate throwaway random tensors and advance the global RNG even
        # when the caller supplies ``w``.
        w = kwargs.get("w")
        if w is None:
            hidden_size = self.num_qo_heads * self.head_size
            w = {
                # Random packed-FP4 bytes; the // 2 column count presumably
                # reflects two FP4 values per uint8 — confirm.
                "weight": torch.randint(
                    256,
                    (hidden_size, hidden_size // 2),
                    dtype=FP4_DTYPE,
                    device=self.device,
                ),
                # Random FP8 placeholder block scales of shape
                # (hidden_size, hidden_size // 16).
                "wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to(
                    dtype=FP8_DTYPE, device=self.device
                ),
                "wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
                "scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
            }
        self.w = w

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Forward pass that creates the pattern to be fused.

        Runs attention, NVFP4-quantizes its output, then runs the FP4 GEMM.
        """
        attn_output = self.attn(q, k, v)
        # The quant op receives the reciprocal of the input scale.
        quant_output, output_block_scale = scaled_fp4_quant(
            attn_output, 1 / self.w["scale"]
        )
        return cutlass_scaled_fp4_mm(
            a=quant_output,
            b=self.w["weight"],
            block_scale_a=output_block_scale,
            block_scale_b=self.w["wscale_swizzled"],
            # Combined input * weight scale applied to the product (presumably
            # the dequantization factor — confirm against the kernel docs).
            alpha=self.w["scale"] * self.w["wscale"],
            out_dtype=attn_output.dtype,
        )


# Per-platform test parametrization:
#   MODELS: (checkpoint name, test model class) pairs
#   HEADS:  (num_qo_heads, num_kv_heads) configurations
# Platforms that are neither CUDA nor ROCm keep the empty defaults, which
# gives the parametrized test an empty parameter set.
MODELS = []
HEADS = []
if current_platform.is_cuda():
    MODELS = [
        (
            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
            TestAttentionFp8StaticQuantPatternModel,
        ),
        (
            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
            TestAttentionNvfp4QuantPatternModel,
        ),
    ]
    HEADS = [(64, 8), (40, 8)]
elif current_platform.is_rocm():
    MODELS = [
        ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
    ]
    HEADS = [(32, 8), (40, 8)]


@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize(
    "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize(
    "backend",
    [_Backend.FLASHINFER]
    if current_platform.is_cuda()
    else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN],
)
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
    "use_inductor_graph_partition",
    [False] if current_platform.is_rocm() else [False, True],
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(
    current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
    reason="On CUDA only test on SM100(Blackwell)",
)
def test_attention_quant_pattern(
    num_qo_heads: int,
    num_kv_heads: int,
    head_size: int,
    batch_size: int,
    dtype: torch.dtype,
    model_name: str,
    model_class: type[AttentionQuantPatternModel],
    backend: _Backend,
    use_inductor_graph_partition: bool,
    dist_init,
    caplog_vllm,
):
    """Test AttentionStaticQuantPattern fusion pass.

    Runs ``model_class`` eagerly without fusion and again compiled with the
    noop-elimination, attention-fusion and cleanup passes. It checks that:

    - the quant op is folded into the attention op, which gains an
      ``output_scale`` (and, for FP4, an ``output_block_scale``) kwarg;
    - the fused and unfused outputs are numerically close.
    """
    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

    device = torch.device("cuda:0")
    torch.manual_seed(42)

    vllm_config = VllmConfig(
        model_config=ModelConfig(
            model=model_name,
            max_model_len=2048,
            dtype=dtype,
        ),
        scheduler_config=SchedulerConfig(max_num_seqs=1024),
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            custom_ops=["+quant_fp8"],
            use_inductor_graph_partition=use_inductor_graph_partition,
        ),
        cache_config=CacheConfig(cache_dtype="fp8"),
    )

    # Create test inputs
    q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device)
    k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
    v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)

    # Mark first dimension as dynamic for realistic testing
    torch._dynamo.mark_dynamic(q, 0)
    torch._dynamo.mark_dynamic(k, 0)
    torch._dynamo.mark_dynamic(v, 0)

    # Snapshot the config before the fusion pass config is set below, so the
    # reference model runs with the default pass config.
    vllm_config_unfused = copy.deepcopy(vllm_config)
    with (
        set_current_vllm_config(vllm_config_unfused),
        set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
        global_force_attn_backend_context_manager(backend),
    ):
        model_unfused = model_class(
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            kv_cache_dtype=FP8_DTYPE,
            device=device,
            vllm_config=vllm_config_unfused,
        )
        model_unfused = model_unfused.to(device)

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)

        # Run model directly without compilation and fusion
        result_unfused = model_unfused(q, k, v)

    # Run model with attn fusion enabled
    vllm_config.compilation_config.pass_config = PassConfig(
        enable_attn_fusion=True, enable_noop=True
    )
    with (
        set_current_vllm_config(vllm_config),
        set_forward_context(attn_metadata=None, vllm_config=vllm_config),
        global_force_attn_backend_context_manager(backend),
    ):
        # Share the reference model's weights so both paths compute the same GEMM.
        model_fused = model_class(
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            kv_cache_dtype=FP8_DTYPE,
            device=device,
            vllm_config=vllm_config,
            w=model_unfused.w,
        )
        model_fused = model_fused.to(device)

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)

        # Create test backend with fusion passes enabled
        noop_pass = NoOpEliminationPass(vllm_config)
        attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)

        # Compile model with fusion enabled
        model_compiled = torch.compile(
            model_fused, backend=test_backend, fullgraph=True
        )
        assert model_compiled.attn._o_scale_float is None

        result_fused_1 = model_compiled(q, k, v)

        if backend == _Backend.FLASHINFER:
            # With the Flashinfer backend after the 1st round of the forward
            # pass, output quant scale should be loaded into the attn layer's
            # _o_scale_float, the 2nd round should reuse the loaded
            # _o_scale_float
            assert model_compiled.attn._o_scale_float is not None
            result_fused_2 = model_compiled(q, k, v)
            assert model_compiled.attn._o_scale_float is not None

            torch.testing.assert_close(
                result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
            )

    # Check attn fusion support
    quant_key = model_class.quant_key
    static_forward_context = vllm_config.compilation_config.static_forward_context
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key)
        for layer in static_forward_context.values()
    ]
    if any(attn_fusion_supported):
        # Check quantization ops in the graph before and after fusion
        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)

    # access the underlying `AttnFusionPass` on the `LazyInitPass`
    assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

    # Check attention ops in the graph before and after fusion
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass))

    assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
    assert len(attn_nodes_pre) == len(attn_nodes_post), (
        "Should have same number of attention nodes before and after fusion"
    )
    assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
        "Attention should not have output_scale before fusion"
    )
    assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
        "Attention should have output_scale after fusion"
    )

    assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
        "Attention should not have output_block_scale before fusion"
    )
    if quant_key.dtype == FP8_DTYPE:
        assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
            "Attention should not have output_block_scale after FP8 fusion"
        )
    elif quant_key.dtype == FP4_DTYPE:
        assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
            "Attention should have output_block_scale after FP4 fusion"
        )

    # Check that results are close
    torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2)