test_moe_w8a8_fused_shared_experts.py 12.8 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
# SPDX-License-Identifier: MIT
import pytest
import torch
import itertools
from typing import Optional, List  # Add this import at the top

from op_tests.utility.scalar_type import ScalarType, scalar_types
from op_tests.utility.utils import quantize_weights
from op_tests.utility.utils import torch_moe as torch_score_moe
from aiter.ops.triton.moe_op import fused_moe
from aiter.fused_moe import fused_topk, torch_moe
from aiter import ActivationType
from aiter.test_common import checkAllclose, perftest,benchmark
from aiter import ck_moe, ck_shuffle_moe, dtypes, silu_and_mul, gelu_and_mul, moe_sum
from aiter.ops.triton.utils.moe_config_utils import get_optimal_moe_config_func
from aiter.ops.triton.fused_moe import fused_experts_impl
from aiter.ops.triton.utils.types import torch_to_triton_dtype
from aiter.fused_moe_asm_wna16 import fused_experts_asm_impl
from aiter import per_token_quant_hip, per_block_quant_wrapper
import pandas as pd
import aiter
import os
# Module-level tuning constants (not referenced by the test code visible in this file).
BLOCK_SIZE_M = 32
MAX_TOKENS = 1 << 16
# Opt in to AMDGCN buffer ops for kernels compiled in this process
# (NOTE(review): presumably read by the Triton AMD backend at compile time — confirm).
os.environ.update({"AMDGCN_USE_BUFFER_OPS": "1"})

@perftest(num_warmup=1, num_iters=2)
def torch_moe_test(
    hidden_states,
    w1,
    w2,
    topk_weight,
    topk_ids,
    # following for int8 quant
    fc1_scale=None,  # [expert, inter_dim, 1]
    fc2_scale=None,  # [expert, model_dim, 1]
    fc1_smooth_scale=None,  # [expert, 1, model_dim]
    fc2_smooth_scale=None,  # [expert, 1, inter_dim]
    expert_mask=None,
):
    """Time the pure-torch ``torch_moe`` reference under ``perftest``.

    Every argument is forwarded positionally, in declaration order, to
    ``aiter.fused_moe.torch_moe``. Callers in this file unpack the decorated
    result as ``(output, avg_time_us)``.
    """
    routing = (topk_weight, topk_ids)
    quant_args = (fc1_scale, fc2_scale, fc1_smooth_scale, fc2_smooth_scale)
    return torch_moe(hidden_states, w1, w2, *routing, *quant_args, expert_mask)

@perftest(num_warmup=5, num_iters=10, testGraph=False)
def asm_fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    dtype,
    inplace: bool = False,
    activation: str = "silu",
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w4a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
):
    """Time ``aiter.fused_moe_asm_wna16.fused_experts_asm_impl`` under ``perftest``.

    All arguments are forwarded positionally in the exact order declared
    here. Note that ``use_int8_w4a8`` sits between ``use_int8_w8a16``'s
    neighbours rather than at the end, unlike the Triton wrapper; the order
    presumably mirrors the callee's signature (NOTE(review): confirm).
    This wrapper is not called anywhere else in this file.
    """
    routing = (topk_weights, topk_ids)
    quant_flags = (
        use_fp8_w8a8,
        use_int8_w8a8,
        use_int8_w4a8,
        use_int8_w8a16,
        use_int4_w4a16,
    )
    expert_args = (per_channel_quant, global_num_experts, expert_map)
    scale_args = (w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale)
    return fused_experts_asm_impl(
        hidden_states, w1, w2, *routing, dtype, inplace, activation,
        *quant_flags, *expert_args, *scale_args, block_shape,
    )

@perftest(num_warmup=5, num_iters=10, testGraph=False)
def triton_fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    odtype,
    inplace: bool = False,
    activation: str = "silu",
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_int4_w4a8: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
):
    """Time ``aiter.ops.triton.fused_moe.fused_experts_impl`` under ``perftest``.

    All arguments are forwarded positionally in the exact order declared
    here (``odtype`` is the output dtype slot). Callers in this file unpack
    the decorated result as ``(output, avg_time_us)``.
    """
    routing = (topk_weights, topk_ids)
    quant_flags = (
        use_fp8_w8a8,
        use_int8_w8a8,
        use_int8_w8a16,
        use_int4_w4a16,
        use_int4_w4a8,
    )
    expert_args = (per_channel_quant, global_num_experts, expert_map)
    scale_args = (w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale)
    return fused_experts_impl(
        hidden_states, w1, w2, *routing, odtype, inplace, activation,
        *quant_flags, *expert_args, *scale_args, block_shape,
    )

def _quantize_per_channel_int8(weight):
    """Symmetric per-output-row int8 quantization of a stacked expert weight.

    Each row along the last dim gets its own scale ``absmax / 127``; absmax is
    clamped to at least 1e-5 so all-zero rows do not divide by zero.

    Args:
        weight: expert weight stack, reduced over its last dim (callers pass
            (E, N, K) tensors).

    Returns:
        (qweight, scales): int8 tensor shaped like ``weight`` and float32
        scales with the last dim reduced to size 1.
    """
    max_vals = torch.abs(weight.to(torch.float32)).max(dim=-1, keepdim=True)[0]
    max_vals = max_vals.clamp(min=1e-5)
    scales = max_vals / 127.0
    qweight = (weight / max_vals * 127.0).round().clamp(min=-128, max=127).to(torch.int8)
    return qweight, scales


def _torch_shared_experts_ref(hidden_states, w1_shared, w2_shared, expert_weight):
    """Reference output of the shared experts: the weighted sum of each expert's MLP.

    For every shared expert ``e``: ``silu_and_mul(x @ w1[e].T) @ w2[e].T``,
    scaled by ``expert_weight`` and accumulated in ``hidden_states.dtype``.

    Args:
        hidden_states: (token, model_dim) activations.
        w1_shared: (shared_E, inter_dim * 2, model_dim) gate/up weights.
        w2_shared: (shared_E, model_dim, inter_dim) down weights.
        expert_weight: scalar weight applied to every shared expert.
    """
    token = hidden_states.shape[0]
    model_dim, inter_dim = w2_shared.shape[1], w2_shared.shape[2]
    out = torch.zeros((token, model_dim), dtype=hidden_states.dtype, device=hidden_states.device)
    for w1_e, w2_e in zip(w1_shared, w2_shared):
        gate_up = hidden_states @ w1_e.transpose(0, 1)  # (token, inter_dim * 2)
        act = torch.empty((token, inter_dim), dtype=hidden_states.dtype, device=hidden_states.device)
        silu_and_mul(act, gate_up)  # writes the activated result into `act`
        out += (act @ w2_e.transpose(0, 1)) * expert_weight  # (token, model_dim)
    return out


@benchmark()
def test_fmoe_w8a8_fused_shared_experts(
    dtype,
    token,
    model_dim,
    inter_dim,
    local_E,
    shared_E,
    local_topk,
    has_zp,
    quant="No",
):
    """Check int8 W8A8 fused MoE where shared experts are folded into routing.

    Shared experts are appended after the ``local_E`` routed experts (ids
    ``local_E .. local_E + shared_E - 1``) and selected for every token with
    weight 1.0. This is equivalent to adding their outputs to the routed-MoE
    output. The test compares, via ``checkAllclose``:

    1. hand-built fused routing vs ``aiter.biased_grouped_topk`` with
       ``num_fused_shared_experts``;
    2. the torch MoE over the concatenated experts vs "routed MoE + sum of
       shared experts";
    3. the Triton int8 W8A8 fused-experts kernel vs that reference.

    ``has_zp`` and ``quant`` are not used by this body; they are kept for
    interface compatibility with existing callers.

    Returns:
        dict with the torch routed-MoE time (``local_avg_torch``) and the
        Triton kernel time (``avg_triton``) measured by ``perftest``.
    """
    hidden_states = torch.randn((token, model_dim), dtype=dtype, device="cuda") / 10

    # 1. Original (unquantized) weights. NOTE: keep this RNG call order stable
    #    so results stay reproducible against earlier runs.
    w1_local = torch.randn((local_E, inter_dim * 2, model_dim), dtype=dtype, device="cuda") / 2
    w2_local = torch.randn((local_E, model_dim, inter_dim), dtype=dtype, device="cuda") / 2
    w1_shared = torch.randn((shared_E, inter_dim * 2, model_dim), dtype=dtype, device="cuda") / 2
    w2_shared = torch.randn((shared_E, model_dim, inter_dim), dtype=dtype, device="cuda") / 2

    # Defensive copies used only by the torch reference paths.
    w1_local_ref = w1_local.clone()
    w2_local_ref = w2_local.clone()
    w1_shared_ref = w1_shared.clone()
    w2_shared_ref = w2_shared.clone()

    w1_total_ref = torch.cat([w1_local, w1_shared], dim=0)  # (local_E + shared_E, inter_dim * 2, model_dim)
    w2_total_ref = torch.cat([w2_local, w2_shared], dim=0)  # (local_E + shared_E, model_dim, inter_dim)

    # 2. Per-row symmetric int8 quantization of every expert weight.
    w1_local_qweight, w1_local_scales = _quantize_per_channel_int8(w1_local)
    w2_local_qweight, w2_local_scales = _quantize_per_channel_int8(w2_local)
    w1_shared_qweight, w1_shared_scales = _quantize_per_channel_int8(w1_shared)
    w2_shared_qweight, w2_shared_scales = _quantize_per_channel_int8(w2_shared)

    w1_total_qweight = torch.cat([w1_local_qweight, w1_shared_qweight], dim=0)
    w2_total_qweight = torch.cat([w2_local_qweight, w2_shared_qweight], dim=0)
    w1_total_scales = torch.cat([w1_local_scales, w1_shared_scales], dim=0)
    w2_total_scales = torch.cat([w2_local_scales, w2_shared_scales], dim=0)

    # 3. Router scores for the routed (local) experts.
    local_score = torch.randn((token, local_E), device="cuda", dtype=dtype)
    # Small per-expert offset intended to break score ties.
    # NOTE(review): in fp16, offsets of a few 1e-5 are below the resolution
    # (~5e-4) for |score| ~ 1, so many of them round away; ties remain possible.
    local_score += torch.arange(local_E, device="cuda", dtype=dtype).unsqueeze(0) * 1e-5

    # biased_grouped_topk does not support softmax scoring yet (it applies a
    # sigmoid to the scores), so use is_softmax=False here to match it.
    local_topk_weights, local_topk_ids = fused_topk(hidden_states, local_score, local_topk, True, is_softmax=False)

    # 4. Golden output of the routed experts.
    torch_local_out, torch_local_avg = torch_moe_test(
        hidden_states, w1_local_ref, w2_local_ref, local_topk_weights, local_topk_ids
    )

    # 5. Golden output of the shared experts. Summing the shared experts'
    #    outputs directly is equivalent to giving them all the same top-k
    #    weight, i.e. fused_topk_weights[:, local_topk:] = 1.0.
    SHARED_EXPERT_WEIGHT = 1.0
    torch_shared_out = _torch_shared_experts_ref(
        hidden_states, w1_shared_ref, w2_shared_ref, SHARED_EXPERT_WEIGHT
    )

    # 6. Routed + shared output.
    torch_total_out = torch_local_out + torch_shared_out

    # 7. Fused top-k ids and weights.
    # 7.1 Built by hand: append the shared experts' ids after the routed top-k.
    total_topk = local_topk + shared_E
    fused_topk_ids = torch.empty(token, total_topk, dtype=dtypes.i32, device=hidden_states.device)
    fused_topk_ids[:, :local_topk] = local_topk_ids
    fused_topk_ids[:, local_topk:] = (
        torch.arange(local_E, local_E + shared_E, device=hidden_states.device)
        .unsqueeze(0)
        .expand(token, -1)
    )

    fused_topk_weights = torch.empty(token, total_topk, dtype=dtypes.fp32, device=hidden_states.device)
    fused_topk_weights[:, :local_topk] = local_topk_weights
    fused_topk_weights[:, local_topk:] = SHARED_EXPERT_WEIGHT

    # 7.2 Produced by the kernel with fused shared experts.
    fused_topk_ids_0 = torch.empty(token, total_topk, dtype=dtypes.i32, device=hidden_states.device)
    fused_topk_weights_0 = torch.empty(token, total_topk, dtype=dtypes.fp32, device=hidden_states.device)
    aiter.biased_grouped_topk(
        gating_output=local_score,
        correction_bias=torch.zeros(local_E, device=hidden_states.device),  # zero bias for every expert
        topk_weights=fused_topk_weights_0,
        topk_ids=fused_topk_ids_0,
        num_expert_group=1,
        topk_group=1,
        num_fused_shared_experts=shared_E,
        need_renorm=True,
        routed_scaling_factor=1.0,
    )

    msg = f"[Fused shared expert topk ids] {token=}, {local_E=}, {local_topk=}, {shared_E=}"
    checkAllclose(fused_topk_ids, fused_topk_ids_0, rtol=0.01, atol=0.01, msg=msg)

    msg = f"[Fused shared expert topk weights] {token=}, {local_E=}, {local_topk=}, {shared_E=}"
    checkAllclose(fused_topk_weights, fused_topk_weights_0, rtol=0.01, atol=0.01, msg=msg)

    # 8. Torch MoE over the concatenated (routed + shared) experts.
    torch_fused_out, torch_fused_avg = torch_moe_test(
        hidden_states, w1_total_ref, w2_total_ref, fused_topk_weights, fused_topk_ids
    )

    # 9. The fused-experts formulation must reproduce routed + shared.
    msg = f"[Fused shared expert logit] {token=}, {model_dim=}, {inter_dim=}, {local_E=}, {shared_E=}, {local_topk=}, dtype: {dtype}, torch_fused_avg: {torch_fused_avg:<8.2f} us"
    checkAllclose(torch_total_out, torch_fused_out, rtol=0.01, atol=10, msg=msg)

    # 10. Triton int8 W8A8 kernel with per-token activation quantization.
    input_q, input_scale = per_token_quant_hip(hidden_states, quant_dtype=torch.int8)
    triton_output, avg_triton = triton_fused_experts_impl(
        input_q,
        w1_total_qweight,
        w2_total_qweight,
        fused_topk_weights_0,
        fused_topk_ids_0,
        dtype,
        inplace=False,
        activation="silu",
        use_fp8_w8a8=False,
        use_int8_w8a8=True,
        use_int8_w8a16=False,
        use_int4_w4a16=False,
        use_int4_w4a8=False,
        per_channel_quant=True,
        global_num_experts=local_E + shared_E,
        expert_map=None,
        w1_scale=w1_total_scales,
        w2_scale=w2_total_scales,
        w1_zp=None,
        w2_zp=None,
        a1_scale=input_scale,
        a2_scale=None,
        block_shape=None,
    )
    msg = f"[TRITON_fused_test] {token=}, {model_dim=}, {inter_dim=}, {local_E=}, {shared_E=}, {local_topk=}, dtype: {dtype}, triton_avg: {avg_triton:>8.2f} us, torch_fused_avg: {torch_fused_avg:>8.2f} us"
    # Loose atol: the reference is unquantized while the kernel runs in int8.
    checkAllclose(torch_total_out, triton_output, rtol=0.01, atol=50, msg=msg)

    return {
        "local_avg_torch": torch_local_avg,
        "avg_triton": avg_triton,
    }



# Benchmark sweep: 128 routed experts with top-8 routing, int8 W8A8 weights,
# swept over dtype x tokens x model_dim x inter_dim x shared-expert count.
# itertools.product varies its last argument fastest, matching nested loops.
rows = [
    test_fmoe_w8a8_fused_shared_experts(
        dtype, m, dim, hdim, 128, sh_e, 8, False, quant="int8"
    )
    for dtype, m, dim, hdim, sh_e in itertools.product(
        [dtypes.fp16],  # dtype
        [1, 16],  # token count
        [7168],  # model_dim
        [256],  # inter_dim
        [1, 2],  # shared experts
    )
]

df = pd.DataFrame(rows)
# df.to_csv("test_moe_w8a8_fused_shared_experts.csv")
aiter.logger.info(f"summary:\n{df}")