# topkrouter.py - operator test comparing the InfiniOP topkrouter kernel with a PyTorch reference.
import ctypes
from ctypes import c_uint64
import torch
import torch.nn as nn
import torch.nn.functional as F

from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
    torch_device_map,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # x_shape, x_stride, topk, routed_scaling_factor
    ((1, 256), None, 8, 2.5),
    ((2, 256), None, 8, 1.0),
]

# Input (x) dtypes to exercise; one test case is generated per entry below.
# The full list is kept for reference. It is left empty here ("CPU CI"), which
# means _TEST_CASES is empty and no case is actually executed.
# _X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
_X_DTYPES = []  # CPU CI

# Value dtypes handed to test_operator alongside the test cases
_VALUE_DTYPES = [InfiniDtype.F32]

# Form the test cases by appending each element of _X_DTYPES to each tuple in _TEST_CASES_
_TEST_CASES = [
    test_case + (x_dtype,) for test_case in _TEST_CASES_ for x_dtype in _X_DTYPES
]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F32: {"atol": 1e-3, "rtol": 1e-3},
    InfiniDtype.BF16: {"atol": 1e-3, "rtol": 1e-3},
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
}

# Runtime switches; overridden from the command line in the __main__ section
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


def tensorInfo(data):
    """Print layout and location details of a tensor for debugging.

    Args:
        data: Tensor-like object exposing ``is_contiguous()``, ``device``,
            ``dtype``, ``shape``, ``stride()`` and ``data_ptr()``.
    """
    address = data.data_ptr()
    print(
        "data:  ",
        data.is_contiguous(),
        data.device,
        data.dtype,
        data.shape,
        data.stride(),
        address,
        hex(address),  # same address, in hex
    )


class DeepseekV3TopkRouter(nn.Module):
    """PyTorch reference implementation of a group-limited top-k expert router.

    The router works on already-computed router logits (one row per token,
    one column per expert). Experts are split into ``n_group`` equally sized
    groups; only experts from the ``topk_group`` best groups may be selected.

    The routing hyper-parameters are fixed in this reference: 256 routed
    experts, 8 groups, 4 selected groups, and normalisation of the selected
    weights enabled.
    """

    def __init__(
        self,
        correction_bias,
        routed_scaling_factor: float = 2.5,
        topk: int = 8,
        config=None,
    ):
        """Build the router.

        Args:
            correction_bias: Per-expert bias tensor added to the scores when
                choosing experts (it is not applied to the returned weights).
            routed_scaling_factor: Multiplier applied to the final weights.
            topk: Number of experts selected per row; only 8 is accepted.
            config: Stored on the instance but otherwise unused here.
        """
        super().__init__()
        self.config = config
        self.top_k = topk  # config.num_experts_per_tok 8
        assert topk == 8
        self.n_routed_experts = 256  # config.n_routed_experts
        self.routed_scaling_factor = (
            routed_scaling_factor  # config.routed_scaling_factor 2.5
        )
        self.n_group = 8  # config.n_group
        self.topk_group = 4  # config.topk_group
        self.norm_topk_prob = True  # config.norm_topk_prob

        # Private float32 copy of the bias, living on the bias' own device.
        self.e_score_correction_bias = torch.zeros(
            self.n_routed_experts, device=correction_bias.device
        )
        self.e_score_correction_bias[:] = correction_bias[:]

    @torch.no_grad()
    def get_topk_indices(self, scores):
        """Return the indices of the ``top_k`` experts chosen for each row.

        Args:
            scores: Expert scores, viewable as ``(-1, n_routed_experts)``.

        Returns:
            Integer tensor of shape ``(rows, top_k)`` holding expert indices,
            ordered by decreasing bias-corrected score.
        """
        experts_per_group = self.n_routed_experts // self.n_group

        # Bias-corrected scores are used for selection only.
        scores_for_choice = scores.view(
            -1, self.n_routed_experts
        ) + self.e_score_correction_bias.unsqueeze(0)

        # A group is rated by the sum of its two best experts.
        group_scores = (
            scores_for_choice.view(-1, self.n_group, experts_per_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )

        # Mark the best `topk_group` groups with 1, all others with 0.
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)

        # Broadcast the per-group flag to every expert of that group.
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, experts_per_group)
            .reshape(-1, self.n_routed_experts)
        )

        # Experts of rejected groups get score 0 before the final top-k.
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(
            scores_for_choice, k=self.top_k, dim=-1, sorted=True
        )[1]

        return topk_indices

    def forward(self, router_logits):
        """Route each row of ``router_logits`` to ``top_k`` experts.

        Args:
            router_logits: Tensor of shape ``(rows, n_routed_experts)``.

        Returns:
            Tuple ``(topk_indices, topk_weights)``. The weights are the
            sigmoid scores of the chosen experts (without the correction
            bias), normalised to sum to one per row and then multiplied by
            ``routed_scaling_factor``.
        """
        scores = router_logits.sigmoid()
        scores = scores.to(torch.float32)

        topk_indices = self.get_topk_indices(scores)
        topk_weights = scores.gather(1, topk_indices)

        if self.norm_topk_prob:
            # The tiny epsilon guards against division by zero.
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor

        return topk_indices, topk_weights


def torch_topkrouter(router_logits, correction_bias, routed_scaling_factor, topk):
    """Compute the reference routing result with DeepseekV3TopkRouter.

    Args:
        router_logits: Router logits, one row per token.
        correction_bias: Per-expert correction bias used for expert selection.
        routed_scaling_factor: Multiplier applied to the selected weights.
        topk: Number of experts selected per row.

    Returns:
        Tuple ``(values, indices)``: the selected weights and the selected
        expert indices cast to ``torch.int32``. Note the order is swapped
        relative to what the router module itself returns.
    """
    router = DeepseekV3TopkRouter(correction_bias, routed_scaling_factor, topk)
    label_indices, label_values = router(router_logits)
    label_indices = label_indices.to(torch.int32)
    return label_values, label_indices


def test(
    handle,
    device,
    x_shape,
    x_stride,
    topk,
    routed_scaling_factor,
    x_dtype=InfiniDtype.F32,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Run one topkrouter case through the library and compare with PyTorch.

    Args:
        handle: Library handle passed to the descriptor constructor.
        device: Device identifier (key of ``InfiniDeviceNames``).
        x_shape: Two-element shape ``(N, width)`` of the router logits.
        x_stride: Only reported in the log line; the input tensor is always
            created with the contiguous strides of ``x_shape``.
        topk: Number of experts selected per row.
        routed_scaling_factor: Multiplier applied to the selected weights.
        x_dtype: Data type of the input tensor.
        dtype: Data type used to look up the comparison tolerance.
        sync: Optional callable invoked once after the inputs are created.

    Raises:
        AssertionError: If values or indices differ from the reference.
    """
    print(
        f"Testing topkrouter on {InfiniDeviceNames[device]} with x_shape:{x_shape} "
        f"x_stride:{x_stride} w_dtype:{InfiniDtypeNames[x_dtype]} dtype:{InfiniDtypeNames[dtype]}"
    )

    N, _ = x_shape
    # Contiguous strides for a tensor of this shape.
    contiguous_stride = torch.empty(x_shape).stride()

    x = TestTensor(
        x_shape, contiguous_stride, x_dtype, device, scale=5.0, bias=-5.0, mode="random"
    )
    correction_bias = TestTensor(
        [x_shape[1]], [1], InfiniDtype.F32, device, mode="random"
    )

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateTopkrouterDescriptor(
            handle, ctypes.byref(descriptor), x.descriptor, correction_bias.descriptor
        )
    )

    # From here on the descriptor exists; make sure it is destroyed even when
    # a library call or an assertion below fails.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
        for tensor in [x, correction_bias]:
            tensor.destroy_desc()

        workspace_size = c_uint64(0)
        check_error(
            LIBINFINIOP.infiniopGetTopkrouterWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = TestWorkspace(workspace_size.value, x.device)

        # Output buffers written by the library call.
        values = torch.zeros(
            (N, topk), dtype=torch.float32, device=torch_device_map[x.device]
        )
        indices = torch.zeros(
            (N, topk), dtype=torch.int32, device=torch_device_map[x.device]
        )

        def lib_topkrouter():
            check_error(
                LIBINFINIOP.infiniopTopkrouter(
                    descriptor,
                    workspace.data(),
                    workspace_size.value,
                    values.data_ptr(),
                    indices.data_ptr(),
                    x.data(),
                    correction_bias.data(),
                    routed_scaling_factor,
                    topk,
                    None,
                )
            )

        lib_topkrouter()

        label_values, label_indices = torch_topkrouter(
            x.actual_tensor(),
            correction_bias.actual_tensor(),
            routed_scaling_factor,
            topk,
        )
        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug(label_values, values, atol=atol, rtol=rtol)
            debug(label_indices, indices, atol=atol, rtol=rtol)

        assert torch.allclose(label_values, values, atol=atol, rtol=rtol)
        assert torch.allclose(label_indices, indices, atol=atol, rtol=rtol)

        # Profiling workflow
        if PROFILE:
            # fmt: off
            profile_operation("PyTorch", lambda: torch_topkrouter(x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk), device, NUM_PRERUN, NUM_ITERATIONS)
            profile_operation("    lib", lib_topkrouter, device, NUM_PRERUN, NUM_ITERATIONS)
            # fmt: on
    finally:
        check_error(LIBINFINIOP.infiniopDestroyTopkrouterDescriptor(descriptor))


if __name__ == "__main__":
    # Parse the command-line options.
    args = get_args()

    # Override the module-level switches with the values given on the command line.
    DEBUG, PROFILE = args.debug, args.profile
    NUM_PRERUN, NUM_ITERATIONS = args.num_prerun, args.num_iterations

    # Run every configured case once per requested device.
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _VALUE_DTYPES)

    print("\033[92mTest passed!\033[0m")