/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/spmm.h
 * @brief SpMM CPU kernel function header.
 */
#ifndef DGL_ARRAY_CPU_SPMM_H_
#define DGL_ARRAY_CPU_SPMM_H_

#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dgl/runtime/config.h>
#include <dgl/runtime/parallel_for.h>
#include <math.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>
#include <memory>
#include <type_traits>
#include <vector>

#include "spmm_binary_ops.h"
#if !defined(_WIN32)
#ifdef USE_LIBXSMM
#include "spmm_blocking_libxsmm.h"
#endif  // USE_LIBXSMM
#endif  // _WIN32

namespace dgl {
namespace aten {
namespace cpu {

30
31
32
33
template <typename DType>
using AccType = typename std::conditional<
    std::is_same<DType, BFloat16>::value, float, DType>::type;

34
/**
35
36
37
38
39
40
41
42
 * @brief Naive CPU kernel of SpMM on Csr format.
 * @param cpu_spec JIT'ed kernel
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param X The feature on source nodes.
 * @param W The feature on edges.
 * @param O The result feature on destination nodes.
 * @note it uses node parallel strategy, different threads are responsible
43
44
45
 *       for the computation of different nodes.
 */
template <typename IdType, typename DType, typename Op>
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
typename std::enable_if<!std::is_same<DType, BFloat16>::value, void>::type
SpMMSumCsrNaive(
    const BcastOff& bcast, const CSRMatrix& csr, const DType* X, const DType* W,
    DType* O) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      DType* out_off = O + rid * dim;
      for (IdType j = row_start; j < row_end; ++j) {
        const IdType cid = indices[j];
        const IdType eid = has_idx ? edges[j] : j;
        for (int64_t k = 0; k < dim; ++k) {
          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* lhs_off =
              Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
          const DType* rhs_off =
              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
          out_off[k] += Op::Call(lhs_off, rhs_off);
        }
      }
    }
  });
}

// Naive implementation with additional accumulator, which prevents accuracy
// degradation in less precise data types, like bfloat16.
template <typename IdType, typename DType, typename Op>
typename std::enable_if<std::is_same<DType, BFloat16>::value, void>::type
SpMMSumCsrNaive(
81
82
    const BcastOff& bcast, const CSRMatrix& csr, const DType* X, const DType* W,
    DType* O) {
83
84
85
86
87
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
88
89
90
91
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      DType* out_off = O + rid * dim;
92
93
94
95
96
      for (int64_t k = 0; k < dim; ++k) {
        AccType<DType> acc = 0.;
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType cid = indices[j];
          const IdType eid = has_idx ? edges[j] : j;
97
98
99
          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* lhs_off =
100
              Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
101
          const DType* rhs_off =
102
              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
103
          acc += Op::Call(lhs_off, rhs_off);
104
        }
105
        out_off[k] += acc;
106
107
      }
    }
108
  });
109
110
}

111
/**
112
113
114
115
116
117
118
 * @brief CPU kernel of SpMM on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @note it uses node parallel strategy, different threads are responsible
119
120
121
 *       for the computation of different nodes.
 */
template <typename IdType, typename DType, typename Op>
122
123
124
void SpMMSumCsr(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out) {
125
126
127
128
129
130
131
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  const DType* X = ufeat.Ptr<DType>();
  const DType* W = efeat.Ptr<DType>();
  DType* O = out.Ptr<DType>();
132
133
134
135
136
137
138
  CHECK_NOTNULL(indptr);
  CHECK_NOTNULL(O);
  if (Op::use_lhs) {
    CHECK_NOTNULL(indices);
    CHECK_NOTNULL(X);
  }
  if (Op::use_rhs) {
139
    if (has_idx) CHECK_NOTNULL(edges);
140
141
    CHECK_NOTNULL(W);
  }
142
#if !defined(_WIN32)
143
#ifdef USE_LIBXSMM
144
145
146
147
148
  int cpu_id = libxsmm_cpuid_x86();
  const bool no_libxsmm =
      bcast.use_bcast || std::is_same<DType, double>::value ||
      (std::is_same<DType, BFloat16>::value && cpu_id < LIBXSMM_X86_AVX512) ||
      !dgl::runtime::Config::Global()->IsLibxsmmAvailable();
149
150
  if (!no_libxsmm) {
    SpMMSumCsrLibxsmm<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
151
  } else {
152
#endif  // USE_LIBXSMM
153
#endif  // _WIN32
154
    SpMMSumCsrNaive<IdType, DType, Op>(bcast, csr, X, W, O);
155
#if !defined(_WIN32)
156
#ifdef USE_LIBXSMM
157
  }
158
#endif  // USE_LIBXSMM
159
#endif  // _WIN32
160
161
}

162
/**
163
164
165
166
167
168
169
 * @brief CPU kernel of SpMM on Coo format.
 * @param bcast Broadcast information.
 * @param coo The Coo matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @note it uses node parallel strategy, different threads are responsible
170
171
172
173
 *       for the computation of different nodes. To avoid possible data hazard,
 *       we use atomic operators in the reduction phase.
 */
template <typename IdType, typename DType, typename Op>
174
175
typename std::enable_if<!std::is_same<DType, BFloat16>::value, void>::type
SpMMSumCoo(
176
177
    const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
    NDArray out) {
178
179
180
181
182
183
  const bool has_idx = !IsNullArray(coo.data);
  const IdType* row = coo.row.Ptr<IdType>();
  const IdType* col = coo.col.Ptr<IdType>();
  const IdType* edges = coo.data.Ptr<IdType>();
  const DType* X = ufeat.Ptr<DType>();
  const DType* W = efeat.Ptr<DType>();
184
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
185
186
187
188
189
190
191
192
193
  DType* O = out.Ptr<DType>();
  const int64_t nnz = coo.row->shape[0];
  // fill zero elements
  memset(O, 0, out.GetSize());
  // spmm
#pragma omp parallel for
  for (IdType i = 0; i < nnz; ++i) {
    const IdType rid = row[i];
    const IdType cid = col[i];
194
    const IdType eid = has_idx ? edges[i] : i;
195
196
197
198
    DType* out_off = O + cid * dim;
    for (int64_t k = 0; k < dim; ++k) {
      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
199
      const DType* lhs_off =
200
          Op::use_lhs ? X + rid * lhs_dim + lhs_add : nullptr;
201
      const DType* rhs_off =
202
          Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
203
      const DType val = Op::Call(lhs_off, rhs_off);
204
      if (val != 0) {
205
#pragma omp atomic
206
207
        out_off[k] += val;
      }
208
209
210
211
    }
  }
}

212
213
214
215
216
217
218
219
/**
 * @brief BFloat16 overload of SpMMSumCoo. No CPU kernel exists for it, so it
 *        always aborts via LOG(FATAL).
 */
template <typename IdType, typename DType, typename Op>
std::enable_if_t<std::is_same<DType, BFloat16>::value> SpMMSumCoo(
    const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
    NDArray out) {
  LOG(FATAL) << "Unsupported CPU kernel for SpMMSumCoo for BF16.";
}

/**
221
222
223
224
225
226
227
 * @brief CPU kernel of SpMM-Min/Max on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers the source node indices
228
 *        correspond to the minimum/maximum values of reduction result on
229
 *        destination nodes. It's useful in computing gradients of Min/Max
230
 *        reducer.
231
 * @param arge Arg-Min/Max on edges. which refers the source node indices
232
          correspond to the minimum/maximum values of reduction result on
233
 *        destination nodes. It's useful in computing gradients of Min/Max
234
 *        reducer.
235
 * @note It uses node parallel strategy, different threads are responsible for
236
 *       the computation of different nodes.
237
 * @note The result will contain infinity for zero-degree nodes.
238
239
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
240
241
242
void SpMMCmpCsr(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge) {
243
244
245
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices = static_cast<IdType*>(csr.indices->data);
246
  const IdType* edges =
247
      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
248
249
250
  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
251
252
                rhs_dim = bcast.rhs_len;
  DType* O = static_cast<DType*>(out->data);
253
254
  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
255
256
257
258
259
260
261
262
  CHECK_NOTNULL(indptr);
  CHECK_NOTNULL(O);
  if (Op::use_lhs) {
    CHECK_NOTNULL(indices);
    CHECK_NOTNULL(X);
    CHECK_NOTNULL(argX);
  }
  if (Op::use_rhs) {
263
    if (has_idx) CHECK_NOTNULL(edges);
264
265
266
267
268
    CHECK_NOTNULL(W);
    CHECK_NOTNULL(argW);
  }
#if !defined(_WIN32)
#ifdef USE_LIBXSMM
269
270
271
272
273
  int cpu_id = libxsmm_cpuid_x86();
  const bool no_libxsmm =
      bcast.use_bcast || std::is_same<DType, double>::value ||
      (std::is_same<DType, BFloat16>::value && cpu_id < LIBXSMM_X86_AVX512) ||
      !dgl::runtime::Config::Global()->IsLibxsmmAvailable();
274
  if (!no_libxsmm) {
275
276
    SpMMCmpCsrLibxsmm<IdType, DType, Op, Cmp>(
        bcast, csr, ufeat, efeat, out, argu, arge);
277
278
279
280
  } else {
#endif  // USE_LIBXSMM
#endif  // _WIN32

281
282
283
284
285
286
287
288
289
290
291
292
293
    runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
      for (auto rid = b; rid < e; ++rid) {
        const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
        DType* out_off = O + rid * dim;
        IdType* argx_off = argX + rid * dim;
        IdType* argw_off = argW + rid * dim;
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType cid = indices[j];
          const IdType eid = has_idx ? edges[j] : j;
          for (int64_t k = 0; k < dim; ++k) {
            const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
            const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
            const DType* lhs_off =
294
                Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
295
            const DType* rhs_off =
296
                Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
297
298
299
300
301
302
303
            const DType val = Op::Call(lhs_off, rhs_off);
            if (Cmp::Call(out_off[k], val)) {
              out_off[k] = val;
              if (Op::use_lhs) argx_off[k] = cid;
              if (Op::use_rhs) argw_off[k] = eid;
            }
          }
304
305
        }
      }
306
    });
307
308
309
310
311
#if !defined(_WIN32)
#ifdef USE_LIBXSMM
  }
#endif  // USE_LIBXSMM
#endif  // _WIN32
312
313
}

314
/**
315
316
317
318
319
320
321
 * @brief CPU kernel of SpMM-Min/Max on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers the source node indices
322
323
 *        correspond to the minimum/maximum values of reduction result on
 *        destination nodes. It's useful in computing gradients of Min/Max
324
 *        reducer.
325
 * @param arge Arg-Min/Max on edges. which refers the source node indices
326
 *        correspond to the minimum/maximum values of reduction result on
327
 *        destination nodes. It's useful in computing gradients of Min/Max
328
 *        reducer.
329
 * @param argu_ntype Node type of the arg-Min/Max on source nodes, which refers
330
331
332
 *        the source node types correspond to the minimum/maximum values of
 *        reduction result on destination nodes. It's useful in computing
 *        gradients of Min/Max reducer.
333
 * @param arge_etype Edge-type of the arg-Min/Max on edges. which refers the
334
335
336
 *        source node indices correspond to the minimum/maximum values of
 *        reduction result on destination nodes. It's useful in computing
 *        gradients of Min/Max reducer.
337
338
 * @param src_type Node type of the source nodes of an etype
 * @param etype Edge type
339
340
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
341
342
343
344
void SpMMCmpCsrHetero(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge, NDArray argu_ntype,
    NDArray arge_etype, const int ntype, const int etype) {
345
346
347
348
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices = static_cast<IdType*>(csr.indices->data);
  const IdType* edges =
349
      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
350
351
352
353
354
355
356
  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len;
  DType* O = static_cast<DType*>(out->data);
  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
357
358
359
360
  IdType* argX_ntype =
      Op::use_lhs ? static_cast<IdType*>(argu_ntype->data) : nullptr;
  IdType* argW_etype =
      Op::use_rhs ? static_cast<IdType*>(arge_etype->data) : nullptr;
361
362
363
364
365
366
367
368
  CHECK_NOTNULL(indptr);
  CHECK_NOTNULL(O);
  if (Op::use_lhs) {
    CHECK_NOTNULL(indices);
    CHECK_NOTNULL(X);
    CHECK_NOTNULL(argX);
  }
  if (Op::use_rhs) {
369
    if (has_idx) CHECK_NOTNULL(edges);
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
    CHECK_NOTNULL(W);
    CHECK_NOTNULL(argW);
  }
  // TODO(Israt): Use LIBXSMM. Homogeneous graph uses LIBXMM when enabled.
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      DType* out_off = O + rid * dim;
      IdType* argx_off = argX + rid * dim;
      IdType* argw_off = argW + rid * dim;
      IdType* argx_ntype = argX_ntype + rid * dim;
      IdType* argw_etype = argW_etype + rid * dim;
      for (IdType j = row_start; j < row_end; ++j) {
        const IdType cid = indices[j];
        const IdType eid = has_idx ? edges[j] : j;
        for (int64_t k = 0; k < dim; ++k) {
          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* lhs_off =
389
              Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
390
          const DType* rhs_off =
391
              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
          const DType val = Op::Call(lhs_off, rhs_off);
          if (Cmp::Call(out_off[k], val)) {
            out_off[k] = val;
            if (Op::use_lhs) {
              argx_off[k] = cid;
              argx_ntype[k] = ntype;
            }
            if (Op::use_rhs) {
              argw_off[k] = eid;
              argw_etype[k] = etype;
            }
          }
        }
      }
    }
  });
}

410
/**
411
412
413
414
415
416
417
 * @brief CPU kernel of SpMM-Min/Max on Coo format.
 * @param bcast Broadcast information.
 * @param coo The Coo matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers the source node indices
418
 *        correspond to the minimum/maximum values of reduction result on
419
 *        destination nodes. It's useful in computing gradients of Min/Max
420
 *        reducer.
421
 * @param arge Arg-Min/Max on edges. which refers the source node indices
422
 *        correspond to the minimum/maximum values of reduction result on
423
 *        destination nodes. It's useful in computing gradients of Min/Max
424
 *        reducer.
425
 * @note it uses node parallel strategy, different threads are responsible for
426
427
 *       the computation of different nodes. To avoid possible data hazard, we
 *       use atomic operators in the reduction phase.
428
 * @note The result will contain infinity for zero-degree nodes.
429
430
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
431
432
433
void SpMMCmpCoo(
    const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge) {
434
435
436
  const bool has_idx = !IsNullArray(coo.data);
  const IdType* row = static_cast<IdType*>(coo.row->data);
  const IdType* col = static_cast<IdType*>(coo.col->data);
437
  const IdType* edges =
438
      has_idx ? static_cast<IdType*>(coo.data->data) : nullptr;
439
440
441
  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
442
443
                rhs_dim = bcast.rhs_len;
  DType* O = static_cast<DType*>(out->data);
444
445
  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
446
447
448
449
450
451
452
453
  const int64_t nnz = coo.row->shape[0];
  // fill zero elements
  std::fill(O, O + out.NumElements(), Cmp::zero);
  // spmm
#pragma omp parallel for
  for (IdType i = 0; i < nnz; ++i) {
    const IdType rid = row[i];
    const IdType cid = col[i];
454
    const IdType eid = has_idx ? edges[i] : i;
455
    DType* out_off = O + cid * dim;
456
457
    IdType* argx_off = Op::use_lhs ? argX + cid * dim : nullptr;
    IdType* argw_off = Op::use_rhs ? argW + cid * dim : nullptr;
458
459
460
    for (int64_t k = 0; k < dim; ++k) {
      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
461
      const DType* lhs_off =
462
          Op::use_lhs ? X + rid * lhs_dim + lhs_add : nullptr;
463
      const DType* rhs_off =
464
          Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
465
466
467
468
      const DType val = Op::Call(lhs_off, rhs_off);
#pragma omp critical
      if (Cmp::Call(out_off[k], val)) {
        out_off[k] = val;
469
470
        if (Op::use_lhs) argx_off[k] = rid;
        if (Op::use_rhs) argw_off[k] = eid;
471
472
473
474
475
      }
    }
  }
}

476
/**
477
478
479
480
481
482
 * @brief CPU kernel of Edge_softmax_csr_forward on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result of edge_softmax_forward.
483
484
 */
template <typename IdType, typename DType, typename Op>
485
486
487
void Edge_softmax_csr_forward(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out) {
488
489
490
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* edges =
491
      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
492
493
494
495
496
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
497
      std::vector<AccType<DType>> data_e(row_end - row_start, 0);
498
      std::vector<IdType> num(row_end - row_start, 0);
499
500
501
502
503
504
      for (int64_t k = 0; k < dim; ++k) {
        DType max_v = -std::numeric_limits<DType>::infinity();
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off =
505
506
507
              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
          data_e[j - row_start] = *rhs_off;
          num[j - row_start] = eid * rhs_dim + rhs_add;
508
509
510
511
512
513
514
515
          max_v = std::max<DType>(max_v, (*rhs_off));
        }
        DType exp_sum = 0;
        for (auto& element : data_e) {
          element -= max_v;
          element = std::exp(element);
          exp_sum += element;
        }
516
517
        for (int i = 0; i < row_end - row_start; i++) {
          out.Ptr<DType>()[num[i]] = data_e[i] / exp_sum;
518
519
520
521
522
523
        }
      }
    }
  });
}

524
/**
525
526
527
528
529
530
 * @brief CPU kernel of Edge_softmax_csr_backward on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param out The result of forward.
 * @param sds The result of gradiet * out.
 * @param back_out The result of edge_softmax_backward.
531
532
 */
template <typename IdType, typename DType, typename Op>
533
534
535
void Edge_softmax_csr_backward(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray out, NDArray sds,
    NDArray back_out) {
536
537
  typedef typename std::conditional<
      std::is_same<DType, BFloat16>::value, float, DType>::type AccType;
538
539
540
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* edges =
541
      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
542
543
544
545
546
547
548
  const DType* W_out = Op::use_rhs ? static_cast<DType*>(out->data) : nullptr;
  const DType* W_sds = Op::use_rhs ? static_cast<DType*>(sds->data) : nullptr;
  const int64_t dim = bcast.out_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      for (int64_t k = 0; k < dim; ++k) {
549
        AccType sum_sds = 0;
550
551
552
553
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off_sds =
554
              Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;
555
556
          sum_sds += (*rhs_off_sds);
        }
557
        for (IdType j = row_start; j < row_end; ++j) {
558
559
560
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off_out =
561
              Op::use_rhs ? W_out + eid * rhs_dim + rhs_add : nullptr;
562
          const DType* rhs_off_sds =
563
564
565
              Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;
          back_out.Ptr<DType>()[eid * rhs_dim + rhs_add] =
              (*rhs_off_sds) - sum_sds * (*rhs_off_out);
566
567
568
569
570
571
        }
      }
    }
  });
}

572
573
574
575
576
}  // namespace cpu
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_SPMM_H_