# test_fusion_attn.py (17.1 KB)
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
5
6
7

import pytest
import torch._dynamo

8
from tests.compile.backend import LazyInitPass, TestBackend
9
from tests.utils import TestFP8Layer, flat_product
10
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
11
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
12
13
14
15
16
from vllm.compilation.passes.fusion.attn_quant_fusion import ATTN_OP, AttnFusionPass
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
17
from vllm.config import (
18
    AttentionConfig,
19
20
    CacheConfig,
    CompilationConfig,
21
    CompilationMode,
22
23
24
25
26
27
    ModelConfig,
    PassConfig,
    SchedulerConfig,
    VllmConfig,
    set_current_vllm_config,
)
28
from vllm.forward_context import get_forward_context, set_forward_context
29
from vllm.model_executor.layers.attention import Attention
30
from vllm.model_executor.layers.quantization.utils.quant_utils import (
31
    QuantKey,
32
    kFp8StaticTensorSym,
33
    kNvfp4Dynamic,
34
)
35
from vllm.platforms import current_platform
36
from vllm.utils.flashinfer import has_flashinfer
37
38
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
39
40
41
from vllm.v1.kv_cache_interface import AttentionSpec

# FP8 element dtype reported by the current platform.
FP8_DTYPE: torch.dtype = current_platform.fp8_dtype()
# Storage dtype for NVFP4 tensors in this file (presumably two 4-bit values
# packed per uint8 byte, given the ``hidden_size // 2`` weight width used by
# the NVFP4 test model).
FP4_DTYPE: torch.dtype = torch.uint8


class AttentionQuantPatternModel(torch.nn.Module):
    """Base model for AttentionQuantPattern fusion.

    Holds a single vLLM ``Attention`` layer plus a metadata builder for the
    attention backend that layer selected. Subclasses add the quantized op
    that consumes the attention output in ``forward``.
    """

    def __init__(
        self,
        num_qo_heads: int,
        num_kv_heads: int,
        head_size: int,
        kv_cache_dtype: torch.dtype,
        device: torch.device,
        vllm_config: VllmConfig,
        **kwargs,
    ):
        """Create the attention layer and its metadata builder.

        Args:
            num_qo_heads: Number of query/output heads.
            num_kv_heads: Number of key/value heads.
            head_size: Size of each attention head.
            kv_cache_dtype: Element dtype of the KV cache (used for the
                builder's ``AttentionSpec`` and the dummy cache buffer).
            device: Device for the K/V scales, the builder and the cache.
            vllm_config: Config handed to the metadata builder; its
                ``cache_config`` is also passed to ``Attention``.
            **kwargs: Unused here; accepted so subclass-specific options
                (e.g. ``w``) can be forwarded through ``super().__init__``.
        """
        super().__init__()
        self.num_qo_heads = num_qo_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.device = device
        self.vllm_config = vllm_config

        self.attn = Attention(
            num_heads=self.num_qo_heads,
            head_size=self.head_size,
            scale=1.0 / (self.head_size**0.5),
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            prefix="model.layers.0.self_attn.attn",
        )

        # Ensure the K/V scale tensors live on the test device.
        self.attn._k_scale = self.attn._k_scale.to(device)
        self.attn._v_scale = self.attn._v_scale.to(device)

        self.block_size = 16

        # Initialize attn MetadataBuilder
        self.builder = self.attn.attn_backend.get_builder_cls()(
            kv_cache_spec=AttentionSpec(
                block_size=self.block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_dtype,
            ),
            layer_names=[self.attn.layer_name],
            vllm_config=self.vllm_config,
            device=self.device,
        )

    def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
        """Attach a zeroed dummy KV cache and build attention metadata.

        Every request in the batch has ``seq_len == query_len == 1``.
        Side effects: sets ``self.attn.kv_cache`` and ``self.attn_metadata``.

        Args:
            batch_size: Number of single-token requests in the batch.

        Returns:
            The metadata produced by the backend's builder.
        """
        # TODO (Rohan138) reuse utils from vllm/v1/worker/gpu/attn_utils.py

        # Create common attn metadata
        batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
        common_attn_metadata = create_common_attn_metadata(
            batch_spec, self.block_size, self.device, arange_block_indices=True
        )

        max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
        num_blocks = batch_size * max_blocks

        # Fetch the attention backend and kv cache shape and stride order
        attn_backend = self.attn.attn_backend
        kv_cache_shape = attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size
        )
        try:
            kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
        except (AttributeError, NotImplementedError):
            # No backend-specific order: fall back to the identity order.
            kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

        # Allocate in stride (physical) order, then permute back so the
        # resulting view has the shape reported by get_kv_cache_shape.
        kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
        inv_order = [
            kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
        ]

        # Size the buffer from the backend-reported shape rather than assuming
        # a fixed 2 * num_blocks * block_size * num_kv_heads * head_size layout,
        # so the .view() below always matches.
        raw_tensor = torch.zeros(
            torch.Size(kv_cache_shape).numel(),
            dtype=self.kv_cache_dtype,
            device=self.device,
        )
        kv_cache = raw_tensor.view(kv_cache_shape).permute(*inv_order)
        self.attn.kv_cache = [kv_cache]

        # Build attn metadata
        self.attn_metadata = self.builder.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )
        return self.attn_metadata


class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
    """Attention followed by a static per-tensor FP8 linear layer.

    Used to exercise the AttentionFp8StaticQuantPattern fusion.
    """

    quant_key = kFp8StaticTensorSym

    def __init__(self, *args, **kwargs):
        """Create the attention layer and the FP8 linear layer after it.

        Keyword Args:
            w: Optional dict with ``weight``, ``wscale`` and ``scale`` entries
                (e.g. another instance's ``self.w``) whose tensors replace the
                freshly created layer parameters.
        """
        super().__init__(*args, **kwargs)

        hidden = self.num_qo_heads * self.head_size
        self.fp8_linear = TestFP8Layer(
            weight_shape=(hidden, hidden),
            activation_quant_key=self.quant_key,
            weight_quant_key=self.quant_key,
            device=self.device,
        )

        # Reuse externally supplied parameters so two instances can share them.
        shared = kwargs.get("w")
        if shared is not None:
            layer = self.fp8_linear
            layer.weight = shared["weight"]
            layer.weight_scale = shared["wscale"]
            layer.input_scale = shared["scale"]

        # Expose the (possibly shared) parameters in the common ``w`` format.
        self.w = dict(
            weight=self.fp8_linear.weight,
            wscale=self.fp8_linear.weight_scale,
            scale=self.fp8_linear.input_scale,
        )

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Run attention and feed its output through the FP8 linear layer."""
        return self.fp8_linear(self.attn(q, k, v))


class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
    """Test model for AttentionNvfp4QuantPattern fusion.

    The attention output is quantized with ``scaled_fp4_quant`` and then
    multiplied by an FP4 weight via ``cutlass_scaled_fp4_mm``.
    """

    quant_key = kNvfp4Dynamic

    def __init__(self, *args, **kwargs):
        """Create the attention layer and the FP4 GEMM parameters.

        Keyword Args:
            w: Optional dict with ``weight``, ``wscale_swizzled``, ``wscale``
                and ``scale`` entries (e.g. another instance's ``self.w`` so
                fused and unfused models share parameters). Random parameters
                are generated only when ``w`` is missing or None, so supplying
                it neither allocates throwaway tensors nor draws from the RNG.
        """
        super().__init__(*args, **kwargs)

        w = kwargs.get("w")
        if w is None:
            hidden_size = self.num_qo_heads * self.head_size
            w = {
                # Presumably two packed FP4 values per uint8 byte, hence // 2.
                "weight": torch.randint(
                    256,
                    (hidden_size, hidden_size // 2),
                    dtype=FP4_DTYPE,
                    device=self.device,
                ),
                # Presumably one FP8 block scale per 16 weight elements.
                "wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to(
                    dtype=FP8_DTYPE, device=self.device
                ),
                "wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
                "scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
            }
        self.w = w

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Forward pass that creates the pattern to be fused."""
        attn_output = self.attn(q, k, v)
        # The quant op receives the reciprocal of ``scale``; the GEMM's
        # ``alpha`` is ``scale * wscale``.
        quant_output, output_block_scale = scaled_fp4_quant(
            attn_output, 1 / self.w["scale"]
        )
        return cutlass_scaled_fp4_mm(
            a=quant_output,
            b=self.w["weight"],
            block_scale_a=output_block_scale,
            block_scale_b=self.w["wscale_swizzled"],
            alpha=self.w["scale"] * self.w["wscale"],
            out_dtype=attn_output.dtype,
        )


# Parametrization tables consumed by test_attention_quant_pattern. They are
# filled per platform below and remain empty on any other platform.
# SPLIT_ATTENTION is never populated here.
HEADS: list[tuple[int, int]] = []
SPLIT_ATTENTION: list[bool] = []
PATTERN_TEST_MODELS_FP8: list[tuple[str, type]] = []
PATTERN_TEST_MODELS_FP4: list[tuple[str, type]] = []
BACKENDS_FP8: list[AttentionBackendEnum] = []
BACKENDS_FP4: list[AttentionBackendEnum] = []

if current_platform.is_cuda():
    HEADS = [(64, 8), (40, 8)]
    BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER]
    BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER]
    PATTERN_TEST_MODELS_FP8 = [
        ("RedHatAI/Meta-Llama-3.1-8B-FP8", TestAttentionFp8StaticQuantPatternModel),
    ]
    PATTERN_TEST_MODELS_FP4 = [
        ("nvidia/Llama-3.1-8B-Instruct-NVFP4", TestAttentionNvfp4QuantPatternModel),
    ]
elif current_platform.is_rocm():
    # ROCm only exercises the FP8 pattern; the FP4 tables stay empty.
    HEADS = [(32, 8), (40, 8)]
    BACKENDS_FP8 = [
        AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
        AttentionBackendEnum.ROCM_ATTN,
        AttentionBackendEnum.TRITON_ATTN,
    ]
    PATTERN_TEST_MODELS_FP8 = [
        ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel),
    ]


@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize(
    "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
    "backend, model_name, model_class, custom_ops",
    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
    list(
        flat_product(
            BACKENDS_FP8, PATTERN_TEST_MODELS_FP8, ["+quant_fp8", "-quant_fp8"]
        )
    )
    # quant_fp4 only has the custom impl
    + list(flat_product(BACKENDS_FP4, PATTERN_TEST_MODELS_FP4, [""])),
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
def test_attention_quant_pattern(
    num_qo_heads: int,
    num_kv_heads: int,
    head_size: int,
    batch_size: int,
    dtype: torch.dtype,
    custom_ops: str,
    model_name: str,
    model_class: type[AttentionQuantPatternModel],
    backend: AttentionBackendEnum,
    dist_init,
    monkeypatch,
    use_fresh_inductor_cache,
):
    """Test AttentionStaticQuantPattern fusion pass.

    ``torch.set_default_dtype`` is process-global, so the previous default
    dtype is restored afterwards to keep this test's parametrized dtype from
    leaking into later tests.
    """
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    if backend == AttentionBackendEnum.FLASHINFER and (
        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
    ):
        # This also captures the FP4 case
        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")

    prev_default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        _run_attention_quant_pattern(
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            batch_size=batch_size,
            dtype=dtype,
            custom_ops=custom_ops,
            model_name=model_name,
            model_class=model_class,
            backend=backend,
        )
    finally:
        torch.set_default_dtype(prev_default_dtype)


def _run_attention_quant_pattern(
    num_qo_heads: int,
    num_kv_heads: int,
    head_size: int,
    batch_size: int,
    dtype: torch.dtype,
    custom_ops: str,
    model_name: str,
    model_class: type[AttentionQuantPatternModel],
    backend: AttentionBackendEnum,
) -> None:
    """Compile unfused and fused models and check the attention+quant fusion.

    Expects torch's default dtype to already be ``dtype``.
    """
    custom_ops_list = custom_ops.split(",") if custom_ops else []

    device = torch.device("cuda:0")
    torch.manual_seed(42)

    model_config = ModelConfig(
        model=model_name,
        max_model_len=2048,
        dtype=dtype,
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        scheduler_config=SchedulerConfig(
            max_num_seqs=1024,
            max_model_len=model_config.max_model_len,
            is_encoder_decoder=model_config.is_encoder_decoder,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops_list,
        ),
        cache_config=CacheConfig(cache_dtype="fp8"),
        attention_config=AttentionConfig(backend=backend),
    )

    # Create test inputs
    q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device)
    k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
    v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)

    # Mark first dimension as dynamic for realistic testing
    torch._dynamo.mark_dynamic(q, 0)
    torch._dynamo.mark_dynamic(k, 0)
    torch._dynamo.mark_dynamic(v, 0)

    # Run model directly without compilation and fusion
    vllm_config_unfused = copy.deepcopy(vllm_config)
    with (
        set_current_vllm_config(vllm_config_unfused),
        set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
    ):
        model_unfused = model_class(
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            kv_cache_dtype=FP8_DTYPE,
            device=device,
            vllm_config=vllm_config_unfused,
        )
        model_unfused = model_unfused.to(device)
        # NOTE(review): the fused path below cites vllm issue 31044 for the
        # same warm-up hack; confirm which issue this reference means.
        result_unfused_0 = model_unfused(q, k, v)  # noqa: F841  HACK: See #131044

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)

        # Run model directly without fusion
        # Still compile so query QuantFP8 has closer numerics
        compiled_unfused = torch.compile(model_unfused, fullgraph=True)
        result_unfused = compiled_unfused(q, k, v)

    # Run model with attn fusion enabled
    vllm_config.compilation_config.pass_config = PassConfig(
        fuse_attn_quant=True, eliminate_noops=True
    )
    with (
        set_current_vllm_config(vllm_config),
        set_forward_context(attn_metadata=None, vllm_config=vllm_config),
    ):
        # Share the unfused model's weights so results are comparable.
        model_fused = model_class(
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            kv_cache_dtype=FP8_DTYPE,
            device=device,
            vllm_config=vllm_config,
            w=model_unfused.w,
        )
        model_fused = model_fused.to(device)

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)

        # Create test backend with fusion passes enabled
        noop_pass = NoOpEliminationPass(vllm_config)
        attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)

        # HACK: See https://github.com/vllm-project/vllm/issues/31044
        result_fused_0 = model_fused(q, k, v)  # noqa: F841

        # Compile model with fusion enabled
        compiled_fused = torch.compile(
            model_fused, backend=test_backend, fullgraph=True
        )
        assert compiled_fused.attn._o_scale_float is None

        result_fused = compiled_fused(q, k, v)

        if backend == AttentionBackendEnum.FLASHINFER:
            # With the Flashinfer backend after the 1st round of the forward
            # pass, output quant scale should be loaded into the attn layer's
            # _o_scale_float, the 2nd round should reuse the loaded
            # _o_scale_float
            assert compiled_fused.attn._o_scale_float is not None
            result_fused_2 = compiled_fused(q, k, v)

            assert compiled_fused.attn._o_scale_float is not None

            torch.testing.assert_close(
                result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
            )

    # Check attn fusion support
    quant_key: QuantKey = model_class.quant_key
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key)
        for layer in vllm_config.compilation_config.static_forward_context.values()
    ]
    assert all(attn_fusion_supported), "All layers should support attention fusion"

    # Check quantization ops in the graph before and after fusion
    quant_op = (
        torch.ops.aten.reciprocal
        if "-quant_fp8" in custom_ops_list
        else QUANT_OPS[quant_key]
    )

    # Note: for fp8, fully_replaced=False because query quant ops remain in graph.
    # Only output quant ops are fused into attention.
    test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic)

    # access the underlying `AttnFusionPass` on the `LazyInitPass`
    assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

    # Check attention ops in the graph before and after fusion
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass))

    assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
    assert len(attn_nodes_pre) == len(attn_nodes_post), (
        "Should have same number of attention nodes before and after fusion"
    )
    assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
        "Attention should not have output_scale before fusion"
    )
    assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
        "Attention should have output_scale after fusion"
    )

    assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
        "Attention should not have output_block_scale before fusion"
    )

    # The fusion must neither add nor drop the kv_cache_dummy_dep kwarg.
    dummy_dep_pre = attn_nodes_pre[0].kwargs.get("kv_cache_dummy_dep")
    dummy_dep_post = attn_nodes_post[0].kwargs.get("kv_cache_dummy_dep")
    assert (dummy_dep_pre is None) == (dummy_dep_post is None), (
        "The kv_cache_dummy_dep should be consistent before and after fusion"
    )

    if quant_key.dtype == FP8_DTYPE:
        assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
            "Attention should not have output_block_scale after FP8 fusion"
        )
    elif quant_key.dtype == FP4_DTYPE:
        assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
            "Attention should have output_block_scale after FP4 fusion"
        )

    # Check that results are close
    torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2)