# SPDX-License-Identifier: MIT
import argparse
import logging
import os
from multiprocessing import Pool, freeze_support, set_start_method

from typing_extensions import Optional

import torch
import torch.distributed as dist
import pandas as pd

from aiter import dtypes
from aiter.dist.communication_op import tensor_model_parallel_all_reduce
from aiter.dist.parallel_state import (
    destroy_distributed_environment,
    destroy_model_parallel,
    ensure_model_parallel_initialized,
    get_tp_group,
    graph_capture,
    init_distributed_environment,
    set_custom_all_reduce,
)
from aiter.dist.utils import get_distributed_init_method, get_ip, get_open_port
from aiter.test_common import benchmark, checkAllclose, perftest

logger = logging.getLogger("aiter")

# Worker processes use CUDA, which cannot be re-initialized in a forked
# child, so force the "spawn" start method for the pool created below.
set_start_method("spawn", force=True)


def allreduce_custom(
    tp_size,
    pp_size,
    rankID,
    x,
    withGraph=False,
    distributed_init_method: Optional[str] = None,
    enable_register_for_capturing: bool = True,
):
    """Run one rank of the custom all-reduce test inside a worker process.

    Args:
        tp_size: tensor-parallel world size (number of participating ranks).
        pp_size: pipeline-parallel size forwarded to
            ``ensure_model_parallel_initialized``.
        rankID: this worker's rank; also selects the device ``cuda:{rankID}``.
        x: this rank's input tensor; it is moved to the rank's device here.
        withGraph: if True, capture the all-reduce into a CUDA graph and time
            ``graph.replay()``; otherwise time eager all-reduce calls.
        distributed_init_method: init URL forwarded to
            ``init_distributed_environment``.
        enable_register_for_capturing: exported to the communicator through
            the ``AITER_AR_ENABLE_REG_CAPTURE`` environment variable.

    Returns:
        ``(out, us)``: a host (CPU) copy of the reduced tensor and the timing
        reported by ``perftest`` (named ``us``; presumably microseconds).
    """
    device = torch.device(f"cuda:{rankID}")
    torch.cuda.set_device(device)

    # Forward the user-requested capturing-registration policy down to the
    # CustomAllreduce constructor via the env var consumed inside
    # CudaCommunicator. Must be set BEFORE init_distributed_environment so
    # the worker process picks it up when the communicator is built.
    os.environ["AITER_AR_ENABLE_REG_CAPTURE"] = (
        "1" if enable_register_for_capturing else "0"
    )

    try:
        logger.info(f"RANK: {rankID} {tp_size} init_process_group...")
        set_custom_all_reduce(True)
        init_distributed_environment(
            world_size=tp_size,
            rank=rankID,
            distributed_init_method=distributed_init_method,
        )
        ensure_model_parallel_initialized(tp_size, pp_size)
        x = x.to(device)

        # Warm up and align all ranks with a tiny collective on the TP group.
        group = get_tp_group().device_group
        dist.all_reduce(torch.zeros(1).cuda(), group=group)
        torch.cuda.synchronize()

        if withGraph:
            graph = torch.cuda.CUDAGraph()
            with graph_capture() as gc:
                with torch.cuda.graph(graph, stream=gc.stream):
                    out = tensor_model_parallel_all_reduce(x)
            # Clear the captured output buffer so the checked result can only
            # come from the timed graph replays below.
            out.fill_(0)

            @perftest()
            def run_ca():
                graph.replay()

            _, us = run_ca()
        else:

            @perftest()
            def run_ca(x):
                return tensor_model_parallel_all_reduce(x)

            out, us = run_ca(x)

        # Return a host copy. Sending a CUDA tensor through the pool's result
        # pipe shares it via CUDA IPC, which requires this (soon to exit)
        # worker to keep the allocation alive while the parent uses it.
        return out.cpu(), us
    finally:
        # Tear down even when the test body raised, so the process group and
        # communicator resources are not left behind.
        if dist.is_initialized():
            destroy_model_parallel()
            destroy_distributed_environment()
        torch.cuda.empty_cache()


@benchmark()
def test_allreduce_custom(
    tp_size,
    pp_size,
    shape,
    dtype,
    withGraph=False,
    distributed_init_method: Optional[str] = None,
    enable_register_for_capturing: bool = True,
):
    """Run ``allreduce_custom`` on ``tp_size`` ranks and check every result.

    Each rank gets an independent ``torch.randn(shape, dtype=dtype)`` input.
    The expected output is the elementwise sum of all inputs.

    Returns:
        dict with ``min_us`` / ``max_us`` (timing range across ranks) and
        ``err`` (the largest ``checkAllclose`` result across ranks). The
        ``@benchmark()`` decorator presumably adds the call arguments as
        extra columns; the ``__main__`` section filters on those names.
    """
    # Rendezvous fallback; presumably only consulted when the env:// init
    # path is used (distributed_init_method is None) -- TODO confirm.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "49373"

    # Accumulate the reference in fp32 and round once at the end. Summing
    # tp_size fp16/bf16 tensors in low precision would add reference-side
    # rounding error unrelated to the kernel under test.
    ref_acc = torch.zeros(shape, dtype=torch.float32)
    with Pool(processes=tp_size) as pool:
        async_rets = []
        for rank in range(tp_size):
            x = torch.randn(shape, dtype=dtype)
            ref_acc += x.float()
            async_rets.append(
                pool.apply_async(
                    allreduce_custom,
                    args=(
                        tp_size,
                        pp_size,
                        rank,
                        x,
                        withGraph,
                        distributed_init_method,
                        enable_register_for_capturing,
                    ),
                )
            )
        pool.close()
        pool.join()
        # .get() re-raises any worker exception; leaving the `with` block
        # then terminates the pool.
        rets = [r.get() for r in async_rets]
    ref = ref_acc.to(dtype)

    all_us = [us for _, us in rets]
    max_err = 0.0
    for out, us in rets:
        msg = (
            f"test_allreduce_custom: {shape=} {dtype=} "
            f"{withGraph=} reg_cap={enable_register_for_capturing} {us:>8.2f}"
        )
        err = checkAllclose(ref, out.to(ref), msg=msg)
        max_err = max(max_err, err)
    return {
        "min_us": min(all_us),
        "max_us": max(all_us),
        "err": max_err,
    }
# Default sweep: every dtype x shape x register-for-capturing mode, unless
# narrowed on the command line.
l_dtype = ["fp16", "bf16"]
l_shape = [(2, 7168), (128, 8192)]

parser = argparse.ArgumentParser(description="config input of test")
parser.add_argument(
    "-d",
    "--dtype",
    type=str,
    choices=l_dtype,
    nargs="?",
    const=None,
    default=None,
    help="data type",
)
parser.add_argument(
    "-s",
    "--shape",
    type=dtypes.str2tuple,
    nargs="?",
    const=None,
    default=None,
    help="shape. e.g. -s 128,8192",
)
parser.add_argument(
    "-g",
    "--with-graph",
    type=lambda x: str(x).lower() in ["true", "1", "yes"],
    default=True,
    help="use CUDA graph (default: True). e.g. -g true or -g false",
)
parser.add_argument(
    "--reg-capturing",
    type=str,
    choices=["true", "false", "both"],
    default="both",
    help=(
        "whether CustomAllreduce.enable_register_for_capturing is True/False. "
        "'both' (default) exercises both paths as a regression sweep."
    ),
)
parser.add_argument(
    "-t",
    "--tp-size",
    type=int,
    default=8,
    help="tensor-parallel world size, i.e. number of GPU ranks (default: 8). "
    "e.g. -t 4",
)

if __name__ == "__main__":
    freeze_support()
    args = parser.parse_args()

    # Fail fast with a clear message instead of a device error (or a hang in
    # a collective) inside the spawned workers, which use cuda:{rank}.
    if args.tp_size < 1:
        parser.error("--tp-size must be >= 1")
    n_visible = torch.cuda.device_count()
    if args.tp_size > n_visible:
        parser.error(
            f"--tp-size={args.tp_size} exceeds the {n_visible} visible GPU(s)"
        )

    if args.dtype is None:
        l_dtype = [dtypes.d_dtypes[key] for key in l_dtype]
    else:
        l_dtype = [dtypes.d_dtypes[args.dtype]]
    if args.shape is not None:
        l_shape = [args.shape]
    l_reg = {"true": [True], "false": [False], "both": [True, False]}[
        args.reg_capturing
    ]

    df = []
    for dtype in l_dtype:
        for shape in l_shape:
            for reg in l_reg:
                ret = test_allreduce_custom(
                    args.tp_size,
                    1,
                    shape,
                    dtype,
                    withGraph=args.with_graph,
                    # Fresh port per run so consecutive runs don't collide.
                    distributed_init_method=get_distributed_init_method(
                        get_ip(), get_open_port()
                    ),
                    enable_register_for_capturing=reg,
                )
                df.append(ret)
    df = pd.DataFrame(df)
    show_cols = [
        "tp_size",
        "shape",
        "dtype",
        "withGraph",
        "enable_register_for_capturing",
        "min_us",
        "max_us",
        "err",
    ]
    # Keep only the columns the result rows actually contain.
    show_cols = [c for c in show_cols if c in df.columns]
    df[show_cols].to_csv("test_custom_allreduce.csv", index=False)
    logger.info(
        "custom allreduce summary (markdown):\n%s",
        df[show_cols].to_markdown(index=False),
    )