"docs/source/ko/using-diffusers/loading_adapters.md" did not exist on "1a8b3c2ee86c09d0d3e066f7e9ea2ab69e8e78fa"
spmm.h 22.6 KB
Newer Older
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/spmm.h
 * @brief SPMM CPU kernel function header.
 */
#ifndef DGL_ARRAY_CPU_SPMM_H_
#define DGL_ARRAY_CPU_SPMM_H_

#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dgl/runtime/config.h>
#include <dgl/runtime/parallel_for.h>
#include <math.h>

#include <algorithm>
#include <limits>
#include <memory>
#include <vector>

#include "spmm_binary_ops.h"
#if !defined(_WIN32)
#ifdef USE_AVX
#include "intel/cpu_support.h"
#ifdef USE_LIBXSMM
#include "spmm_blocking_libxsmm.h"
#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32

namespace dgl {
namespace aten {
namespace cpu {

#if !defined(_WIN32)
#ifdef USE_AVX
/**
 * @brief CPU kernel of SpMM on Csr format using Xbyak.
 * @param cpu_spec JIT'ed kernel
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param X The feature on source nodes.
 * @param W The feature on edges.
 * @param O The result feature on destination nodes.
 * @note It uses a node-parallel strategy: different threads are responsible
 *       for the computation of different nodes. For each edge, it applies the
 *       JIT'ed kernel.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCsrXbyak(
    dgl::ElemWiseAddUpdate<Op>* cpu_spec, const BcastOff& bcast,
    const CSRMatrix& csr, const DType* X, const DType* W, DType* O) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      DType* out_off = O + rid * dim;
      for (IdType j = row_start; j < row_end; ++j) {
        const IdType cid = indices[j];
        const IdType eid = has_idx ? edges[j] : j;
        cpu_spec->run(out_off, X + cid * lhs_dim, W + eid * rhs_dim, dim);
      }
    }
  });
}
#endif  // USE_AVX
#endif  // _WIN32

/**
 * @brief Naive CPU kernel of SpMM on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param X The feature on source nodes.
 * @param W The feature on edges.
 * @param O The result feature on destination nodes.
 * @note It uses a node-parallel strategy: different threads are responsible
 *       for the computation of different nodes.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCsrNaive(
    const BcastOff& bcast, const CSRMatrix& csr, const DType* X, const DType* W,
    DType* O) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      DType* out_off = O + rid * dim;
      for (IdType j = row_start; j < row_end; ++j) {
        const IdType cid = indices[j];
        const IdType eid = has_idx ? edges[j] : j;
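        // With broadcasting, lhs_offset/rhs_offset map each output position k
        // back to the matching position in the (possibly smaller) operand
        // feature; otherwise the operands are indexed by k directly.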
        for (int64_t k = 0; k < dim; ++k) {
          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* lhs_off =
              Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
          const DType* rhs_off =
              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
          out_off[k] += Op::Call(lhs_off, rhs_off);
        }
      }
    }
  });
}

/**
 * @brief CPU kernel of SpMM on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @note It uses a node-parallel strategy: different threads are responsible
 *       for the computation of different nodes.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCsr(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  const DType* X = ufeat.Ptr<DType>();
  const DType* W = efeat.Ptr<DType>();
  DType* O = out.Ptr<DType>();
  CHECK_NOTNULL(indptr);
  CHECK_NOTNULL(O);
  if (Op::use_lhs) {
    CHECK_NOTNULL(indices);
    CHECK_NOTNULL(X);
  }
  if (Op::use_rhs) {
    if (has_idx) CHECK_NOTNULL(edges);
    CHECK_NOTNULL(W);
  }
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
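  // Fall back to the non-LIBXSMM path when broadcasting is required, for
  // double precision, or when LIBXSMM is disabled at runtime.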
  const bool no_libxsmm = bcast.use_bcast ||
                          std::is_same<DType, double>::value ||
                          !dgl::runtime::Config::Global()->IsLibxsmmAvailable();
  if (!no_libxsmm) {
    SpMMSumCsrLibxsmm<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
  } else {
#endif  // USE_LIBXSMM
    typedef dgl::ElemWiseAddUpdate<Op> ElemWiseUpd;
    /* Prepare an assembler kernel */
    static std::unique_ptr<ElemWiseUpd> asm_kernel_ptr(
        (dgl::IntelKernel<>::IsEnabled()) ? new ElemWiseUpd() : nullptr);
    /* Distribute the kernel among OMP threads */
    ElemWiseUpd* cpu_spec = (asm_kernel_ptr && asm_kernel_ptr->applicable())
                                ? asm_kernel_ptr.get()
                                : nullptr;
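    // Use the JIT'ed elementwise kernel only for sufficiently wide feature
    // dimensions without broadcasting; otherwise take the naive path.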
    if (cpu_spec && bcast.out_len > 16 && !bcast.use_bcast) {
      SpMMSumCsrXbyak<IdType, DType, Op>(cpu_spec, bcast, csr, X, W, O);
    } else {
#endif  // USE_AVX
#endif  // _WIN32
      SpMMSumCsrNaive<IdType, DType, Op>(bcast, csr, X, W, O);
#if !defined(_WIN32)
#ifdef USE_AVX
    }
#ifdef USE_LIBXSMM
  }
#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32
}

/**
 * @brief CPU kernel of SpMM on Coo format.
 * @param bcast Broadcast information.
 * @param coo The Coo matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @note It parallelizes over the edges (nonzeros). To avoid possible data
 *       hazards when several edges update the same destination node, atomic
 *       operators are used in the reduction phase.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCoo(
    const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
    NDArray out) {
  const bool has_idx = !IsNullArray(coo.data);
  const IdType* row = coo.row.Ptr<IdType>();
  const IdType* col = coo.col.Ptr<IdType>();
  const IdType* edges = coo.data.Ptr<IdType>();
  const DType* X = ufeat.Ptr<DType>();
  const DType* W = efeat.Ptr<DType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
  DType* O = out.Ptr<DType>();
  const int64_t nnz = coo.row->shape[0];
  // fill zero elements
  memset(O, 0, out.GetSize());
  // spmm
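  // Edge-parallel loop: each nonzero (rid -> cid) adds its message into the
  // destination row of O, so concurrent updates to the same destination are
  // made atomic below.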
#pragma omp parallel for
  for (IdType i = 0; i < nnz; ++i) {
    const IdType rid = row[i];
    const IdType cid = col[i];
    const IdType eid = has_idx ? edges[i] : i;
    DType* out_off = O + cid * dim;
    for (int64_t k = 0; k < dim; ++k) {
      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
      const DType* lhs_off =
          Op::use_lhs ? X + rid * lhs_dim + lhs_add : nullptr;
      const DType* rhs_off =
          Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
      const DType val = Op::Call(lhs_off, rhs_off);
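      // Skip zero-valued contributions to avoid the cost of an atomic update.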
      if (val != 0) {
#pragma omp atomic
        out_off[k] += val;
      }
    }
  }
}

/**
 * @brief CPU kernel of SpMM-Min/Max on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It's useful in computing gradients of
 *        the Min/Max reducer.
 * @param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction result
 *        on destination nodes. It's useful in computing gradients of the
 *        Min/Max reducer.
 * @note It uses a node-parallel strategy: different threads are responsible
 *       for the computation of different nodes.
 * @note The result will contain infinity for zero-degree nodes.
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsr(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices = static_cast<IdType*>(csr.indices->data);
  const IdType* edges =
      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len;
  DType* O = static_cast<DType*>(out->data);
  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
  CHECK_NOTNULL(indptr);
  CHECK_NOTNULL(O);
  if (Op::use_lhs) {
    CHECK_NOTNULL(indices);
    CHECK_NOTNULL(X);
    CHECK_NOTNULL(argX);
  }
  if (Op::use_rhs) {
    if (has_idx) CHECK_NOTNULL(edges);
    CHECK_NOTNULL(W);
    CHECK_NOTNULL(argW);
  }
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
  const bool no_libxsmm = bcast.use_bcast ||
                          std::is_same<DType, double>::value ||
                          !dgl::runtime::Config::Global()->IsLibxsmmAvailable();
  if (!no_libxsmm) {
    SpMMCmpCsrLibxsmm<IdType, DType, Op, Cmp>(
        bcast, csr, ufeat, efeat, out, argu, arge);
  } else {
#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32
    runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
      for (auto rid = b; rid < e; ++rid) {
        const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
        DType* out_off = O + rid * dim;
        IdType* argx_off = argX + rid * dim;
        IdType* argw_off = argW + rid * dim;
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType cid = indices[j];
          const IdType eid = has_idx ? edges[j] : j;
          for (int64_t k = 0; k < dim; ++k) {
            const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
            const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
            const DType* lhs_off =
                Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
            const DType* rhs_off =
                Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
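            // Keep a running min/max per output position and record which
            // source node / edge produced it for the backward pass.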
            const DType val = Op::Call(lhs_off, rhs_off);
            if (Cmp::Call(out_off[k], val)) {
              out_off[k] = val;
              if (Op::use_lhs) argx_off[k] = cid;
              if (Op::use_rhs) argw_off[k] = eid;
            }
          }
        }
      }
    });
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
  }
#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32
}

/**
 * @brief CPU kernel of SpMM-Min/Max on Csr format for heterogeneous graphs.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It's useful in computing gradients of
 *        the Min/Max reducer.
 * @param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction result
 *        on destination nodes. It's useful in computing gradients of the
 *        Min/Max reducer.
 * @param argu_ntype Node type of the arg-Min/Max on source nodes, which refers
 *        to the source node types corresponding to the minimum/maximum values
 *        of the reduction result on destination nodes. It's useful in
 *        computing gradients of the Min/Max reducer.
 * @param arge_etype Edge type of the arg-Min/Max on edges, which refers to the
 *        edge types corresponding to the minimum/maximum values of the
 *        reduction result on destination nodes. It's useful in computing
 *        gradients of the Min/Max reducer.
 * @param ntype Node type of the source nodes of this etype.
 * @param etype Edge type.
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsrHetero(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge, NDArray argu_ntype,
    NDArray arge_etype, const int ntype, const int etype) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices = static_cast<IdType*>(csr.indices->data);
  const IdType* edges =
      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len;
  DType* O = static_cast<DType*>(out->data);
  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
  IdType* argX_ntype =
      Op::use_lhs ? static_cast<IdType*>(argu_ntype->data) : nullptr;
  IdType* argW_etype =
      Op::use_rhs ? static_cast<IdType*>(arge_etype->data) : nullptr;
  CHECK_NOTNULL(indptr);
  CHECK_NOTNULL(O);
  if (Op::use_lhs) {
    CHECK_NOTNULL(indices);
    CHECK_NOTNULL(X);
    CHECK_NOTNULL(argX);
  }
  if (Op::use_rhs) {
    if (has_idx) CHECK_NOTNULL(edges);
    CHECK_NOTNULL(W);
    CHECK_NOTNULL(argW);
  }
  // TODO(Israt): Use LIBXSMM. Homogeneous graph uses LIBXMM when enabled.
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      DType* out_off = O + rid * dim;
      IdType* argx_off = argX + rid * dim;
      IdType* argw_off = argW + rid * dim;
      IdType* argx_ntype = argX_ntype + rid * dim;
      IdType* argw_etype = argW_etype + rid * dim;
      for (IdType j = row_start; j < row_end; ++j) {
        const IdType cid = indices[j];
        const IdType eid = has_idx ? edges[j] : j;
        for (int64_t k = 0; k < dim; ++k) {
          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* lhs_off =
              Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
          const DType* rhs_off =
              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
          const DType val = Op::Call(lhs_off, rhs_off);
          if (Cmp::Call(out_off[k], val)) {
            out_off[k] = val;
            if (Op::use_lhs) {
              argx_off[k] = cid;
              argx_ntype[k] = ntype;
            }
            if (Op::use_rhs) {
              argw_off[k] = eid;
              argw_etype[k] = etype;
            }
          }
        }
      }
    }
  });
}

/**
 * @brief CPU kernel of SpMM-Min/Max on Coo format.
 * @param bcast Broadcast information.
 * @param coo The Coo matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It's useful in computing gradients of
 *        the Min/Max reducer.
 * @param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction result
 *        on destination nodes. It's useful in computing gradients of the
 *        Min/Max reducer.
 * @note It parallelizes over the edges (nonzeros). To avoid possible data
 *       hazards when several edges update the same destination node, the
 *       compare-and-update step is guarded by a critical section.
 * @note The result will contain infinity for zero-degree nodes.
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCoo(
    const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge) {
  const bool has_idx = !IsNullArray(coo.data);
  const IdType* row = static_cast<IdType*>(coo.row->data);
  const IdType* col = static_cast<IdType*>(coo.col->data);
  const IdType* edges =
      has_idx ? static_cast<IdType*>(coo.data->data) : nullptr;
  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len;
  DType* O = static_cast<DType*>(out->data);
  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
  const int64_t nnz = coo.row->shape[0];
  // fill zero elements
  std::fill(O, O + out.NumElements(), Cmp::zero);
  // spmm
#pragma omp parallel for
  for (IdType i = 0; i < nnz; ++i) {
    const IdType rid = row[i];
    const IdType cid = col[i];
    const IdType eid = has_idx ? edges[i] : i;
    DType* out_off = O + cid * dim;
    IdType* argx_off = Op::use_lhs ? argX + cid * dim : nullptr;
    IdType* argw_off = Op::use_rhs ? argW + cid * dim : nullptr;
    for (int64_t k = 0; k < dim; ++k) {
      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
      const DType* lhs_off =
          Op::use_lhs ? X + rid * lhs_dim + lhs_add : nullptr;
      const DType* rhs_off =
          Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
      const DType val = Op::Call(lhs_off, rhs_off);
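      // The critical section keeps the value and its arg indices consistent
      // when several edges update the same destination concurrently.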
#pragma omp critical
      if (Cmp::Call(out_off[k], val)) {
        out_off[k] = val;
        if (Op::use_lhs) argx_off[k] = rid;
        if (Op::use_rhs) argw_off[k] = eid;
      }
    }
  }
}

/**
 * @brief CPU kernel of Edge_softmax_csr_forward on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result of edge_softmax_forward.
 */
template <typename IdType, typename DType, typename Op>
void Edge_softmax_csr_forward(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* edges =
      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      std::vector<DType> data_e(row_end - row_start, 0);
      std::vector<IdType> num(row_end - row_start, 0);
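      // For each feature position, compute a numerically stable softmax over
      // this node's incoming edges: subtract the per-row maximum before
      // exponentiating, then normalize by the sum of exponentials.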
      for (int64_t k = 0; k < dim; ++k) {
        DType max_v = -std::numeric_limits<DType>::infinity();
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off =
              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
          data_e[j - row_start] = *rhs_off;
          num[j - row_start] = eid * rhs_dim + rhs_add;
          max_v = std::max<DType>(max_v, (*rhs_off));
        }
        DType exp_sum = 0;
        for (auto& element : data_e) {
          element -= max_v;
          element = std::exp(element);
          exp_sum += element;
        }
        for (int i = 0; i < row_end - row_start; i++) {
          out.Ptr<DType>()[num[i]] = data_e[i] / exp_sum;
        }
      }
    }
  });
}

/**
 * @brief CPU kernel of Edge_softmax_csr_backward on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param out The result of the forward pass.
 * @param sds The result of gradient * out.
 * @param back_out The result of edge_softmax_backward.
 */
template <typename IdType, typename DType, typename Op>
void Edge_softmax_csr_backward(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray out, NDArray sds,
    NDArray back_out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* edges =
      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* W_out = Op::use_rhs ? static_cast<DType*>(out->data) : nullptr;
  const DType* W_sds = Op::use_rhs ? static_cast<DType*>(sds->data) : nullptr;
  const int64_t dim = bcast.out_len, rhs_dim = bcast.rhs_len;
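  // Softmax backward: with sds = out * grad, the gradient w.r.t. each edge
  // logit is sds - out * sum(sds), where the sum runs over the incoming edges
  // of the destination node (per feature position).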
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      for (int64_t k = 0; k < dim; ++k) {
        DType sum_sds = 0;
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off_sds =
              Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;
          sum_sds += (*rhs_off_sds);
        }
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off_out =
              Op::use_rhs ? W_out + eid * rhs_dim + rhs_add : nullptr;
          const DType* rhs_off_sds =
              Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;
          back_out.Ptr<DType>()[eid * rhs_dim + rhs_add] =
              (*rhs_off_sds) - sum_sds * (*rhs_off_out);
        }
      }
    }
  });
}

}  // namespace cpu
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_SPMM_H_