CollectiveArgs.hpp 4.66 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
/*************************************************************************
 * Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/
#pragma once
#include <cstddef>
#include <string>

#include "PtrUnion.hpp"
#include "PrepDataFuncs.hpp"
#include "rccl/rccl.h"

namespace RcclUnitTesting
{
  // Enumeration of all collective functions currently supported.
  // ncclNumFuncs is not a collective: it is the count of the enumerators that
  // precede it and is used to size per-function tables such as ncclFuncNames.
  typedef enum
  {
    ncclCollBroadcast = 0,
    ncclCollReduce,
    ncclCollAllGather,
    ncclCollReduceScatter,
    ncclCollAllReduce,
    ncclCollGather,
    ncclCollScatter,
    ncclCollAllToAll,
    ncclCollAllToAllv,
    ncclCollSend,
    ncclCollRecv,
    ncclNumFuncs
  } ncclFunc_t;

  // Printable name for each ncclFunc_t value, indexed by the enum value.
  // The outer bound is deduced from the initializer (rather than written as
  // [ncclNumFuncs]) so that the static_assert below rejects a table that has
  // fallen out of sync with ncclFunc_t; with an explicit bound, a forgotten
  // entry would silently become an empty string.
  char const ncclFuncNames[][32] =
  {
    "Broadcast",
    "Reduce",
    "AllGather",
    "ReduceScatter",
    "AllReduce",
    "Gather",
    "Scatter",
    "AllToAll",
    "AllToAllv",
    "Send",
    "Recv"
  };
  static_assert(sizeof(ncclFuncNames) / sizeof(ncclFuncNames[0]) == ncclNumFuncs,
                "ncclFuncNames must have exactly one entry per ncclFunc_t value");

  // Printable name for each ncclDataType_t value, indexed by the enum value.
  // NOTE(review): assumes these entries follow the declaration order of
  // ncclDataType_t in rccl.h — verify whenever rccl.h gains new types. If
  // ncclNumTypes exceeds the number of names listed, the remaining rows are
  // zero-filled (empty strings) by aggregate initialization.
  char const ncclDataTypeNames[ncclNumTypes][32] =
  {
    "ncclInt8",     "ncclUint8",
    "ncclInt32",    "ncclUint32",
    "ncclInt64",    "ncclUint64",
    "ncclFloat16",  "ncclFloat32",  "ncclFloat64",
    "ncclBfloat16"
  };

  // Printable name for each built-in ncclRedOp_t, indexed by the enum value.
  // NOTE(review): assumes rccl.h declares the ops in this order
  // (sum, prod, max, min, avg) — confirm if ncclNumOps changes.
  char const ncclRedOpNames[ncclNumOps][32] =
  {
    "sum", "prod", "max", "min", "avg"
  };

  class CollectiveArgs;

  // Upper bound on the rank count used to size the fixed-capacity buffers in
  // ScalarTransport and OptionalColArgs.
  #define MAX_RANKS 32

  // Fixed-size, trivially copyable byte buffer used to carry scalar values
  // (e.g. for custom reduction operators). It has room for MAX_RANKS values of
  // up to sizeof(double) bytes each.
  struct ScalarTransport
  {
    // Raw bytes — despite the name this is an array, not a pointer.
    // Zero-initialized so a default-constructed instance never holds
    // indeterminate bytes (copying such bytes is otherwise a latent hazard).
    char ptr[MAX_RANKS * sizeof(double)] = {};
  };

  // Optional per-collective arguments. Every member has a default, so a
  // default-constructed OptionalColArgs is fully initialized and can be copied
  // without reading indeterminate values.
  struct OptionalColArgs
  {
    ncclRedOp_t     redOp = ncclSum;
    int             root = 0;               // Used as "peer" for Send/Recv
    ScalarTransport scalarTransport = {};   // Used for custom reduction operators
    int             scalarMode = -1;        // -1 if scalar not used

    // allToAllv args: each table holds MAX_RANKS*MAX_RANKS entries.
    // NOTE(review): presumably laid out as one row of MAX_RANKS entries per
    // rank — confirm against the AllToAllv data-preparation code.
    // Zero-initialized so unused entries are well-defined.
    size_t          sendcounts[MAX_RANKS*MAX_RANKS] = {};
    size_t          sdispls[MAX_RANKS*MAX_RANKS] = {};
    size_t          recvcounts[MAX_RANKS*MAX_RANKS] = {};
    size_t          rdispls[MAX_RANKS*MAX_RANKS] = {};
  };

  // Function pointer for functions that operate on CollectiveArgs
  // e.g. For filling input / computing expected results
  typedef ErrCode (*CollFuncPtr)(CollectiveArgs &);

  // Describes one collective call in the unit tests: the arguments to execute,
  // the buffers it operates on, and the host-side expected results.
  // Scalar members are value-initialized (0 / false / enumerator value 0) so a
  // default-constructed CollectiveArgs never holds indeterminate scalars before
  // SetArgs()/AllocateMem() are called.
  // NOTE(review): PtrUnion members rely on PtrUnion's own default
  // construction — confirm whether it zeroes its pointer.
  class CollectiveArgs
  {
  public:
    // Arguments to execute
    int             globalRank{};
    int             totalRanks{};
    int             deviceId{};
    ncclFunc_t      funcType{};
    ncclDataType_t  dataType{};
    size_t          numInputElements{};
    size_t          numOutputElements{};
    PtrUnion        localScalar;
    int             streamIdx{};
    OptionalColArgs options;

    // Data
    PtrUnion       inputGpu;
    PtrUnion       outputGpu;
    PtrUnion       outputCpu;
    PtrUnion       expected;
    bool           inPlace{};
    bool           useManagedMem{};
    size_t         numInputBytesAllocated{};
    size_t         numOutputBytesAllocated{};
    size_t         numInputElementsAllocated{};
    size_t         numOutputElementsAllocated{};

    // Set collective arguments.
    // optionalArgs defaults to a value-initialized OptionalColArgs.
    ErrCode SetArgs(int             const globalRank,
                    int             const totalRanks,
                    int             const deviceId,
                    ncclFunc_t      const funcType,
                    ncclDataType_t  const dataType,
                    size_t          const numInputElements,
                    size_t          const numOutputElements,
                    int             const streamIdx,
                    OptionalColArgs const &optionalArgs = {});

    // Allocates GPU memory for input/output and CPU memory for expected
    // When inPlace is true, input and output share the same memory
    // NOTE(review): useManagedMem presumably selects managed allocations —
    // confirm against the definition.
    ErrCode AllocateMem(bool   const inPlace,
                        bool   const useManagedMem);

    // Execute the provided data preparation function to fill input and compute expected results
    ErrCode PrepareData(CollFuncPtr const prepareDataFunc);

    // Compare outputs to expected values
    ErrCode ValidateResults();

    // Deallocate memory
    ErrCode DeallocateMem();

    // Provide a description for the current collective arguments
    std::string GetDescription() const;

    // Computes the number of input and output elements for a collective of
    // type funcType, writing them through numInputElements / numOutputElements
    // (this function returns nothing itself).
    // NOTE(review): N is presumably the base element count the counts are
    // derived from — confirm against the definition. The out-params are int
    // while the corresponding members above are size_t.
    static void GetNumElementsForFuncType(ncclFunc_t const funcType,
                                          int        const N,
                                          int        const totalRanks,
                                          int*             numInputElements,
                                          int*             numOutputElements);

    // Returns true if collective function performs reduction
    static bool UsesReduce(ncclFunc_t const funcType);

    // Returns true if collective function utilizes a root rank
    static bool UsesRoot(ncclFunc_t const funcType);
  };
}