"docs/source/en/using-diffusers/custom_pipeline_examples.md" did not exist on "2fdd094c10aadd71c17d3e3d28e47b6c3cf26bd2"
spmm.h 22.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/spmm.h
 * \brief SPMM CPU kernel function header.
 */
#ifndef DGL_ARRAY_CPU_SPMM_H_
#define DGL_ARRAY_CPU_SPMM_H_

#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/runtime/config.h>
#include <math.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <vector>
#include "spmm_binary_ops.h"
#if !defined(_WIN32)
#ifdef USE_AVX
#include "intel/cpu_support.h"
#ifdef USE_LIBXSMM
#include "spmm_blocking_libxsmm.h"
#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32
namespace dgl {
namespace aten {
namespace cpu {

#if !defined(_WIN32)
#ifdef USE_AVX
/*!
 * \brief CPU kernel of SpMM on Csr format using Xbyak.
 * \param cpu_spec JIT'ed kernel
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param X The feature on source nodes.
 * \param W The feature on edges.
 * \param O The result feature on destination nodes.
 * \note It uses a node-parallel strategy: different threads are responsible
 *       for the computation of different nodes. For each edge, it invokes the
 *       JIT'ed kernel.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCsrXbyak(dgl::ElemWiseAddUpdate<Op>* cpu_spec, const BcastOff& bcast,
                     const CSRMatrix& csr, const DType* X, const DType* W, DType* O) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;

  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      DType* out_off = O + rid * dim;
      for (IdType j = row_start; j < row_end; ++j) {
        const IdType cid = indices[j];
        const IdType eid = has_idx ? edges[j] : j;
        cpu_spec->run(out_off, X + cid * lhs_dim, W + eid * rhs_dim, dim);
      }
    }
  });
}
#endif  // USE_AVX
#endif  // _WIN32

/*!
 * \brief Naive CPU kernel of SpMM on Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param X The feature on source nodes.
 * \param W The feature on edges.
 * \param O The result feature on destination nodes.
 * \note It uses a node-parallel strategy: different threads are responsible
 *       for the computation of different nodes.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCsrNaive(const BcastOff& bcast, const CSRMatrix& csr, const DType* X,
                     const DType* W, DType* O) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      DType* out_off = O + rid * dim;
      for (IdType j = row_start; j < row_end; ++j) {
        const IdType cid = indices[j];
        const IdType eid = has_idx ? edges[j] : j;
        for (int64_t k = 0; k < dim; ++k) {
          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* lhs_off =
            Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
          const DType* rhs_off =
            Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
          out_off[k] += Op::Call(lhs_off, rhs_off);
        }
      }
    }
  });
}
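
/*
 * Illustrative note (added, hedged): for a destination row `rid` with incoming
 * edges j in [indptr[rid], indptr[rid + 1]), the naive kernel accumulates
 *     O[rid * dim + k] += Op::Call(X + indices[j] * lhs_dim + lhs_offset[k],
 *                                  W + eid * rhs_dim + rhs_offset[k])
 * for every output position k, where the offsets collapse to k when no
 * broadcasting is used and `O` is assumed to be zero-initialized by the caller.
 * A direct call might look like the sketch below, assuming float features,
 * 32-bit indices and the op::Add functor from spmm_binary_ops.h:
 *
 *   SpMMSumCsrNaive<int32_t, float, op::Add<float>>(bcast, csr, X, W, O);
 */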

/*!
 * \brief CPU kernel of SpMM on Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result feature on destination nodes.
 * \note It uses a node-parallel strategy: different threads are responsible
 *       for the computation of different nodes.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
                NDArray efeat, NDArray out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  const DType* X = ufeat.Ptr<DType>();
  const DType* W = efeat.Ptr<DType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
  DType* O = out.Ptr<DType>();
  CHECK_NOTNULL(indptr);
  CHECK_NOTNULL(O);
  if (Op::use_lhs) {
    CHECK_NOTNULL(indices);
    CHECK_NOTNULL(X);
  }
  if (Op::use_rhs) {
    if (has_idx)
      CHECK_NOTNULL(edges);
    CHECK_NOTNULL(W);
  }
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
  const bool no_libxsmm =
       bcast.use_bcast ||
       std::is_same<DType, double>::value ||
       !dgl::runtime::Config::Global()->IsLibxsmmAvailable();
  if (!no_libxsmm) {
    SpMMSumCsrLibxsmm<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
  } else {
#endif  // USE_LIBXSMM
    typedef dgl::ElemWiseAddUpdate<Op> ElemWiseUpd;
    /* Prepare an assembler kernel */
    static std::unique_ptr<ElemWiseUpd> asm_kernel_ptr(
        (dgl::IntelKernel<>::IsEnabled()) ? new ElemWiseUpd() : nullptr);
    /* Distribute the kernel among OMP threads */
    ElemWiseUpd* cpu_spec = (asm_kernel_ptr && asm_kernel_ptr->applicable())
      ? asm_kernel_ptr.get()
      : nullptr;
    if (cpu_spec && dim > 16 && !bcast.use_bcast) {
      SpMMSumCsrXbyak<IdType, DType, Op>(cpu_spec, bcast, csr, X, W, O);
    } else {
#endif  // USE_AVX
#endif  // _WIN32
    SpMMSumCsrNaive<IdType, DType, Op>(bcast, csr, X, W, O);
#if !defined(_WIN32)
#ifdef USE_AVX
    }
#ifdef USE_LIBXSMM
  }
#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32
}
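
/*
 * Dispatch overview (added note): on non-Windows AVX builds this wrapper first
 * tries the LIBXSMM blocked kernel (skipped for broadcasting, double precision,
 * or when LIBXSMM is disabled at runtime), then the Xbyak JIT path for wide,
 * non-broadcast features (dim > 16), and otherwise falls back to
 * SpMMSumCsrNaive. A hypothetical instantiation, assuming the op::Mul functor
 * from spmm_binary_ops.h and float32 NDArrays shaped consistently with `bcast`:
 *
 *   SpMMSumCsr<int64_t, float, op::Mul<float>>(bcast, csr, ufeat, efeat, out);
 */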

/*!
 * \brief CPU kernel of SpMM on Coo format.
 * \param bcast Broadcast information.
 * \param coo The Coo matrix.
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result feature on destination nodes.
 * \note It parallelizes over the nonzero entries (edges), so different threads
 *       may update the same destination node. To avoid possible data hazards,
 *       atomic operators are used in the reduction phase.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat,
                NDArray efeat, NDArray out) {
  const bool has_idx = !IsNullArray(coo.data);
  const IdType* row = coo.row.Ptr<IdType>();
  const IdType* col = coo.col.Ptr<IdType>();
  const IdType* edges = coo.data.Ptr<IdType>();
  const DType* X = ufeat.Ptr<DType>();
  const DType* W = efeat.Ptr<DType>();
  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
  DType* O = out.Ptr<DType>();
  const int64_t nnz = coo.row->shape[0];
  // fill zero elements
  memset(O, 0, out.GetSize());
  // spmm
#pragma omp parallel for
  for (IdType i = 0; i < nnz; ++i) {
    const IdType rid = row[i];
    const IdType cid = col[i];
    const IdType eid = has_idx ? edges[i] : i;
    DType* out_off = O + cid * dim;
    for (int64_t k = 0; k < dim; ++k) {
      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
      const DType* lhs_off =
        Op::use_lhs ? X + rid * lhs_dim + lhs_add : nullptr;
      const DType* rhs_off =
        Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
      const DType val = Op::Call(lhs_off, rhs_off);
      if (val != 0) {
#pragma omp atomic
        out_off[k] += val;
      }
    }
  }
}
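
/*
 * Illustrative note (added): unlike the CSR kernel, this COO kernel
 * parallelizes over nonzero entries, so several threads may accumulate into
 * the same destination row; each scalar update is protected by the
 * `#pragma omp atomic` above. A hypothetical call that sums edge features into
 * destination nodes, assuming the op::CopyRhs functor from spmm_binary_ops.h:
 *
 *   SpMMSumCoo<int32_t, float, op::CopyRhs<float>>(bcast, coo, ufeat, efeat, out);
 */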

/*!
 * \brief CPU kernel of SpMM-Min/Max on Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result feature on destination nodes.
 * \param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It's useful in computing gradients of the
 *        Min/Max reducer.
 * \param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction result on
 *        destination nodes. It's useful in computing gradients of the Min/Max
 *        reducer.
 * \note It uses a node-parallel strategy: different threads are responsible
 *       for the computation of different nodes.
 * \note The result will contain infinity for zero-degree nodes.
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
                NDArray efeat, NDArray out, NDArray argu, NDArray arge) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices = static_cast<IdType*>(csr.indices->data);
  const IdType* edges =
    has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len;
  DType* O = static_cast<DType*>(out->data);
  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
  CHECK_NOTNULL(indptr);
  CHECK_NOTNULL(O);
  if (Op::use_lhs) {
    CHECK_NOTNULL(indices);
    CHECK_NOTNULL(X);
    CHECK_NOTNULL(argX);
  }
  if (Op::use_rhs) {
    if (has_idx)
      CHECK_NOTNULL(edges);
    CHECK_NOTNULL(W);
    CHECK_NOTNULL(argW);
  }
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM

  const bool no_libxsmm =
       bcast.use_bcast ||
       std::is_same<DType, double>::value ||
       !dgl::runtime::Config::Global()->IsLibxsmmAvailable();
  if (!no_libxsmm) {
    SpMMCmpCsrLibxsmm<IdType, DType, Op, Cmp>(bcast, csr, ufeat, efeat, out, argu, arge);
  } else {
#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32

    runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
      for (auto rid = b; rid < e; ++rid) {
        const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
        DType* out_off = O + rid * dim;
        IdType* argx_off = argX + rid * dim;
        IdType* argw_off = argW + rid * dim;
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType cid = indices[j];
          const IdType eid = has_idx ? edges[j] : j;
          for (int64_t k = 0; k < dim; ++k) {
            const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
            const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
            const DType* lhs_off =
              Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
            const DType* rhs_off =
              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
            const DType val = Op::Call(lhs_off, rhs_off);
            if (Cmp::Call(out_off[k], val)) {
              out_off[k] = val;
              if (Op::use_lhs) argx_off[k] = cid;
              if (Op::use_rhs) argw_off[k] = eid;
            }
          }
        }
      }
    });
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
  }
#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32
}
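
/*
 * Illustrative note (added): `out` is not initialized here, so the caller is
 * expected to pre-fill it with the reduction identity (e.g. -infinity for a
 * max reduction), which is also why zero-degree nodes keep infinity in the
 * result. For every winning comparison the kernel records the source node in
 * `argu` and the edge in `arge`. A hypothetical max-reduction over copied
 * source features, assuming the op::CopyLhs and op::Max functors from
 * spmm_binary_ops.h:
 *
 *   SpMMCmpCsr<int32_t, float, op::CopyLhs<float>, op::Max<float>>(
 *       bcast, csr, ufeat, efeat, out, argu, arge);
 */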

/*!
 * \brief CPU kernel of SpMM-Min/Max on Csr format for heterogeneous graphs.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result feature on destination nodes.
 * \param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It's useful in computing gradients of the
 *        Min/Max reducer.
 * \param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction result on
 *        destination nodes. It's useful in computing gradients of the Min/Max
 *        reducer.
 * \param argu_ntype Node type of the arg-Min/Max on source nodes, which refers
 *        to the source node types corresponding to the minimum/maximum values of
 *        the reduction result on destination nodes. It's useful in computing
 *        gradients of the Min/Max reducer.
 * \param arge_etype Edge type of the arg-Min/Max on edges, which refers to the
 *        edge types corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It's useful in computing gradients of the
 *        Min/Max reducer.
 * \param ntype Node type of the source nodes of the current etype.
 * \param etype Edge type.
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
                NDArray efeat, NDArray out, NDArray argu, NDArray arge,
                NDArray argu_ntype, NDArray arge_etype,
                const int ntype, const int etype) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices = static_cast<IdType*>(csr.indices->data);
  const IdType* edges =
    has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len;
  DType* O = static_cast<DType*>(out->data);
  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
  IdType* argX_ntype = Op::use_lhs ? static_cast<IdType*>(argu_ntype->data) : nullptr;
  IdType* argW_etype = Op::use_rhs ? static_cast<IdType*>(arge_etype->data) : nullptr;
  CHECK_NOTNULL(indptr);
  CHECK_NOTNULL(O);
  if (Op::use_lhs) {
    CHECK_NOTNULL(indices);
    CHECK_NOTNULL(X);
    CHECK_NOTNULL(argX);
  }
  if (Op::use_rhs) {
    if (has_idx)
      CHECK_NOTNULL(edges);
    CHECK_NOTNULL(W);
    CHECK_NOTNULL(argW);
  }
  // TODO(Israt): Use LIBXSMM. Homogeneous graph uses LIBXSMM when enabled.
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      DType* out_off = O + rid * dim;
      IdType* argx_off = argX + rid * dim;
      IdType* argw_off = argW + rid * dim;
      IdType* argx_ntype = argX_ntype + rid * dim;
      IdType* argw_etype = argW_etype + rid * dim;
      for (IdType j = row_start; j < row_end; ++j) {
        const IdType cid = indices[j];
        const IdType eid = has_idx ? edges[j] : j;
        for (int64_t k = 0; k < dim; ++k) {
          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* lhs_off =
            Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
          const DType* rhs_off =
            Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
          const DType val = Op::Call(lhs_off, rhs_off);
          if (Cmp::Call(out_off[k], val)) {
            out_off[k] = val;
            if (Op::use_lhs) {
              argx_off[k] = cid;
              argx_ntype[k] = ntype;
            }
            if (Op::use_rhs) {
              argw_off[k] = eid;
              argw_etype[k] = etype;
            }
          }
        }
      }
    }
  });
}
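
/*
 * Illustrative note (added): this heterogeneous variant is intended to be
 * invoked once per relation with shared `out`/`argu`/`arge` buffers, so the
 * running min/max keeps being refined across relations while `argu_ntype` and
 * `arge_etype` record which node type and edge type produced the current
 * winner. A hypothetical per-relation call, assuming op::CopyLhs and op::Min
 * from spmm_binary_ops.h:
 *
 *   SpMMCmpCsrHetero<int32_t, float, op::CopyLhs<float>, op::Min<float>>(
 *       bcast, rel_csr, ufeat, efeat, out, argu, arge,
 *       argu_ntype, arge_etype, ntype, etype);
 */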


/*!
 * \brief CPU kernel of SpMM-Min/Max on Coo format.
 * \param bcast Broadcast information.
 * \param coo The Coo matrix.
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result feature on destination nodes.
 * \param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It's useful in computing gradients of the
 *        Min/Max reducer.
 * \param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction result on
 *        destination nodes. It's useful in computing gradients of the Min/Max
 *        reducer.
 * \note It parallelizes over the nonzero entries (edges). To avoid possible
 *       data hazards, a critical section is used in the reduction phase.
 * \note The result will contain infinity for zero-degree nodes.
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat,
                NDArray efeat, NDArray out, NDArray argu, NDArray arge) {
  const bool has_idx = !IsNullArray(coo.data);
  const IdType* row = static_cast<IdType*>(coo.row->data);
  const IdType* col = static_cast<IdType*>(coo.col->data);
  const IdType* edges =
    has_idx ? static_cast<IdType*>(coo.data->data) : nullptr;
  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len;
  DType* O = static_cast<DType*>(out->data);
  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
  const int64_t nnz = coo.row->shape[0];
  // fill zero elements
  std::fill(O, O + out.NumElements(), Cmp::zero);
  // spmm
#pragma omp parallel for
  for (IdType i = 0; i < nnz; ++i) {
    const IdType rid = row[i];
    const IdType cid = col[i];
    const IdType eid = has_idx ? edges[i] : i;
    DType* out_off = O + cid * dim;
    IdType* argx_off = Op::use_lhs ? argX + cid * dim : nullptr;
    IdType* argw_off = Op::use_rhs ? argW + cid * dim : nullptr;
    for (int64_t k = 0; k < dim; ++k) {
      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
      const DType* lhs_off =
        Op::use_lhs ? X + rid * lhs_dim + lhs_add : nullptr;
      const DType* rhs_off =
        Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
      const DType val = Op::Call(lhs_off, rhs_off);
#pragma omp critical
      if (Cmp::Call(out_off[k], val)) {
        out_off[k] = val;
        if (Op::use_lhs) argx_off[k] = rid;
        if (Op::use_rhs) argw_off[k] = eid;
      }
    }
  }
}
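
/*
 * Illustrative note (added): this COO variant initializes `out` itself
 * (std::fill with Cmp::zero) and parallelizes over nonzeros; the
 * `#pragma omp critical` section keeps the compare-and-update of the value and
 * its argmin/argmax indices consistent when threads race on the same
 * destination row. A hypothetical call, assuming op::CopyLhs and op::Max from
 * spmm_binary_ops.h:
 *
 *   SpMMCmpCoo<int64_t, float, op::CopyLhs<float>, op::Max<float>>(
 *       bcast, coo, ufeat, efeat, out, argu, arge);
 */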


/*!
 * \brief CPU kernel of Edge_softmax_csr_forward on Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result of edge_softmax_forward.
 */
template <typename IdType, typename DType, typename Op>
void Edge_softmax_csr_forward(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
                NDArray efeat, NDArray out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* edges =
    has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      std::vector<DType> data_e(row_end-row_start, 0);
      std::vector<IdType> num(row_end-row_start, 0);
      for (int64_t k = 0; k < dim; ++k) {
        DType max_v = -std::numeric_limits<DType>::infinity();
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off =
            Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
          data_e[j-row_start] = *rhs_off;
          num[j-row_start] = eid*rhs_dim+rhs_add;
          max_v = std::max<DType>(max_v, (*rhs_off));
        }
        DType exp_sum = 0;
        for (auto& element : data_e) {
          element -= max_v;
          element = std::exp(element);
          exp_sum += element;
        }
        for (int i=0; i < row_end-row_start; i++) {
          out.Ptr<DType>()[num[i]] = data_e[i]/exp_sum;
        }
      }
    }
  });
}
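
/*
 * Added note: for each destination row and each feature position k, the kernel
 * above computes a numerically stable softmax over the row's incoming edges,
 *     out[e][k] = exp(W[e][k] - m_k) / sum_j exp(W[j][k] - m_k),
 * where j ranges over the edges incident to the same destination node and
 * m_k = max_j W[j][k].
 */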


/*!
 * \brief CPU kernel of Edge_softmax_csr_backward on Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param out The result of forward.
 * \param sds The result of gradient * out.
 * \param back_out The result of edge_softmax_backward.
 */
template <typename IdType, typename DType, typename Op>
void Edge_softmax_csr_backward(const BcastOff& bcast, const CSRMatrix& csr, NDArray out,
                NDArray sds, NDArray back_out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* edges =
    has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* W_out = Op::use_rhs ? static_cast<DType*>(out->data) : nullptr;
  const DType* W_sds = Op::use_rhs ? static_cast<DType*>(sds->data) : nullptr;
  const int64_t dim = bcast.out_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      for (int64_t k = 0; k < dim; ++k) {
        DType sum_sds = 0;
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off_sds =
            Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;
          sum_sds += (*rhs_off_sds);
        }
        for (IdType j = row_start; j< row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off_out =
            Op::use_rhs ? W_out + eid * rhs_dim + rhs_add : nullptr;
          const DType* rhs_off_sds =
            Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;
          back_out.Ptr<DType>()[eid*rhs_dim+rhs_add] =  (*rhs_off_sds) - sum_sds*(*rhs_off_out);
        }
      }
    }
  });
}
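
/*
 * Added note: with `out` the forward softmax result and `sds` the elementwise
 * product of the upstream gradient with `out`, the kernel above computes, per
 * destination row and feature position k,
 *     grad_w[e][k] = sds[e][k] - out[e][k] * sum_j sds[j][k],
 * i.e. the standard softmax backward (Jacobian-vector product).
 */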

}  // namespace cpu
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_SPMM_H_