test_fusion_attn.py 21.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
5
6
7
8
9
10
from typing import Optional

import pytest
import torch._dynamo

from tests.compile.backend import TestBackend
from tests.models.utils import check_outputs_equal
11
12
from tests.v1.attention.utils import (BatchSpec, _Backend,
                                      create_common_attn_metadata)
13
from vllm import LLM, SamplingParams
14
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
15
16
from vllm.attention import Attention
from vllm.attention.selector import global_force_attn_backend_context_manager
17
from vllm.compilation.fusion import QUANT_OPS
18
19
20
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
21
22
23
24
25
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
                         ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
                         set_current_vllm_config)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
26
    QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
27
28
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp)
29
from vllm.platforms import current_platform
30
from vllm.utils import is_torch_equal_or_newer
31
32
33
from vllm.v1.kv_cache_interface import AttentionSpec

# The FP8 dtype is platform dependent, so ask the current platform for it.
FP8_DTYPE = current_platform.fp8_dtype()
# NVFP4 tensors in this module are held in uint8 storage.
FP4_DTYPE = torch.uint8

# Module-level slots for the custom Dynamo backends: CompilationConfig.backend
# refers to them by dotted import path ("tests.compile.test_fusion_attn.*"),
# so they must be globals that the tests assign before building an LLM.
backend: Optional[TestBackend]
backend_unfused: Optional[TestBackend]
backend = backend_unfused = None


def _dynamo_as_is_compile_config(backend_path: str) -> CompilationConfig:
    """Return the CompilationConfig used by both runs of the V0 fusion test.

    ``backend_path`` is the dotted import path of a module-level
    ``TestBackend`` global that Dynamo uses as its custom backend.
    """
    return CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend=backend_path,
        custom_ops=["+quant_fp8"],
    )


def _bf16_vllm_config(model: str,
                      compile_config: CompilationConfig) -> VllmConfig:
    """Return a bfloat16 VllmConfig for ``model`` using ``compile_config``."""
    return VllmConfig(compilation_config=compile_config,
                      model_config=ModelConfig(
                          model=model,
                          dtype=torch.bfloat16,
                      ))


@pytest.mark.parametrize(
    "model, quant_key",
    [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize("use_triton_fa", [True, False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="V0 attn quant fusion only on ROCm")
def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
                             quant_key: QuantKey, use_triton_fa: bool):
    """End-to-end check of AttnFusionPass on a real model.

    Generates greedily twice: once through a backend with only the noop
    elimination pass, and once through a backend that also runs
    AttnFusionPass. Verifies that each attention node gained an
    ``output_scale`` exactly when its attention impl reports support for
    fused output quantization, and that both generations match.
    """
    # Clean Dynamo cache to avoid reusing other test cases
    # (for some reason the reset at the end is not enough)
    torch._dynamo.reset()

    # The backends are module globals so that CompilationConfig.backend can
    # refer to them by dotted import path.
    global backend, backend_unfused

    # NOTE(review): the test name and skip reason say "V0", yet V1 is
    # enabled here -- confirm which engine this test is meant to exercise.
    monkeypatch.setenv("VLLM_USE_V1", "1")
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))

    # Prompt 4 seems too open-ended, differs between fused and unfused
    # (both outputs look reasonable though)
    prompts = example_prompts[:4] + example_prompts[5:]

    # ---- Unfused reference run ----
    compile_config = _dynamo_as_is_compile_config(
        "tests.compile.test_fusion_attn.backend_unfused")
    vllm_config = _bf16_vllm_config(model, compile_config)
    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))

    llm = LLM(model,
              enforce_eager=True,
              compilation_config=compile_config,
              gpu_memory_utilization=0.5,
              max_model_len=2048)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=10,
                                     top_p=0.95)

    unfused_output = llm.generate(prompts, sampling_params)
    backend_unfused = None  # Reset backend to make sure llm gets released
    del llm

    # ---- Fused run ----
    compile_config = _dynamo_as_is_compile_config(
        "tests.compile.test_fusion_attn.backend")
    vllm_config = _bf16_vllm_config(model, compile_config)

    # AttnFusionPass needs attention layers to be registered in config upon init
    # so we initialize it during compilation.
    attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
    backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
    llm2 = LLM(model,
               enforce_eager=True,
               compilation_config=compile_config,
               gpu_memory_utilization=0.5,
               max_model_len=2048)

    # Per attention layer: does its impl support fusing this output quant?
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key)
        for layer in compile_config.static_forward_context.values()
    ]

    print(f"{attn_fusion_supported=}")
    if any(attn_fusion_supported):
        # Check quant ops
        backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)

    # attention ops present in both, just output_scale param changes
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
    assert len(attn_nodes_pre) == len(attn_nodes_post)

    for i, (node_pre, node_post) in enumerate(
            zip(attn_nodes_pre, attn_nodes_post)):
        assert node_pre.kwargs["output_scale"] is None
        fused = node_post.kwargs["output_scale"] is not None
        expected = attn_fusion_supported[i]
        # The message states the expectation that was violated.
        assert fused == expected, (
            f"Node {i} {'' if expected else 'not '}expected "
            f"to have fused output quant")

    # check outputs
    fused_output = llm2.generate(prompts, sampling_params)

    # transform outputs to format expected by check_outputs_equal
    sample_outs = lambda s: (list(s.token_ids), s.text)
    outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]

    check_outputs_equal(
        outputs_0_lst=outs_lst(unfused_output),
        outputs_1_lst=outs_lst(fused_output),
        name_0="unfused",
        name_1="fused",
    )

    # Clean Dynamo cache to avoid polluting other case(s)
    torch._dynamo.reset()

    # Reset backend to make sure llm2 gets released
    backend = None
157
158


159
160
class AttentionQuantPatternModel(torch.nn.Module):
    """Base model for AttentionQuantPattern fusion.

    Holds one vLLM ``Attention`` layer (prefix
    ``model.layers.0.self_attn.attn``) plus a metadata builder for the
    selected attention backend. Subclasses append the output quantization
    that the fusion pass is expected to fold into attention.
    """

    def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
                 kv_cache_dtype: torch.dtype, device: torch.device,
                 vllm_config: VllmConfig, **kwargs):
        super().__init__()
        # Remember the geometry; build_attn_metadata and subclasses use it.
        self.num_qo_heads = num_qo_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.device = device
        self.vllm_config = vllm_config

        layer = Attention(
            num_heads=num_qo_heads,
            head_size=head_size,
            scale=1.0 / (head_size**0.5),
            num_kv_heads=num_kv_heads,
            cache_config=vllm_config.cache_config,
            prefix="model.layers.0.self_attn.attn",
        )
        # Put the k/v scale tensors on the target device.
        layer._k_scale = layer._k_scale.to(device)
        layer._v_scale = layer._v_scale.to(device)
        self.attn = layer

        self.block_size = 16

        # Metadata builder of whichever backend the Attention layer selected.
        spec = AttentionSpec(
            block_size=self.block_size,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=kv_cache_dtype,
            use_mla=False,
        )
        builder_cls = layer.attn_backend.get_builder_cls()
        self.builder = builder_cls(
            kv_cache_spec=spec,
            layer_names=[layer.layer_name],
            vllm_config=vllm_config,
            device=device,
        )

    def build_attn_metadata(self, batch_size: int, use_hnd: bool):
        """Build decode metadata and a zeroed KV cache for ``batch_size``.

        Every request is a single-token decode (seq_len == query_len == 1).
        The resulting metadata is stored on ``self.attn_metadata`` and
        returned.
        """
        batch_spec = BatchSpec(seq_lens=[1] * batch_size,
                               query_lens=[1] * batch_size)
        common_attn_metadata = create_common_attn_metadata(
            batch_spec,
            self.block_size,
            self.device,
            arange_block_indices=True)

        # Ceiling division: blocks needed by the longest sequence.
        blocks_per_seq = -(-max(batch_spec.seq_lens) // self.block_size)
        num_blocks = batch_size * blocks_per_seq

        # The cache is allocated as
        #   [num_blocks, 2 (k/v), num_kv_heads, block_size, head_size]
        # and then permuted to the layout each platform's backend reads:
        #   - NHD: [..., block_size, num_kv_heads, head_size]
        #   - HND: [..., num_kv_heads, block_size, head_size]
        if current_platform.is_rocm():
            # k/v moved to the leading dimension; HND or NHD per use_hnd.
            order = (1, 0, 2, 3, 4) if use_hnd else (1, 0, 3, 2, 4)
        else:
            # k/v kept as the 2nd dimension, viewed as NHD here
            # (later permuted back to HND layout in the forward pass).
            order = (0, 1, 3, 2, 4)
        kv_cache = torch.zeros(num_blocks,
                               2,
                               self.num_kv_heads,
                               self.block_size,
                               self.head_size,
                               dtype=self.kv_cache_dtype,
                               device=self.device).permute(*order)
        self.attn.kv_cache = [kv_cache]

        self.attn_metadata = self.builder.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata)
        return self.attn_metadata

245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270

class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
    """Test model for AttentionFp8StaticQuantPattern fusion.

    Attention followed by an ``Fp8LinearOp`` configured for static
    per-tensor activation quantization (``kFp8StaticTensorSym``).

    Keyword args:
        w: optional dict with "weight", "wscale" and "scale" tensors, used
            to share weights between model instances. Random defaults are
            created only when it is not given.
    """

    quant_key = kFp8StaticTensorSym

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.quant_key.scale.static,
            act_quant_group_shape=self.quant_key.scale.group_shape)

        # Build the random defaults lazily: passing them as the default of
        # kwargs.get() would allocate a hidden_size x hidden_size tensor
        # (and a device copy) even when the caller supplies ``w``.
        w = kwargs.get("w")
        if w is None:
            hidden_size = self.num_qo_heads * self.head_size
            w = {
                "weight":
                torch.randn(hidden_size, hidden_size).to(
                    dtype=FP8_DTYPE, device=self.device).t(),
                "wscale":
                torch.tensor([1.0], dtype=torch.float32, device=self.device),
                "scale":
                torch.tensor([1.0], dtype=torch.float32, device=self.device),
            }
        self.w = w

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Forward pass that creates the pattern to be fused."""
        attn_output = self.attn(q, k, v)
        return self.fp8_linear.apply(input=attn_output,
                                     weight=self.w["weight"],
                                     weight_scale=self.w["wscale"],
                                     input_scale=self.w["scale"])


class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
    """Test model for AttentionNvfp4QuantPattern fusion.

    Attention followed by ``scaled_fp4_quant`` on its output and a
    ``cutlass_scaled_fp4_mm`` GEMM with a random FP4 weight.

    Keyword args:
        w: optional dict with "weight", "wscale_swizzled", "wscale" and
            "scale" tensors, used to share weights between model instances.
            Random defaults are created only when it is not given.
    """

    quant_key = kNvfp4Quant

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Build the random defaults lazily: passing them as the default of
        # kwargs.get() would allocate them even when the caller supplies
        # ``w``, only to discard them.
        w = kwargs.get("w")
        if w is None:
            hidden_size = self.num_qo_heads * self.head_size
            w = {
                # hidden_size // 2 uint8 columns per row -- presumably two
                # packed FP4 values per byte.
                "weight":
                torch.randint(256, (hidden_size, hidden_size // 2),
                              dtype=FP4_DTYPE,
                              device=self.device),
                # One FP8 scale per 16 weight elements; the key name says
                # these are already in swizzled layout.
                "wscale_swizzled":
                torch.randn(hidden_size, hidden_size // 16).to(
                    dtype=FP8_DTYPE, device=self.device),
                "wscale":
                torch.tensor([500], dtype=torch.float32, device=self.device),
                "scale":
                torch.tensor([0.002], dtype=torch.float32, device=self.device),
            }
        self.w = w

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Forward pass that creates the pattern to be fused."""
        attn_output = self.attn(q, k, v)
        # Quantize with the reciprocal of "scale"; the GEMM's alpha then
        # multiplies "scale" and "wscale" back in.
        quant_output, output_block_scale = scaled_fp4_quant(
            attn_output, 1 / self.w["scale"])
        return cutlass_scaled_fp4_mm(a=quant_output,
                                     b=self.w["weight"],
                                     block_scale_a=output_block_scale,
                                     block_scale_b=self.w["wscale_swizzled"],
                                     alpha=self.w["scale"] * self.w["wscale"],
                                     out_dtype=attn_output.dtype)
314
315


316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
# Per-platform parametrization for test_attention_quant_pattern:
#   MODELS: (model name, test model class) pairs
#   HEADS:  (num_qo_heads, num_kv_heads) pairs
if current_platform.is_cuda():
    MODELS, HEADS = (
        [
            ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
             TestAttentionFp8StaticQuantPatternModel),
            ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
             TestAttentionNvfp4QuantPatternModel),
        ],
        [(64, 8), (40, 8)],
    )
elif current_platform.is_rocm():
    MODELS, HEADS = (
        [
            ("amd/Llama-3.1-8B-Instruct-FP8-KV",
             TestAttentionFp8StaticQuantPatternModel),
        ],
        [(32, 8), (40, 8)],
    )
else:
    # Neither CUDA nor ROCm: nothing to parametrize over.
    MODELS, HEADS = [], []


@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("batch_size",
                         [7, 256, 533] if current_platform.is_cuda() else [8])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize("backend",
                         [_Backend.FLASHINFER] if current_platform.is_cuda()
                         else [_Backend.TRITON_ATTN_VLLM_V1])
@pytest.mark.parametrize(
    "split_attention",
    [False, True] if current_platform.is_rocm() else [False])
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
    "use_inductor_graph_partition",
    [False] if current_platform.is_rocm() else [False, True])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Only test ROCm or CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(current_platform.is_cuda()
                    and not current_platform.is_device_capability((10, 0)),
                    reason="On CUDA only test on SM100(Blackwell)")
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
                                 head_size: int, batch_size: int,
                                 dtype: torch.dtype, model_name: str,
                                 model_class: type[AttentionQuantPatternModel],
                                 backend: _Backend, split_attention: bool,
                                 use_inductor_graph_partition: bool,
                                 monkeypatch, dist_init, caplog_vllm):
    """Test AttnFusionPass on a single attention + output-quant pattern.

    Runs ``model_class`` (FP8 static or NVFP4 output quant) once uncompiled
    and once compiled through a TestBackend with the noop-elimination and
    attention-fusion passes. Checks the following:
    - the quant op is fully replaced when the attention impl supports fusion;
    - the attention op gains ``output_scale``, plus ``output_block_scale``
      for FP4;
    - both runs produce close results.
    """
    if use_inductor_graph_partition and not is_torch_equal_or_newer(
            "2.9.0.dev"):
        pytest.skip("inductor graph partition is only available "
                    "in PyTorch 2.9+")

    monkeypatch.setenv("VLLM_USE_V1", "1")
    if split_attention:
        monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")

    device = torch.device("cuda:0")
    torch.manual_seed(42)

    vllm_config = VllmConfig(
        model_config=ModelConfig(
            model=model_name,
            max_model_len=2048,
            dtype=dtype,
        ),
        scheduler_config=SchedulerConfig(max_num_seqs=1024),
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            custom_ops=["+quant_fp8"],
            use_inductor_graph_partition=use_inductor_graph_partition,
        ),
        cache_config=CacheConfig(cache_dtype="fp8"))

    # Create test inputs
    q = torch.randn(batch_size,
                    num_qo_heads * head_size,
                    dtype=dtype,
                    device=device)
    k = torch.randn(batch_size,
                    num_kv_heads * head_size,
                    dtype=dtype,
                    device=device)
    v = torch.randn(batch_size,
                    num_kv_heads * head_size,
                    dtype=dtype,
                    device=device)

    # Mark first dimension as dynamic for realistic testing
    torch._dynamo.mark_dynamic(q, 0)
    torch._dynamo.mark_dynamic(k, 0)
    torch._dynamo.mark_dynamic(v, 0)

    # Run model directly without compilation and fusion
    vllm_config_unfused = copy.deepcopy(vllm_config)
    with set_current_vllm_config(vllm_config_unfused), set_forward_context(
            attn_metadata=None, vllm_config=vllm_config_unfused
    ), global_force_attn_backend_context_manager(backend):
        model_unfused = model_class(num_qo_heads=num_qo_heads,
                                    num_kv_heads=num_kv_heads,
                                    head_size=head_size,
                                    kv_cache_dtype=FP8_DTYPE,
                                    device=device,
                                    vllm_config=vllm_config_unfused)
        model_unfused = model_unfused.to(device)

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
            batch_size, use_hnd=split_attention)

        # Run model directly without compilation and fusion
        result_unfused = model_unfused(q, k, v)

    # Run model with attn fusion enabled
    vllm_config.compilation_config.pass_config = PassConfig(
        enable_attn_fusion=True, enable_noop=True)
    with set_current_vllm_config(vllm_config), set_forward_context(
            attn_metadata=None, vllm_config=vllm_config
    ), global_force_attn_backend_context_manager(backend):
        # Share weights with the unfused model so results are comparable.
        model_fused = model_class(num_qo_heads=num_qo_heads,
                                  num_kv_heads=num_kv_heads,
                                  head_size=head_size,
                                  kv_cache_dtype=FP8_DTYPE,
                                  device=device,
                                  vllm_config=vllm_config,
                                  w=model_unfused.w)
        model_fused = model_fused.to(device)

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_fused.build_attn_metadata(
            batch_size, use_hnd=split_attention)

        # Create test backend with fusion passes enabled
        noop_pass = NoOpEliminationPass(vllm_config)

        def attn_pass(*args, **kw):
            # AttnFusionPass needs the attention layers to be registered in
            # the config upon init, so it is constructed at compile time.
            return AttnFusionPass(vllm_config)(*args, **kw)

        test_backend = TestBackend(noop_pass, attn_pass)

        # Compile model with fusion enabled
        model_compiled = torch.compile(model_fused,
                                       backend=test_backend,
                                       fullgraph=True)
        assert model_compiled.attn._o_scale_float is None

        result_fused_1 = model_compiled(q, k, v)

        if backend == _Backend.FLASHINFER:
            # With the Flashinfer backend after the 1st round of the forward
            # pass, output quant scale should be loaded into the attn layer's
            # _o_scale_float, the 2nd round should reuse the loaded
            # _o_scale_float
            assert model_compiled.attn._o_scale_float is not None
            result_fused_2 = model_compiled(q, k, v)
            assert model_compiled.attn._o_scale_float is not None

            torch.testing.assert_close(result_unfused,
                                       result_fused_2,
                                       atol=1e-2,
                                       rtol=1e-2)

    # Check attn fusion support
    quant_key = model_class.quant_key
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key)
        for layer in vllm_config.compilation_config.static_forward_context.
        values()
    ]
    if any(attn_fusion_supported):
        # Check quantization ops in the graph before and after fusion
        test_backend.check_before_ops([QUANT_OPS[quant_key]],
                                      fully_replaced=True)

    # Check attention ops in the graph before and after fusion
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP,
                                         test_backend.graph_post_pass))

    assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
    assert len(attn_nodes_pre) == len(attn_nodes_post), \
        "Should have same number of attention nodes before and after fusion"
    assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \
        "Attention should not have output_scale before fusion"
    assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
        "Attention should have output_scale after fusion"

    assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
        "Attention should not have output_block_scale before fusion"
    if quant_key.dtype == FP8_DTYPE:
        assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
            "Attention should not have output_block_scale after FP8 fusion"
    elif quant_key.dtype == FP4_DTYPE:
        assert attn_nodes_post[0].kwargs.get(
            "output_block_scale") is not None, \
            "Attention should have output_block_scale after FP4 fusion"

    # Check that results are close
    torch.testing.assert_close(result_unfused,
                               result_fused_1,
                               atol=1e-2,
                               rtol=1e-2)