test_grouped_gemm.py 23.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
# SPDX-License-Identifier: MIT
"""
aiter.ck_grouped_gemm 精度与性能测试。

语义
----
对每个 group i,内核计算  C[i] = A[i] @ B[i]^T。
  A[i] : [M_i, K_i]  行主序,输入 dtype
  B[i] : [N_i, K_i]  行主序,输入 dtype(CK 侧按列主序解释,等价于 B^T)
  C[i] : [M_i, N_i]  行主序,输出 dtype

输出 dtype
  fp16  -> fp16
  bf16  -> bf16
  fp8   -> float32(fp8 点积,float32 累加,无 scale)
  int8  -> int32  (int8 点积,int32  累加,无 scale)

fp8 / int8 说明
---------------
CK 执行原始硬件 MMAC,不做量化 scale。
下方参考实现为 a.float() @ b.float().T,与之对齐。
  - int8:int32 累加,结果应严格一致。
  - fp8:硬件先 fp8×fp8 再 float32 累加;参考实现先扩到 float32 再乘,
    存在数值差异,容差见 TOLERANCE。

dense / random fp8、int8 正确性仍待跟进:
  - int8 dense random:应精确;失败则可能是 CK instance bug。
  - fp8  dense random:允许小误差,rtol/atol ~ 0.2。

Layout 支持(--layout)
-----------------------
  RC(默认):A 行主序,B 存 [N,K] 行主序,CK 按列主序读(等价 B^T)
  RR/CR/CC:其他组合,经同一 C ABI 与 stride 路由;本测试对 RC 做正确性校验。

CI smoke 命令(小 shape,快速验正确性)
---------------------------------------
  python op_tests/test_grouped_gemm.py --smoke --dtype fp16
  python op_tests/test_grouped_gemm.py --smoke --dtype all --variable
  python op_tests/test_grouped_gemm.py --moe --dtype fp16

性能测试(默认 m=n=k=1024,约 6.4 GFLOPS/group)
-------------------------------------------------
  python op_tests/test_grouped_gemm.py --dtype fp16
  python op_tests/test_grouped_gemm.py --dtype all --variable --warmup 10 --repeat 100

说明:TFLOPS = sum(2*M*N*K) / time。shape 过小时 launch/sync 开销主导,数值会偏低;
torch_gemm 使用原生 dtype 的 matmul,torch_ref 仅用于精度校验(float32 慢路径)。
"""

import argparse
import time

import torch

import aiter


DTYPE_MAP = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp8":  torch.float8_e4m3fn,
    "int8": torch.int8,
}

# 各 dtype 的 assert_close 容差。
# fp8:硬件 fp8×fp8 MMAC 与 float32 参考可能舍入不同。
# int8:必须精确(int32 累加,无舍入)。
TOLERANCE = {
    torch.float16:       dict(rtol=5e-2, atol=5e-2),
    torch.bfloat16:      dict(rtol=5e-2, atol=5e-2),
    torch.float8_e4m3fn: dict(rtol=2e-1, atol=2e-1),
    torch.int8:          dict(rtol=0,    atol=0),
}

# 各 dtype 的 shape 对齐要求(CK tile 约束)。
# fp16/bf16 GemmConfigComputeV4 的 K_Tile=64;当前 CK instance 在 K=64(恰好一个 tile)会出错,
# 实际要求 K >= 128。
SHAPE_ALIGN = {
    torch.float16:       dict(m=64,  n=128, k=128),
    torch.bfloat16:      dict(m=64,  n=128, k=128),
    torch.float8_e4m3fn: dict(m=128, n=128, k=128),
    torch.int8:          dict(m=32,  n=32,  k=128),
}

# 各 dtype 预定义变长 shape(来自 Stage 1 验证)。
VARIABLE_SHAPES = {
    torch.float16:       [(128, 128, 128), (256, 128, 128), (384, 128, 128)],
    torch.bfloat16:      [(128, 128, 128), (256, 128, 128), (384, 128, 128)],
    torch.float8_e4m3fn: [(128, 128, 128), (256, 128, 256), (128, 256, 128)],
    torch.int8:          [(64,  64,  128), (128, 128, 256), (192, 64,  128)],
}

# 异构 shape:每组 M/N 不同(同 dtype 内 K 保持一致)。
# CK fp16/bf16 GemmConfigComputeV4 不支持组间 mixed K,因此只变 M/N。
HETERO_SHAPES = {
    torch.float16:       [(128, 128, 128), (192, 256, 128), (256, 128, 128)],
    torch.bfloat16:      [(128, 128, 128), (192, 256, 128), (256, 128, 128)],
    torch.float8_e4m3fn: [(128, 128, 128), (256, 128, 128), (128, 256, 128)],
    torch.int8:          [(32,  32,  128), (64,  64,  128), (96,  32,  128)],
}

# MOE:固定 N/K,每组 M 可不对齐(wrapper 自动 pad M 到 tile 边界)。
MOE_FIXED_NK = {
    torch.float16:       (128, 128),
    torch.bfloat16:      (128, 128),
    torch.float8_e4m3fn: (128, 128),
    torch.int8:          (32,  128),
}
MOE_M_VALUES = [1, 17, 33, 63, 65, 100]


# ── 张量构造 ──────────────────────────────────────────────────────────────────

def _rand_tensor(shape, dtype):
    """8-bit 用确定性 projection;fp16/bf16 用随机数。"""
    if dtype is torch.int8:
        values = (torch.arange(shape[1], device="cuda", dtype=torch.int16) % 5).to(torch.int8)
        return values.view(1, shape[1]).repeat(shape[0], 1).contiguous()
    if dtype is torch.float8_e4m3fn:
        values = (torch.arange(shape[1], device="cuda", dtype=torch.float16) % 5).to(dtype)
        return values.view(1, shape[1]).repeat(shape[0], 1).contiguous()
    src = (torch.randn(shape, device="cuda", dtype=torch.float16) / 10).contiguous()
    return src.to(dtype).contiguous()


def _make_b(shape, dtype):
    """
    构造 B 张量:8-bit 用 projection(仅首行非零),fp16/bf16 用随机数。
    保持与既有 projection 测试一致。
    """
    n, k = shape
    if dtype in (torch.int8, torch.float8_e4m3fn):
        b = torch.zeros(shape, device="cuda", dtype=dtype)
        if dtype is torch.int8:
            b[0, :] = 1
        else:
            b[0, :] = torch.ones(k, device="cuda", dtype=torch.float16).to(dtype)
        return b.contiguous()
    return _rand_tensor(shape, dtype)


def _make_b_layout(shape, dtype, b_layout):
    """
    按指定 layout 构造 B 的内存布局。

    Layout 编码(与 CK C ABI 的 a_layout / b_layout 一致):
      'R' -> 行主序 -> [rows, cols] stride=(cols, 1)
      'C' -> 列主序 -> [rows, cols] stride=(1, rows)

    Python API 固定 a_layout='R', b_layout='C',B 存 [N,K] 行主序。
    测其他 layout 时需传连续转置张量以匹配 shape 语义。
    """
    if b_layout == "C":
        # 常规:B 存 [N,K],CK 按 [K,N] 列主序读,等价 B^T。
        return _make_b(shape, dtype)
    # b_layout == "R":B 存 [N,K] 行主序;CK 按 [K,N] 行主序读,
    # 即 C = A @ B(非 A @ B^T);参考实现需相应调整。
    n, k = shape
    return _make_b((n, k), dtype)


def _reference_rc(a, b):
    """RC layout 精度参考:C = A @ B^T(float32 累加,用于 assert_close)。"""
    return a.float() @ b.float().T


def _torch_gemm_grouped(a_tensors, b_tensors, dtype):
    """
    torch GEMM 性能基准:尽量用原生 dtype 的 matmul,避免 float() 转换开销。
    fp8/int8 无高效原生 matmul 时回退到 float32 参考。
    """
    if dtype in (torch.float16, torch.bfloat16):
        return [a @ b.T for a, b in zip(a_tensors, b_tensors)]
    return [_reference_rc(a, b) for a, b in zip(a_tensors, b_tensors)]


def _reference_rr(a, b):
    """RR layout 参考:C = A @ B(CK 视角下 B 为 [K,N])。"""
    # b_layout='R' 时 B 存 [N,K] 行主序,CK 按 [K,N] 行主序解释。
    # 为简化,测试统一用 RC 约定。
    return a.float() @ b.float().T


# ── 断言 / 基准 / 算力 ────────────────────────────────────────────────────────

def _accuracy_stats(outputs, refs):
    """逐 group 统计 max_abs / max_rel(用于打印,不判定 pass/fail)。"""
    stats = []
    for out, ref in zip(outputs, refs):
        out_f = out.float()
        ref_f = ref.float()
        diff = (out_f - ref_f).abs()
        max_abs = diff.max().item()
        denom = ref_f.abs().clamp_min(1e-6)
        max_rel = (diff / denom).max().item()
        stats.append(dict(max_abs=max_abs, max_rel=max_rel))
    return stats


def _assert_close(dtype, outputs, refs):
    tol = TOLERANCE[dtype]
    for idx, (out, ref) in enumerate(zip(outputs, refs)):
        torch.testing.assert_close(
            out.float(),
            ref.float(),
            **tol,
            msg=lambda msg: f"group {idx} failed\n{msg}",
        )


def _format_accuracy_line(dtype_name, dtype, outputs, refs, shapes):
    """
    执行精度校验(失败则抛 AssertionError),并返回可打印的 PASS 行。
    """
    stats = _accuracy_stats(outputs, refs)
    _assert_close(dtype, outputs, refs)
    tol = TOLERANCE[dtype]
    shape_text = ",".join(f"{m}x{n}x{k}" for m, n, k in shapes)
    group_text = "  ".join(
        f"g{i}: max_abs={s['max_abs']:.2e} max_rel={s['max_rel']:.2e}"
        for i, s in enumerate(stats)
    )
    return (
        f"  [PASS] accuracy  dtype={dtype_name}  groups={len(shapes)}  shapes={shape_text}"
        f"  rtol={tol['rtol']} atol={tol['atol']}  {group_text}"
    )


def _bench(fn, warmup, repeat):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(repeat):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000.0 / repeat


def _total_gemm_flops(shapes):
    """
    计算 grouped GEMM 总 FLOPs(乘加各算 1 op,共 2*M*N*K)。
    shapes: [(M, N, K), ...]
    """
    return sum(2 * m * n * k for m, n, k in shapes)


def _ms_to_tflops(flops, ms):
    """将耗时(ms)与 FLOPs 换算为 TFLOPS。"""
    if ms <= 0:
        return 0.0
    return flops / (ms * 1e-3) / 1e12


def _output_dtype(dtype):
    if dtype is torch.int8:
        return torch.int32
    if dtype is torch.float8_e4m3fn:
        return torch.float32
    return dtype


def _make_c_tensors(shapes, dtype):
    out_dtype = _output_dtype(dtype)
    return [torch.empty((m, n), device="cuda", dtype=out_dtype) for m, n, _ in shapes]


def _format_perf_line(dtype_name, shapes, ck_ms, torch_ms, ck_tag="ck"):
    """格式化延迟与 TFLOPS 输出(单行)。"""
    flops = _total_gemm_flops(shapes)
    gflops = flops / 1e9
    ck_tflops = _ms_to_tflops(flops, ck_ms)
    torch_tflops = _ms_to_tflops(flops, torch_ms)
    shape_text = ",".join(f"{m}x{n}x{k}" for m, n, k in shapes)
    return (
        f"grouped_gemm  dtype={dtype_name}  groups={len(shapes)}  total={gflops:.3f} GFLOPS"
        f"  shapes={shape_text}"
        f"  {ck_tag}={ck_ms:.4f} ms ({ck_tflops:.3f} TFLOPS)"
        f"  torch_gemm={torch_ms:.4f} ms ({torch_tflops:.3f} TFLOPS)"
    )


# ── shape 工厂 ────────────────────────────────────────────────────────────────

def _make_shapes_uniform(args, dtype):
    return [(args.m, args.n, args.k)] * args.groups


def _make_shapes_variable(args, dtype):
    return VARIABLE_SHAPES[dtype][: args.groups]


def _make_shapes_heterogeneous(args, dtype):
    return HETERO_SHAPES[dtype][: args.groups]


def _make_shapes(args, dtype):
    if args.heterogeneous:
        return _make_shapes_heterogeneous(args, dtype)
    if args.variable:
        return _make_shapes_variable(args, dtype)
    return _make_shapes_uniform(args, dtype)


# ── 各 dtype 测试用例 ────────────────────────────────────────────────────────

def run_case(dtype_name, args):
    dtype = DTYPE_MAP[dtype_name]
    shapes = _make_shapes(args, dtype)

    a_tensors = [_rand_tensor((m, k), dtype) for m, _, k in shapes]
    b_tensors = [_make_b((n, k), dtype) for _, n, k in shapes]

    outputs = aiter.ck_grouped_gemm(a_tensors, b_tensors)
    refs = [_reference_rc(a, b) for a, b in zip(a_tensors, b_tensors)]
    print(_format_accuracy_line(dtype_name, dtype, outputs, refs, shapes))

    c_tensors = _make_c_tensors(shapes, dtype)
    outputs_pre = aiter.ck_grouped_gemm_out(a_tensors, b_tensors, c_tensors)
    print(_format_accuracy_line(dtype_name, dtype, outputs_pre, refs, shapes).replace(
        "[PASS]", "[PASS prealloc c]", 1))

    ck_ms = _bench(lambda: aiter.ck_grouped_gemm(a_tensors, b_tensors), args.warmup, args.repeat)
    ck_prealloc_ms = _bench(
        lambda: aiter.ck_grouped_gemm_out(a_tensors, b_tensors, c_tensors),
        args.warmup,
        args.repeat,
    )
    torch_ms = _bench(
        lambda: _torch_gemm_grouped(a_tensors, b_tensors, dtype),
        args.warmup,
        args.repeat,
    )

    print(_format_perf_line(dtype_name, shapes, ck_ms, torch_ms, ck_tag="ck"))
    print(_format_perf_line(dtype_name, shapes, ck_prealloc_ms, torch_ms, ck_tag="ck_prealloc"))


# ── layout 变体测试 ───────────────────────────────────────────────────────────

def run_layout_cases(args):
    """
    对 fp16/bf16 测试 RC layout 下多种 shape 的正确性。

    aiter Python API 目前仅暴露 RC(A 行主序,B 存 [N,K] 行主序,CK 按列主序读),
    为推理最常见 layout。本函数用标准 API 校验各 shape 无回归。

    CK layout 说明:
      - RC: A[M,K] 行主序,B[N,K] 行主序(CK 读 [K,N] 列主序)-> C = A @ B^T
      - RR: A[M,K] 行主序,B[K,N] 行主序 -> C = A @ B
      - CR/CC: A 列主序变体,实践中较少用。

    测其他 layout 需直接调 C ABI;此处只测 RC 路径。
    """
    print("\n=== Layout 变体测试(RC,fp16/bf16)===")
    for dtype_name in ("fp16", "bf16"):
        dtype = DTYPE_MAP[dtype_name]
        for m, n, k in VARIABLE_SHAPES[dtype]:
            a = _rand_tensor((m, k), dtype)
            b = _make_b((n, k), dtype)
            out = aiter.ck_grouped_gemm([a], [b])[0]
            ref = _reference_rc(a, b)
            print(_format_accuracy_line(dtype_name, dtype, [out], [ref], [(m, n, k)]))


# ── 异构 shape 测试 ───────────────────────────────────────────────────────────

def run_heterogeneous_cases(args):
    """
    每组 M、N 不同(同 dtype 内 K 固定)。
    验证内核能正确处理逐组 shape 描述符。
    """
    print("\n=== 异构 shape 测试 ===")
    dtype_names = ["fp16", "bf16", "fp8", "int8"] if args.dtype == "all" else [args.dtype]
    for dtype_name in dtype_names:
        dtype = DTYPE_MAP[dtype_name]
        shapes = HETERO_SHAPES[dtype]
        a_tensors = [_rand_tensor((m, k), dtype) for m, _, k in shapes]
        b_tensors = [_make_b((n, k), dtype) for _, n, k in shapes]
        outputs = aiter.ck_grouped_gemm(a_tensors, b_tensors)
        refs = [_reference_rc(a, b) for a, b in zip(a_tensors, b_tensors)]
        print(_format_accuracy_line(dtype_name, dtype, outputs, refs, shapes))


# ── MOE(动态 M,固定 N/K)测试 ───────────────────────────────────────────────
#
# MOE 语义:每组 C_i = A_i @ B_i^T,A_i: [M_i, K],B_i: [N, K](全组相同 N/K)。
# M_i 可为任意正整数,N/K 对齐要求:N % 128 == 0, K % 64 == 0, K >= 128。
#
# 实现路径(CK kernel 层 kPadM):
#   1. ck_tile_dcu_grouped_gemm_run 检测 M_i 是否全部对齐 64
#   2. 对齐 → 走 GemmConfigComputeV4(kPadM=false,~40 TFLOPS 大 GEMM)
#   3. 不对齐 → 走 GemmConfigComputeV4Mpad(kPadM=true,M 维自动 pad 到 tile 边界)
#   4. kernel 内 pad_tensor_view 确保越界写入被抑制,epilogue 只写有效行
#   5. 无需 Python 侧 zero-pad A / slice C(Stage 5 的 ck_grouped_gemm_moe 保留作 fallback)

def run_moe_cases(args):
    """
    MOE 场景:每组 token 数 M_i 任意(不对齐),固定 N/K。
    直接调 ck_grouped_gemm / ck_grouped_gemm_out,CK kernel 层 kPadM 自动处理 M padding。
    """
    print("\n=== MOE 动态 M 测试(固定 N/K,kernel kPadM)===")
    # fp16/bf16 → V4Mpad (MPerBlock=64), fp8 → V5Mpad (MPerBlock=128)
    moe_dtypes = ["fp16", "bf16", "fp8"] if args.dtype == "all" else [args.dtype]
    for dtype_name in moe_dtypes:
        if dtype_name not in ("fp16", "bf16", "fp8"):
            print(f"  [SKIP] {dtype_name}: MOE kPadM not supported")
            continue
        dtype = DTYPE_MAP[dtype_name]
        n, k = MOE_FIXED_NK[dtype]
        ms = MOE_M_VALUES[: args.groups] if args.groups <= len(MOE_M_VALUES) else MOE_M_VALUES
        shapes = [(m, n, k) for m in ms]

        # 构造不对齐 M 的 A 张量(M=1,17,33,63,65,100),全部 M % 64 != 0
        a_tensors = [_rand_tensor((m, k), dtype) for m, _, _ in shapes]
        b_tensors = [_make_b((n, k), dtype) for _ in shapes]

        # -- 路径 1: ck_grouped_gemm(alloc 输出) --
        # CK C ABI → ck_tile_dcu_grouped_gemm_run 检测 M 不对齐 → 自动走 V4Mpad
        # kernel 内 pad_tensor_view 将 M pad 到 ceil(M/64)*64,epilogue 只写有效行
        outputs = aiter.ck_grouped_gemm(a_tensors, b_tensors)
        refs = [_reference_rc(a, b) for a, b in zip(a_tensors, b_tensors)]
        print(_format_accuracy_line(dtype_name, dtype, outputs, refs, shapes).replace(
            "[PASS]", "[PASS moe alloc]", 1))

        # -- 路径 2: ck_grouped_gemm_out(prealloc 输出) --
        # 调用方预先分配 [M_i, N] 的 C tensor,kernel 直接写入逻辑尺寸
        c_tensors = [torch.empty((m, n), device="cuda", dtype=_output_dtype(dtype))
                     for m, _, _ in shapes]
        outputs_pre = aiter.ck_grouped_gemm_out(a_tensors, b_tensors, c_tensors)
        print(_format_accuracy_line(dtype_name, dtype, outputs_pre, refs, shapes).replace(
            "[PASS]", "[PASS moe prealloc]", 1))

        if args.warmup > 0 and args.repeat > 0:
            # 性能:kernel kPadM vs torch baseline
            ck_ms = _bench(lambda: aiter.ck_grouped_gemm(a_tensors, b_tensors),
                           args.warmup, args.repeat)
            ck_pre_ms = _bench(
                lambda: aiter.ck_grouped_gemm_out(a_tensors, b_tensors, c_tensors),
                args.warmup, args.repeat)
            torch_ms = _bench(lambda: _torch_gemm_grouped(a_tensors, b_tensors, dtype),
                              args.warmup, args.repeat)
            print(_format_perf_line(dtype_name, shapes, ck_ms, torch_ms, ck_tag="ck_moe"))
            print(_format_perf_line(dtype_name, shapes, ck_pre_ms, torch_ms, ck_tag="ck_moe_out"))


# ── 非法输入测试 ──────────────────────────────────────────────────────────────

def run_bad_input_cases():
    """
    验证 wrapper 对非法输入抛出合理错误。
    期望 RuntimeError / ValueError(来自 C++ TORCH_CHECK)。
    """
    print("\n=== 非法输入测试 ===")
    device = "cuda"
    dtype = torch.float16

    def _expect_error(desc, fn):
        try:
            fn()
            print(f"  [FAIL] {desc}: 期望报错但未抛出")
        except Exception as e:
            print(f"  [PASS] {desc}: {type(e).__name__}: {e}")

    _expect_error("空 a_tensors", lambda: aiter.ck_grouped_gemm([], []))

    _expect_error(
        "a/b 列表长度不一致",
        lambda: aiter.ck_grouped_gemm(
            [torch.zeros((128, 64), device=device, dtype=dtype)],
            [],
        ),
    )

    _expect_error(
        "非 2D 张量",
        lambda: aiter.ck_grouped_gemm(
            [torch.zeros((2, 128, 64), device=device, dtype=dtype)],
            [torch.zeros((128, 64), device=device, dtype=dtype)],
        ),
    )

    _expect_error(
        "K 维不匹配",
        lambda: aiter.ck_grouped_gemm(
            [torch.zeros((128, 64), device=device, dtype=dtype)],
            [torch.zeros((128, 128), device=device, dtype=dtype)],
        ),
    )

    # M 不对齐现在由 kernel kPadM 支持(V4Mpad),不应报错
    _expect_error(
        "shape 对齐违规(fp16 N 非 128 倍数)",
        lambda: aiter.ck_grouped_gemm(
            [torch.zeros((128, 128), device=device, dtype=dtype).contiguous()],
            [torch.zeros((127, 128), device=device, dtype=dtype).contiguous()],
        ),
    )


# ── main ──────────────────────────────────────────────────────────────────────
#
# 常用命令示例
# ────────────
# 对齐 M 测试(uniform 模式,M/N/K 均对齐 tile 边界,走 V4/V5 快速路径)
#   python op_tests/test_grouped_gemm.py --dtype fp16                          # 默认 1024³ x3, ~40 TFLOPS
#   python op_tests/test_grouped_gemm.py --dtype fp8                           # fp8 128³ 对齐
#   python op_tests/test_grouped_gemm.py --dtype all --variable                # 变长对齐 shape
#   python op_tests/test_grouped_gemm.py --heterogeneous --dtype fp16          # 异构 M/N,固定 K
#
# 任意 M 测试(MOE 模式,固定 N/K,M 任意值,kernel kPadM 自动处理不对齐)
#   python op_tests/test_grouped_gemm.py --moe --dtype fp16                    # M=1,17,33,63,65,100
#   python op_tests/test_grouped_gemm.py --moe --dtype fp8                     # M 不对齐 128 也可
#   python op_tests/test_grouped_gemm.py --moe --dtype fp16 --groups 3         # 只测前 3 个 M
#
# CI smoke(快速冒烟,小 shape)
#   python op_tests/test_grouped_gemm.py --smoke --dtype fp16
#   python op_tests/test_grouped_gemm.py --smoke --dtype all --variable
#
# 其他
#   python op_tests/test_grouped_gemm.py --layout                              # layout 变体
#   python op_tests/test_grouped_gemm.py --bad-input                           # 错误输入测试

def main():
    parser = argparse.ArgumentParser(
        description="aiter.ck_grouped_gemm 精度与性能测试",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--groups", type=int, default=3,
                        help="GEMM group 数量(uniform/variable 模式)")
    parser.add_argument("--m", type=int, default=1024, help="M 维(uniform 模式,默认 1024 便于测 TFLOPS)")
    parser.add_argument("--n", type=int, default=1024, help="N 维(uniform 模式)")
    parser.add_argument("--k", type=int, default=1024, help="K 维(uniform 模式)")
    parser.add_argument("--warmup", type=int, default=10, help="预热迭代次数")
    parser.add_argument("--repeat", type=int, default=100, help="基准测试迭代次数")
    parser.add_argument("--smoke", action="store_true",
                        help="CI smoke:groups=2, 128^3, warmup=1, repeat=5")
    parser.add_argument("--dtype", choices=["fp16", "bf16", "fp8", "int8", "all"], default="all",
                        help="测试的数据类型")
    parser.add_argument("--variable", action="store_true",
                        help="使用预定义变长 shape(见 VARIABLE_SHAPES)")
    parser.add_argument("--heterogeneous", action="store_true",
                        help="异构 shape:每组 M/N 不同(同 dtype 内 K 相同)")
    parser.add_argument("--moe", action="store_true",
                        help="MOE 模式:固定 N/K,M 任意(CK kernel kPadM 自动处理 M 对齐)")
    parser.add_argument("--layout", action="store_true",
                        help="运行 layout 变体正确性测试(RC,fp16/bf16)")
    parser.add_argument("--bad-input", action="store_true",
                        help="运行非法输入错误处理测试")
    args = parser.parse_args()

    if args.smoke:
        args.groups = 2
        args.m = args.n = args.k = 128
        args.warmup = 1
        args.repeat = 5

    if args.bad_input:
        run_bad_input_cases()
        return

    if args.layout:
        run_layout_cases(args)
        return

    if args.heterogeneous:
        run_heterogeneous_cases(args)
        return

    if args.moe:
        run_moe_cases(args)
        return

    dtype_names = ["fp16", "bf16", "fp8", "int8"] if args.dtype == "all" else [args.dtype]
    for dtype_name in dtype_names:
        run_case(dtype_name, args)


if __name__ == "__main__":
    main()