# SPDX-License-Identifier: MIT
"""Accuracy/performance comparison of RMSNorm implementations.

Compares torch's ``F.rms_norm`` (reference) against aiter's CK kernels
(``aiter.rms_norm`` / ``aiter.rmsnorm2d_fwd_with_add``) and CU kernels
(``aiter.rms_norm_cu`` / ``aiter.fused_add_rms_norm_cu``). Run as a script;
see ``main`` for the command-line options.
"""
import argparse

import pandas as pd
import torch
import torch.nn.functional as F

import aiter
from aiter import dtypes
from aiter.test_common import checkAllclose, perftest

# Per-case metric rows appended by test_rmsnorm2d_fuseAdd and rendered as
# summary tables (and written to CSV) after all cases have run.
ERROR_SUMMARY_ROWS_OUT = []
ERROR_SUMMARY_ROWS_RES = []

# Quantile levels of |err| reported by print_error_metrics.
_ABS_ERR_QUANTILES = (0.50, 0.90, 0.99)


def _quantiles(values, qs):
    """Return the ``qs`` quantiles of the 1-D tensor ``values`` as floats.

    Uses linear interpolation between the two nearest ranks, the same rule as
    ``torch.quantile``'s default ``interpolation="linear"``, but needs only one
    sort and no per-call size limit (``torch.quantile`` rejects inputs above a
    fixed element count, which large ``-m``/``-n`` runs can exceed).
    """
    sorted_vals = torch.sort(values).values
    last = sorted_vals.numel() - 1
    result = []
    for q in qs:
        pos = q * last
        lo = int(pos)  # floor, since pos >= 0
        hi = min(lo + 1, last)
        frac = pos - lo
        lo_v = sorted_vals[lo].item()
        hi_v = sorted_vals[hi].item()
        result.append(lo_v + (hi_v - lo_v) * frac)
    return result


def print_error_metrics(ref, pred, name, dim, dtype, ck_us=None):
    """Print error statistics of ``pred`` relative to ``ref`` and return a summary row.

    Both tensors are flattened and promoted to fp32 so the statistics are
    stable regardless of the (low-precision) input dtype.

    Args:
        ref: reference tensor.
        pred: tensor under test; must have the same number of elements as ``ref``.
        name: case label printed in the header and stored in the row.
        dim: problem shape, used only for labelling.
        dtype: input dtype, used only for labelling.
        ck_us: optional kernel time (microseconds) carried into the row.

    Returns:
        dict with keys ``case, dim, dtype, cos, mae, rmse, nrmse, max_ae, p99,
        ck_us`` as consumed by ``print_error_summary_table``.
    """
    ref_f = ref.float().flatten()
    pred_f = pred.float().flatten()
    err = pred_f - ref_f
    abs_err = err.abs()

    mae = abs_err.mean().item()
    rmse = torch.sqrt((err * err).mean()).item()
    max_ae = abs_err.max().item()
    ref_rms = torch.sqrt((ref_f * ref_f).mean()).item()
    nrmse = rmse / (ref_rms + 1e-12)  # epsilon guards an all-zero reference
    mean_err = err.mean().item()
    std_err = err.std(unbiased=False).item()
    p50, p90, p99 = _quantiles(abs_err, _ABS_ERR_QUANTILES)

    ref_mean = ref_f.mean().item()
    pred_mean = pred_f.mean().item()
    ref_std = ref_f.std(unbiased=False).item()
    pred_std = pred_f.std(unbiased=False).item()
    pred_rms = torch.sqrt((pred_f * pred_f).mean()).item()

    # Least-squares fit pred ~= a * ref + b separates a scale mismatch (a != 1)
    # from a constant bias (b != 0); fit_rmse is the error left after both.
    ref_centered = ref_f - ref_f.mean()
    pred_centered = pred_f - pred_f.mean()
    var_ref = (ref_centered * ref_centered).mean().item()
    cov = (ref_centered * pred_centered).mean().item()
    a = cov / (var_ref + 1e-12)
    b = pred_mean - a * ref_mean
    fit_residual = pred_f - (a * ref_f + b)
    fit_rmse = torch.sqrt((fit_residual * fit_residual).mean()).item()

    header = f"[metrics:{name}] dim={dim}, dtype={dtype}"
    print("\n" + "=" * len(header))
    print(header)
    print("=" * len(header))
    rows = [
        ("mae", mae),
        ("rmse", rmse),
        ("nrmse", nrmse),
        ("max(|err|)", max_ae),
        ("p50(|err|)", p50),
        ("p90(|err|)", p90),
        ("p99(|err|)", p99),
        ("mean(err)", mean_err),
        ("std(err)", std_err),
        ("mean(ref)", ref_mean),
        ("mean(pred)", pred_mean),
        ("std(ref)", ref_std),
        ("std(pred)", pred_std),
        ("rms(ref)", ref_rms),
        ("rms(pred)", pred_rms),
        ("fit_a", a),
        ("fit_b", b),
        ("fit_rmse", fit_rmse),
    ]
    for k, v in rows:
        print(f" {k:<12}: {v:>14.8f}")

    return {
        "case": name,
        "dim": str(dim),
        "dtype": str(dtype),
        "cos": F.cosine_similarity(ref_f, pred_f, dim=0).item(),
        "mae": mae,
        "rmse": rmse,
        "nrmse": nrmse,
        "max_ae": max_ae,
        "p99": p99,
        "ck_us": ck_us,
    }


def print_error_summary_table(rows):
    """Print rows from ``print_error_metrics`` as an aligned text table.

    Args:
        rows: list of dicts as returned by ``print_error_metrics``.

    Returns:
        The rows with every value formatted as a string (floats to 8
        decimals, ``ck_us`` to 2 decimals or ``"N/A"``), or ``[]`` when
        ``rows`` is empty (nothing is printed in that case).
    """
    if not rows:
        return []
    cols = [
        ("case", "case"),
        ("dim", "dim"),
        ("dtype", "dtype"),
        ("cos", "cos"),
        ("mae", "mae"),
        ("rmse", "rmse"),
        ("nrmse", "nrmse"),
        ("max_ae", "max_ae"),
        ("p99", "p99"),
        ("ck_us", "ck_us"),
    ]
    formatted_rows = [
        {
            "case": r["case"],
            "dim": r["dim"],
            "dtype": r["dtype"],
            "cos": f"{r['cos']:.8f}",
            "mae": f"{r['mae']:.8f}",
            "rmse": f"{r['rmse']:.8f}",
            "nrmse": f"{r['nrmse']:.8f}",
            "max_ae": f"{r['max_ae']:.8f}",
            "p99": f"{r['p99']:.8f}",
            "ck_us": f"{r['ck_us']:.2f}" if r.get("ck_us") is not None else "N/A",
        }
        for r in rows
    ]
    # Each column is as wide as its header or its widest value.
    widths = {
        key: max(len(header), *(len(fr[key]) for fr in formatted_rows))
        for key, header in cols
    }

    title = "[metrics-summary] cross-case error comparison"
    print("\n" + "=" * len(title))
    print(title)
    print("=" * len(title))
    header = " | ".join(f"{h:<{widths[k]}}" for k, h in cols)
    sep = "-+-".join("-" * widths[k] for k, _ in cols)
    print(header)
    print(sep)
    for fr in formatted_rows:
        print(" | ".join(f"{fr[k]:<{widths[k]}}" for k, _ in cols))
    return formatted_rows


@perftest()
def run_torch(input, weight, eps, residual=None):
    """Reference: torch RMSNorm, optionally applied to ``input + residual``.

    Returns ``(output, residual_out)``; ``residual_out`` is ``None`` when no
    residual is given.
    """
    if residual is None:
        residual_out = None
        output = F.rms_norm(
            input=input, normalized_shape=(input.shape[-1],), weight=weight, eps=eps
        )
    else:
        residual_out = input + residual
        output = F.rms_norm(
            input=residual_out,
            normalized_shape=(input.shape[-1],),
            weight=weight,
            eps=eps,
        )
    return output, residual_out


@perftest()
def run_ck(input, weight, eps, residual=None):
    """CK kernels: ``aiter.rms_norm``, or ``aiter.rmsnorm2d_fwd_with_add`` with a residual.

    Returns ``(output, residual_out)``; ``residual_out`` is ``None`` when no
    residual is given. In the fused path both outputs are freshly allocated
    and filled by the kernel.
    """
    if residual is None:
        residual_out = None
        output = aiter.rms_norm(input, weight, eps)
    else:
        residual_out = torch.empty_like(input)
        output = torch.empty_like(input)
        aiter.rmsnorm2d_fwd_with_add(
            output,
            input,
            residual,
            residual_out,
            weight,
            eps,
        )
    return output, residual_out


@perftest()
def run_cu(input, weight, eps, residual=None):
    """CU kernels: ``aiter.rms_norm_cu``, or ``aiter.fused_add_rms_norm_cu`` with a residual.

    Returns ``(output, residual_out)``; ``residual_out`` is ``None`` when no
    residual is given.
    """
    if residual is None:
        residual_out = None
        output = torch.empty_like(input)
        aiter.rms_norm_cu(output, input, weight, eps)
    else:
        # NOTE(review): the results are read back from `input` and `residual`,
        # i.e. this path appears to write into the caller's tensors. perftest
        # reuses the same arguments across iterations, so confirm the kernel's
        # in-place semantics (and clone the inputs if needed) before re-enabling
        # the cu fused-add comparison.
        aiter.fused_add_rms_norm_cu(input, residual, weight, eps)
        output = input
        residual_out = residual
    return output, residual_out


def test_rmsnorm2d(dtype, m, n):
    """Compare torch, CK and CU RMSNorm (no residual) on a random (m, n) input."""
    dim = (m, n)
    x = torch.randn(dim, dtype=dtype, device="cuda")
    weight = torch.randn(n, dtype=dtype, device="cuda")
    (a, *_), avg_a = run_torch(x, weight, 1e-5)
    (b, *_), avg_b = run_ck(x, weight, 1e-5)
    (c, *_), avg_c = run_cu(x, weight, 1e-5)
    msg = f"[perf] dim: {str(dim):<20}, dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, cu avg: {avg_c:<8.2f} us, uplift: {avg_a/avg_b-1:<5.1%}"
    checkAllclose(a, b, msg=msg)
    checkAllclose(a, c, msg="cu")


def test_rmsnorm2d_fuseAdd(dtype, m, n):
    """Compare torch and CK fused add+RMSNorm on random (m, n) inputs.

    Prints cosine similarities and detailed error metrics for both the
    normalized output and the residual sum, appends summary rows to
    ERROR_SUMMARY_ROWS_OUT / ERROR_SUMMARY_ROWS_RES, then checks closeness.
    """
    dim = (m, n)
    x = torch.randn(dim, dtype=dtype, device="cuda")
    weight = torch.randn(n, dtype=dtype, device="cuda")
    res = torch.randn(dim, dtype=dtype, device="cuda")
    (a, res_a, *_), avg_a = run_torch(x, weight, 1e-5, residual=res)
    (b, res_b, *_), avg_b = run_ck(x, weight, 1e-5, residual=res)
    # (c, res_c, *_), avg_c = run_ck(
    #     x, weight, 1e-5, residual=res, use_model_sensitive_rmsnorm=1
    # )
    # (d, res_d, *_), avg_d = run_cu(x, weight, 1e-5, residual=res)

    msg = f"[perf] dim: {str(dim):<20}, dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, uplift: {avg_a/avg_b-1:<5.1%}"
    cos_ab = F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
    cos_res_ab = F.cosine_similarity(res_a.flatten(), res_b.flatten(), dim=0).item()
    print(
        f"[cos] dim: {str(dim):<20}, dtype: {dtype}, out(torch,ck): {cos_ab:.8f}, res(torch,ck): {cos_res_ab:.8f}"
    )
    ERROR_SUMMARY_ROWS_OUT.append(
        print_error_metrics(a, b, "out_torch_vs_ck", dim, dtype, ck_us=avg_b)
    )
    ERROR_SUMMARY_ROWS_RES.append(
        print_error_metrics(res_a, res_b, "res_torch_vs_ck", dim, dtype, ck_us=avg_b)
    )
    checkAllclose(a, b, atol=0.03, rtol=0.001, msg=msg)
    checkAllclose(
        res_a, res_b, atol=0.03, rtol=0.001, msg="ck res check (NO_SPECIFIC_MODEL)"
    )
    # checkAllclose(a, c, atol=0.03, msg=msg)
    # checkAllclose(res_a, res_c, msg="ck res check (T5_MODEL_LIKE)")
    # checkAllclose(a, d, atol=0.03, msg='cu')
    # checkAllclose(res_a, res_d, atol=0.01, msg='cu res check')


# dtype names accepted on the command line (keys of dtypes.d_dtypes).
_SUPPORTED_DTYPES = ["fp16", "bf16"]


def main():
    """Parse CLI options, run the fused-add sweep, print summaries, write rmsnorm2d.csv.

    Full sweep used previously (not run by default):
      dtype in fp16, bf16; m in 1, 2, 4, ..., 256; n in 4096, 8192, ..., 65536.
    """
    l_dtype = ["fp16"]
    l_m = [1, 16, 64, 128, 256]
    l_n = [1024]

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="config input of test",
    )
    parser.add_argument(
        "-d",
        "--dtype",
        type=str,
        # Accept every supported dtype, not just the default sweep list,
        # so the documented `-d bf16` works.
        choices=_SUPPORTED_DTYPES,
        nargs="?",
        const=None,
        default=None,
        help="""Data type.
    e.g.: -d bf16""",
    )
    parser.add_argument(
        "-m",
        "--m",
        type=int,
        nargs="?",
        default=None,
        help="""M of mnk.
    e.g.: -m 32""",
    )
    parser.add_argument(
        "-n",
        "--n",
        type=int,
        nargs="?",
        default=None,
        help="""N of mnk.
    e.g.: -n 1024""",
    )
    args = parser.parse_args()

    if args.dtype is None:
        l_dtype = [dtypes.d_dtypes[key] for key in l_dtype]
    else:
        l_dtype = [dtypes.d_dtypes[args.dtype]]
    if args.m is not None:
        l_m = [args.m]
    if args.n is not None:
        l_n = [args.n]

    print("\nstart fuse add test")
    for dtype in l_dtype:
        for m in l_m:
            for n in l_n:
                test_rmsnorm2d_fuseAdd(dtype, m, n)

    csv_rows = []
    csv_rows.extend(print_error_summary_table(ERROR_SUMMARY_ROWS_OUT))
    csv_rows.extend(print_error_summary_table(ERROR_SUMMARY_ROWS_RES))
    if csv_rows:
        pd.DataFrame(csv_rows).to_csv("rmsnorm2d.csv", index=False)


if __name__ == "__main__":
    main()