/**
 *  Copyright (c) 2020 by Contributors
 * @file kernel/cpu/spmm.cc
 * @brief SPMM C APIs and definitions.
 */
#include "./spmm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

12
/** @brief Generalized SpMM on Csr format. */
13
template <int XPU, typename IdType, typename DType>
14
15
16
17
18
19
20
void SpMMCsr(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
21
  const int64_t dim = bcast.out_len;
22
  if (reduce == "sum") {
23
24
    SWITCH_OP(op, Op, {
      cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
25
26
    });
  } else if (reduce == "max" || reduce == "min") {
27
28
29
30
31
32
33
34
35
36
37
    SWITCH_OP(op, Op, {
      DType *out_off = out.Ptr<DType>();
      if (reduce == "max") {
        std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
        cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
      } else {
        std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
        cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
      }
38
39
40
41
42
43
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

44
/** @brief Generalized SpMM on Csr format. */
45
template <int XPU, typename IdType, typename DType>
46
47
48
49
50
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const std::vector<CSRMatrix>& vec_csr,
             const std::vector<NDArray>& vec_ufeat,
             const std::vector<NDArray>& vec_efeat,
51
52
             std::vector<NDArray>* vec_out,
             std::vector<std::vector<NDArray>>* out_aux,
53
54
55
56
             const std::vector<dgl_type_t>& ufeat_node_tids,
             const std::vector<dgl_type_t>& out_node_tids) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
57
58
59
60
61
62
63
64
65
66
67
    SWITCH_OP(op, Op, {
      /* Call  SpMM for each relation type */
      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
        const dgl_type_t src_id = ufeat_node_tids[etype];
        const dgl_type_t dst_id = out_node_tids[etype];
        CSRMatrix csr = vec_csr[etype];
        NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        NDArray out = (*vec_out)[dst_id];
        cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
      }
68
69
    });
  } else if (reduce == "max" || reduce == "min") {
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
    SWITCH_OP(op, Op, {
      std::vector<bool> updated((*vec_out).size(), false);
      // TODO(Israt): use vector updated to fill(out...) too
      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
        DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
        if (reduce == "max")
          std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Max<DType>::zero);
        else
          std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Min<DType>::zero);
        const dgl_type_t dst_id = out_node_tids[etype];
        if (!updated[dst_id]) {
          updated[dst_id] = true;
          if (Op::use_lhs) {
            IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
            std::fill(argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
85
          }
86
87
88
          if (Op::use_rhs) {
            IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
            std::fill(arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
89
90
          }
        }
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
      }
      /* Call  SpMM for each relation type */
      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
        const dgl_type_t src_id = ufeat_node_tids[etype];
        const dgl_type_t dst_id = out_node_tids[etype];
        CSRMatrix csr = vec_csr[etype];
        NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        NDArray out = (*vec_out)[dst_id];
        if (reduce == "max") {
          cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
              (*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
        } else {
          cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
              (*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
        }
      }
110
111
112
113
114
115
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

116
template void SpMMCsr<kDGLCPU, int32_t, float>(
117
118
119
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
120
template void SpMMCsr<kDGLCPU, int64_t, float>(
121
122
123
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
124
template void SpMMCsr<kDGLCPU, int32_t, double>(
125
126
127
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
128
template void SpMMCsr<kDGLCPU, int64_t, double>(
129
130
131
132
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

133
template void SpMMCsrHetero<kDGLCPU, int32_t, float>(
134
135
136
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
137
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
138
139
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
140
template void SpMMCsrHetero<kDGLCPU, int64_t, float>(
141
142
143
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
144
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
145
146
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
147
template void SpMMCsrHetero<kDGLCPU, int32_t, double>(
148
149
150
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
151
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
152
153
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
154
template void SpMMCsrHetero<kDGLCPU, int64_t, double>(
155
156
157
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
158
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
159
160
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
161

162
/** @brief Edge_softmax_csr forward op on Csr format. */
163
template <int XPU, typename IdType, typename DType>
164
165
166
167
168
169
void Edge_softmax_csr_forward(const std::string& op,
             const BcastOff& bcast,
             const CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out) {
170
171
172
  SWITCH_OP(op, Op, {
    cpu::Edge_softmax_csr_forward<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
  });
173
174
}

175
/** @brief Edge_softmax_csr backward op on Csr format. */
176
template <int XPU, typename IdType, typename DType>
177
178
179
180
181
182
void Edge_softmax_csr_backward(const std::string& op,
             const BcastOff& bcast,
             const CSRMatrix& csr,
             NDArray out,
             NDArray sds,
             NDArray back_out) {
183
184
  SWITCH_OP(op, Op, {
    cpu::Edge_softmax_csr_backward<IdType, DType, Op>(bcast, csr, out, sds, back_out);
185
186
187
  });
}

188
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(
189
190
191
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
192
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, float>(
193
194
195
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
196
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, double>(
197
198
199
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
200
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(
201
202
203
204
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);

205
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(
206
207
208
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
209
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, float>(
210
211
212
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
213
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, double>(
214
215
216
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
217
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, double>(
218
219
220
221
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);

222
/** @brief Generalized SpMM on Coo format. */
223
template <int XPU, typename IdType, typename DType>
224
225
226
227
228
229
230
231
void SpMMCoo(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const COOMatrix& coo,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
232
233
    SWITCH_OP(op, Op, {
      cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
234
235
    });
  } else if (reduce == "max" || reduce == "min") {
236
237
238
239
240
241
242
    SWITCH_OP(op, Op, {
      if (reduce == "max")
        cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
            bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
      else
        cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
            bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
243
244
245
246
247
248
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

249
template void SpMMCoo<kDGLCPU, int32_t, float>(
250
251
252
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
253
template void SpMMCoo<kDGLCPU, int64_t, float>(
254
255
256
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
257
template void SpMMCoo<kDGLCPU, int32_t, double>(
258
259
260
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
261
template void SpMMCoo<kDGLCPU, int64_t, double>(
262
263
264
265
266
267
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl