import argparse
import traceback

import torch


def parse_args():
    parser = argparse.ArgumentParser(description='GEMM benchmark')
    parser.add_argument('--M', type=int, default=4096, help='M dimension')
    parser.add_argument('--K', type=int, default=4096, help='K dimension')
    parser.add_argument('--N', type=int, default=4096, help='N dimension')
    parser.add_argument('--dtype', type=str, default='bfloat16',
                        choices=['float64', 'float32', 'float16', 'bfloat16', 'tf32', 'int8',
                                 'mixed_fp16_fp32', 'mixed_bf16_fp32', 'mixed_int8_int32',
                                 'mixed_tf32_fp32', 'w8a8'],
                        help='data type configuration to benchmark')
    parser.add_argument('--alpha', type=float, default=1.0, help='alpha scaling factor')
    parser.add_argument('--beta', type=float, default=0.0, help='beta scaling factor')
    parser.add_argument('--warmup_iterations', type=int, default=50,
                        help='number of warmup iterations')
    parser.add_argument('--bench_iterations', type=int, default=1000,
                        help='number of benchmark iterations')
    parser.add_argument('--transA', action='store_true', default=False,
                        help='transpose matrix A')
    parser.add_argument('--transB', action='store_true', default=False,
                        help='transpose matrix B')
    return parser.parse_args()


def get_matrix(dims, dtype, device='cuda'):
    """Create a random matrix of the requested dtype."""
    if dtype in [torch.float64, torch.float32, torch.float16, torch.bfloat16]:
        return torch.randn(dims, dtype=dtype, device=device)
    elif dtype == torch.int8:
        return torch.randint(-128, 127, dims, dtype=torch.int8, device=device)
    else:
        return torch.randn(dims, dtype=torch.float32, device=device)


def get_blas_op(alpha, beta, transA=False, transB=False):
    """Return a BLAS-style GEMM callable: alpha * op(A) @ op(B) + beta * C."""
    def blas_op(a, b, c):
        # torch.addmm has no transpose flags, so transpose the views manually.
        a_op = a.t() if transA else a
        b_op = b.t() if transB else b
        return torch.addmm(c, a_op, b_op, beta=beta, alpha=alpha)
    return blas_op


def benchmark_gemm(args, dtype_config):
    """Build operands for the requested dtype configuration and time the GEMM."""
    M, K, N = args.M, args.K, args.N
    alpha, beta = args.alpha, args.beta
    transA, transB = args.transA, args.transB

    # The physical storage shapes depend on the transpose flags.
    a_rows, a_cols = (K, M) if transA else (M, K)
    b_rows, b_cols = (N, K) if transB else (K, N)

    # Build operands and the matmul op for the requested dtype configuration.
    if dtype_config == 'mixed_fp16_fp32':
        # A, B: fp16; C: fp32. addmm cannot mix dtypes, so implement it manually.
        a = get_matrix((a_rows, a_cols), torch.float16)
        b = get_matrix((b_rows, b_cols), torch.float16)
        c = torch.zeros((M, N), dtype=torch.float32, device='cuda')

        def matmul_op(a, b, c):
            a_op = a.t() if transA else a
            b_op = b.t() if transB else b
            # mm returns fp16 (cuBLAS accumulates in fp32 internally);
            # copy_ upcasts the result into the fp32 C.
            result = torch.mm(a_op, b_op)
            if alpha != 1.0 or beta != 0.0:
                result = alpha * result + beta * c
            c.copy_(result)
            return c

    elif dtype_config == 'mixed_bf16_fp32':
        # A, B: bf16; C: fp32. addmm cannot mix dtypes, so implement it manually.
        a = get_matrix((a_rows, a_cols), torch.bfloat16)
        b = get_matrix((b_rows, b_cols), torch.bfloat16)
        c = torch.zeros((M, N), dtype=torch.float32, device='cuda')

        def matmul_op(a, b, c):
            a_op = a.t() if transA else a
            b_op = b.t() if transB else b
            # mm returns bf16; copy_ upcasts the result into the fp32 C.
            result = torch.mm(a_op, b_op)
            if alpha != 1.0 or beta != 0.0:
                result = alpha * result + beta * c
            c.copy_(result)
            return c

    elif dtype_config == 'mixed_int8_int32':
        # A, B: int8; C: int32. addmm is unsupported, so implement it manually.
        a = get_matrix((a_rows, a_cols), torch.int8)
        b = get_matrix((b_rows, b_cols), torch.int8)
        c = torch.zeros((M, N), dtype=torch.int32, device='cuda')

        if hasattr(torch, '_int_mm'):
            print("  Using torch._int_mm for int8 matmul")

            def matmul_op(a, b, c):
                a_op = a.t() if transA else a
                b_op = b.t() if transB else b
                result = torch._int_mm(a_op, b_op)
                if alpha != 1.0 or beta != 0.0:
                    result = (alpha * result.float()).to(torch.int32) + beta * c
                c.copy_(result)
                return c
        else:
            print("  Warning: torch._int_mm not available, using fp32 fallback")

            def matmul_op(a, b, c):
                a_op = a.t() if transA else a
                b_op = b.t() if transB else b
                # Note: this measures an fp32 GEMM plus casts, not true int8 throughput.
                result = torch.mm(a_op.float(), b_op.float()).to(torch.int32)
                if alpha != 1.0 or beta != 0.0:
                    result = (alpha * result.float()).to(torch.int32) + beta * c
                c.copy_(result)
                return c

    elif dtype_config == 'w8a8':
        # 'w8a8' config: int8 weights, fp16 activations. The weights are
        # dequantized to fp16 on the fly, so the compute is effectively W8A16.
        a = get_matrix((a_rows, a_cols), torch.float16)
        b = get_matrix((b_rows, b_cols), torch.int8)
        c = torch.zeros((M, N), dtype=torch.float16, device='cuda')

        def matmul_op(a, b, c):
            a_op = a.t() if transA else a
            b_op = b.t() if transB else b
            b_fp16 = b_op.to(torch.float16)  # dequantize weights per call
            result = torch.mm(a_op, b_fp16)
            if alpha != 1.0 or beta != 0.0:
                result = alpha * result + beta * c
            c.copy_(result)
            return c

    elif dtype_config in ('tf32', 'mixed_tf32_fp32'):
        # TF32 mode: fp32 tensors with TF32 tensor cores enabled; addmm works directly.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        a = get_matrix((a_rows, a_cols), torch.float32)
        b = get_matrix((b_rows, b_cols), torch.float32)
        c = torch.zeros((M, N), dtype=torch.float32, device='cuda')
        matmul_op = get_blas_op(alpha, beta, transA, transB)

    elif dtype_config == 'int8':
        # Pure int8 mode: addmm is unsupported, so fall back to fp32 mm plus casts.
        # Note: this measures fp32 GEMM work, not true int8 throughput.
        a = get_matrix((a_rows, a_cols), torch.int8)
        b = get_matrix((b_rows, b_cols), torch.int8)
        c = torch.zeros((M, N), dtype=torch.int8, device='cuda')

        def matmul_op(a, b, c):
            a_op = a.t() if transA else a
            b_op = b.t() if transB else b
            result = torch.mm(a_op.float(), b_op.float()).to(torch.int8)
            if alpha != 1.0 or beta != 0.0:
                result = (alpha * result.float()).to(torch.int8) + beta * c
            c.copy_(result)
            return c

    else:
        # Uniform-precision mode: addmm maps directly onto a cuBLAS GEMM.
        dtype_map = {
            'float64': torch.float64,
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }
        dtype = dtype_map.get(dtype_config, torch.float32)
        a = get_matrix((a_rows, a_cols), dtype)
        b = get_matrix((b_rows, b_cols), dtype)
        c = torch.zeros((M, N), dtype=dtype, device='cuda')
        matmul_op = get_blas_op(alpha, beta, transA, transB)

    # Warmup.
    for _ in range(args.warmup_iterations):
        matmul_op(a, b, c)
    # Make sure all warmup kernels have finished before timing starts.
    torch.cuda.synchronize()

    # Time the benchmark loop with CUDA events.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(args.bench_iterations):
        matmul_op(a, b, c)
    end_event.record()
    end_event.synchronize()

    total_ms = start_event.elapsed_time(end_event)
    avg_latency_us = total_ms * 1e3 / args.bench_iterations

    # FLOP count: 2*M*N*K (one multiply and one add per multiply-accumulate).
    total_flops = 2 * M * N * K
    tflops = total_flops / (avg_latency_us * 1e-6) / 1e12
    return avg_latency_us, tflops, True


def main():
    args = parse_args()
    print(f"\n{'=' * 80}")
    print("GEMM Benchmark")
    print(f"Matrix Size: [{args.M}, {args.K}] x [{args.K}, {args.N}]")
    if args.transA:
        print(f"Transpose A: Yes (actual A shape: [{args.K}, {args.M}])")
    if args.transB:
        print(f"Transpose B: Yes (actual B shape: [{args.N}, {args.K}])")
    print(f"Alpha: {args.alpha}, Beta: {args.beta}")
    print(f"Data Type: {args.dtype}")
    print(f"{'=' * 80}")

    try:
        avg_latency_us, tflops, success = benchmark_gemm(args, args.dtype)
        if success:
            print("\nResults:")
            print(f"  Warmup iterations: {args.warmup_iterations}")
            print(f"  Benchmark iterations: {args.bench_iterations}")
            print(f"  Average latency: {avg_latency_us:.3f} μs")
            print(f"  Performance: {tflops:.3f} TFLOPS")
        else:
            print(f"\nBenchmark failed for {args.dtype}")
    except Exception as e:
        print(f"\nError: {e}")
        traceback.print_exc()
        print(f"  Benchmark failed for {args.dtype}")


if __name__ == "__main__":
    main()