/*************************************************************************
 * Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#include "TestBed.hpp"

// Unit tests for the AllReduce collective.
//
// Most tests build a small configuration (data types, reduction ops, element
// counts, in-place / managed-memory / HIP-graph flags) and hand it to
// TestBed::RunSimpleSweep. PreMultScalar drives the TestBed step by step so it
// can pass custom per-rank scalars through OptionalColArgs.
namespace RcclUnitTesting
{
  // Out-of-place float32 sum, one large and one small element count.
  TEST(AllReduce, OutOfPlace)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAllReduce};
    std::vector<ncclDataType_t> const dataTypes       = {ncclFloat32};
    std::vector<ncclRedOp_t>    const redOps          = {ncclSum};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {393216, 384};
    std::vector<bool>           const inPlaceList     = {false};
    std::vector<bool>           const managedMemList  = {false};
    std::vector<bool>           const useHipGraphList = {false};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Out-of-place float16/float64 min with the HIP-graph flag enabled.
  TEST(AllReduce, OutOfPlaceGraph)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAllReduce};
    std::vector<ncclDataType_t> const dataTypes       = {ncclFloat16, ncclFloat64};
    std::vector<ncclRedOp_t>    const redOps          = {ncclMin};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {12888};
    std::vector<bool>           const inPlaceList     = {false};
    std::vector<bool>           const managedMemList  = {false};
    std::vector<bool>           const useHipGraphList = {true};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // In-place int32/int8 product.
  TEST(AllReduce, InPlace)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAllReduce};
    std::vector<ncclDataType_t> const dataTypes       = {ncclInt32, ncclInt8};
    std::vector<ncclRedOp_t>    const redOps          = {ncclProd};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {384};
    std::vector<bool>           const inPlaceList     = {true};
    std::vector<bool>           const managedMemList  = {false};
    std::vector<bool>           const useHipGraphList = {false};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // In-place int32 max with the HIP-graph flag enabled.
  TEST(AllReduce, InPlaceGraph)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAllReduce};
    std::vector<ncclDataType_t> const dataTypes       = {ncclInt32};
    std::vector<ncclRedOp_t>    const redOps          = {ncclMax};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {393216};
    std::vector<bool>           const inPlaceList     = {true};
    std::vector<bool>           const managedMemList  = {false};
    std::vector<bool>           const useHipGraphList = {true};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Out-of-place uint8/uint64 sum with the managed-memory flag enabled.
  TEST(AllReduce, ManagedMem)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAllReduce};
    std::vector<ncclDataType_t> const dataTypes       = {ncclUint8, ncclUint64};
    std::vector<ncclRedOp_t>    const redOps          = {ncclSum};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {2500};
    std::vector<bool>           const inPlaceList     = {false};
    std::vector<bool>           const managedMemList  = {true};
    std::vector<bool>           const useHipGraphList = {false};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Out-of-place float64/bfloat16 sum with both the managed-memory and
  // HIP-graph flags enabled.
  TEST(AllReduce, ManagedMemGraph)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAllReduce};
    std::vector<ncclDataType_t> const dataTypes       = {ncclFloat64, ncclBfloat16};
    std::vector<ncclRedOp_t>    const redOps          = {ncclSum};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {4314};
    std::vector<bool>           const inPlaceList     = {false};
    std::vector<bool>           const managedMemList  = {true};
    std::vector<bool>           const useHipGraphList = {true};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Full sweep over every data type and reduction op reported by the TestBed,
  // run with RCCL_ENABLE_CLIQUE=1. The DISABLED_ prefix keeps GoogleTest from
  // running it by default.
  TEST(AllReduce, DISABLED_Clique)
  {
    // Set clique env var prior to TestBed
    setenv("RCCL_ENABLE_CLIQUE", "1", 1);

    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t> const funcTypes       = {ncclCollAllReduce};
    auto                    const dataTypes       = testBed.GetAllSupportedDataTypes();
    auto                    const redOps          = testBed.GetAllSupportedRedOps();
    std::vector<int>        const roots           = {0};
    std::vector<int>        const numElements     = {1048576, 1024};
    std::vector<bool>       const inPlaceList     = {false, true};
    std::vector<bool>       const managedMemList  = {false};
    std::vector<bool>       const useHipGraphList = {false, true};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();

    // Remove the variable again so it does not stay set for later tests
    unsetenv("RCCL_ENABLE_CLIQUE");
  }

  // This tests using custom pre-mult scalars reductions
  //
  // For every (rank count, single/multi-process) combination reported by the
  // environment, rank i is given the scalar value i, and a float32 sum
  // AllReduce is executed for scalarMode 0 and 1 at three element counts.
  TEST(AllReduce, PreMultScalar)
  {
    TestBed testBed;

    // Configuration
    ncclFunc_t                  const funcType      = ncclCollAllReduce;
    // Held by value (not by reference to a temporary) like every other test here
    std::vector<ncclDataType_t> const dataTypes     = {ncclFloat32};
    ncclRedOp_t                 const redOp         = ncclSum;
    std::vector<int>            const numElements   = {384 * 1024, 384 * 32, 384};
    bool                        const inPlace       = false;
    bool                        const useManagedMem = false;

    OptionalColArgs options;

    // Terminate the test as soon as first failure occurs.
    // Only the inner loops check this flag; the outer rank/process loops still
    // create and destroy communicators for the remaining configurations.
    bool isCorrect = true;

    for (int totalRanks : testBed.ev.GetNumGpusList())
    {
      for (int isMultiProcess : testBed.ev.GetIsMultiProcessList())
      {
        // Multi-process: one process per rank. Single-process: one process total.
        int const numProcesses = isMultiProcess ? totalRanks : 1;
        testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks));

        for (ncclDataType_t const dataType : dataTypes)
        {
          if (!isCorrect) break;

          // Size in bytes of one scalar per rank (computed once, used for both
          // the allocation and the copy below)
          auto const numBytes = totalRanks * DataTypeToBytes(dataType);

          // Set scalars per rank
          // NOTE(review): no matching release of scalarsPerRank is visible in
          // this file - confirm whether PtrUnion must be freed explicitly.
          PtrUnion scalarsPerRank;
          scalarsPerRank.AllocateCpuMem(numBytes);
          for (int rank = 0; rank < totalRanks; rank++)
          {
            double scalarValue = rank;
            scalarsPerRank.Set(dataType, rank, rank, scalarValue);
          }

          // NOTE(review): assumes options.scalarTransport.ptr already refers to
          // at least numBytes of writable memory - confirm in OptionalColArgs.
          memcpy(options.scalarTransport.ptr, scalarsPerRank.ptr, numBytes);

          // Test various scalar residence modes
          for (int scalarMode = 0; scalarMode <= 1 && isCorrect; ++scalarMode)
          {
            if (testBed.ev.showNames)
              INFO("%s %d-ranks AllReduce (custom-scalar Mode %d %s)\n",
                   isMultiProcess ? "MP" : "SP", totalRanks, scalarMode, ncclDataTypeNames[dataType]);

            for (size_t sizeIdx = 0; sizeIdx < numElements.size() && isCorrect; ++sizeIdx)
            {
              options.scalarMode = scalarMode;
              options.redOp      = redOp;
              testBed.SetCollectiveArgs(funcType, dataType,
                                        numElements[sizeIdx], numElements[sizeIdx],
                                        options);

              // For performance, only allocate and prepare data on largest size
              // (numElements is listed largest first)
              if (sizeIdx == 0)
              {
                testBed.AllocateMem(inPlace, useManagedMem);
                testBed.PrepareData();
              }

              testBed.ExecuteCollectives();
              // Presumably writes the pass/fail outcome back into isCorrect
              testBed.ValidateResults(isCorrect);
            }
            testBed.DeallocateMem();
          }
        }
        testBed.DestroyComms();
      }
    }
    testBed.Finalize();
  }
}