/**
 *  Copyright (c) 2020 by Contributors
 * @file kernel/cpu/spmm.cc
 * @brief CPU SpMM (CSR/COO, homogeneous and heterogeneous) and CSR
 *        edge-softmax entry points, with their explicit instantiations.
 */
#include "./spmm.h"

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

#include <dgl/array.h>

namespace dgl {
namespace aten {

13
/** @brief Generalized SpMM on Csr format. */
14
template <int XPU, typename IdType, typename DType>
15
16
17
18
void SpMMCsr(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
19
  const int64_t dim = bcast.out_len;
20
  if (reduce == "sum") {
21
22
    SWITCH_OP(op, Op, {
      cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
23
24
    });
  } else if (reduce == "max" || reduce == "min") {
25
    SWITCH_OP(op, Op, {
26
      DType* out_off = out.Ptr<DType>();
27
      if (reduce == "max") {
28
29
        std::fill(
            out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
30
31
32
        cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
      } else {
33
34
        std::fill(
            out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
35
36
37
        cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
      }
38
39
40
41
42
43
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

44
/** @brief Generalized SpMM on Csr format. */
45
template <int XPU, typename IdType, typename DType>
46
47
48
49
50
51
52
53
void SpMMCsrHetero(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids) {
54
55
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
56
57
58
59
60
61
    SWITCH_OP(op, Op, {
      /* Call  SpMM for each relation type */
      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
        const dgl_type_t src_id = ufeat_node_tids[etype];
        const dgl_type_t dst_id = out_node_tids[etype];
        CSRMatrix csr = vec_csr[etype];
62
63
64
65
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
66
67
68
        NDArray out = (*vec_out)[dst_id];
        cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
      }
69
70
    });
  } else if (reduce == "max" || reduce == "min") {
71
72
73
74
    SWITCH_OP(op, Op, {
      std::vector<bool> updated((*vec_out).size(), false);
      // TODO(Israt): use vector updated to fill(out...) too
      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
75
        DType* out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
76
        if (reduce == "max")
77
78
79
          std::fill(
              out_off, out_off + vec_csr[etype].num_rows * dim,
              cpu::op::Max<DType>::zero);
80
        else
81
82
83
          std::fill(
              out_off, out_off + vec_csr[etype].num_rows * dim,
              cpu::op::Min<DType>::zero);
84
85
86
87
        const dgl_type_t dst_id = out_node_tids[etype];
        if (!updated[dst_id]) {
          updated[dst_id] = true;
          if (Op::use_lhs) {
88
89
90
            IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
            std::fill(
                argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
91
          }
92
          if (Op::use_rhs) {
93
94
95
            IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
            std::fill(
                arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
96
97
          }
        }
98
99
100
101
102
103
      }
      /* Call  SpMM for each relation type */
      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
        const dgl_type_t src_id = ufeat_node_tids[etype];
        const dgl_type_t dst_id = out_node_tids[etype];
        CSRMatrix csr = vec_csr[etype];
104
105
106
107
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
108
109
110
        NDArray out = (*vec_out)[dst_id];
        if (reduce == "max") {
          cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
111
112
113
              bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
              (*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
              (*out_aux)[3][dst_id], src_id, etype);
114
115
        } else {
          cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
116
117
118
              bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
              (*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
              (*out_aux)[3][dst_id], src_id, etype);
119
120
        }
      }
121
122
123
124
125
126
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

127
template void SpMMCsr<kDGLCPU, int32_t, float>(
128
129
130
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
131
template void SpMMCsr<kDGLCPU, int64_t, float>(
132
133
134
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
135
template void SpMMCsr<kDGLCPU, int32_t, double>(
136
137
138
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
139
template void SpMMCsr<kDGLCPU, int64_t, double>(
140
141
142
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
143

144
template void SpMMCsrHetero<kDGLCPU, int32_t, float>(
145
146
147
148
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
149
150
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
151
template void SpMMCsrHetero<kDGLCPU, int64_t, float>(
152
153
154
155
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
156
157
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
158
template void SpMMCsrHetero<kDGLCPU, int32_t, double>(
159
160
161
162
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
163
164
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
165
template void SpMMCsrHetero<kDGLCPU, int64_t, double>(
166
167
168
169
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
170
171
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
172

173
/** @brief Edge_softmax_csr forward op on Csr format. */
174
template <int XPU, typename IdType, typename DType>
175
176
177
void Edge_softmax_csr_forward(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out) {
178
  SWITCH_OP(op, Op, {
179
180
    cpu::Edge_softmax_csr_forward<IdType, DType, Op>(
        bcast, csr, ufeat, efeat, out);
181
  });
182
183
}

184
/** @brief Edge_softmax_csr backward op on Csr format. */
185
template <int XPU, typename IdType, typename DType>
186
187
188
void Edge_softmax_csr_backward(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out) {
189
  SWITCH_OP(op, Op, {
190
191
    cpu::Edge_softmax_csr_backward<IdType, DType, Op>(
        bcast, csr, out, sds, back_out);
192
193
194
  });
}

195
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(
196
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
197
    NDArray ufeat, NDArray efeat, NDArray out);
198
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, float>(
199
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
200
    NDArray ufeat, NDArray efeat, NDArray out);
201
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, double>(
202
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
203
    NDArray ufeat, NDArray efeat, NDArray out);
204
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(
205
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
206
207
    NDArray ufeat, NDArray efeat, NDArray out);

208
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(
209
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
210
    NDArray ufeat, NDArray efeat, NDArray out);
211
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, float>(
212
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
213
    NDArray ufeat, NDArray efeat, NDArray out);
214
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, double>(
215
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
216
    NDArray ufeat, NDArray efeat, NDArray out);
217
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, double>(
218
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
219
220
    NDArray ufeat, NDArray efeat, NDArray out);

221
/** @brief Generalized SpMM on Coo format. */
222
template <int XPU, typename IdType, typename DType>
223
224
225
226
void SpMMCoo(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
227
  if (reduce == "sum") {
228
229
    SWITCH_OP(op, Op, {
      cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
230
231
    });
  } else if (reduce == "max" || reduce == "min") {
232
233
234
235
236
237
238
    SWITCH_OP(op, Op, {
      if (reduce == "max")
        cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
            bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
      else
        cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
            bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
239
240
241
242
243
244
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

245
template void SpMMCoo<kDGLCPU, int32_t, float>(
246
247
248
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
249
template void SpMMCoo<kDGLCPU, int64_t, float>(
250
251
252
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
253
template void SpMMCoo<kDGLCPU, int32_t, double>(
254
255
256
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
257
template void SpMMCoo<kDGLCPU, int64_t, double>(
258
259
260
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
261
262
263

}  // namespace aten
}  // namespace dgl