"examples/vscode:/vscode.git/clone" did not exist on "a5a35d1f7a104ab3a2e6e275d3b6440aebf79a0e"
sddmm.cc 12.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
/*!
 *  Copyright (c) 2020 by Contributors
 * \file aten/cpu/sddmm.cc
 * \brief SDDMM C APIs and definitions.
 */
#include "./sddmm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

/*!
 * \brief Dispatch a runtime rhs target id to a compile-time constant.
 *
 * The body (__VA_ARGS__) is expanded once per supported id (0, 1, 2) with
 * `constexpr int RhsTarget` bound to that id, so the body can use it as a
 * template argument.
 * NOTE(review): presumably 0 = source node, 1 = edge, 2 = destination node;
 * confirm against the kernel templates in sddmm.h.
 *
 * Any other id is a caller error and is reported with LOG(FATAL) rather than
 * LOG(INFO): silently skipping the body would skip the kernel and leave the
 * output array unwritten.
 */
#define SWITCH_RHS(rhs_target, RhsTarget, ...)                        \
  do {                                                                \
    if ((rhs_target) == 0) {                                          \
      constexpr int RhsTarget = 0;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 1) {                                   \
      constexpr int RhsTarget = 1;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 2) {                                   \
      constexpr int RhsTarget = 2;                                    \
      { __VA_ARGS__ }                                                 \
    } else {                                                          \
      LOG(FATAL) << "Invalid rhs target: " << (rhs_target);           \
    }                                                                 \
  } while (0)

/*!
 * \brief Dispatch runtime lhs and rhs target ids to compile-time constants.
 *
 * The lhs id (0, 1 or 2) is bound here to `constexpr int LhsTarget`, and the
 * rhs id is resolved by SWITCH_RHS, so the body is expanded for every
 * (lhs, rhs) combination of supported ids.
 *
 * An unsupported lhs id is reported with LOG(FATAL) rather than LOG(INFO):
 * silently skipping the body would skip the kernel and leave the output
 * array unwritten.
 */
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
  do {                                                                  \
    if ((lhs_target) == 0) {                                            \
      constexpr int LhsTarget = 0;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 1) {                                     \
      constexpr int LhsTarget = 1;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 2) {                                     \
      constexpr int LhsTarget = 2;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else {                                                            \
      LOG(FATAL) << "Invalid lhs target: " << (lhs_target);             \
    }                                                                   \
  } while (0)

/*!
 * \brief Dispatch a floating-point width in bits to a C++ type named DType.
 *
 * 32 bits maps to float and 64 bits to double; the body (__VA_ARGS__) is
 * expanded with that typedef in scope.
 *
 * 16 bits is rejected explicitly. Aliasing DType to float for half-precision
 * data would make the kernels interpret 2-byte elements as 4-byte ones.
 * NOTE(review): this assumes the kernels access the NDArray data as DType*;
 * confirm in sddmm.h.
 *
 * Any other width is also fatal.
 */
#define SWITCH_BITS(bits, DType, ...)                                  \
  do {                                                                 \
    if ((bits) == 16) {                                                \
      LOG(FATAL) << "FP16 is not supported by the CPU SDDMM kernels";  \
    } else if ((bits) == 32) {                                         \
      typedef float DType;                                             \
      { __VA_ARGS__ }                                                  \
    } else if ((bits) == 64) {                                         \
      typedef double DType;                                            \
      { __VA_ARGS__ }                                                  \
    } else {                                                           \
      LOG(FATAL) << "Data type not recognized with bits " << (bits);   \
    }                                                                  \
  } while (0)

57

58
/*! \brief Generalized SDDMM on Csr format. */
59
template <int XPU, typename IdType, int bits>
60
61
62
void SDDMMCsr(const std::string& op,
              const BcastOff& bcast,
              const CSRMatrix& csr,
63
64
65
66
67
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
68
69
70
71
72
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
      });
73
    });
74
75
76
  });
}

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
/*! \brief Generalized SDDMM on Csr format with Heterograph support. */
template <int XPU, typename IdType, int bits>
void SDDMMCsrHetero(const std::string& op,
              const BcastOff& bcast,
              const std::vector<CSRMatrix>& vec_csr,
              const std::vector<NDArray>& vec_lhs,
              const std::vector<NDArray>& vec_rhs,
              std::vector<NDArray> vec_out,
              int lhs_target,
              int rhs_target,
              const std::vector<dgl_type_t>& lhs_nid,
              const std::vector<dgl_type_t>& rhs_nid) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        /* Call  SDDMM for each relation type */
        for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
          CSRMatrix csr = vec_csr[etype];
          NDArray lhs = vec_lhs[lhs_nid[etype]];
          NDArray rhs = vec_rhs[rhs_nid[etype]];
          NDArray out = vec_out[etype];
          cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
        }
      });
    });
  });
}

105
template void SDDMMCsr<kDGLCPU, int32_t, 16>(
106
107
108
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
109
template void SDDMMCsr<kDGLCPU, int64_t, 16>(
110
111
112
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
113
template void SDDMMCsr<kDGLCPU, int32_t, 32>(
114
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
115
116
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
117
template void SDDMMCsr<kDGLCPU, int64_t, 32>(
118
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
119
120
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
121
template void SDDMMCsr<kDGLCPU, int32_t, 64>(
122
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
123
124
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
125
template void SDDMMCsr<kDGLCPU, int64_t, 64>(
126
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
127
128
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
129

130
template void SDDMMCsrHetero<kDGLCPU, int32_t, 16>(
131
132
133
134
135
136
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
137
template void SDDMMCsrHetero<kDGLCPU, int64_t, 16>(
138
139
140
141
142
143
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
144
template void SDDMMCsrHetero<kDGLCPU, int32_t, 32>(
145
146
147
148
149
150
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
151
template void SDDMMCsrHetero<kDGLCPU, int64_t, 32>(
152
153
154
155
156
157
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
158
template void SDDMMCsrHetero<kDGLCPU, int32_t, 64>(
159
160
161
162
163
164
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
165
template void SDDMMCsrHetero<kDGLCPU, int64_t, 64>(
166
167
168
169
170
171
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
172

173
/*! \brief Generalized SDDMM on Coo format. */
174
template <int XPU, typename IdType, int bits>
175
176
177
void SDDMMCoo(const std::string& op,
              const BcastOff& bcast,
              const COOMatrix& coo,
178
179
180
181
182
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
183
184
185
186
187
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
      });
188
    });
189
190
191
  });
}

192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
/*! \brief Generalized SDDMM on Coo format with Heterograph support. */
template <int XPU, typename IdType, int bits>
void SDDMMCooHetero(const std::string& op,
              const BcastOff& bcast,
              const std::vector<COOMatrix>& vec_coo,
              const std::vector<NDArray>& vec_lhs,
              const std::vector<NDArray>& vec_rhs,
              std::vector<NDArray> vec_out,
              int lhs_target,
              int rhs_target,
              const std::vector<dgl_type_t>& lhs_nid,
              const std::vector<dgl_type_t>& rhs_nid) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        /* Call  SDDMM for each relation type */
        for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
          COOMatrix coo = vec_coo[etype];
          NDArray lhs = vec_lhs[lhs_nid[etype]];
          NDArray rhs = vec_rhs[rhs_nid[etype]];
          NDArray out = vec_out[etype];
          cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
        }
      });
    });
  });
}

220
template void SDDMMCoo<kDGLCPU, int32_t, 16>(
221
222
223
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
224
template void SDDMMCoo<kDGLCPU, int64_t, 16>(
225
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
226
227
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
228
template void SDDMMCoo<kDGLCPU, int32_t, 32>(
229
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
230
231
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
232
template void SDDMMCoo<kDGLCPU, int64_t, 32>(
233
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
234
235
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
236
template void SDDMMCoo<kDGLCPU, int32_t, 64>(
237
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
238
239
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
240
template void SDDMMCoo<kDGLCPU, int64_t, 64>(
241
242
243
244
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);

245
template void SDDMMCooHetero<kDGLCPU, int32_t, 16>(
246
247
248
249
250
251
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
252
template void SDDMMCooHetero<kDGLCPU, int64_t, 16>(
253
254
255
256
257
258
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
259
template void SDDMMCooHetero<kDGLCPU, int32_t, 32>(
260
261
262
263
264
265
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
266
template void SDDMMCooHetero<kDGLCPU, int64_t, 32>(
267
268
269
270
271
272
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
273
template void SDDMMCooHetero<kDGLCPU, int32_t, 64>(
274
275
276
277
278
279
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
280
template void SDDMMCooHetero<kDGLCPU, int64_t, 64>(
281
282
283
284
285
286
287
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);

288
289
290

}  // namespace aten
}  // namespace dgl