"vllm/vscode:/vscode.git/clone" did not exist on "22aeb430072f676424e7a27966b074d2710b29d4"
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

import vllm.envs as envs
from vllm._custom_ops import scaled_fp8_quant
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul

from .backend import TestBackend


class TestModel(torch.nn.Module):
    """Tiny module that applies ``SiluAndMul`` and then FP8-quantizes.

    The forward pass feeds the activation output straight into
    ``scaled_fp8_quant`` together with a fixed, randomly initialised
    scale, which is the op sequence the fusion test in this file
    compiles and inspects.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Activation layer under test.
        self.silu_and_mul = SiluAndMul()
        # Single-element float32 scale, drawn once at construction and
        # passed to the quant op on every call.
        self.scale = torch.rand(1, dtype=torch.float32)

    def forward(self, x):
        """Return ``scaled_fp8_quant(silu_and_mul(x), scale)`` unchanged."""
        return scaled_fp8_quant(self.silu_and_mul(x), self.scale)


@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
    """Check that ActivationQuantFusionPass fuses SiluAndMul + FP8 quant.

    The test runs ``TestModel`` eagerly and through ``torch.compile``
    with a ``TestBackend`` wrapping the fusion pass, then verifies:

    * the first element of both results agrees (compared in float16
      with ``atol=rtol=1e-3``), and
    * the graph captured before the pass contains the standalone
      static FP8 quant op but not the fused op, while the graph after
      the pass contains the fused op and no standalone quant op.

    Args:
        num_tokens: Size of the first (dynamic) dimension of the input.
        hidden_size: Size of the second dimension of the input.
    """
    # NOTE: these change process-wide torch defaults and are not restored.
    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.float16)

    # Reshape pass is needed for the fusion pass to work
    config = VllmConfig()
    config.compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_fusion=True, enable_reshape=True))
    fusion_pass = ActivationQuantFusionPass(config)

    backend = TestBackend(fusion_pass)
    model = TestModel()

    # First dimension dynamic
    x = torch.rand(num_tokens, hidden_size)
    torch._dynamo.mark_dynamic(x, 0)

    # Eager reference result, computed before compiling the model.
    result = model(x)

    model2 = torch.compile(model, backend=backend)
    result2 = model2(x)

    # Check that it gives the same answer
    torch.testing.assert_close(result[0].to(dtype=torch.float16),
                               result2[0].to(dtype=torch.float16),
                               atol=1e-3,
                               rtol=1e-3)

    # Check substitution worked
    pre_nodes = backend.graph_pre_pass.nodes
    post_nodes = backend.graph_post_pass.nodes

    silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default
    fp8_quant = torch.ops._C.static_scaled_fp8_quant.default

    # In pre-nodes, fp8 quant should be present and fused kernels should not
    assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None
    find_auto_fn(pre_nodes, fp8_quant)

    # In post-nodes, fused kernels should be present and fp8 quant should not
    find_auto_fn(post_nodes, silu_and_mul_quant)
    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None