/*************************************************************************
 * Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#include "TestBed.hpp"

// ReduceScatter unit tests.
// Each test builds a single sweep configuration (collective, data types,
// reduction ops, roots, element counts, in-place / managed-memory / HIP-graph
// flags), runs it through TestBed::RunSimpleSweep, then calls
// TestBed::Finalize.
namespace RcclUnitTesting
{
  // Out-of-place float32 max-reduction at a large and a small element count.
  TEST(ReduceScatter, OutOfPlace)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollReduceScatter};
    std::vector<ncclDataType_t> const dataTypes       = {ncclFloat32};
    std::vector<ncclRedOp_t>    const redOps          = {ncclMax};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {393216, 384};
    std::vector<bool>           const inPlaceList     = {false};
    std::vector<bool>           const managedMemList  = {false};
    std::vector<bool>           const useHipGraphList = {false};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Out-of-place float64 / bfloat16 max-reduction captured in a HIP graph.
  TEST(ReduceScatter, OutOfPlaceGraph)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollReduceScatter};
    std::vector<ncclDataType_t> const dataTypes       = {ncclFloat64, ncclBfloat16};
    std::vector<ncclRedOp_t>    const redOps          = {ncclMax};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {1048576};
    std::vector<bool>           const inPlaceList     = {false};
    std::vector<bool>           const managedMemList  = {false};
    std::vector<bool>           const useHipGraphList = {true};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // In-place int32 product-reduction at an odd element count.
  TEST(ReduceScatter, InPlace)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollReduceScatter};
    std::vector<ncclDataType_t> const dataTypes       = {ncclInt32};
    std::vector<ncclRedOp_t>    const redOps          = {ncclProd};
    // NOTE(review): ncclReduceScatter takes no root argument, so the second
    // root presumably just repeats the sweep; confirm {0, 1} is intentional.
    std::vector<int>            const roots           = {0, 1};
    std::vector<int>            const numElements     = {542357};
    std::vector<bool>           const inPlaceList     = {true};
    std::vector<bool>           const managedMemList  = {false};
    std::vector<bool>           const useHipGraphList = {false};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // In-place uint8 / float16 min-reduction captured in a HIP graph.
  TEST(ReduceScatter, InPlaceGraph)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollReduceScatter};
    std::vector<ncclDataType_t> const dataTypes       = {ncclUint8, ncclFloat16};
    std::vector<ncclRedOp_t>    const redOps          = {ncclMin};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {246};
    std::vector<bool>           const inPlaceList     = {true};
    std::vector<bool>           const managedMemList  = {false};
    std::vector<bool>           const useHipGraphList = {true};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Out-of-place int64 / uint8 average-reduction using managed memory.
  TEST(ReduceScatter, ManagedMem)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollReduceScatter};
    std::vector<ncclDataType_t> const dataTypes       = {ncclInt64, ncclUint8};
    std::vector<ncclRedOp_t>    const redOps          = {ncclAvg};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {1024};
    std::vector<bool>           const inPlaceList     = {false};
    std::vector<bool>           const managedMemList  = {true};
    std::vector<bool>           const useHipGraphList = {false};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Out-of-place uint32 / uint64 average-reduction using managed memory,
  // captured in a HIP graph.
  TEST(ReduceScatter, ManagedMemGraph)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollReduceScatter};
    std::vector<ncclDataType_t> const dataTypes       = {ncclUint32, ncclUint64};
    std::vector<ncclRedOp_t>    const redOps          = {ncclAvg};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {6485423};
    std::vector<bool>           const inPlaceList     = {false};
    std::vector<bool>           const managedMemList  = {true};
    std::vector<bool>           const useHipGraphList = {true};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }
}