test_fusions_e2e.py 16 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import itertools
import logging
from collections.abc import Iterable
from typing import Any, NamedTuple

import pytest
import regex as re

14
from tests.v1.attention.utils import AttentionBackendEnum
15
16
17
18
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
19
from vllm.utils.torch_utils import is_torch_equal_or_newer
20

21
from ...utils import flat_product, multi_gpu_test
22

23
24
25
26
27
28
29
30
31
32
def is_blackwell() -> bool:
    """Return whether the current device is Blackwell (device capability 100).

    A lot of tests in this module depend on it, e.g. to skip NVFP4 cases or
    to pick/skip the FlashInfer attention backend.
    """
    return current_platform.is_device_capability(100)


class Matches(NamedTuple):
    """Expected pattern-match counts, one per compilation pass under test.

    The tests compare each field against the count the corresponding pass
    writes to the log. Every field defaults to 0.
    """

    # Compared with "Fused quant onto N attention nodes" (fusion_attn.py).
    attention_fusion: int = 0
    # Compared with "Replaced N patterns" (collective_fusion.py) when the
    # FlashInfer allreduce fusion pass is enabled.
    allreduce_fusion: int = 0
    # Compared with "Replaced N patterns" (sequence_parallelism.py).
    sequence_parallel: int = 0
    # Compared with "Replaced N patterns" (collective_fusion.py) when
    # async TP is enabled.
    async_tp: int = 0

33
34
35
36

class ModelBackendTestCase(NamedTuple):
    """One parametrization row for the fusion e2e tests.

    Pairs a model with the attention backend to force and the fusion match
    counts the tests expect for that combination.
    """

    model_name: str
    # Extra keyword arguments forwarded to ``LLM`` through ``run_model``.
    model_kwargs: dict[str, Any]
    # Exported by name through the VLLM_ATTENTION_BACKEND env var.
    backend: AttentionBackendEnum
    matches: Matches
39
40
41
42
43
44
45
46
47
48
49


# Parametrization tables, filled in per platform below. Each entry pairs a
# model with the attention backend to force and the match counts each fusion
# pass is expected to log for it.
MODELS_FP8: list[ModelBackendTestCase] = []
MODELS_FP4: list[ModelBackendTestCase] = []
MODELS: list[ModelBackendTestCase] = []  # tp-only

if current_platform.is_cuda():
    MODELS_FP8 = [
        ModelBackendTestCase(
            # Use smaller model for L40s in CI
            model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
            # TODO while llama4 is broken, use FLASHINFER for llama3 on Blackwell
            #  so FI attention+fp8_quant is at least tested once
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            backend=AttentionBackendEnum.FLASHINFER
            if is_blackwell()
            else AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                attention_fusion=32,
                allreduce_fusion=65,
                sequence_parallel=65,
                async_tp=128,
            ),
        ),
        ModelBackendTestCase(
            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            # TODO FlashInfer attn broken on Hopper with kvcache=fp8:
            # https://github.com/vllm-project/vllm/issues/28568
            # TODO FlashInfer attn broken on Blackwell for llama4:
            # https://github.com/vllm-project/vllm/issues/28604
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                attention_fusion=48,
                allreduce_fusion=96,
                sequence_parallel=96,
                async_tp=95,  # mlp is moe, no fusion there
            ),
        ),
    ]

    MODELS_FP4 = [
        ModelBackendTestCase(
            model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            backend=AttentionBackendEnum.FLASHINFER,
            matches=Matches(
                attention_fusion=32,
                allreduce_fusion=65,
                sequence_parallel=65,
                async_tp=128,
            ),
        ),
    ]

    # TP only
    MODELS = [
        ModelBackendTestCase(
            model_name="meta-llama/Llama-3.1-8B-Instruct",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                attention_fusion=0,
                allreduce_fusion=65,
                sequence_parallel=65,
                async_tp=128,
            ),
        ),
        ModelBackendTestCase(
            model_name="Qwen/Qwen3-30B-A3B",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                attention_fusion=0,
                allreduce_fusion=97,
                sequence_parallel=97,
                async_tp=96,  # MLP is MoE, half the fusions of dense
            ),
        ),
    ]

elif current_platform.is_rocm():
    MODELS_FP8 = [
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(attention_fusion=32),
        ),
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.ROCM_ATTN,
            matches=Matches(attention_fusion=32),
        ),
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
            matches=Matches(attention_fusion=32),
        ),
    ]

142
# Disable ("-") / enable ("+") the QuantFP8 custom op so both the torch and
# the custom impls of QuantFP8 get exercised.
CUSTOM_OPS_FP8 = [
    "-quant_fp8",
    "+quant_fp8",
]
143
144
145


@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, matches, custom_ops",
    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
    list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
    # quant_fp4 only has the custom impl
    + list(flat_product(MODELS_FP4, [""])),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
def test_attn_quant(
    model_name: str,
    model_kwargs: dict[str, Any],
    backend: AttentionBackendEnum,
    matches: Matches,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    """Single-GPU check that the attention+quant fusion pass fires as expected.

    Compiles and runs ``model_name`` with the attention-fusion pass enabled,
    then asserts the pass logged exactly one "Fused quant onto N attention
    nodes" line with ``N == matches.attention_fusion``.
    """
    if backend == AttentionBackendEnum.FLASHINFER and (
        not is_blackwell() or not has_flashinfer()
    ):
        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    # custom_ops is a comma-joined list of +op/-op toggles ("" means none).
    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        # FIXME: Llama-4-Scout-17B-16E-Instruct-FP8 + FlashInfer + Blackwell end at
        # CUDAGraphMode.NONE here because it derives an attention backend that
        # does not support full cudagraphs
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable, compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

    compilation_config = CompilationConfig(
        # Testing properties
        custom_ops=custom_ops_list,
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        splitting_ops=splitting_ops,
        # Common
        mode=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(compilation_config, model_name, **model_kwargs)

    # Escape "." and "]" so only the literal "fusion_attn.py:<line>]" matches.
    log_matches = re.findall(
        r"fusion_attn\.py:\d+\] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(log_matches) == 1, log_holder.text
    assert int(log_matches[0]) == matches.attention_fusion
213
214


215
# Disable ("-") / enable ("+") the RMSNorm custom op so fusions are checked
# against both the torch and the custom impls of RMSNorm.
CUSTOM_OPS_RMS_NORM = [
    "-rms_norm",
    "+rms_norm",
]
216
217
218
219
220
221
222
223
224


def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
    """Lazily yield every combination of the given custom-op toggles.

    Takes the Cartesian product of ``custom_ops_lists`` (one element from each
    list, in argument order) and yields each combination comma-joined, e.g.
    ``"-quant_fp8,+rms_norm"``.
    """
    yield from map(",".join, itertools.product(*custom_ops_lists))


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, matches, custom_ops",
    # Toggle RMSNorm and QuantFP8 for FP8 models
    list(
        flat_product(
            MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
        )
    )
    # Toggle RMSNorm for FP4 models and unquant models
    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda()
    or not has_flashinfer()
    or not current_platform.has_device_capability(90),
    reason="allreduce+rmsnorm fusion requires flashinfer",
)
def test_tp2_attn_quant_allreduce_rmsnorm(
    model_name: str,
    model_kwargs: dict,
    backend: AttentionBackendEnum,
    matches: Matches,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    """TP=2 check of attention+quant fusion and FlashInfer allreduce fusion.

    Runs ``model_name`` with ``tensor_parallel_size=2`` and asserts that the
    attention-fusion pass and the collective-fusion pass each log exactly two
    counts, both equal to the expected values in ``matches``.
    """
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    if "fp4" in model_name.lower() and not is_blackwell():
        pytest.skip("NVFP4 quant requires Blackwell")

    if backend == AttentionBackendEnum.FLASHINFER and not is_blackwell():
        # FlashInfer attn fusion requires Blackwell
        matches = matches._replace(attention_fusion=0)

    # custom_ops is a comma-joined list of +op/-op toggles ("" means none).
    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable, compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

    compilation_config = CompilationConfig(
        # Testing properties
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        custom_ops=custom_ops_list,
        splitting_ops=splitting_ops,
        # Common
        mode=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(
            enable_attn_fusion=True,
            enable_noop=True,
            enable_fi_allreduce_fusion=True,
        ),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(
            compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
        )

    # Two log lines are expected for each pass with TP=2; both must agree.
    # "." and "]" are escaped so only the literal "<file>.py:<line>]" matches.
    log_matches = re.findall(
        r"fusion_attn\.py:\d+\] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(log_matches) == 2, log_holder.text

    assert int(log_matches[0]) == matches.attention_fusion
    assert int(log_matches[1]) == matches.attention_fusion

    log_matches = re.findall(
        r"collective_fusion\.py:\d+\] Replaced (\d+) patterns",
        log_holder.text,
    )
    assert len(log_matches) == 2, log_holder.text

    assert int(log_matches[0]) == matches.allreduce_fusion
    assert int(log_matches[1]) == matches.allreduce_fusion


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, matches, custom_ops",
    # Toggle RMSNorm and QuantFP8 for FP8 models
    list(
        flat_product(
            MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
        )
    )
    # Toggle RMSNorm for FP4 models and unquant models
    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="sequence parallel only tested on CUDA",
)
def test_tp2_attn_quant_async_tp(
    model_name: str,
    model_kwargs: dict,
    backend: AttentionBackendEnum,
    matches: Matches,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    """TP=2 check of attention+quant, sequence-parallel and async-TP passes.

    Runs ``model_name`` with ``tensor_parallel_size=2`` and asserts that the
    attention-fusion, sequence-parallelism and collective-fusion passes each
    log exactly two counts, both equal to the expected values in ``matches``.
    """
    if is_blackwell():
        # TODO: https://github.com/vllm-project/vllm/issues/27893
        pytest.skip("Blackwell is not supported for AsyncTP pass")

    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    # NOTE: while Blackwell is skipped above, this always skips FP4 models and
    # the FlashInfer branch below always zeroes attention_fusion.
    if "fp4" in model_name.lower() and not is_blackwell():
        pytest.skip("NVFP4 quant requires Blackwell")

    if backend == AttentionBackendEnum.FLASHINFER:
        if not has_flashinfer():
            pytest.skip("FlashInfer backend requires flashinfer installed")
        if not is_blackwell():
            # FlashInfer attn fusion requires Blackwell
            matches = matches._replace(attention_fusion=0)

    # custom_ops is a comma-joined list of +op/-op toggles ("" means none).
    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable, compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

    compilation_config = CompilationConfig(
        # Testing properties
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        custom_ops=custom_ops_list,
        splitting_ops=splitting_ops,
        # Common
        # Use `mode` (not the deprecated `level`), like the other tests here.
        mode=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(
            enable_attn_fusion=True,
            enable_noop=True,
            enable_sequence_parallelism=True,
            enable_async_tp=True,
        ),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(
            compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
        )

    # Two log lines are expected for each pass with TP=2; both must agree.
    # "." and "]" are escaped so only the literal "<file>.py:<line>]" matches.
    log_matches = re.findall(
        r"fusion_attn\.py:\d+\] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(log_matches) == 2, log_holder.text

    assert int(log_matches[0]) == matches.attention_fusion
    assert int(log_matches[1]) == matches.attention_fusion

    log_matches = re.findall(
        r"sequence_parallelism\.py:\d+\] Replaced (\d+) patterns",
        log_holder.text,
    )
    assert len(log_matches) == 2, log_holder.text

    assert int(log_matches[0]) == matches.sequence_parallel
    assert int(log_matches[1]) == matches.sequence_parallel

    log_matches = re.findall(
        r"collective_fusion\.py:\d+\] Replaced (\d+) patterns",
        log_holder.text,
    )
    assert len(log_matches) == 2, log_holder.text

    assert int(log_matches[0]) == matches.async_tp
    assert int(log_matches[1]) == matches.async_tp
430
431
432
433
434
435


def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
    """Build an ``LLM`` for ``model`` and generate for a few fixed prompts.

    Args:
        compile_config: Either a ready ``CompilationConfig``, or a value that
            is wrapped as ``CompilationConfig(mode=compile_config)``.
        model: Model name/path passed to ``LLM``.
        **model_kwargs: Extra ``LLM`` keyword arguments. Unless overridden
            here, ``tensor_parallel_size`` defaults to 1 and
            ``disable_custom_all_reduce`` to True.

    Note:
        If ``cudagraph_mode`` is None it is set to ``CUDAGraphMode.NONE`` in
        place, so a caller-supplied ``CompilationConfig`` is mutated.
    """
    compilation_config = (
        compile_config
        if isinstance(compile_config, CompilationConfig)
        else CompilationConfig(mode=compile_config)
    )

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0)
    # Defaults come first so that explicit model_kwargs override them.
    model_kwargs = {
        "tensor_parallel_size": 1,
        "disable_custom_all_reduce": True,
        **model_kwargs,
    }

    # No cudagraphs by default
    if compilation_config.cudagraph_mode is None:
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    llm = LLM(
        model=model,
        compilation_config=compilation_config,
        **model_kwargs,
    )
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")