# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import pytest
import torch

import vllm.config
from tests.compile.backend import TestBackend
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.utils import rocm_unquantized_gemm


class TestModel(torch.nn.Module):
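    """Toy model repeating the pattern targeted by the fusion pass.

    Each layer applies a fused-add RMSNorm, computes router logits from the
    normalized activations, and right-pads the activations to a multiple of
    ``x_pad_to_multiple``.
    """
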
    def __init__(
        self,
        num_layers: int,
        hidden_size: int,
        num_local_experts: int,
        x_pad_to_multiple: int,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.x_pad_to_multiple = x_pad_to_multiple
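        # Right-padding needed to bring hidden_size up to the next multiple of
        # x_pad_to_multiple (assumes hidden_size is not already a multiple).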
        self.pad_dim = x_pad_to_multiple - (hidden_size % x_pad_to_multiple)

        self.norm = [RMSNorm(hidden_size, eps=1e-5) for _ in range(num_layers)]
        self.router = [
            torch.nn.Linear(hidden_size, num_local_experts) for _ in range(num_layers)
        ]

    def forward(self, x):
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        all_router_logits = []
        for layer in range(self.num_layers):
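            # Slice off the padding added by the previous layer (a no-op on the
            # first iteration, which NoOpEliminationPass can remove).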
            x = x[:, : self.hidden_size]
            x, resid = self.norm[layer](x, resid)
            router_logits = rocm_unquantized_gemm(
                self, x, self.router[layer].weight, self.router[layer].bias
            )
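            # Pad the activations back to a multiple of x_pad_to_multiple; this
            # constant_pad_nd is what the pass should fold into the fused op.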
            x = torch.nn.functional.pad(
                x, (0, self.pad_dim), mode="constant", value=0.0
            )
            all_router_logits.append(router_logits)

        return x, resid, *all_router_logits

    def ops_in_model_before(self):
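        # Ops expected in the graph before fusion: the AITER fused-add RMSNorm
        # op plus the standalone aten pad.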
        return [
            rocm_aiter_ops.get_rmsnorm_fused_add_op(),
            torch.ops.aten.constant_pad_nd,
        ]

    def ops_in_model_after(self):
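        # After fusion, both should be replaced by one Triton add-RMSNorm-pad op.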
        return [rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()]


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_layers", [3])
@pytest.mark.parametrize("hidden_size", [2880])
@pytest.mark.parametrize("num_local_experts", [128])
@pytest.mark.parametrize("x_pad_to_multiple", [256])
@pytest.mark.skipif(
    not is_aiter_found_and_supported(),
    reason="Only test on ROCm with AITER installed and supported",
)
def test_fuse_act_padding(
    dtype: torch.dtype,
    num_layers: int,
    hidden_size: int,
    num_local_experts: int,
    x_pad_to_multiple: int,
    monkeypatch: pytest.MonkeyPatch,
):
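    # Enable the custom rms_norm op and the act-padding fusion pass (plus no-op
    # elimination) so the pattern the pass rewrites actually appears in the graph.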
    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm"],
            pass_config=PassConfig(fuse_act_padding=True, eliminate_noops=True),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
        from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
            RocmAiterTritonAddRMSNormPadFusionPass,
        )

        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        m.setenv("VLLM_ROCM_USE_AITER", "1")
        rocm_aiter_ops.refresh_env_variables()

        fusion_pass = RocmAiterTritonAddRMSNormPadFusionPass(vllm_config)
        passes = [
            NoOpEliminationPass(vllm_config),
            fusion_pass,
            PostCleanupPass(vllm_config),
        ]
        backend = TestBackend(*passes)
        model = TestModel(num_layers, hidden_size, num_local_experts, x_pad_to_multiple)

        x = torch.rand(1, hidden_size)
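        # Compile with a symbolic (dynamic) batch dimension.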
        torch._dynamo.mark_dynamic(x, 0)

        outputs_unfused = model(x)

        model_fused = torch.compile(model, backend=backend)
        outputs_fused = model_fused(x)

        torch.testing.assert_close(outputs_unfused, outputs_fused)

        assert fusion_pass.matched_count == num_layers

        backend.check_before_ops(model.ops_in_model_before())
        backend.check_after_ops(model.ops_in_model_after())