# SPDX-License-Identifier: MIT """ aiter.ck_grouped_gemm 精度与性能测试。 语义 ---- 对每个 group i,内核计算 C[i] = A[i] @ B[i]^T。 A[i] : [M_i, K_i] 行主序,输入 dtype B[i] : [N_i, K_i] 行主序,输入 dtype(CK 侧按列主序解释,等价于 B^T) C[i] : [M_i, N_i] 行主序,输出 dtype 输出 dtype fp16 -> fp16 bf16 -> bf16 fp8 -> float32(fp8 点积,float32 累加,无 scale) int8 -> int32 (int8 点积,int32 累加,无 scale) fp8 / int8 说明 --------------- CK 执行原始硬件 MMAC,不做量化 scale。 下方参考实现为 a.float() @ b.float().T,与之对齐。 - int8:int32 累加,结果应严格一致。 - fp8:硬件先 fp8×fp8 再 float32 累加;参考实现先扩到 float32 再乘, 存在数值差异,容差见 TOLERANCE。 dense / random fp8、int8 正确性仍待跟进: - int8 dense random:应精确;失败则可能是 CK instance bug。 - fp8 dense random:允许小误差,rtol/atol ~ 0.2。 Layout 支持(--layout) ----------------------- RC(默认):A 行主序,B 存 [N,K] 行主序,CK 按列主序读(等价 B^T) RR/CR/CC:其他组合,经同一 C ABI 与 stride 路由;本测试对 RC 做正确性校验。 CI smoke 命令(小 shape,快速验正确性) --------------------------------------- python op_tests/test_grouped_gemm.py --smoke --dtype fp16 python op_tests/test_grouped_gemm.py --smoke --dtype all --variable python op_tests/test_grouped_gemm.py --moe --dtype fp16 性能测试(默认 m=n=k=1024,约 6.4 GFLOPS/group) ------------------------------------------------- python op_tests/test_grouped_gemm.py --dtype fp16 python op_tests/test_grouped_gemm.py --dtype all --variable --warmup 10 --repeat 100 说明:TFLOPS = sum(2*M*N*K) / time。shape 过小时 launch/sync 开销主导,数值会偏低; torch_gemm 使用原生 dtype 的 matmul,torch_ref 仅用于精度校验(float32 慢路径)。 """ import argparse import time import torch import aiter DTYPE_MAP = { "fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn, "int8": torch.int8, } # 各 dtype 的 assert_close 容差。 # fp8:硬件 fp8×fp8 MMAC 与 float32 参考可能舍入不同。 # int8:必须精确(int32 累加,无舍入)。 TOLERANCE = { torch.float16: dict(rtol=5e-2, atol=5e-2), torch.bfloat16: dict(rtol=5e-2, atol=5e-2), torch.float8_e4m3fn: dict(rtol=2e-1, atol=2e-1), torch.int8: dict(rtol=0, atol=0), } # 各 dtype 的 shape 对齐要求(CK tile 约束)。 # fp16/bf16 GemmConfigComputeV4 的 K_Tile=64;当前 CK instance 在 K=64(恰好一个 tile)会出错, # 实际要求 K >= 128。 SHAPE_ALIGN = { torch.float16: dict(m=64, n=128, k=128), torch.bfloat16: dict(m=64, n=128, k=128), torch.float8_e4m3fn: dict(m=128, n=128, k=128), torch.int8: dict(m=32, n=32, k=128), } # 各 dtype 预定义变长 shape(来自 Stage 1 验证)。 VARIABLE_SHAPES = { torch.float16: [(128, 128, 128), (256, 128, 128), (384, 128, 128)], torch.bfloat16: [(128, 128, 128), (256, 128, 128), (384, 128, 128)], torch.float8_e4m3fn: [(128, 128, 128), (256, 128, 256), (128, 256, 128)], torch.int8: [(64, 64, 128), (128, 128, 256), (192, 64, 128)], } # 异构 shape:每组 M/N 不同(同 dtype 内 K 保持一致)。 # CK fp16/bf16 GemmConfigComputeV4 不支持组间 mixed K,因此只变 M/N。 HETERO_SHAPES = { torch.float16: [(128, 128, 128), (192, 256, 128), (256, 128, 128)], torch.bfloat16: [(128, 128, 128), (192, 256, 128), (256, 128, 128)], torch.float8_e4m3fn: [(128, 128, 128), (256, 128, 128), (128, 256, 128)], torch.int8: [(32, 32, 128), (64, 64, 128), (96, 32, 128)], } # MOE:固定 N/K,每组 M 可不对齐(wrapper 自动 pad M 到 tile 边界)。 MOE_FIXED_NK = { torch.float16: (128, 128), torch.bfloat16: (128, 128), torch.float8_e4m3fn: (128, 128), torch.int8: (32, 128), } MOE_M_VALUES = [1, 17, 33, 63, 65, 100] # ── 张量构造 ────────────────────────────────────────────────────────────────── def _rand_tensor(shape, dtype): """8-bit 用确定性 projection;fp16/bf16 用随机数。""" if dtype is torch.int8: values = (torch.arange(shape[1], device="cuda", dtype=torch.int16) % 5).to(torch.int8) return values.view(1, shape[1]).repeat(shape[0], 1).contiguous() if dtype is torch.float8_e4m3fn: values = (torch.arange(shape[1], device="cuda", dtype=torch.float16) % 5).to(dtype) return values.view(1, shape[1]).repeat(shape[0], 1).contiguous() src = (torch.randn(shape, device="cuda", dtype=torch.float16) / 10).contiguous() return src.to(dtype).contiguous() def _make_b(shape, dtype): """ 构造 B 张量:8-bit 用 projection(仅首行非零),fp16/bf16 用随机数。 保持与既有 projection 测试一致。 """ n, k = shape if dtype in (torch.int8, torch.float8_e4m3fn): b = torch.zeros(shape, device="cuda", dtype=dtype) if dtype is torch.int8: b[0, :] = 1 else: b[0, :] = torch.ones(k, device="cuda", dtype=torch.float16).to(dtype) return b.contiguous() return _rand_tensor(shape, dtype) def _make_b_layout(shape, dtype, b_layout): """ 按指定 layout 构造 B 的内存布局。 Layout 编码(与 CK C ABI 的 a_layout / b_layout 一致): 'R' -> 行主序 -> [rows, cols] stride=(cols, 1) 'C' -> 列主序 -> [rows, cols] stride=(1, rows) Python API 固定 a_layout='R', b_layout='C',B 存 [N,K] 行主序。 测其他 layout 时需传连续转置张量以匹配 shape 语义。 """ if b_layout == "C": # 常规:B 存 [N,K],CK 按 [K,N] 列主序读,等价 B^T。 return _make_b(shape, dtype) # b_layout == "R":B 存 [N,K] 行主序;CK 按 [K,N] 行主序读, # 即 C = A @ B(非 A @ B^T);参考实现需相应调整。 n, k = shape return _make_b((n, k), dtype) def _reference_rc(a, b): """RC layout 精度参考:C = A @ B^T(float32 累加,用于 assert_close)。""" return a.float() @ b.float().T def _torch_gemm_grouped(a_tensors, b_tensors, dtype): """ torch GEMM 性能基准:尽量用原生 dtype 的 matmul,避免 float() 转换开销。 fp8/int8 无高效原生 matmul 时回退到 float32 参考。 """ if dtype in (torch.float16, torch.bfloat16): return [a @ b.T for a, b in zip(a_tensors, b_tensors)] return [_reference_rc(a, b) for a, b in zip(a_tensors, b_tensors)] def _reference_rr(a, b): """RR layout 参考:C = A @ B(CK 视角下 B 为 [K,N])。""" # b_layout='R' 时 B 存 [N,K] 行主序,CK 按 [K,N] 行主序解释。 # 为简化,测试统一用 RC 约定。 return a.float() @ b.float().T # ── 断言 / 基准 / 算力 ──────────────────────────────────────────────────────── def _accuracy_stats(outputs, refs): """逐 group 统计 max_abs / max_rel(用于打印,不判定 pass/fail)。""" stats = [] for out, ref in zip(outputs, refs): out_f = out.float() ref_f = ref.float() diff = (out_f - ref_f).abs() max_abs = diff.max().item() denom = ref_f.abs().clamp_min(1e-6) max_rel = (diff / denom).max().item() stats.append(dict(max_abs=max_abs, max_rel=max_rel)) return stats def _assert_close(dtype, outputs, refs): tol = TOLERANCE[dtype] for idx, (out, ref) in enumerate(zip(outputs, refs)): torch.testing.assert_close( out.float(), ref.float(), **tol, msg=lambda msg: f"group {idx} failed\n{msg}", ) def _format_accuracy_line(dtype_name, dtype, outputs, refs, shapes): """ 执行精度校验(失败则抛 AssertionError),并返回可打印的 PASS 行。 """ stats = _accuracy_stats(outputs, refs) _assert_close(dtype, outputs, refs) tol = TOLERANCE[dtype] shape_text = ",".join(f"{m}x{n}x{k}" for m, n, k in shapes) group_text = " ".join( f"g{i}: max_abs={s['max_abs']:.2e} max_rel={s['max_rel']:.2e}" for i, s in enumerate(stats) ) return ( f" [PASS] accuracy dtype={dtype_name} groups={len(shapes)} shapes={shape_text}" f" rtol={tol['rtol']} atol={tol['atol']} {group_text}" ) def _bench(fn, warmup, repeat): for _ in range(warmup): fn() torch.cuda.synchronize() t0 = time.perf_counter() for _ in range(repeat): fn() torch.cuda.synchronize() return (time.perf_counter() - t0) * 1000.0 / repeat def _total_gemm_flops(shapes): """ 计算 grouped GEMM 总 FLOPs(乘加各算 1 op,共 2*M*N*K)。 shapes: [(M, N, K), ...] """ return sum(2 * m * n * k for m, n, k in shapes) def _ms_to_tflops(flops, ms): """将耗时(ms)与 FLOPs 换算为 TFLOPS。""" if ms <= 0: return 0.0 return flops / (ms * 1e-3) / 1e12 def _output_dtype(dtype): if dtype is torch.int8: return torch.int32 if dtype is torch.float8_e4m3fn: return torch.float32 return dtype def _make_c_tensors(shapes, dtype): out_dtype = _output_dtype(dtype) return [torch.empty((m, n), device="cuda", dtype=out_dtype) for m, n, _ in shapes] def _format_perf_line(dtype_name, shapes, ck_ms, torch_ms, ck_tag="ck"): """格式化延迟与 TFLOPS 输出(单行)。""" flops = _total_gemm_flops(shapes) gflops = flops / 1e9 ck_tflops = _ms_to_tflops(flops, ck_ms) torch_tflops = _ms_to_tflops(flops, torch_ms) shape_text = ",".join(f"{m}x{n}x{k}" for m, n, k in shapes) return ( f"grouped_gemm dtype={dtype_name} groups={len(shapes)} total={gflops:.3f} GFLOPS" f" shapes={shape_text}" f" {ck_tag}={ck_ms:.4f} ms ({ck_tflops:.3f} TFLOPS)" f" torch_gemm={torch_ms:.4f} ms ({torch_tflops:.3f} TFLOPS)" ) # ── shape 工厂 ──────────────────────────────────────────────────────────────── def _make_shapes_uniform(args, dtype): return [(args.m, args.n, args.k)] * args.groups def _make_shapes_variable(args, dtype): return VARIABLE_SHAPES[dtype][: args.groups] def _make_shapes_heterogeneous(args, dtype): return HETERO_SHAPES[dtype][: args.groups] def _make_shapes(args, dtype): if args.heterogeneous: return _make_shapes_heterogeneous(args, dtype) if args.variable: return _make_shapes_variable(args, dtype) return _make_shapes_uniform(args, dtype) # ── 各 dtype 测试用例 ──────────────────────────────────────────────────────── def run_case(dtype_name, args): dtype = DTYPE_MAP[dtype_name] shapes = _make_shapes(args, dtype) a_tensors = [_rand_tensor((m, k), dtype) for m, _, k in shapes] b_tensors = [_make_b((n, k), dtype) for _, n, k in shapes] outputs = aiter.ck_grouped_gemm(a_tensors, b_tensors) refs = [_reference_rc(a, b) for a, b in zip(a_tensors, b_tensors)] print(_format_accuracy_line(dtype_name, dtype, outputs, refs, shapes)) c_tensors = _make_c_tensors(shapes, dtype) outputs_pre = aiter.ck_grouped_gemm_out(a_tensors, b_tensors, c_tensors) print(_format_accuracy_line(dtype_name, dtype, outputs_pre, refs, shapes).replace( "[PASS]", "[PASS prealloc c]", 1)) ck_ms = _bench(lambda: aiter.ck_grouped_gemm(a_tensors, b_tensors), args.warmup, args.repeat) ck_prealloc_ms = _bench( lambda: aiter.ck_grouped_gemm_out(a_tensors, b_tensors, c_tensors), args.warmup, args.repeat, ) torch_ms = _bench( lambda: _torch_gemm_grouped(a_tensors, b_tensors, dtype), args.warmup, args.repeat, ) print(_format_perf_line(dtype_name, shapes, ck_ms, torch_ms, ck_tag="ck")) print(_format_perf_line(dtype_name, shapes, ck_prealloc_ms, torch_ms, ck_tag="ck_prealloc")) # ── layout 变体测试 ─────────────────────────────────────────────────────────── def run_layout_cases(args): """ 对 fp16/bf16 测试 RC layout 下多种 shape 的正确性。 aiter Python API 目前仅暴露 RC(A 行主序,B 存 [N,K] 行主序,CK 按列主序读), 为推理最常见 layout。本函数用标准 API 校验各 shape 无回归。 CK layout 说明: - RC: A[M,K] 行主序,B[N,K] 行主序(CK 读 [K,N] 列主序)-> C = A @ B^T - RR: A[M,K] 行主序,B[K,N] 行主序 -> C = A @ B - CR/CC: A 列主序变体,实践中较少用。 测其他 layout 需直接调 C ABI;此处只测 RC 路径。 """ print("\n=== Layout 变体测试(RC,fp16/bf16)===") for dtype_name in ("fp16", "bf16"): dtype = DTYPE_MAP[dtype_name] for m, n, k in VARIABLE_SHAPES[dtype]: a = _rand_tensor((m, k), dtype) b = _make_b((n, k), dtype) out = aiter.ck_grouped_gemm([a], [b])[0] ref = _reference_rc(a, b) print(_format_accuracy_line(dtype_name, dtype, [out], [ref], [(m, n, k)])) # ── 异构 shape 测试 ─────────────────────────────────────────────────────────── def run_heterogeneous_cases(args): """ 每组 M、N 不同(同 dtype 内 K 固定)。 验证内核能正确处理逐组 shape 描述符。 """ print("\n=== 异构 shape 测试 ===") dtype_names = ["fp16", "bf16", "fp8", "int8"] if args.dtype == "all" else [args.dtype] for dtype_name in dtype_names: dtype = DTYPE_MAP[dtype_name] shapes = HETERO_SHAPES[dtype] a_tensors = [_rand_tensor((m, k), dtype) for m, _, k in shapes] b_tensors = [_make_b((n, k), dtype) for _, n, k in shapes] outputs = aiter.ck_grouped_gemm(a_tensors, b_tensors) refs = [_reference_rc(a, b) for a, b in zip(a_tensors, b_tensors)] print(_format_accuracy_line(dtype_name, dtype, outputs, refs, shapes)) # ── MOE(动态 M,固定 N/K)测试 ─────────────────────────────────────────────── # # MOE 语义:每组 C_i = A_i @ B_i^T,A_i: [M_i, K],B_i: [N, K](全组相同 N/K)。 # M_i 可为任意正整数,N/K 对齐要求:N % 128 == 0, K % 64 == 0, K >= 128。 # # 实现路径(CK kernel 层 kPadM): # 1. ck_tile_dcu_grouped_gemm_run 检测 M_i 是否全部对齐 64 # 2. 对齐 → 走 GemmConfigComputeV4(kPadM=false,~40 TFLOPS 大 GEMM) # 3. 不对齐 → 走 GemmConfigComputeV4Mpad(kPadM=true,M 维自动 pad 到 tile 边界) # 4. kernel 内 pad_tensor_view 确保越界写入被抑制,epilogue 只写有效行 # 5. 无需 Python 侧 zero-pad A / slice C(Stage 5 的 ck_grouped_gemm_moe 保留作 fallback) def run_moe_cases(args): """ MOE 场景:每组 token 数 M_i 任意(不对齐),固定 N/K。 直接调 ck_grouped_gemm / ck_grouped_gemm_out,CK kernel 层 kPadM 自动处理 M padding。 """ print("\n=== MOE 动态 M 测试(固定 N/K,kernel kPadM)===") # fp16/bf16 → V4Mpad (MPerBlock=64), fp8 → V5Mpad (MPerBlock=128) moe_dtypes = ["fp16", "bf16", "fp8"] if args.dtype == "all" else [args.dtype] for dtype_name in moe_dtypes: if dtype_name not in ("fp16", "bf16", "fp8"): print(f" [SKIP] {dtype_name}: MOE kPadM not supported") continue dtype = DTYPE_MAP[dtype_name] n, k = MOE_FIXED_NK[dtype] ms = MOE_M_VALUES[: args.groups] if args.groups <= len(MOE_M_VALUES) else MOE_M_VALUES shapes = [(m, n, k) for m in ms] # 构造不对齐 M 的 A 张量(M=1,17,33,63,65,100),全部 M % 64 != 0 a_tensors = [_rand_tensor((m, k), dtype) for m, _, _ in shapes] b_tensors = [_make_b((n, k), dtype) for _ in shapes] # -- 路径 1: ck_grouped_gemm(alloc 输出) -- # CK C ABI → ck_tile_dcu_grouped_gemm_run 检测 M 不对齐 → 自动走 V4Mpad # kernel 内 pad_tensor_view 将 M pad 到 ceil(M/64)*64,epilogue 只写有效行 outputs = aiter.ck_grouped_gemm(a_tensors, b_tensors) refs = [_reference_rc(a, b) for a, b in zip(a_tensors, b_tensors)] print(_format_accuracy_line(dtype_name, dtype, outputs, refs, shapes).replace( "[PASS]", "[PASS moe alloc]", 1)) # -- 路径 2: ck_grouped_gemm_out(prealloc 输出) -- # 调用方预先分配 [M_i, N] 的 C tensor,kernel 直接写入逻辑尺寸 c_tensors = [torch.empty((m, n), device="cuda", dtype=_output_dtype(dtype)) for m, _, _ in shapes] outputs_pre = aiter.ck_grouped_gemm_out(a_tensors, b_tensors, c_tensors) print(_format_accuracy_line(dtype_name, dtype, outputs_pre, refs, shapes).replace( "[PASS]", "[PASS moe prealloc]", 1)) if args.warmup > 0 and args.repeat > 0: # 性能:kernel kPadM vs torch baseline ck_ms = _bench(lambda: aiter.ck_grouped_gemm(a_tensors, b_tensors), args.warmup, args.repeat) ck_pre_ms = _bench( lambda: aiter.ck_grouped_gemm_out(a_tensors, b_tensors, c_tensors), args.warmup, args.repeat) torch_ms = _bench(lambda: _torch_gemm_grouped(a_tensors, b_tensors, dtype), args.warmup, args.repeat) print(_format_perf_line(dtype_name, shapes, ck_ms, torch_ms, ck_tag="ck_moe")) print(_format_perf_line(dtype_name, shapes, ck_pre_ms, torch_ms, ck_tag="ck_moe_out")) # ── 非法输入测试 ────────────────────────────────────────────────────────────── def run_bad_input_cases(): """ 验证 wrapper 对非法输入抛出合理错误。 期望 RuntimeError / ValueError(来自 C++ TORCH_CHECK)。 """ print("\n=== 非法输入测试 ===") device = "cuda" dtype = torch.float16 def _expect_error(desc, fn): try: fn() print(f" [FAIL] {desc}: 期望报错但未抛出") except Exception as e: print(f" [PASS] {desc}: {type(e).__name__}: {e}") _expect_error("空 a_tensors", lambda: aiter.ck_grouped_gemm([], [])) _expect_error( "a/b 列表长度不一致", lambda: aiter.ck_grouped_gemm( [torch.zeros((128, 64), device=device, dtype=dtype)], [], ), ) _expect_error( "非 2D 张量", lambda: aiter.ck_grouped_gemm( [torch.zeros((2, 128, 64), device=device, dtype=dtype)], [torch.zeros((128, 64), device=device, dtype=dtype)], ), ) _expect_error( "K 维不匹配", lambda: aiter.ck_grouped_gemm( [torch.zeros((128, 64), device=device, dtype=dtype)], [torch.zeros((128, 128), device=device, dtype=dtype)], ), ) # M 不对齐现在由 kernel kPadM 支持(V4Mpad),不应报错 _expect_error( "shape 对齐违规(fp16 N 非 128 倍数)", lambda: aiter.ck_grouped_gemm( [torch.zeros((128, 128), device=device, dtype=dtype).contiguous()], [torch.zeros((127, 128), device=device, dtype=dtype).contiguous()], ), ) # ── main ────────────────────────────────────────────────────────────────────── # # 常用命令示例 # ──────────── # 对齐 M 测试(uniform 模式,M/N/K 均对齐 tile 边界,走 V4/V5 快速路径) # python op_tests/test_grouped_gemm.py --dtype fp16 # 默认 1024³ x3, ~40 TFLOPS # python op_tests/test_grouped_gemm.py --dtype fp8 # fp8 128³ 对齐 # python op_tests/test_grouped_gemm.py --dtype all --variable # 变长对齐 shape # python op_tests/test_grouped_gemm.py --heterogeneous --dtype fp16 # 异构 M/N,固定 K # # 任意 M 测试(MOE 模式,固定 N/K,M 任意值,kernel kPadM 自动处理不对齐) # python op_tests/test_grouped_gemm.py --moe --dtype fp16 # M=1,17,33,63,65,100 # python op_tests/test_grouped_gemm.py --moe --dtype fp8 # M 不对齐 128 也可 # python op_tests/test_grouped_gemm.py --moe --dtype fp16 --groups 3 # 只测前 3 个 M # # CI smoke(快速冒烟,小 shape) # python op_tests/test_grouped_gemm.py --smoke --dtype fp16 # python op_tests/test_grouped_gemm.py --smoke --dtype all --variable # # 其他 # python op_tests/test_grouped_gemm.py --layout # layout 变体 # python op_tests/test_grouped_gemm.py --bad-input # 错误输入测试 def main(): parser = argparse.ArgumentParser( description="aiter.ck_grouped_gemm 精度与性能测试", formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument("--groups", type=int, default=3, help="GEMM group 数量(uniform/variable 模式)") parser.add_argument("--m", type=int, default=1024, help="M 维(uniform 模式,默认 1024 便于测 TFLOPS)") parser.add_argument("--n", type=int, default=1024, help="N 维(uniform 模式)") parser.add_argument("--k", type=int, default=1024, help="K 维(uniform 模式)") parser.add_argument("--warmup", type=int, default=10, help="预热迭代次数") parser.add_argument("--repeat", type=int, default=100, help="基准测试迭代次数") parser.add_argument("--smoke", action="store_true", help="CI smoke:groups=2, 128^3, warmup=1, repeat=5") parser.add_argument("--dtype", choices=["fp16", "bf16", "fp8", "int8", "all"], default="all", help="测试的数据类型") parser.add_argument("--variable", action="store_true", help="使用预定义变长 shape(见 VARIABLE_SHAPES)") parser.add_argument("--heterogeneous", action="store_true", help="异构 shape:每组 M/N 不同(同 dtype 内 K 相同)") parser.add_argument("--moe", action="store_true", help="MOE 模式:固定 N/K,M 任意(CK kernel kPadM 自动处理 M 对齐)") parser.add_argument("--layout", action="store_true", help="运行 layout 变体正确性测试(RC,fp16/bf16)") parser.add_argument("--bad-input", action="store_true", help="运行非法输入错误处理测试") args = parser.parse_args() if args.smoke: args.groups = 2 args.m = args.n = args.k = 128 args.warmup = 1 args.repeat = 5 if args.bad_input: run_bad_input_cases() return if args.layout: run_layout_cases(args) return if args.heterogeneous: run_heterogeneous_cases(args) return if args.moe: run_moe_cases(args) return dtype_names = ["fp16", "bf16", "fp8", "int8"] if args.dtype == "all" else [args.dtype] for dtype_name in dtype_names: run_case(dtype_name, args) if __name__ == "__main__": main()