# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# This is a test for the AITER grouped top-k ops.
# It tests whether the AITER ops
# 1. are correctly registered as custom ops,
# 2. correctly define the relationship between
#    the implementation and the fake function, and
# 3. can be used with torch.compile.
# This file is skipped unless the platform is ROCm
# and the aiter package is installed.

import importlib.util

import pytest
import torch

# this import statement is needed to ensure the ops are registered
import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe  # noqa: F401
from vllm.platforms import current_platform

# Probe for the optional aiter package; find_spec on a top-level name only
# locates the package, it does not import it.
aiter_available = importlib.util.find_spec("aiter") is not None

# Skip every test in this module unless we are on ROCm AND aiter is installed.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_rocm() and aiter_available),
    reason="AITER ops are only available on ROCm with aiter package installed",
)


def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
    """Test that the custom op is correctly registered.

    Only checks that ``rocm_aiter_biased_grouped_topk`` is reachable under
    ``torch.ops.vllm`` and is callable; it does not invoke the op.
    """
    # Check if the op exists in torch.ops.vllm
    assert hasattr(torch.ops.vllm, "rocm_aiter_biased_grouped_topk")

    # Check if the op is callable
    assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)


def test_rocm_aiter_grouped_topk_custom_op_registration():
    """Test that the custom op is correctly registered.

    Only checks that ``rocm_aiter_grouped_topk`` is reachable under
    ``torch.ops.vllm`` and is callable; it does not invoke the op.
    """
    # Check if the op exists in torch.ops.vllm
    assert hasattr(torch.ops.vllm, "rocm_aiter_grouped_topk")

    # Check if the op is callable
    assert callable(torch.ops.vllm.rocm_aiter_grouped_topk)


def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
    """Test that the op can be used with torch.compile.

    Checks ``rocm_aiter_biased_grouped_topk`` with ``torch.library.opcheck``
    (fake-tensor test only), then runs it eagerly and through a
    ``torch.compile``-d wrapper on identical inputs and compares the results.
    The comparison reads back the ``topk_weights`` / ``topk_ids`` buffers
    passed into each call, i.e. it relies on the op filling them in place.
    """
    # Create test tensors
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scale_factor = 1.0

    gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
    e_score_correction_bias = torch.randn(
        (expert,), dtype=torch.bfloat16, device="cuda"
    )

    device = gating_output.device
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
    topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)

    # Define a function that uses the op
    def biased_grouped_topk_fn(
        gating_output, e_score_correction_bias, topk_weights, topk_ids
    ):
        return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
            gating_output,
            e_score_correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
            scale_factor,
        )

    # Check the op's fake-tensor implementation.
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_biased_grouped_topk,
        (gating_output, e_score_correction_bias, topk_weights, topk_ids),
        kwargs={
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
            "need_renorm": renormalize,
            "routed_scaling_factor": scale_factor,
        },
        # Trailing comma matters: ("test_faketensor") is just a str.
        test_utils=("test_faketensor",),
    )

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(
        biased_grouped_topk_fn,
        fullgraph=True,
        backend="inductor",
        mode="reduce-overhead",
        dynamic=False,
    )

    def new_output_buffers():
        # Fresh (uninitialized) weight/id buffers for a single op call.
        return (
            torch.empty((token, topk), dtype=torch.float32, device=device),
            torch.empty((token, topk), dtype=torch.int32, device=device),
        )

    topk_weights_original, topk_ids_original = new_output_buffers()
    topk_weights_compiled, topk_ids_compiled = new_output_buffers()

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    biased_grouped_topk_fn(
        gating_output, e_score_correction_bias, topk_weights_original, topk_ids_original
    )
    compiled_fn(
        gating_output, e_score_correction_bias, topk_weights_compiled, topk_ids_compiled
    )

    def sort_by_expert_id(ids, weights):
        # The order of experts within a row is not assumed to be deterministic,
        # so sort each row by expert id and permute the weights to match.
        sorted_ids, order = torch.sort(ids, dim=1)
        return sorted_ids, torch.gather(weights, 1, order)

    topk_ids_original, topk_weights_original = sort_by_expert_id(
        topk_ids_original, topk_weights_original
    )
    topk_ids_compiled, topk_weights_compiled = sort_by_expert_id(
        topk_ids_compiled, topk_weights_compiled
    )

    # Verify results match; assert_close reports where and by how much they
    # differ, unlike a bare `assert torch.allclose(...)`.
    torch.testing.assert_close(
        topk_weights_compiled, topk_weights_original, rtol=1e-2, atol=1e-2
    )
    # Expert ids are integers and must match exactly.
    torch.testing.assert_close(topk_ids_compiled, topk_ids_original, rtol=0, atol=0)


def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
    """Test that the op can be used with torch.compile.

    Checks ``rocm_aiter_grouped_topk`` (with ``scoring_func="softmax"``) using
    ``torch.library.opcheck`` (fake-tensor test only), then runs it eagerly
    and through a ``torch.compile``-d wrapper on identical inputs and compares
    the results. The comparison reads back the ``topk_weights`` / ``topk_ids``
    buffers passed into each call, i.e. it relies on the op filling them in
    place.
    """
    # Create test tensors
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scoring_func = "softmax"
    scale_factor = 1.0

    gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")

    device = gating_output.device
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
    topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)

    # Define a function that uses the op
    def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
        return torch.ops.vllm.rocm_aiter_grouped_topk(
            gating_output,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
            scoring_func,
            scale_factor,
        )

    # Check the op's fake-tensor implementation.
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_grouped_topk,
        (gating_output, topk_weights, topk_ids),
        kwargs={
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
            "need_renorm": renormalize,
            "scoring_func": scoring_func,
            "routed_scaling_factor": scale_factor,
        },
        # Trailing comma matters: ("test_faketensor") is just a str.
        test_utils=("test_faketensor",),
    )

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(
        grouped_topk_fn,
        fullgraph=True,
        backend="inductor",
        mode="reduce-overhead",
        dynamic=False,
    )

    def new_output_buffers():
        # Fresh (uninitialized) weight/id buffers for a single op call.
        return (
            torch.empty((token, topk), dtype=torch.float32, device=device),
            torch.empty((token, topk), dtype=torch.int32, device=device),
        )

    topk_weights_original, topk_ids_original = new_output_buffers()
    topk_weights_compiled, topk_ids_compiled = new_output_buffers()

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    grouped_topk_fn(
        gating_output, topk_weights_original, topk_ids_original, scoring_func
    )
    compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, scoring_func)

    def sort_by_expert_id(ids, weights):
        # The order of experts within a row is not assumed to be deterministic,
        # so sort each row by expert id and permute the weights to match.
        sorted_ids, order = torch.sort(ids, dim=1)
        return sorted_ids, torch.gather(weights, 1, order)

    topk_ids_original, topk_weights_original = sort_by_expert_id(
        topk_ids_original, topk_weights_original
    )
    topk_ids_compiled, topk_weights_compiled = sort_by_expert_id(
        topk_ids_compiled, topk_weights_compiled
    )

    # Verify results match; assert_close reports where and by how much they
    # differ, unlike a bare `assert torch.allclose(...)`.
    torch.testing.assert_close(
        topk_weights_compiled, topk_weights_original, rtol=1e-2, atol=1e-2
    )
    # Expert ids are integers and must match exactly.
    torch.testing.assert_close(topk_ids_compiled, topk_ids_original, rtol=0, atol=0)