// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2020 by Contributors
 * @file kernel/cpu/spmm.cc
 * @brief SPMM C APIs and definitions.
 */
#include "spmm.h"
8

9
10
11
12
13
#include <dgl/array.h>

namespace dgl {
namespace aten {

14
/** @brief Generalized SpMM on Csr format. */
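// Argument notes (summarizing the dispatch below):
//   - op:      name of the element-wise binary operator resolved by the
//              SWITCH_OP macro (e.g. "add" or "mul"; the exact set of
//              supported names is defined where SWITCH_OP is declared).
//   - reduce:  reducer applied per output row; only "sum", "max" and "min"
//              are handled, anything else aborts with LOG(FATAL).
//   - ufeat / efeat: source-node and edge feature arrays; out must hold
//              csr.num_rows * bcast.out_len elements (the range the
//              "max"/"min" path pre-fills with cpu::op::Max/Min<DType>::zero).
//   - out_aux: auxiliary outputs; for "max"/"min", out_aux[0] and out_aux[1]
//              are forwarded to cpu::SpMMCmpCsr as the argument-index arrays.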
template <int XPU, typename IdType, typename DType>
void SpMMCsr(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_OP(op, Op, {
      cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_OP(op, Op, {
      DType* out_off = out.Ptr<DType>();
      if (reduce == "max") {
        std::fill(
            out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
        cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
      } else {
        std::fill(
            out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
        cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
      }
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}
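// A minimal usage sketch for one of the instantiations declared below. This
// is illustrative only: the operator/reducer strings and the surrounding
// setup are assumptions, not taken from this file.
//
//   CSRMatrix csr = ...;               // graph in CSR layout
//   NDArray ufeat = ..., out = ...;    // node features / preallocated output
//   BcastOff bcast = ...;              // broadcast metadata for the operands
//   SpMMCsr<kDGLCPU, int64_t, float>(
//       "copy_lhs", "sum", bcast, csr, ufeat, NullArray(), out, {});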

/** @brief Generalized SpMM on Csr format for heterogeneous graphs. */
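// Compared to SpMMCsr above, every operand is supplied per type: vec_csr has
// one CSR matrix per relation (edge type), vec_ufeat/vec_efeat carry the
// per-node-type and per-edge-type features, and vec_out holds one output per
// destination node type. ufeat_node_tids[etype] and out_node_tids[etype] map
// a relation to its source and destination node-type ids. For "max"/"min",
// the loops below fill and forward (*out_aux)[0..3][dst_id]: [0]/[1] are the
// argument-index arrays and [2]/[3] record the node-type and edge-type ids of
// the winning arguments.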
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_OP(op, Op, {
      /* Call SpMM for each relation type */
      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
        const dgl_type_t src_id = ufeat_node_tids[etype];
        const dgl_type_t dst_id = out_node_tids[etype];
        CSRMatrix csr = vec_csr[etype];
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        NDArray out = (*vec_out)[dst_id];
        cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
      }
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_OP(op, Op, {
      std::vector<bool> updated((*vec_out).size(), false);
      // TODO(Israt): use the 'updated' vector for the fill(out...) calls too
      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
        DType* out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
        if (reduce == "max")
          std::fill(
              out_off, out_off + vec_csr[etype].num_rows * dim,
              cpu::op::Max<DType>::zero);
        else
          std::fill(
              out_off, out_off + vec_csr[etype].num_rows * dim,
              cpu::op::Min<DType>::zero);
        const dgl_type_t dst_id = out_node_tids[etype];
        if (!updated[dst_id]) {
          updated[dst_id] = true;
          if (Op::use_lhs) {
            IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
            std::fill(
                argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
          }
          if (Op::use_rhs) {
            IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
            std::fill(
                arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
          }
        }
      }
      /* Call SpMM for each relation type */
      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
        const dgl_type_t src_id = ufeat_node_tids[etype];
        const dgl_type_t dst_id = out_node_tids[etype];
        CSRMatrix csr = vec_csr[etype];
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        NDArray out = (*vec_out)[dst_id];
        if (reduce == "max") {
          cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
              (*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
              (*out_aux)[3][dst_id], src_id, etype);
        } else {
          cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
              (*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
              (*out_aux)[3][dst_id], src_id, etype);
        }
      }
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

template void SpMMCsr<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

template void SpMMCsrHetero<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);

/** @brief Edge_softmax_csr forward op on Csr format. */
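// Thin dispatch wrapper: SWITCH_OP resolves `op` and forwards to the CPU
// kernel cpu::Edge_softmax_csr_forward, which (following DGL's edge_softmax
// semantics) normalizes the per-edge scores in `efeat` over the incoming
// edges of each destination node and writes the result to `out`.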
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_forward(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out) {
  SWITCH_OP(op, Op, {
    cpu::Edge_softmax_csr_forward<IdType, DType, Op>(
        bcast, csr, ufeat, efeat, out);
  });
}

/** @brief Edge_softmax_csr backward op on Csr format. */
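// Backward counterpart of the wrapper above. By DGL's convention, `sds` is
// the forward softmax output multiplied element-wise by the upstream
// gradient; the CPU kernel uses it together with `out` to accumulate the
// gradient with respect to the edge scores into `back_out`.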
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_backward(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out) {
  SWITCH_OP(op, Op, {
    cpu::Edge_softmax_csr_backward<IdType, DType, Op>(
        bcast, csr, out, sds, back_out);
  });
}
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);

template void Edge_softmax_csr_backward<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);

/** @brief Generalized SpMM on Coo format. */
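// Same dispatch as SpMMCsr, but over a COO matrix. Unlike the CSR path, the
// "max"/"min" branch below does not pre-fill `out` with the reducer identity;
// cpu::SpMMCmpCoo is presumably responsible for initializing the output
// itself.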
template <int XPU, typename IdType, typename DType>
void SpMMCoo(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
    SWITCH_OP(op, Op, {
      cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_OP(op, Op, {
      if (reduce == "max")
        cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
            bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
      else
        cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
            bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

template void SpMMCoo<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl