NonBlockingTests.cpp 2.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
/*************************************************************************
 * Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/
#include "TestBed.hpp"

namespace RcclUnitTesting
{
  /**
   * Runs a series of collectives, one at a time, on communicators that were
   * initialized with the blocking flag turned off.
   *
   * Every (rank count, single-/multi-process) combination reported by the
   * TestBed environment is visited. Communicators are created once per
   * combination and reused for each collective in the list; device memory is
   * allocated and released per collective.
   */
  TEST(NonBlocking, SingleCalls)
  {
    TestBed testBed;

    // Collectives under test, executed in this order for every configuration
    std::vector<ncclFunc_t> const collectives = {
      ncclCollBroadcast,
      ncclCollReduce,
      ncclCollAllGather,
      ncclCollReduceScatter,
      ncclCollAllReduce,
      ncclCollGather,
      ncclCollScatter
    };

    // Fixed parameters shared by every run
    int  const elementCount  = 1048576;
    bool const runInPlace    = false;
    bool const managedMemory = false;
    bool const blockingComms = false;  // the point of this test: non-blocking comms

    // Reduction operator for the collectives that perform a reduction
    OptionalColArgs options;
    options.redOp = ncclSum;

    bool isCorrect = true;
    for (int totalRanks : testBed.ev.GetNumGpusList())
    {
      for (int isMultiProcess : testBed.ev.GetIsMultiProcessList())
      {
        // Multi-process runs use as many processes as ranks; otherwise one process
        int numProcesses = 1;
        if (isMultiProcess)
          numProcesses = totalRanks;

        char const* const modeLabel = isMultiProcess ? "MP" : "SP";

        // Bring up the communicators with blocking disabled
        // NOTE(review): the literal 1 is presumably the number of collectives per group - confirm against TestBed::InitComms
        testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks), 1, blockingComms);

        // Run each collective once on the communicators created above
        for (ncclFunc_t const funcType : collectives)
        {
          if (testBed.ev.showNames)
            INFO("%s %d-ranks Non-Blocking %s\n",
                 modeLabel, totalRanks, ncclFuncNames[funcType]);

          // Derive the input / output element counts for this collective
          int inputCount;
          int outputCount;
          CollectiveArgs::GetNumElementsForFuncType(funcType, elementCount, totalRanks,
                                                    &inputCount, &outputCount);

          testBed.SetCollectiveArgs(funcType, ncclFloat, inputCount, outputCount, options);

          // Allocate, fill, run, check, release
          testBed.AllocateMem(runInPlace, managedMemory);
          testBed.PrepareData();
          testBed.ExecuteCollectives();
          testBed.ValidateResults(isCorrect);
          testBed.DeallocateMem();
        }

        testBed.DestroyComms();
      }
    }
    testBed.Finalize();
  }
}