# SPDX-License-Identifier: MIT # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import os from typing import Optional import aiter import torch import torch.nn.functional as F import torch.distributed as dist import argparse import itertools import pandas as pd from aiter import dtypes from aiter.dist.parallel_state import ( ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce, get_tp_group, graph_capture, destroy_model_parallel, destroy_distributed_environment, ) from aiter.dist.utils import get_open_port, get_distributed_init_method, get_ip from aiter.dist.communication_op import ( tensor_model_parallel_all_reduce, tensor_model_parallel_fused_allreduce_rmsnorm, tensor_model_parallel_fused_allreduce_rmsnorm_quant, ) from aiter.test_common import ( checkAllclose, perftest, benchmark, ) from multiprocessing import set_start_method, Pool, freeze_support import logging logger = logging.getLogger("aiter") set_start_method("spawn", force=True) def fused_ar_rmsnorm( tp_size, pp_size, rankID, x, weight, eps, withGraph=False, distributed_init_method: Optional[str] = None, post_per_token_quant: bool = False, ): device = torch.device(f"cuda:{rankID}") torch.cuda.set_device(device) # init logger.info(f"RANK: {rankID} {tp_size} init_process_group...") set_custom_all_reduce(True) init_distributed_environment( world_size=tp_size, rank=rankID, distributed_init_method=distributed_init_method, ) ensure_model_parallel_initialized(tp_size, pp_size) x = x.to(device) weight = weight.to(device) # dist.barrier(device_ids=[i for i in range(tp_size)]) # warmup and align all gpu group = get_tp_group().device_group dist.all_reduce(torch.zeros(1).cuda(), group=group) torch.cuda.synchronize() if withGraph: graph = torch.cuda.CUDAGraph() with graph_capture() as gc: with torch.cuda.graph(graph, stream=gc.stream): if not post_per_token_quant: out, res_out = tensor_model_parallel_fused_allreduce_rmsnorm( x, x, weight, eps ) else: out, res_out, scale_out = ( tensor_model_parallel_fused_allreduce_rmsnorm_quant( x, x, weight, eps ) ) out.fill_(0) res_out.fill_(0) @perftest() def run_ca(): graph.replay() _, us = run_ca() if not post_per_token_quant: out = (out, us) else: out = (out.float() * scale_out, us) else: @perftest() def run_ca(x): if not post_per_token_quant: out, res_out = tensor_model_parallel_fused_allreduce_rmsnorm( x, x, weight, eps ) return out else: out, res_out, scale_out = ( tensor_model_parallel_fused_allreduce_rmsnorm_quant( x, x, weight, eps ) ) return out, scale_out if not post_per_token_quant: out = run_ca(x) else: out = run_ca(x) out = (out[0][0].float() * out[0][1], out[1]) # destroy if dist.is_initialized(): destroy_model_parallel() destroy_distributed_environment() torch.cuda.empty_cache() return out def get_acc_value_with_cudagraph( tp_size, pp_size, rankID, x, weight, eps, loop_time=1, distributed_init_method: Optional[str] = None, ): device = torch.device(f"cuda:{rankID}") torch.cuda.set_device(device) # init logger.info(f"RANK: {rankID} {tp_size} init_process_group...") set_custom_all_reduce(True) init_distributed_environment( world_size=tp_size, rank=rankID, distributed_init_method=distributed_init_method, ) ensure_model_parallel_initialized(tp_size, pp_size) x = x.to(device) weight = weight.to(device) # dist.barrier(device_ids=[i for i in range(tp_size)]) # warmup and align all gpu group = get_tp_group().device_group dist.all_reduce(torch.zeros(1).cuda(), group=group) torch.cuda.synchronize() # out = torch.empty_like(x) graph = torch.cuda.CUDAGraph() with graph_capture() as gc: with torch.cuda.graph(graph, stream=gc.stream): # out = torch.empty_like(x) out, res_out = tensor_model_parallel_fused_allreduce_rmsnorm( x, x, weight, eps ) out.fill_(0) def run_ca(): graph.replay() rslt = out.clone() out.fill_(0) return rslt for i in range(loop_time): out = run_ca() # destroy if dist.is_initialized(): destroy_model_parallel() destroy_distributed_environment() torch.cuda.empty_cache() return out def get_acc_value_only( tp_size, pp_size, rankID, x, weight, eps, loop_time=1, distributed_init_method: Optional[str] = None, ): device = torch.device(f"cuda:{rankID}") torch.cuda.set_device(device) # init logger.info(f"RANK: {rankID} {tp_size} init_process_group...") set_custom_all_reduce(True) init_distributed_environment( world_size=tp_size, rank=rankID, distributed_init_method=distributed_init_method, ) ensure_model_parallel_initialized(tp_size, pp_size) x = x.to(device) weight = weight.to(device) # dist.barrier(device_ids=[i for i in range(tp_size)]) # warmup and align all gpu group = get_tp_group().device_group torch.cuda.synchronize() for i in range(loop_time): out, res = tensor_model_parallel_fused_allreduce_rmsnorm(x, x, weight, eps) # destroy if dist.is_initialized(): destroy_model_parallel() destroy_distributed_environment() torch.cuda.empty_cache() return out def split_ar_rmsnorm( tp_size, pp_size, rankID, x, weight, eps, withGraph=False, distributed_init_method: Optional[str] = None, ): device = torch.device(f"cuda:{rankID}") torch.cuda.set_device(device) # init logger.info(f"RANK: {rankID} {tp_size} init_process_group...") set_custom_all_reduce(True) init_distributed_environment( world_size=tp_size, rank=rankID, distributed_init_method=distributed_init_method, ) ensure_model_parallel_initialized(tp_size, pp_size) x = x.to(device) weight = weight.to(device) # dist.barrier(device_ids=[i for i in range(tp_size)]) # warmup and align all gpu group = get_tp_group().device_group dist.all_reduce(torch.zeros(1).cuda(), group=group) torch.cuda.synchronize() if withGraph: graph = torch.cuda.CUDAGraph() with graph_capture() as gc: with torch.cuda.graph(graph, stream=gc.stream): ar_out = tensor_model_parallel_all_reduce(x) # out = aiter.rms_norm(ar_out, weight, eps, 0) out = torch.empty_like(ar_out) residual_out = torch.empty_like(ar_out) aiter.rmsnorm2d_fwd_with_add( out, ar_out, x, residual_out, weight, eps, 0, ) out.fill_(0) @perftest() def run_ca(): graph.replay() _, us = run_ca() out = (out, us) else: @perftest() def run_ca(x): ar_out = tensor_model_parallel_all_reduce(x) out = torch.empty_like(ar_out) residual_out = torch.empty_like(ar_out) aiter.rmsnorm2d_fwd_with_add( out, ar_out, x, residual_out, weight, eps, 0, ) return out out = run_ca(x) # destroy if dist.is_initialized(): destroy_model_parallel() destroy_distributed_environment() torch.cuda.empty_cache() return out @benchmark() def test_fused_ar_rmsnorm( tp_size, pp_size, shape, dtype, withGraph=False, distributed_init_method: Optional[str] = None, post_per_token_quant: bool = False, ): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "49373" pool = Pool(processes=tp_size) ref = torch.zeros(shape, dtype=dtype) rets = [] cpu_rslt = [] weight_list = [] res_inp = [] # print(type(shape[0]), shape[1], ref.device) n = shape[1] eps = 1e-6 weight = torch.randn((n,), dtype=dtype) x = torch.randn(shape, dtype=dtype) ref = x * tp_size for i in range(tp_size): res_inp.append(x) weight_list.append(weight) rets.append( pool.apply_async( fused_ar_rmsnorm, args=( tp_size, pp_size, i, x, weight, eps, withGraph, distributed_init_method, post_per_token_quant, ), ) ) pool.close() pool.join() print(f"rslt[0][0] = {ref[0][0]}") for i in range(tp_size): host_rslt = F.rms_norm( input=(ref + res_inp[i]), normalized_shape=(ref.shape[-1],), weight=weight_list[i], eps=eps, ) # host_rslt = ref + res_inp[i] cpu_rslt.append(host_rslt) rets = [el.get() for el in rets] all_us = [us for _, us in rets] atol = 5e-2 if post_per_token_quant else 1e-2 rtol = atol max_err = 0.0 for out, us in rets: msg = f"test_fused_ar_rmsnorm: {shape=} {dtype=} {withGraph=} {us:>8.2f}" # print(cpu_rslt[out.device.index]) err = checkAllclose( cpu_rslt[out.device.index], out.to(ref), msg=msg, atol=atol, rtol=rtol ) max_err = max(max_err, err) # checkAllclose(ref, out.to(ref), msg=msg) suffix = "quant" if post_per_token_quant else "fused" return { f"{suffix}_min_us": min(all_us), f"{suffix}_max_us": max(all_us), f"{suffix}_err": max_err, } l_dtype = ["fp16", "bf16"] # (13, 2880): GPT-OSS-120B / GPT-OSS-20B hidden_size (n_bytes=5760, 4096 < 5760 < 8192) l_shape = [ (13, 512), (13, 1024), (13, 2048), (13, 2880), (17, 4096), (17, 7168), (19, 8192), ] l_tp = [8] l_pp = [1] l_graph = [False, True] parser = argparse.ArgumentParser(description="config input of test") parser.add_argument( "-d", "--dtype", type=str, choices=l_dtype, nargs="?", const=None, default=None, help="data type", ) parser.add_argument( "-s", "--shape", type=dtypes.str2tuple, nargs="*", default=None, help="shape(s). e.g. -s 128,8192 256,7168", ) parser.add_argument( "-t", "--tp", type=int, nargs="?", const=None, default=None, help="tp num. e.g. -t 8", ) parser.add_argument( "-p", "--pp", type=int, nargs="?", const=None, default=None, help="tp num. e.g. -p 1", ) parser.add_argument( "-g", "--graphon", type=int, nargs="?", const=None, default=None, help="open cudagraph. e.g. -g 1", ) l_test_types = ["fused", "quant"] parser.add_argument( "--test", type=str, choices=l_test_types, nargs="*", default=None, help="test type(s) to run. e.g. --test fused quant", ) if __name__ == "__main__": freeze_support() args = parser.parse_args() if args.dtype is None: l_dtype = [dtypes.d_dtypes[key] for key in l_dtype] else: l_dtype = [dtypes.d_dtypes[args.dtype]] if args.shape is not None: l_shape = args.shape if args.tp is not None: l_tp = [args.tp] if args.pp is not None: l_pp = [args.pp] if args.graphon is not None: print(args.graphon) l_graph = [args.graphon] run_tests = args.test if args.test else l_test_types df = [] for dtype, shape, tp, pp, graph_on in itertools.product( l_dtype, l_shape, l_tp, l_pp, l_graph ): row = {} if "fused" in run_tests: ret = test_fused_ar_rmsnorm( tp, pp, shape, dtype, withGraph=graph_on, distributed_init_method=get_distributed_init_method( get_ip(), get_open_port() ), post_per_token_quant=False, ) row.update(ret) if "quant" in run_tests: ret = test_fused_ar_rmsnorm( tp, pp, shape, dtype, withGraph=graph_on, distributed_init_method=get_distributed_init_method( get_ip(), get_open_port() ), post_per_token_quant=True, ) row.update(ret) df.append(row) df = pd.DataFrame(df) show_cols = [ "tp_size", "shape", "dtype", "withGraph", "fused_min_us", "fused_max_us", "fused_err", "quant_min_us", "quant_max_us", "quant_err", ] show_cols = [c for c in show_cols if c in df.columns] df[show_cols].to_csv("test_fused_ar_rms.csv", index=False) logger.info( "fused allreduce rmsnorm summary (markdown):\n%s", df[show_cols].to_markdown(index=False), )