sddmm.cc 12.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
/*!
 *  Copyright (c) 2020 by Contributors
 * \file aten/cpu/sddmm.cc
 * \brief SDDMM C APIs and definitions.
 */
#include "./sddmm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
/*!
 * \brief Turn a runtime rhs target id into a compile-time constant.
 *
 * Runs the trailing code block with \c RhsTarget declared as a
 * \c constexpr \c int equal to the matched value (0, 1 or 2), so the block
 * can use it as a template argument.
 *
 * \param rhs_target Runtime integer selecting the target.
 * \param RhsTarget  Identifier declared inside the block for the constant.
 * \param ...        Code block executed for the matched target.
 *
 * \note Any other value is reported with LOG(FATAL), the same severity
 *       SWITCH_BITS uses for an unknown bit width. Reporting at INFO would
 *       let the caller carry on although the block never ran.
 */
#define SWITCH_RHS(rhs_target, RhsTarget, ...)                        \
  do {                                                                \
    if ((rhs_target) == 0) {                                          \
      constexpr int RhsTarget = 0;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 1) {                                   \
      constexpr int RhsTarget = 1;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 2) {                                   \
      constexpr int RhsTarget = 2;                                    \
      { __VA_ARGS__ }                                                 \
    } else {                                                          \
      LOG(FATAL) << "Invalid rhs target: " << (rhs_target);           \
    }                                                                 \
  } while (0)

/*!
 * \brief Turn runtime lhs/rhs target ids into compile-time constants.
 *
 * Matches \c lhs_target against 0, 1 and 2, declares \c LhsTarget as a
 * \c constexpr \c int with the matched value, and then hands the rhs side to
 * SWITCH_RHS, which declares \c RhsTarget and runs the code block.
 *
 * \param lhs_target Runtime integer selecting the lhs target.
 * \param rhs_target Runtime integer selecting the rhs target.
 * \param LhsTarget  Identifier declared for the lhs constant.
 * \param RhsTarget  Identifier declared for the rhs constant.
 * \param ...        Code block executed for the matched pair.
 *
 * \note An unsupported lhs value is reported with LOG(FATAL), the same
 *       severity SWITCH_BITS uses for an unknown bit width. Reporting at INFO
 *       would let the caller carry on although the block never ran.
 */
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
  do {                                                                  \
    if ((lhs_target) == 0) {                                            \
      constexpr int LhsTarget = 0;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 1) {                                     \
      constexpr int LhsTarget = 1;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 2) {                                     \
      constexpr int LhsTarget = 2;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else {                                                            \
      LOG(FATAL) << "Invalid lhs target: " << (lhs_target);             \
    }                                                                   \
  } while (0)

44
45
46
47
48
49
50
51
52
/*!
 * \brief Pick the floating-point value type that matches a bit width.
 *
 * Runs the trailing code block with \c DType aliased to \c float when
 * \c bits is 16 or 32, and to \c double when \c bits is 64. Any other width
 * is reported with LOG(FATAL).
 *
 * \param bits  Bit width of the data (16, 32 or 64).
 * \param DType Identifier declared inside the block for the chosen type.
 * \param ...   Code block executed with \c DType in scope.
 *
 * \note \c bits is parenthesized everywhere, including the log statement, so
 *       an expression argument (e.g. \c a|b) is not split by \c <<.
 * \note NOTE(review): 16-bit data is dispatched as \c float here; confirm
 *       that callers hand over buffers that really hold 32-bit floats.
 */
#define SWITCH_BITS(bits, DType, ...)                             \
  do {                                                            \
    if ((bits) == 16 || (bits) == 32) {                           \
      using DType = float;                                        \
      { __VA_ARGS__ }                                             \
    } else if ((bits) == 64) {                                    \
      using DType = double;                                       \
      { __VA_ARGS__ }                                             \
    } else {                                                      \
      LOG(FATAL) << "Data type not recognized with bits " << (bits); \
    }                                                             \
  } while (0)

57

58
/*! \brief Generalized SDDMM on Csr format. */
59
template <int XPU, typename IdType, int bits>
60
61
62
void SDDMMCsr(const std::string& op,
              const BcastOff& bcast,
              const CSRMatrix& csr,
63
64
65
66
67
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
68
69
70
71
72
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
      });
73
    });
74
75
76
  });
}

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
/*!
 * \brief Generalized SDDMM on Csr format with Heterograph support.
 *
 * Resolves \c bits, \c op and the two targets to template arguments once,
 * then calls cpu::SDDMMCsr for every relation type. Relation \c etype uses
 * \c vec_csr[etype] and \c vec_out[etype], with operands
 * \c vec_lhs[lhs_nid[etype]] and \c vec_rhs[rhs_nid[etype]].
 *
 * \tparam XPU    Device tag; not referenced in the body.
 * \tparam IdType Index type forwarded to the kernel.
 * \tparam bits   Bit width of the feature data.
 * \param op         Name of the operator to apply.
 * \param bcast      Broadcast information forwarded to the kernel.
 * \param vec_csr    One CSR matrix per relation type.
 * \param vec_lhs    Left-hand operand arrays, indexed through \c lhs_nid.
 * \param vec_rhs    Right-hand operand arrays, indexed through \c rhs_nid.
 * \param vec_out    One output array per relation type.
 * \param lhs_target Target id of the left operand (0, 1 or 2).
 * \param rhs_target Target id of the right operand (0, 1 or 2).
 * \param lhs_nid    Per relation, the index into \c vec_lhs.
 * \param rhs_nid    Per relation, the index into \c vec_rhs.
 *
 * \note The loop runs \c lhs_nid.size() times and performs no bounds checks;
 *       callers must supply at least that many entries in \c vec_csr,
 *       \c vec_out and \c rhs_nid, and valid indices in both id vectors.
 */
template <int XPU, typename IdType, int bits>
void SDDMMCsrHetero(const std::string& op,
              const BcastOff& bcast,
              const std::vector<CSRMatrix>& vec_csr,
              const std::vector<NDArray>& vec_lhs,
              const std::vector<NDArray>& vec_rhs,
              std::vector<NDArray> vec_out,
              int lhs_target,
              int rhs_target,
              const std::vector<dgl_type_t>& lhs_nid,
              const std::vector<dgl_type_t>& rhs_nid) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        /* Call SDDMM for each relation type */
        for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
          // Bind by reference: the vectors outlive the call, so there is no
          // need to copy a matrix and three array handles per relation.
          const CSRMatrix& csr = vec_csr[etype];
          const NDArray& lhs = vec_lhs[lhs_nid[etype]];
          const NDArray& rhs = vec_rhs[rhs_nid[etype]];
          NDArray& out = vec_out[etype];
          cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
              bcast, csr, lhs, rhs, out);
        }
      });
    });
  });
}

105
106
107
108
109
110
111
112
113
/*!
 * Explicit instantiations of SDDMMCsr with XPU = kDLCPU, for 32- and 64-bit
 * index types combined with 16-, 32- and 64-bit data.
 */
template void SDDMMCsr<kDLCPU, int32_t, 16>(
    const std::string& op,
    const BcastOff& bcast,
    const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 16>(
    const std::string& op,
    const BcastOff& bcast,
    const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, 32>(
    const std::string& op,
    const BcastOff& bcast,
    const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 32>(
    const std::string& op,
    const BcastOff& bcast,
    const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, 64>(
    const std::string& op,
    const BcastOff& bcast,
    const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 64>(
    const std::string& op,
    const BcastOff& bcast,
    const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
129

130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
/*!
 * Explicit instantiations of SDDMMCsrHetero with XPU = kDLCPU, for 32- and
 * 64-bit index types combined with 16-, 32- and 64-bit data. Parameter names
 * follow the template definition (vec_lhs, vec_rhs, vec_out, lhs_nid,
 * rhs_nid).
 */
template void SDDMMCsrHetero<kDLCPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDLCPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDLCPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
172

173
/*! \brief Generalized SDDMM on Coo format. */
174
template <int XPU, typename IdType, int bits>
175
176
177
void SDDMMCoo(const std::string& op,
              const BcastOff& bcast,
              const COOMatrix& coo,
178
179
180
181
182
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
183
184
185
186
187
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
      });
188
    });
189
190
191
  });
}

192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
/*!
 * \brief Generalized SDDMM on Coo format with Heterograph support.
 *
 * Resolves \c bits, \c op and the two targets to template arguments once,
 * then calls cpu::SDDMMCoo for every relation type. Relation \c etype uses
 * \c vec_coo[etype] and \c vec_out[etype], with operands
 * \c vec_lhs[lhs_nid[etype]] and \c vec_rhs[rhs_nid[etype]].
 *
 * \tparam XPU    Device tag; not referenced in the body.
 * \tparam IdType Index type forwarded to the kernel.
 * \tparam bits   Bit width of the feature data.
 * \param op         Name of the operator to apply.
 * \param bcast      Broadcast information forwarded to the kernel.
 * \param vec_coo    One COO matrix per relation type.
 * \param vec_lhs    Left-hand operand arrays, indexed through \c lhs_nid.
 * \param vec_rhs    Right-hand operand arrays, indexed through \c rhs_nid.
 * \param vec_out    One output array per relation type.
 * \param lhs_target Target id of the left operand (0, 1 or 2).
 * \param rhs_target Target id of the right operand (0, 1 or 2).
 * \param lhs_nid    Per relation, the index into \c vec_lhs.
 * \param rhs_nid    Per relation, the index into \c vec_rhs.
 *
 * \note The loop runs \c lhs_nid.size() times and performs no bounds checks;
 *       callers must supply at least that many entries in \c vec_coo,
 *       \c vec_out and \c rhs_nid, and valid indices in both id vectors.
 */
template <int XPU, typename IdType, int bits>
void SDDMMCooHetero(const std::string& op,
              const BcastOff& bcast,
              const std::vector<COOMatrix>& vec_coo,
              const std::vector<NDArray>& vec_lhs,
              const std::vector<NDArray>& vec_rhs,
              std::vector<NDArray> vec_out,
              int lhs_target,
              int rhs_target,
              const std::vector<dgl_type_t>& lhs_nid,
              const std::vector<dgl_type_t>& rhs_nid) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        /* Call SDDMM for each relation type */
        for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
          // Bind by reference: the vectors outlive the call, so there is no
          // need to copy a matrix and three array handles per relation.
          const COOMatrix& coo = vec_coo[etype];
          const NDArray& lhs = vec_lhs[lhs_nid[etype]];
          const NDArray& rhs = vec_rhs[rhs_nid[etype]];
          NDArray& out = vec_out[etype];
          cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
              bcast, coo, lhs, rhs, out);
        }
      });
    });
  });
}

220
221
222
223
224
/*!
 * Explicit instantiations of SDDMMCoo with XPU = kDLCPU, for 32- and 64-bit
 * index types combined with 16-, 32- and 64-bit data.
 */
template void SDDMMCoo<kDLCPU, int32_t, 16>(
    const std::string& op,
    const BcastOff& bcast,
    const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 16>(
    const std::string& op,
    const BcastOff& bcast,
    const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, 32>(
    const std::string& op,
    const BcastOff& bcast,
    const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 32>(
    const std::string& op,
    const BcastOff& bcast,
    const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, 64>(
    const std::string& op,
    const BcastOff& bcast,
    const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 64>(
    const std::string& op,
    const BcastOff& bcast,
    const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);

245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
/*!
 * Explicit instantiations of SDDMMCooHetero with XPU = kDLCPU, for 32- and
 * 64-bit index types combined with 16-, 32- and 64-bit data. Parameter names
 * follow the template definition (vec_lhs, vec_rhs, vec_out, lhs_nid,
 * rhs_nid).
 */
template void SDDMMCooHetero<kDLCPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDLCPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDLCPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDLCPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDLCPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDLCPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);

288
289
290

}  // namespace aten
}  // namespace dgl