/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/spmm.h
 * \brief SPMM CPU kernel function header.
 */
#ifndef DGL_ARRAY_CPU_SPMM_H_
#define DGL_ARRAY_CPU_SPMM_H_

#include <dgl/array.h>
#include <dgl/bcast.h>
#include <algorithm>
#include <cstring>
#include <limits>
#include <memory>
#include <type_traits>

#include "spmm_binary_ops.h"
#if !defined(_WIN32)
#ifdef USE_AVX
#include "intel/cpu_support.h"
#ifdef USE_LIBXSMM
#include "spmm_blocking_libxsmm.h"
#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32
namespace dgl {
namespace aten {
namespace cpu {

#if !defined(_WIN32)
#ifdef USE_AVX
/*!
 * \brief CPU kernel of SpMM on Csr format using Xbyak.
 * \param cpu_spec The JIT'ed elementwise-update kernel.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param X The feature on source nodes.
 * \param W The feature on edges.
 * \param O The result feature on destination nodes.
 * \note It uses a node-parallel strategy: different threads are responsible
 *       for the computation of different nodes. For each edge, it invokes
 *       the JIT'ed kernel.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCsrXbyak(dgl::ElemWiseAddUpdate<Op>* cpu_spec, const BcastOff& bcast,
                     const CSRMatrix& csr, const DType* X, const DType* W, DType* O) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
#pragma omp parallel for
  for (IdType rid = 0; rid < csr.num_rows; ++rid) {
    const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
    DType* out_off = O + rid * dim;
    for (IdType j = row_start; j < row_end; ++j) {
      const IdType cid = indices[j];
      const IdType eid = has_idx ? edges[j] : j;
      cpu_spec->run(out_off, X + cid * lhs_dim, W + eid * rhs_dim, dim);
    }
  }
}
#endif  // USE_AVX
#endif  // _WIN32

/*!
 * \brief Naive CPU kernel of SpMM on Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param X The feature on source nodes.
 * \param W The feature on edges.
 * \param O The result feature on destination nodes.
 * \note It uses a node-parallel strategy: different threads are responsible
 *       for the computation of different nodes.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCsrNaive(const BcastOff& bcast, const CSRMatrix& csr, const DType* X,
                     const DType* W, DType* O) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
#pragma omp parallel for
  for (IdType rid = 0; rid < csr.num_rows; ++rid) {
    const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
    DType* out_off = O + rid * dim;
    for (IdType j = row_start; j < row_end; ++j) {
      const IdType cid = indices[j];
      const IdType eid = has_idx ? edges[j] : j;
      for (int64_t k = 0; k < dim; ++k) {
        const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
        const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
        const DType* lhs_off =
          Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
        const DType* rhs_off =
          Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
        out_off[k] += Op::Call(lhs_off, rhs_off);
      }
    }
  }
}
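
/*
 * A minimal illustration of the broadcast logic above (the Mul operator and
 * the shapes are assumptions for illustration, not fixed by this header):
 * with lhs_dim = 4, rhs_dim = 1 and dim = 4, bcast.use_bcast is true and
 * bcast.rhs_offset is all zeros, so the scalar edge feature multiplies every
 * lane of the node feature:
 *
 *   // Op = op::Mul<DType>: out_off[k] += X[cid * 4 + k] * W[eid]
 *
 * Without broadcasting, lhs_add == rhs_add == k and the inner loop reduces
 * to a plain elementwise update over the feature dimension.
 */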

/*!
 * \brief CPU kernel of SpMM on Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result feature on destination nodes.
 * \note It uses a node-parallel strategy: different threads are responsible
 *       for the computation of different nodes.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
                NDArray efeat, NDArray out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  const DType* X = ufeat.Ptr<DType>();
  const DType* W = efeat.Ptr<DType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
  DType* O = out.Ptr<DType>();
  CHECK_NOTNULL(indptr);
  CHECK_NOTNULL(O);
  if (Op::use_lhs) {
    CHECK_NOTNULL(indices);
    CHECK_NOTNULL(X);
  }
  if (Op::use_rhs) {
    if (has_idx)
      CHECK_NOTNULL(edges);
    CHECK_NOTNULL(W);
  }
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
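  // Prefer the LIBXSMM blocking kernel when it applies; broadcast cases and
  // double-precision inputs are not covered by that path and fall through to
  // the kernels below.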
  const bool no_libxsmm =
       bcast.use_bcast || std::is_same<DType, double>::value;
  if (!no_libxsmm) {
    SpMMSumCsrLibxsmm<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
  } else {
#endif  // USE_LIBXSMM
    typedef dgl::ElemWiseAddUpdate<Op> ElemWiseUpd;
    /* Prepare an assembler kernel */
    static std::unique_ptr<ElemWiseUpd> asm_kernel_ptr(
        (dgl::IntelKernel<>::IsEnabled()) ? new ElemWiseUpd() : nullptr);
    /* Distribute the kernel among OMP threads */
    ElemWiseUpd* cpu_spec = (asm_kernel_ptr && asm_kernel_ptr->applicable())
      ? asm_kernel_ptr.get()
      : nullptr;
    if (cpu_spec && dim > 16 && !bcast.use_bcast) {
      SpMMSumCsrXbyak<IdType, DType, Op>(cpu_spec, bcast, csr, X, W, O);
    } else {
#endif  // USE_AVX
#endif  // _WIN32
    SpMMSumCsrNaive<IdType, DType, Op>(bcast, csr, X, W, O);
#if !defined(_WIN32)
#ifdef USE_AVX
    }
#ifdef USE_LIBXSMM
  }
#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32
}
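
/*
 * Usage sketch (hedged: op::CopyLhs and op::Mul name the operators expected
 * from spmm_binary_ops.h; the shapes are illustrative assumptions). With the
 * copy-u operator, SpMMSumCsr computes a plain sparse-dense product in which
 * row rid of the output aggregates the features of its neighbors:
 *
 *   // out[rid, :] = sum over cid in row rid of ufeat[cid, :]
 *   SpMMSumCsr<int32_t, float, op::CopyLhs<float>>(bcast, csr, ufeat, efeat, out);
 *
 * Substituting op::Mul<float> scales each neighbor feature by the matching
 * edge feature before the summation (weighted aggregation).
 */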

/*!
 * \brief CPU kernel of SpMM on Coo format.
 * \param bcast Broadcast information.
 * \param coo The Coo matrix.
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result feature on destination nodes.
 * \note It uses an edge-parallel strategy: different threads are responsible
 *       for the computation of different edges. To avoid possible data
 *       hazards when edges share a destination node, it uses atomic
 *       operators in the reduction phase.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat,
                NDArray efeat, NDArray out) {
  const bool has_idx = !IsNullArray(coo.data);
  const IdType* row = coo.row.Ptr<IdType>();
  const IdType* col = coo.col.Ptr<IdType>();
  const IdType* edges = coo.data.Ptr<IdType>();
  const DType* X = ufeat.Ptr<DType>();
  const DType* W = efeat.Ptr<DType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
  DType* O = out.Ptr<DType>();
  const int64_t nnz = coo.row->shape[0];
  // fill zero elements
  memset(O, 0, out.GetSize());
  // spmm
#pragma omp parallel for
  for (IdType i = 0; i < nnz; ++i) {
    const IdType rid = row[i];
    const IdType cid = col[i];
    const IdType eid = has_idx ? edges[i] : i;
    DType* out_off = O + cid * dim;
    for (int64_t k = 0; k < dim; ++k) {
      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
      const DType* lhs_off =
        Op::use_lhs ? X + rid * lhs_dim + lhs_add : nullptr;
      const DType* rhs_off =
        Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
      const DType val = Op::Call(lhs_off, rhs_off);
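      // Skip zero contributions: they leave the sum unchanged, and skipping
      // them avoids the cost of an atomic read-modify-write.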
      if (val != 0) {
#pragma omp atomic
        out_off[k] += val;
      }
    }
  }
}
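
/*
 * Design note: parallelizing over edges (as above) balances work regardless
 * of the degree distribution, at the price of atomic updates; the CSR
 * kernels instead parallelize over rows, so each output row is owned by a
 * single thread and no atomics are needed.
 */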

/*!
 * \brief CPU kernel of SpMM-Min/Max on Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result feature on destination nodes.
 * \param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the
 *        reduction result on destination nodes. It's useful for computing
 *        gradients of the Min/Max reducer.
 * \param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It's useful for computing gradients
 *        of the Min/Max reducer.
 * \note It uses a node-parallel strategy: different threads are responsible
 *       for the computation of different nodes.
 * \note The result will contain infinity for zero-degree nodes.
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
                NDArray efeat, NDArray out, NDArray argu, NDArray arge) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices = static_cast<IdType*>(csr.indices->data);
  const IdType* edges =
    has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len;
  DType* O = static_cast<DType*>(out->data);
  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
  CHECK_NOTNULL(indptr);
  CHECK_NOTNULL(O);
  if (Op::use_lhs) {
    CHECK_NOTNULL(indices);
    CHECK_NOTNULL(X);
    CHECK_NOTNULL(argX);
  }
  if (Op::use_rhs) {
    if (has_idx)
      CHECK_NOTNULL(edges);
    CHECK_NOTNULL(W);
    CHECK_NOTNULL(argW);
  }
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM

  const bool no_libxsmm =
       bcast.use_bcast || std::is_same<DType, double>::value;
  if (!no_libxsmm) {
    SpMMCmpCsrLibxsmm<IdType, DType, Op, Cmp>(bcast, csr, ufeat, efeat, out, argu, arge);
  } else {
#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32

#pragma omp parallel for
  for (IdType rid = 0; rid < csr.num_rows; ++rid) {
    const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
    DType* out_off = O + rid * dim;
    IdType* argx_off = Op::use_lhs ? argX + rid * dim : nullptr;
    IdType* argw_off = Op::use_rhs ? argW + rid * dim : nullptr;
    for (IdType j = row_start; j < row_end; ++j) {
      const IdType cid = indices[j];
      const IdType eid = has_idx ? edges[j] : j;
      for (int64_t k = 0; k < dim; ++k) {
        const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
        const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
        const DType* lhs_off =
          Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
        const DType* rhs_off =
          Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
        const DType val = Op::Call(lhs_off, rhs_off);
        if (Cmp::Call(out_off[k], val)) {
          out_off[k] = val;
          if (Op::use_lhs) argx_off[k] = cid;
          if (Op::use_rhs) argw_off[k] = eid;
        }
      }
    }
  }
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
  }
#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32
}
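
/*
 * Usage sketch (hedged: op::CopyLhs and op::Max name the operator and
 * comparator expected from spmm_binary_ops.h; they are assumptions for
 * illustration). A max aggregation with argmax bookkeeping:
 *
 *   // out[rid, k] = max over cid in row rid of ufeat[cid, k];
 *   // argu[rid, k] records the neighbor cid attaining the maximum.
 *   SpMMCmpCsr<int32_t, float, op::CopyLhs<float>, op::Max<float>>(
 *       bcast, csr, ufeat, efeat, out, argu, arge);
 *
 * The output is expected to be pre-filled with Cmp::zero (-infinity for
 * Max), which is why zero-degree rows end up holding infinity.
 */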

/*!
 * \brief CPU kernel of SpMM-Min/Max on Coo format.
 * \param bcast Broadcast information.
 * \param coo The Coo matrix.
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result feature on destination nodes.
 * \param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the
 *        reduction result on destination nodes. It's useful for computing
 *        gradients of the Min/Max reducer.
 * \param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It's useful for computing gradients
 *        of the Min/Max reducer.
 * \note It uses an edge-parallel strategy: different threads are responsible
 *       for the computation of different edges. To avoid possible data
 *       hazards, the compare-and-update step is protected by a critical
 *       section.
 * \note The result will contain infinity for zero-degree nodes.
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat,
                NDArray efeat, NDArray out, NDArray argu, NDArray arge) {
  const bool has_idx = !IsNullArray(coo.data);
  const IdType* row = static_cast<IdType*>(coo.row->data);
  const IdType* col = static_cast<IdType*>(coo.col->data);
  const IdType* edges =
    has_idx ? static_cast<IdType*>(coo.data->data) : nullptr;
  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len;
  DType* O = static_cast<DType*>(out->data);
  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
  const int64_t nnz = coo.row->shape[0];
  // fill zero elements
  std::fill(O, O + out.NumElements(), Cmp::zero);
  // spmm
#pragma omp parallel for
  for (IdType i = 0; i < nnz; ++i) {
    const IdType rid = row[i];
    const IdType cid = col[i];
    const IdType eid = has_idx ? edges[i] : i;
    DType* out_off = O + cid * dim;
    IdType* argx_off = Op::use_lhs ? argX + cid * dim : nullptr;
    IdType* argw_off = Op::use_rhs ? argW + cid * dim : nullptr;
    for (int64_t k = 0; k < dim; ++k) {
      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
      const DType* lhs_off =
        Op::use_lhs ? X + rid * lhs_dim + lhs_add : nullptr;
      const DType* rhs_off =
        Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
      const DType val = Op::Call(lhs_off, rhs_off);
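      // Serialize the compare-and-update: several edges may target the same
      // destination row concurrently, and the value and its arg index must
      // change together, which a plain `omp atomic` cannot express.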
#pragma omp critical
      if (Cmp::Call(out_off[k], val)) {
        out_off[k] = val;
        if (Op::use_lhs) argx_off[k] = rid;
        if (Op::use_rhs) argw_off[k] = eid;
      }
    }
  }
}

}  // namespace cpu
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_SPMM_H_