/*!
 *  Copyright (c) 2020 by Contributors
 * \file kernel/cpu/spmm.cc
 * \brief SPMM C APIs and definitions.
 */
#include "./spmm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

/*! \brief Generalized SpMM on Csr format. */
13
template <int XPU, typename IdType, int bits>
14
15
16
17
18
19
20
void SpMMCsr(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
21
  const int64_t dim = bcast.out_len;
22
  if (reduce == "sum") {
23
24
25
26
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
      });
27
28
    });
  } else if (reduce == "max" || reduce == "min") {
29
30
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
31
32
33
34
35
        DType *out_off = out.Ptr<DType>();
        IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
        IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
        if (reduce == "max") {
          std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
36
37
          cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
38
39
        } else {
          std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
40
41
          cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
        }
      });
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

/*! \brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const std::vector<CSRMatrix>& vec_csr,
             const std::vector<NDArray>& vec_ufeat,
             const std::vector<NDArray>& vec_efeat,
57
58
             std::vector<NDArray>* vec_out,
             std::vector<std::vector<NDArray>>* out_aux,
59
60
61
62
63
64
65
66
67
68
69
70
71
             const std::vector<dgl_type_t>& ufeat_node_tids,
             const std::vector<dgl_type_t>& out_node_tids) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        /* Call  SpMM for each relation type */
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          CSRMatrix csr = vec_csr[etype];
          NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
72
          NDArray out = (*vec_out)[dst_id];
73
74
75
76
77
78
79
          cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
        }
      });
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
        std::vector<bool> updated((*vec_out).size(), false);
        // TODO(Israt): use vector updated to fill(out...) too
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
          if (reduce == "max")
            std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Max<DType>::zero);
          else
            std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Min<DType>::zero);
          const dgl_type_t dst_id = out_node_tids[etype];
          if (!updated[dst_id]) {
            updated[dst_id] = true;
            if (Op::use_lhs) {
              IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
              std::fill(argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
            }
            if (Op::use_rhs) {
              IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
              std::fill(arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
            }
          }
        }
101
102
103
104
105
        /* Call  SpMM for each relation type */
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          CSRMatrix csr = vec_csr[etype];
106
          DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
107
108
          NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
109
          NDArray out = (*vec_out)[dst_id];
110
          if (reduce == "max") {
111
112
113
            cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
                bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
                (*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
114
          } else {
115
116
117
            cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
                bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
                (*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
118
119
          }
        }
120
      });
121
122
123
124
125
126
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

127
template void SpMMCsr<kDGLCPU, int32_t, 16>(
128
129
130
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
131
template void SpMMCsr<kDGLCPU, int64_t, 16>(
132
133
134
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
135
template void SpMMCsr<kDGLCPU, int32_t, 32>(
136
137
138
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
139
template void SpMMCsr<kDGLCPU, int64_t, 32>(
140
141
142
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
143
template void SpMMCsr<kDGLCPU, int32_t, 64>(
144
145
146
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
147
template void SpMMCsr<kDGLCPU, int64_t, 64>(
148
149
150
151
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

152
template void SpMMCsrHetero<kDGLCPU, int32_t, 16>(
153
154
155
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
156
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
157
158
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
159
template void SpMMCsrHetero<kDGLCPU, int64_t, 16>(
160
161
162
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
163
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
164
165
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
166
template void SpMMCsrHetero<kDGLCPU, int32_t, 32>(
167
168
169
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
170
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
171
172
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
173
template void SpMMCsrHetero<kDGLCPU, int64_t, 32>(
174
175
176
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
177
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
178
179
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
180
template void SpMMCsrHetero<kDGLCPU, int32_t, 64>(
181
182
183
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
184
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
185
186
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
187
template void SpMMCsrHetero<kDGLCPU, int64_t, 64>(
188
189
190
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
191
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
192
193
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
194

195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
/*!
 * \brief Edge_softmax_csr forward op on Csr format.
 *
 * Resolves the data type from `bits` (SWITCH_BITS) and the operator from
 * `op` (SWITCH_OP), then forwards all arguments unchanged to
 * cpu::Edge_softmax_csr_forward.
 *
 * \param op Name of the binary operator, resolved by SWITCH_OP.
 * \param bcast Broadcast information passed through to the kernel.
 * \param csr The sparse matrix in CSR format.
 * \param ufeat First feature operand passed through to the kernel.
 * \param efeat Second feature operand passed through to the kernel.
 * \param out Output array passed through to the kernel.
 */
template <int XPU, typename IdType, int bits>
void Edge_softmax_csr_forward(
    const std::string& op,
    const BcastOff& bcast,
    const CSRMatrix& csr,
    NDArray ufeat,
    NDArray efeat,
    NDArray out) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      cpu::Edge_softmax_csr_forward<IdType, DType, Op>(
          bcast, csr, ufeat, efeat, out);
    });
  });
}

/*!
 * \brief Edge_softmax_csr backward op on Csr format.
 *
 * Resolves the data type from `bits` (SWITCH_BITS) and the operator from
 * `op` (SWITCH_OP), then forwards all arguments unchanged to
 * cpu::Edge_softmax_csr_backward.
 *
 * \param op Name of the binary operator, resolved by SWITCH_OP.
 * \param bcast Broadcast information passed through to the kernel.
 * \param csr The sparse matrix in CSR format.
 * \param out First array argument passed through to the kernel.
 * \param sds Second array argument passed through to the kernel.
 * \param back_out Third array argument passed through to the kernel.
 */
template <int XPU, typename IdType, int bits>
void Edge_softmax_csr_backward(
    const std::string& op,
    const BcastOff& bcast,
    const CSRMatrix& csr,
    NDArray out,
    NDArray sds,
    NDArray back_out) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      cpu::Edge_softmax_csr_backward<IdType, DType, Op>(
          bcast, csr, out, sds, back_out);
    });
  });
}

225
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 16>(
226
227
228
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
229
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 16>(
230
231
232
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
233
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 32>(
234
235
236
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
237
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 32>(
238
239
240
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
241
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 64>(
242
243
244
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
245
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 64>(
246
247
248
249
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);

250
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 16>(
251
252
253
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
254
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 16>(
255
256
257
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
258
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 32>(
259
260
261
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
262
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 32>(
263
264
265
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
266
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 64>(
267
268
269
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
270
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 64>(
271
272
273
274
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);

275
/*! \brief Generalized SpMM on Coo format. */
276
template <int XPU, typename IdType, int bits>
277
278
279
280
281
282
283
284
void SpMMCoo(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const COOMatrix& coo,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
285
286
287
288
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
      });
289
290
    });
  } else if (reduce == "max" || reduce == "min") {
291
292
293
294
295
296
297
298
299
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        if (reduce == "max")
          cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
        else
          cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
      });
300
301
302
303
304
305
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

306
template void SpMMCoo<kDGLCPU, int32_t, 16>(
307
308
309
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
310
template void SpMMCoo<kDGLCPU, int64_t, 16>(
311
312
313
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
314
template void SpMMCoo<kDGLCPU, int32_t, 32>(
315
316
317
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
318
template void SpMMCoo<kDGLCPU, int64_t, 32>(
319
320
321
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
322
template void SpMMCoo<kDGLCPU, int32_t, 64>(
323
324
325
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
326
template void SpMMCoo<kDGLCPU, int64_t, 64>(
327
328
329
330
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

331
332
333

}  // namespace aten
}  // namespace dgl