# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
9
from vllm.compilation.noop_elimination import NoOpEliminationPass
10
from vllm.config import CompilationConfig, PassConfig, VllmConfig
11
from vllm.model_executor.layers.activation import SiluAndMul
12
13
14
15
16
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
from vllm.platforms import current_platform
17
18
19
20
21
22

from .backend import TestBackend


class TestModel(torch.nn.Module):
    """Small module chaining ``SiluAndMul`` into a static FP8 linear op.

    ``forward`` runs the activation and feeds its output to
    ``Fp8LinearOp.apply``, which is configured for static, per-tensor
    activation quantization.

    Args:
        hidden_size: Size of both dimensions of the square weight ``w``.
        cutlass_fp8_enabled: Forwarded to ``Fp8LinearOp`` as
            ``cutlass_fp8_supported``.
        *args, **kwargs: Forwarded unchanged to ``torch.nn.Module``.
    """

    def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.silu_and_mul = SiluAndMul()

        # Two independent random scales: ``wscale`` is handed to the linear
        # op positionally together with the weight (presumably the weight
        # scale -- confirm against Fp8LinearOp.apply), ``scale`` is the
        # static activation (input) scale.
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        # Random square weight, cast to the platform FP8 dtype and transposed.
        self.w = (torch.rand(
            hidden_size,
            hidden_size).to(dtype=current_platform.fp8_dtype()).t())

        self.fp8_linear = Fp8LinearOp(
            cutlass_fp8_supported=cutlass_fp8_enabled,
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

    def forward(self, x):
        """Apply ``SiluAndMul`` to ``x`` and run the FP8 linear op on it."""
        y = self.silu_and_mul(x)
        # Use the dedicated activation scale. Previously ``self.wscale`` was
        # passed here as well, which left ``self.scale`` unused and made the
        # input scale alias the weight scale.
        x2 = self.fp8_linear.apply(y,
                                   self.w,
                                   self.wscale,
                                   input_scale=self.scale)
        return x2


@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("cutlass_fp8_enabled",
                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                    reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
                                   cutlass_fp8_enabled):
    """Check that ActivationQuantFusionPass fuses SiluAndMul with FP8 quant.

    The model is run eagerly and through ``torch.compile`` with a
    ``TestBackend`` that applies the no-op elimination and fusion passes.
    The test asserts that both runs agree numerically and that the graph
    contains ``static_scaled_fp8_quant`` before the passes and
    ``silu_and_mul_quant`` after them (and not the other way around).
    """
    # NOTE(review): these change process-wide torch defaults and are not
    # restored at the end of the test.
    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.float16)

    # The no-op elimination pass runs ahead of the fusion pass; it is needed
    # for the fusion pass to work.
    config = VllmConfig()
    config.compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_fusion=True, enable_noop=True))
    fusion_pass = ActivationQuantFusionPass(config)

    backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
    model = TestModel(hidden_size, cutlass_fp8_enabled)

    # Input has 2 * hidden_size columns; mark the first dimension dynamic.
    x = torch.rand(num_tokens, hidden_size * 2)
    torch._dynamo.mark_dynamic(x, 0)

    # Eager reference result, then the compiled result via the test backend.
    result = model(x)

    model2 = torch.compile(model, backend=backend)
    result2 = model2(x)

    # Check that it gives the same answer.
    # NOTE(review): if the model returns a single tensor, ``[0]`` compares
    # only the first slice along dim 0 -- confirm this is intended.
    torch.testing.assert_close(result[0].to(dtype=torch.float16),
                               result2[0].to(dtype=torch.float16),
                               atol=1e-3,
                               rtol=1e-3)

    # Check substitution worked
    pre_nodes = backend.graph_pre_pass.nodes
    post_nodes = backend.graph_post_pass.nodes

    silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default
    fp8_quant = torch.ops._C.static_scaled_fp8_quant.default

    # In pre-nodes, fp8 quant should be present and fused kernels should not
    assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None
    find_auto_fn(pre_nodes, fp8_quant)

    # In post-nodes, fused kernels should be present and fp8 quant should not
    find_auto_fn(post_nodes, silu_and_mul_quant)
    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None