/*!
 *  Copyright (c) 2020 by Contributors
 * \file kernel/cpu/spmm.cc
 * \brief SPMM C APIs and definitions.
 */
#include "./spmm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

/*! \brief Generalized SpMM on Csr format. */
13
template <int XPU, typename IdType, int bits>
14
15
16
17
18
19
20
void SpMMCsr(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
21
  const int64_t dim = bcast.out_len;
22
  if (reduce == "sum") {
23
24
25
26
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
      });
27
28
    });
  } else if (reduce == "max" || reduce == "min") {
29
30
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
31
32
33
34
35
        DType *out_off = out.Ptr<DType>();
        IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
        IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
        if (reduce == "max") {
          std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
36
37
          cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
38
39
        } else {
          std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
40
41
          cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
        }
      });
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

/*!
 * \brief Generalized SpMM on Csr format for heterogeneous inputs.
 *
 * Iterates over relation (edge) types and runs the homogeneous CPU kernel
 * once per relation, picking the operands for each relation from the
 * per-type vectors.
 *
 * \tparam XPU Device type tag; not referenced inside the body.
 * \tparam IdType Index type handed to the kernels.
 * \tparam bits Bit width used by SWITCH_BITS to pick DType.
 * \param op Name of the binary operator, resolved by SWITCH_OP.
 * \param reduce Name of the reducer: "sum", "max" or "min".
 * \param bcast Broadcast information; bcast.out_len is read here as the
 *        per-row output length and the object is forwarded to the kernel.
 * \param vec_csr One CSR matrix per relation type, indexed by etype.
 * \param vec_ufeat Feature arrays indexed by ufeat_node_tids[etype]; when the
 *        vector is empty a NullArray() is passed instead.
 * \param vec_efeat Feature arrays indexed by etype; when the vector is empty
 *        a NullArray() is passed instead.
 * \param vec_out Output arrays indexed by out_node_tids[etype].
 * \param out_aux Auxiliary outputs; out_aux[0] and out_aux[1] are forwarded
 *        to SpMMCmpCsr for "max"/"min" and are not touched for "sum".
 * \param ufeat_node_tids For each relation, the index into vec_ufeat. Its
 *        size determines the number of relations processed.
 * \param out_node_tids For each relation, the index into vec_out.
 * \note Any other reducer name is reported through LOG(FATAL).
 */
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const std::vector<CSRMatrix>& vec_csr,
             const std::vector<NDArray>& vec_ufeat,
             const std::vector<NDArray>& vec_efeat,
             std::vector<NDArray> vec_out,
             const std::vector<NDArray>& out_aux,
             const std::vector<dgl_type_t>& ufeat_node_tids,
             const std::vector<dgl_type_t>& out_node_tids) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        /* Call SpMM for each relation type */
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          // Bind by reference: the kernel only reads the matrix.
          const CSRMatrix& csr = vec_csr[etype];
          NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
          NDArray out = vec_out[dst_id];
          cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
        }
      });
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        /* Call SpMM for each relation type */
        // NOTE(review): the output buffer of dst_id is re-filled on every
        // iteration and the same out_aux[0]/out_aux[1] are reused for all
        // relations, so relations sharing a destination type overwrite each
        // other's results -- confirm this is intended by the callers.
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          // Bind by reference: the kernel only reads the matrix.
          const CSRMatrix& csr = vec_csr[etype];
          NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
          NDArray out = vec_out[dst_id];
          DType *out_off = out.Ptr<DType>();
          if (reduce == "max") {
            std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
            cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
                bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
          } else {
            std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
            cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
                bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
          }
        }
      });
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

// Explicit instantiations of SpMMCsr on CPU for the (IdType, bits)
// combinations {int32_t, int64_t} x {16, 32, 64}.
template void SpMMCsr<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

// Explicit instantiations of SpMMCsrHetero on CPU for the (IdType, bits)
// combinations {int32_t, int64_t} x {16, 32, 64}.

// 16-bit features.
template void SpMMCsrHetero<kDLCPU, int32_t, 16>(
    const std::string& op,
    const std::string& reduce,
    const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat,
    std::vector<NDArray> vec_out,
    const std::vector<NDArray>& out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 16>(
    const std::string& op,
    const std::string& reduce,
    const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat,
    std::vector<NDArray> vec_out,
    const std::vector<NDArray>& out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);

// 32-bit features.
template void SpMMCsrHetero<kDLCPU, int32_t, 32>(
    const std::string& op,
    const std::string& reduce,
    const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat,
    std::vector<NDArray> vec_out,
    const std::vector<NDArray>& out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 32>(
    const std::string& op,
    const std::string& reduce,
    const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat,
    std::vector<NDArray> vec_out,
    const std::vector<NDArray>& out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);

// 64-bit features.
template void SpMMCsrHetero<kDLCPU, int32_t, 64>(
    const std::string& op,
    const std::string& reduce,
    const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat,
    std::vector<NDArray> vec_out,
    const std::vector<NDArray>& out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
    const std::string& op,
    const std::string& reduce,
    const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat,
    std::vector<NDArray> vec_out,
    const std::vector<NDArray>& out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
173
174

/*! \brief Generalized SpMM on Coo format. */
175
template <int XPU, typename IdType, int bits>
176
177
178
179
180
181
182
183
void SpMMCoo(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const COOMatrix& coo,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
184
185
186
187
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
      });
188
189
    });
  } else if (reduce == "max" || reduce == "min") {
190
191
192
193
194
195
196
197
198
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        if (reduce == "max")
          cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
        else
          cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
      });
199
200
201
202
203
204
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

// Explicit instantiations of SpMMCoo on CPU for the (IdType, bits)
// combinations {int32_t, int64_t} x {16, 32, 64}.
template void SpMMCoo<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl