/*!
 *  Copyright (c) 2020 by Contributors
 * \file kernel/cpu/spmm.cc
 * \brief SPMM C APIs and definitions.
 */
#include "./spmm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

/*! \brief Generalized SpMM on Csr format. */
13
template <int XPU, typename IdType, int bits>
14
15
16
17
18
19
20
void SpMMCsr(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
21
  const int64_t dim = bcast.out_len;
22
  if (reduce == "sum") {
23
24
25
26
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
      });
27
28
    });
  } else if (reduce == "max" || reduce == "min") {
29
30
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
31
32
33
34
35
        DType *out_off = out.Ptr<DType>();
        IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
        IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
        if (reduce == "max") {
          std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
36
37
          cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
38
39
        } else {
          std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
40
41
          cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
        }
      });
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

/*! \brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const std::vector<CSRMatrix>& vec_csr,
             const std::vector<NDArray>& vec_ufeat,
             const std::vector<NDArray>& vec_efeat,
57
58
             std::vector<NDArray>* vec_out,
             std::vector<std::vector<NDArray>>* out_aux,
59
60
61
62
63
64
65
66
67
68
69
70
71
             const std::vector<dgl_type_t>& ufeat_node_tids,
             const std::vector<dgl_type_t>& out_node_tids) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        /* Call  SpMM for each relation type */
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          CSRMatrix csr = vec_csr[etype];
          NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
72
          NDArray out = (*vec_out)[dst_id];
73
74
75
76
77
78
79
          cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
        }
      });
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
        std::vector<bool> updated((*vec_out).size(), false);
        // TODO(Israt): use vector updated to fill(out...) too
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
          if (reduce == "max")
            std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Max<DType>::zero);
          else
            std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Min<DType>::zero);
          const dgl_type_t dst_id = out_node_tids[etype];
          if (!updated[dst_id]) {
            updated[dst_id] = true;
            if (Op::use_lhs) {
              IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
              std::fill(argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
            }
            if (Op::use_rhs) {
              IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
              std::fill(arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
            }
          }
        }
101
102
103
104
105
        /* Call  SpMM for each relation type */
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          CSRMatrix csr = vec_csr[etype];
106
          DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
107
108
          NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
109
          NDArray out = (*vec_out)[dst_id];
110
          if (reduce == "max") {
111
112
113
            cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
                bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
                (*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
114
          } else {
115
116
117
            cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
                bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
                (*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
118
119
          }
        }
120
      });
121
122
123
124
125
126
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

// Explicit instantiations of SpMMCsr for the CPU device (kDLCPU), for each
// combination of index type {int32_t, int64_t} and bit-width {16, 32, 64}.
template void SpMMCsr<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

// Explicit instantiations of SpMMCsrHetero for the CPU device (kDLCPU), for
// each combination of index type {int32_t, int64_t} and bit-width {16, 32, 64}.
template void SpMMCsrHetero<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);

/*! \brief Generalized SpMM on Coo format. */
196
template <int XPU, typename IdType, int bits>
197
198
199
200
201
202
203
204
void SpMMCoo(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const COOMatrix& coo,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
205
206
207
208
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
      });
209
210
    });
  } else if (reduce == "max" || reduce == "min") {
211
212
213
214
215
216
217
218
219
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        if (reduce == "max")
          cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
        else
          cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
      });
220
221
222
223
224
225
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

// Explicit instantiations of SpMMCoo for the CPU device (kDLCPU), for each
// combination of index type {int32_t, int64_t} and bit-width {16, 32, 64}.
template void SpMMCoo<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl