/*!
 *  Copyright (c) 2020 by Contributors
 * \file kernel/cpu/spmm.cc
 * \brief SPMM C APIs and definitions.
 */
#include "./spmm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

/*! \brief Generalized SpMM on Csr format. */
13
template <int XPU, typename IdType, int bits>
14
15
16
17
18
19
20
void SpMMCsr(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
21
  const int64_t dim = bcast.out_len;
22
  if (reduce == "sum") {
23
24
25
26
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
      });
27
28
    });
  } else if (reduce == "max" || reduce == "min") {
29
30
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
31
32
33
34
35
        DType *out_off = out.Ptr<DType>();
        IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
        IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
        if (reduce == "max") {
          std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
36
37
          cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
38
39
        } else {
          std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
40
41
          cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
        }
      });
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

/*! \brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const std::vector<CSRMatrix>& vec_csr,
             const std::vector<NDArray>& vec_ufeat,
             const std::vector<NDArray>& vec_efeat,
57
58
             std::vector<NDArray>* vec_out,
             std::vector<std::vector<NDArray>>* out_aux,
59
60
61
62
63
64
65
66
67
68
69
70
71
             const std::vector<dgl_type_t>& ufeat_node_tids,
             const std::vector<dgl_type_t>& out_node_tids) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        /* Call  SpMM for each relation type */
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          CSRMatrix csr = vec_csr[etype];
          NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
72
          NDArray out = (*vec_out)[dst_id];
73
74
75
76
77
78
79
          cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
        }
      });
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
        std::vector<bool> updated((*vec_out).size(), false);
        // TODO(Israt): use vector updated to fill(out...) too
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
          if (reduce == "max")
            std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Max<DType>::zero);
          else
            std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Min<DType>::zero);
          const dgl_type_t dst_id = out_node_tids[etype];
          if (!updated[dst_id]) {
            updated[dst_id] = true;
            if (Op::use_lhs) {
              IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
              std::fill(argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
            }
            if (Op::use_rhs) {
              IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
              std::fill(arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
            }
          }
        }
101
102
103
104
105
        /* Call  SpMM for each relation type */
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          CSRMatrix csr = vec_csr[etype];
106
          DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
107
108
          NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
109
          NDArray out = (*vec_out)[dst_id];
110
          if (reduce == "max") {
111
112
113
            cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
                bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
                (*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
114
          } else {
115
116
117
            cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
                bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
                (*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
118
119
          }
        }
120
      });
121
122
123
124
125
126
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

// Explicit instantiations of SpMMCsr for the CPU device, covering both index
// types (int32_t, int64_t) at each of the bit widths 16, 32 and 64.

// bits = 16
template void SpMMCsr<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

// bits = 32
template void SpMMCsr<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

// bits = 64
template void SpMMCsr<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

// Explicit instantiations of SpMMCsrHetero for the CPU device, covering both
// index types (int32_t, int64_t) at each of the bit widths 16, 32 and 64.
// Parameter names mirror the template definition.
template void SpMMCsrHetero<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat, const std::vector<NDArray>& vec_efeat,
    std::vector<NDArray>* vec_out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat, const std::vector<NDArray>& vec_efeat,
    std::vector<NDArray>* vec_out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat, const std::vector<NDArray>& vec_efeat,
    std::vector<NDArray>* vec_out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat, const std::vector<NDArray>& vec_efeat,
    std::vector<NDArray>* vec_out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat, const std::vector<NDArray>& vec_efeat,
    std::vector<NDArray>* vec_out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat, const std::vector<NDArray>& vec_efeat,
    std::vector<NDArray>* vec_out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);

195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
/*! \brief Edge_softmax_csr forward op on Csr format. */
template <int XPU, typename IdType, int bits>
void Edge_softmax_csr_forward(const std::string& op,
             const BcastOff& bcast,
             const CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out) {
  SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::Edge_softmax_csr_forward<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
      });
    });
}

/*!
 * \brief Edge_softmax_csr backward op on Csr format.
 *
 * Resolves \c bits to a data type \c DType (SWITCH_BITS) and \c op to an
 * operator type \c Op (SWITCH_OP), then passes \c bcast, \c csr, \c out,
 * \c sds and \c back_out unchanged to cpu::Edge_softmax_csr_backward.
 */
template <int XPU, typename IdType, int bits>
void Edge_softmax_csr_backward(const std::string& op,
                               const BcastOff& bcast,
                               const CSRMatrix& csr,
                               NDArray out,
                               NDArray sds,
                               NDArray back_out) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      cpu::Edge_softmax_csr_backward<IdType, DType, Op>(
          bcast, csr, out, sds, back_out);
    });
  });
}

// Explicit instantiations of Edge_softmax_csr_forward for the CPU device,
// covering both index types at each of the bit widths 16, 32 and 64.

// bits = 16
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);

// bits = 32
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);

// bits = 64
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);

// Explicit instantiations of Edge_softmax_csr_backward for the CPU device,
// covering both index types at each of the bit widths 16, 32 and 64.
// Parameter names follow the template definition (out, sds, back_out).
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 16>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 16>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 32>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 32>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 64>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 64>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);

275
/*! \brief Generalized SpMM on Coo format. */
276
template <int XPU, typename IdType, int bits>
277
278
279
280
281
282
283
284
void SpMMCoo(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const COOMatrix& coo,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
285
286
287
288
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
      });
289
290
    });
  } else if (reduce == "max" || reduce == "min") {
291
292
293
294
295
296
297
298
299
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        if (reduce == "max")
          cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
        else
          cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
      });
300
301
302
303
304
305
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

// Explicit instantiations of SpMMCoo for the CPU device, covering both index
// types (int32_t, int64_t) at each of the bit widths 16, 32 and 64.

// bits = 16
template void SpMMCoo<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

// bits = 32
template void SpMMCoo<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

// bits = 64
template void SpMMCoo<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl