# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# This is a test for the AITER ops.
# It tests that the AITER ops
# 1. are correctly registered as custom ops,
# 2. correctly define the relationship between
#    the implementation and the fake function, and
# 3. can be used with torch.compile.
# This file will be skipped if AITER is not installed
# or the platform is not ROCm.

import importlib.util
import os

import pytest
import torch

from vllm.platforms import current_platform

# Skip the whole module when not running on a ROCm platform.
if not current_platform.is_rocm():
    pytest.skip("This test can only run on ROCm.", allow_module_level=True)

# This environment variable must be set so ops will be registered.
os.environ["VLLM_ROCM_USE_AITER"] = "1"

# This import is needed only for its side effect: it ensures the ops are
# registered. Keep it after the environment variable assignment above.
import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe  # noqa: F401

# Check whether the aiter package is installed (without importing it).
aiter_available = importlib.util.find_spec("aiter") is not None

# Skip the whole module when AITER is missing.
if not aiter_available:
    pytest.skip("These tests require AITER to run.", allow_module_level=True)


def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
    """Test that the biased grouped top-k custom op is correctly registered.

    Checks that ``rocm_aiter_biased_grouped_topk`` is present in the
    ``torch.ops.vllm`` namespace and that the registered object is callable.
    """
    # Check if the op exists in torch.ops.vllm
    assert hasattr(torch.ops.vllm, "rocm_aiter_biased_grouped_topk")

    # Check if the op is callable
    assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)


def test_rocm_aiter_grouped_topk_custom_op_registration():
    """Test that the grouped top-k custom op is correctly registered.

    Checks that ``rocm_aiter_grouped_topk`` is present in the
    ``torch.ops.vllm`` namespace and that the registered object is callable.
    """
    # Check if the op exists in torch.ops.vllm
    assert hasattr(torch.ops.vllm, "rocm_aiter_grouped_topk")

    # Check if the op is callable
    assert callable(torch.ops.vllm.rocm_aiter_grouped_topk)


def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
    """Test that the biased grouped top-k op can be used with torch.compile.

    The test
    1. runs ``torch.library.opcheck`` (fake-tensor check only) on the op,
    2. compiles a thin wrapper around the op with the inductor backend,
    3. runs the wrapper both eagerly and compiled into separate output
       buffers, and
    4. checks that both runs select the same expert ids with matching
       weights.
    """
    # Create test tensors
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scale_factor = 1.0

    gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
    e_score_correction_bias = torch.randn(
        (expert,), dtype=torch.bfloat16, device="cuda"
    )

    device = gating_output.device
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
    topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)

    # Define a function that uses the op
    def biased_grouped_topk_fn(
        gating_output, e_score_correction_bias, topk_weights, topk_ids
    ):
        return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
            gating_output,
            e_score_correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
            scale_factor,
        )

    # Verify the op's fake implementation
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_biased_grouped_topk,
        (gating_output, e_score_correction_bias, topk_weights, topk_ids),
        kwargs={
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
            "need_renorm": renormalize,
            "routed_scaling_factor": scale_factor,
        },
        # One-element tuple: without the trailing comma this is just a
        # parenthesised string.
        test_utils=("test_faketensor",),
    )

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(
        biased_grouped_topk_fn,
        fullgraph=True,
        backend="inductor",
        mode="reduce-overhead",
        dynamic=False,
    )

    # Separate output buffers for the eager and the compiled run; the op is
    # expected to fill them in place.
    topk_weights_original = torch.empty(
        (token, topk), dtype=torch.float32, device=device
    )
    topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)

    topk_weights_compiled = torch.empty(
        (token, topk), dtype=torch.float32, device=device
    )
    topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    biased_grouped_topk_fn(
        gating_output, e_score_correction_bias, topk_weights_original, topk_ids_original
    )
    compiled_fn(
        gating_output, e_score_correction_bias, topk_weights_compiled, topk_ids_compiled
    )

    # Sort the results for comparison since the order might not be deterministic
    topk_ids_original, indices_original = torch.sort(topk_ids_original)
    topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)

    topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
    topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)

    # Verify results match
    assert torch.allclose(
        topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
    )
    # Expert ids are integers, so require exact equality rather than a
    # floating-point tolerance check.
    assert torch.equal(topk_ids_original, topk_ids_compiled)


def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
    """Test that the grouped top-k op can be used with torch.compile.

    The test
    1. runs ``torch.library.opcheck`` (fake-tensor check only) on the op,
    2. compiles a thin wrapper around the op with the inductor backend,
    3. runs the wrapper both eagerly and compiled into separate output
       buffers, and
    4. checks that both runs select the same expert ids with matching
       weights.
    """
    # Create test tensors
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scoring_func = "softmax"
    scale_factor = 1.0

    gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")

    device = gating_output.device
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
    topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)

    # Define a function that uses the op
    def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
        return torch.ops.vllm.rocm_aiter_grouped_topk(
            gating_output,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
            scoring_func,
            scale_factor,
        )

    # Verify the op's fake implementation
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_grouped_topk,
        (gating_output, topk_weights, topk_ids),
        kwargs={
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
            "need_renorm": renormalize,
            "scoring_func": scoring_func,
            "routed_scaling_factor": scale_factor,
        },
        # One-element tuple: without the trailing comma this is just a
        # parenthesised string.
        test_utils=("test_faketensor",),
    )

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(
        grouped_topk_fn,
        fullgraph=True,
        backend="inductor",
        mode="reduce-overhead",
        dynamic=False,
    )

    # Separate output buffers for the eager and the compiled run; the op is
    # expected to fill them in place.
    topk_weights_original = torch.empty(
        (token, topk), dtype=torch.float32, device=device
    )
    topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)

    topk_weights_compiled = torch.empty(
        (token, topk), dtype=torch.float32, device=device
    )
    topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    grouped_topk_fn(
        gating_output, topk_weights_original, topk_ids_original, scoring_func
    )
    compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, scoring_func)

    # Sort the results for comparison since the order might not be deterministic
    topk_ids_original, indices_original = torch.sort(topk_ids_original)
    topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)

    topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
    topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)

    # Verify results match
    assert torch.allclose(
        topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
    )
    # Expert ids are integers, so require exact equality rather than a
    # floating-point tolerance check.
    assert torch.equal(topk_ids_original, topk_ids_compiled)