"examples/backends/vllm/launch/agg_omni_image.sh" did not exist on "93208162753986f9449d3671d6a263dfc4f4381e"
test_rocm_aiter_topk.py 7.55 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Tests for the ROCm AITER top-k routing ops.
# They check that the AITER ops
# 1. are correctly registered as custom ops,
# 2. have a fake implementation consistent with the real one, and
# 3. can be used with torch.compile.
# This file is skipped if the platform is not ROCm or if AITER is
# not installed.

import importlib.util

import pytest
import torch

from vllm.platforms import current_platform

# These ops only exist on ROCm; skip the whole module elsewhere.
if not current_platform.is_rocm():
    pytest.skip("This test can only run on ROCm.", allow_module_level=True)

# Check whether the aiter package is installed *before* importing the vLLM
# AITER integration module, so a missing AITER leads to a clean module-level
# skip instead of running (and possibly failing) that import first.
aiter_available = importlib.util.find_spec("aiter") is not None

if not aiter_available:
    pytest.skip("These tests require AITER to run.", allow_module_level=True)

# This import is needed for its side effect: it ensures the
# torch.ops.vllm.rocm_aiter_* custom ops used below are registered.
import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe  # noqa: F401


def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
    """Test that the biased grouped top-k op is registered as a custom op.

    The op must be reachable as ``torch.ops.vllm.rocm_aiter_biased_grouped_topk``
    and be callable.
    """
    # Check if the op exists in torch.ops.vllm
    assert hasattr(torch.ops.vllm, "rocm_aiter_biased_grouped_topk")

    # Check if the op is callable
    assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)


def test_rocm_aiter_grouped_topk_custom_op_registration():
    """Test that the grouped top-k op is registered as a custom op.

    The op must be reachable as ``torch.ops.vllm.rocm_aiter_grouped_topk``
    and be callable.
    """
    # Check if the op exists in torch.ops.vllm
    assert hasattr(torch.ops.vllm, "rocm_aiter_grouped_topk")

    # Check if the op is callable
    assert callable(torch.ops.vllm.rocm_aiter_grouped_topk)


def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
    """Test that the biased grouped top-k op can be used with torch.compile.

    First checks the op's fake-tensor implementation with
    ``torch.library.opcheck``. Then it calls the same wrapper once eagerly and
    once through an inductor-compiled version. Finally it verifies that both
    runs write matching top-k expert ids and weights into their output buffers.
    """
    # Problem size and routing configuration passed to the op.
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scale_factor = 1.0

    gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
    e_score_correction_bias = torch.randn(
        (expert,), dtype=torch.bfloat16, device="cuda"
    )

    device = gating_output.device
    # Output buffers used only by the opcheck call below.
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
    topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)

    def biased_grouped_topk_fn(
        gating_output, e_score_correction_bias, topk_weights, topk_ids
    ):
        # Thin wrapper around the custom op so it can be traced by
        # torch.compile; results land in the topk_weights/topk_ids buffers.
        return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
            gating_output,
            e_score_correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
            scale_factor,
        )

    # Verify the op's fake implementation (note the trailing comma: test_utils
    # is a one-element tuple, not a parenthesized string).
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_biased_grouped_topk,
        (gating_output, e_score_correction_bias, topk_weights, topk_ids),
        kwargs={
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
            "need_renorm": renormalize,
            "routed_scaling_factor": scale_factor,
        },
        test_utils=("test_faketensor",),
    )

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(
        biased_grouped_topk_fn,
        fullgraph=True,
        backend="inductor",
        mode="reduce-overhead",
        dynamic=False,
    )

    # Separate output buffers for the eager and the compiled run.
    topk_weights_original = torch.empty(
        (token, topk), dtype=torch.float32, device=device
    )
    topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)

    topk_weights_compiled = torch.empty(
        (token, topk), dtype=torch.float32, device=device
    )
    topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    biased_grouped_topk_fn(
        gating_output, e_score_correction_bias, topk_weights_original, topk_ids_original
    )
    compiled_fn(
        gating_output, e_score_correction_bias, topk_weights_compiled, topk_ids_compiled
    )

    def _sort_by_expert_id(ids, weights):
        # The order of the selected experts within a row might not be
        # deterministic, so sort each row by expert id and permute the
        # weights the same way to keep (id, weight) pairs aligned.
        sorted_ids, order = torch.sort(ids)
        return sorted_ids, torch.gather(weights, 1, order)

    topk_ids_original, topk_weights_original = _sort_by_expert_id(
        topk_ids_original, topk_weights_original
    )
    topk_ids_compiled, topk_weights_compiled = _sort_by_expert_id(
        topk_ids_compiled, topk_weights_compiled
    )

    # Weights are floating point: compare with a tolerance.
    assert torch.allclose(
        topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
    )
    # Expert ids are integers: require exact equality.
    assert torch.equal(topk_ids_original, topk_ids_compiled)


def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
    """Test that the grouped top-k op can be used with torch.compile.

    First checks the op's fake-tensor implementation with
    ``torch.library.opcheck``. Then it calls the same wrapper once eagerly and
    once through an inductor-compiled version. Finally it verifies that both
    runs write matching top-k expert ids and weights into their output buffers.
    """
    # Problem size and routing configuration passed to the op.
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scoring_func = "softmax"
    scale_factor = 1.0

    gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")

    device = gating_output.device
    # Output buffers used only by the opcheck call below.
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
    topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)

    def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
        # Thin wrapper around the custom op so it can be traced by
        # torch.compile; results land in the topk_weights/topk_ids buffers.
        return torch.ops.vllm.rocm_aiter_grouped_topk(
            gating_output,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
            scoring_func,
            scale_factor,
        )

    # Verify the op's fake implementation (note the trailing comma: test_utils
    # is a one-element tuple, not a parenthesized string).
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_grouped_topk,
        (gating_output, topk_weights, topk_ids),
        kwargs={
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
            "need_renorm": renormalize,
            "scoring_func": scoring_func,
            "routed_scaling_factor": scale_factor,
        },
        test_utils=("test_faketensor",),
    )

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(
        grouped_topk_fn,
        fullgraph=True,
        backend="inductor",
        mode="reduce-overhead",
        dynamic=False,
    )

    # Separate output buffers for the eager and the compiled run.
    topk_weights_original = torch.empty(
        (token, topk), dtype=torch.float32, device=device
    )
    topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)

    topk_weights_compiled = torch.empty(
        (token, topk), dtype=torch.float32, device=device
    )
    topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    grouped_topk_fn(
        gating_output, topk_weights_original, topk_ids_original, scoring_func
    )
    compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, scoring_func)

    def _sort_by_expert_id(ids, weights):
        # The order of the selected experts within a row might not be
        # deterministic, so sort each row by expert id and permute the
        # weights the same way to keep (id, weight) pairs aligned.
        sorted_ids, order = torch.sort(ids)
        return sorted_ids, torch.gather(weights, 1, order)

    topk_ids_original, topk_weights_original = _sort_by_expert_id(
        topk_ids_original, topk_weights_original
    )
    topk_ids_compiled, topk_weights_compiled = _sort_by_expert_id(
        topk_ids_compiled, topk_weights_compiled
    )

    # Weights are floating point: compare with a tolerance.
    assert torch.allclose(
        topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
    )
    # Expert ids are integers: require exact equality.
    assert torch.equal(topk_ids_original, topk_ids_compiled)