/**
 *  Copyright (c) 2021 Intel Corporation
 *
 * @file array/cpu/spmm_blocking_libxsmm.h
 * @brief SPMM CPU kernel function header.
 * @author Sanchit Misra <sanchit.misra@intel.com>,
 *         Ramanarayan Mohanty <ramanarayan.mohanty@intel.com>,
 *         Vasimuddin Md <vasimuddin.md@intel.com>,
 *         Sasikanth Avancha <sasikanth.avancha@intel.com>
 */
#ifndef DGL_ARRAY_CPU_SPMM_BLOCKING_LIBXSMM_H_
#define DGL_ARRAY_CPU_SPMM_BLOCKING_LIBXSMM_H_

#include <dgl/array.h>
#include <dgl/bcast.h>

#include <cstdint>
#include <limits>
#include <dmlc/logging.h>

#include <algorithm>
#if !defined(_WIN32)
#ifdef USE_LIBXSMM
#include <libxsmm.h>
#include <unistd.h>
#ifdef DEBUG
#include <x86intrin.h>
#endif  // DEBUG
#include <dmlc/omp.h>

#define NUM_BLOCKS_PER_THREAD 20
#define BLOCKING_HEURISTIC_PARAM 500

namespace dgl {
namespace aten {
namespace cpu {

/**
 * @brief Lightweight non-owning view of a CSR matrix (or of one tile of a
 *        blocked CSR matrix). Unlike aten::CSRMatrix, this holds raw pointers
 *        only — the referenced buffers are owned elsewhere (either the
 *        original CSRMatrix or the tile buffers built by SpMMCreateBlocks).
 */
template <typename IdType, typename DType>
struct CSRMatrixInternal {
  IdType num_rows;   // number of rows in this (sub)matrix
  IdType num_cols;   // number of columns in this (sub)matrix
  IdType *indptr;    // row pointer array, length num_rows + 1
  IdType *indices;   // column indices, length indptr[num_rows]
  DType *data;       // per-edge data (edge ids), same length as indices
};

/**
 * @brief Query the size of the last-level (L3) cache in bytes.
 *
 * Falls back to the compile-time default DGL_CPU_LLC_SIZE when the platform
 * cannot report the size (sysconf returns -1), reports a nonsensical value
 * (0), or reports a value that would not fit in int32_t.
 *
 * @return LLC size in bytes.
 * @note `inline` is required here: this is a non-template function with
 *       external linkage defined in a header, so without `inline` every
 *       translation unit including this file would emit its own definition
 *       and violate the ODR at link time.
 */
inline int32_t GetLLCSize() {
#ifdef _SC_LEVEL3_CACHE_SIZE
  // sysconf returns a long; validate at full width before narrowing so a
  // very large reported size cannot wrap into a bogus int32_t value.
  const long reported = sysconf(_SC_LEVEL3_CACHE_SIZE);
  int32_t cache_size = DGL_CPU_LLC_SIZE;
  if (reported > 0 &&
      reported <= static_cast<long>(std::numeric_limits<int32_t>::max())) {
    cache_size = static_cast<int32_t>(reported);
  }
#else
  int32_t cache_size = DGL_CPU_LLC_SIZE;
#endif
  return cache_size;
}

54
/**
55
 * @brief Tile the CSR matrix to roughly make sure that the column tiles and
56
57
 *        corresponding neighbor features fit into LLC and the row tiles
 *        are assigned to OMP threads.
58
59
60
 * @param csr The Csr matrix.
 * @param block_csr_array The array containing csr matrices of all blocks.
 * @param num_M_blocks Number of blocks to create along the rows of adjacency
61
 *        matrix.
62
 * @param num_K_blocks Number of blocks to create along the columns of adjacency
63
 *        matrix.
64
65
66
67
 * @param M_block_size block size along the rows of adjacency matrix.
 * @param K_block_size block size along the columns of adjacency matrix.
 * @param use_lhs Whether to use lhs.
 * @param use_rhs Whether to use rhs.
68
69
70
 */
template <typename IdType>
inline void SpMMCreateBlocks(
71
72
73
    const CSRMatrix &csr, CSRMatrixInternal<IdType, IdType> *block_csr_array,
    IdType num_M_blocks, IdType num_K_blocks, IdType M_block_size,
    IdType K_block_size, bool use_lhs, bool use_rhs) {
74
75
  const IdType M = csr.num_rows;
  const IdType K = csr.num_cols;
76
77
78
  IdType *indptr = csr.indptr.Ptr<IdType>();
  IdType *indices = csr.indices.Ptr<IdType>();
  IdType *edges = csr.data.Ptr<IdType>();
79
  CHECK_NOTNULL(indptr);
80
81
  if (use_lhs) CHECK_NOTNULL(indices);
  if (use_rhs) CHECK_NOTNULL(edges);
82
83

  if (num_K_blocks > 1) {
84
85
    IdType *indptr_block_buf = reinterpret_cast<IdType *>(aligned_alloc(
        64, (M_block_size + 1) * num_M_blocks * num_K_blocks * sizeof(IdType)));
86
87
88
89
90
91
92
93
94
95
    IdType *indices_block_buf = nullptr;
    if (use_lhs) {
      indices_block_buf = reinterpret_cast<IdType *>(
          aligned_alloc(64, indptr[M] * sizeof(IdType)));
    }
    IdType *edges_block_buf = nullptr;
    if (use_rhs) {
      edges_block_buf = reinterpret_cast<IdType *>(
          aligned_alloc(64, indptr[M] * sizeof(IdType)));
    }
96
97
98

#pragma omp parallel
    {
99
100
      IdType *my_cur_col_id = reinterpret_cast<IdType *>(
          aligned_alloc(64, 2 * M_block_size * sizeof(IdType)));
101
102
103
104
105
106
107
108
109

#pragma omp for
      for (IdType m = 0; m < num_M_blocks; m++) {
        const IdType M_start = m * M_block_size;
        const IdType M_end = std::min((m + 1) * M_block_size, M);
        const IdType nnz = indptr[M_end] - indptr[M_start];

        IdType cur_indices_id = 0;
        IdType *my_indices_block_buf, *my_edges_block_buf;
110
111
        if (use_lhs) my_indices_block_buf = indices_block_buf + indptr[M_start];
        if (use_rhs) my_edges_block_buf = edges_block_buf + indptr[M_start];
112
113
114
115
116
117
118
119
120
121
122
123

        for (IdType i = M_start; i < M_end; i++) {
          my_cur_col_id[(i - M_start) * 2] = indptr[i];
          my_cur_col_id[(i - M_start) * 2 + 1] = indptr[i + 1];
        }
        for (IdType k = 0; k < num_K_blocks; k++) {
          const IdType K_start = k * K_block_size;
          const IdType K_end = std::min((k + 1) * K_block_size, K);
          CSRMatrixInternal<IdType, IdType> cur_csr;
          cur_csr.num_rows = M_end - M_start;
          cur_csr.num_cols = K_end - K_start;
          // Create csr_ij
124
125
          IdType *cur_csr_indptr =
              indptr_block_buf + (m * num_K_blocks + k) * (M_block_size + 1);
126
          IdType *cur_csr_indices = nullptr, *cur_csr_edges = nullptr;
127
128
          if (use_lhs) cur_csr_indices = my_indices_block_buf + cur_indices_id;
          if (use_rhs) cur_csr_edges = my_edges_block_buf + cur_indices_id;
129
130
131
          IdType cur_nnz = 0;
          for (IdType i = M_start; i < M_end; i++) {
            const IdType row_start = my_cur_col_id[(i - M_start) * 2];
132
            const IdType row_end = my_cur_col_id[(i - M_start) * 2 + 1];
133
134
135
136
137
138
139
140
141
            cur_csr_indptr[i - M_start] = cur_nnz;
            IdType eid;
            for (eid = row_start; eid < row_end; eid++) {
              const IdType src = indices[eid];
              const IdType edge = edges[eid];
              if (src >= K_end) {
                break;
              }
              CHECK_LT(cur_indices_id + cur_nnz, nnz);
142
143
              if (use_lhs) cur_csr_indices[cur_nnz] = src;
              if (use_rhs) cur_csr_edges[cur_nnz] = edge;
144
145
146
147
148
149
150
              cur_nnz++;
            }
            my_cur_col_id[(i - M_start) * 2] = eid;
          }
          cur_csr_indptr[cur_csr.num_rows] = cur_nnz;
          cur_indices_id += cur_nnz;
          cur_csr.indptr = cur_csr_indptr;
151
152
          if (use_lhs) cur_csr.indices = cur_csr_indices;
          if (use_rhs) cur_csr.data = cur_csr_edges;
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
          block_csr_array[m * num_K_blocks + k] = cur_csr;
        }
        CHECK_EQ(nnz, cur_indices_id);
      }
      free(my_cur_col_id);
    }
  } else {
    for (IdType m = 0; m < num_M_blocks; m++) {
      const IdType M_start = m * M_block_size;
      const IdType M_end = std::min((m + 1) * M_block_size, M);

      CSRMatrixInternal<IdType, IdType> cur_csr;
      cur_csr.num_rows = M_end - M_start;
      cur_csr.num_cols = K;
      cur_csr.indptr = indptr + M_start;
      cur_csr.indices = indices;
      cur_csr.data = edges;

      block_csr_array[m] = cur_csr;
    }
  }
}

/**
 * @brief Create libxsmm kernel.
 * @param has_idx For the edge features, are there indices available.
 * @param N Feature size.
 * @param redop_flag Flag specifying the reduction operation.
 * @param is_cmp Is the reduction operation a compare operation.
 * @note libxsmm_dispatch_meltw_opreduce_vecs_idx creates a JIT'ed kernel.
 *       Given a node u, the kernel performs an elementwise "Op" on the
 *       features of the neighbors and/or the edges incident on u.
 *       Subsequently, it performs an elementwise "Redop" on all such
 *       features created and stores into the feature of node u.
 *       It uses a SIMD and a cache efficient design and also provides
 *       support to enable software prefetching if needed. For IdType,
 *       it supports INT32 and INT64. For DType, it supports BF16 and FP32.
 *       It supports all the "Ops" and "Redops" supported by DGL. Once a
 *       kernel is generated by libxsmm_dispatch_meltw_opreduce_vecs_idx,
 *       it is cached for the entire duration of the execution of a program
 *       so that subsequently if the kernel is needed again, it just returns
 *       the cached copy.
 */
template <typename IdType, typename DType, typename Op>
inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel(
    bool has_idx, IdType N, libxsmm_meltw_opreduce_vecs_flags redop_flag,
    bool is_cmp) {
  // Leading dimension of both input and output feature matrices.
  int _ld = N;
  // NOTE(review): opredop_flags stays uninitialized if Op is none of the six
  // operators below; callers appear to only instantiate those — confirm.
  libxsmm_meltw_opreduce_vecs_flags opredop_flags;
  // First, set the Op in the opredop_flags
  if (std::is_same<Op, op::Add<DType>>::value) {
    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_ADD;
  } else if (std::is_same<Op, op::Sub<DType>>::value) {
    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_SUB;
  } else if (std::is_same<Op, op::Mul<DType>>::value) {
    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_MUL;
  } else if (std::is_same<Op, op::Div<DType>>::value) {
    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_DIV;
  } else if (std::is_same<Op, op::CopyLhs<DType>>::value) {
    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY;
  } else if (std::is_same<Op, op::CopyRhs<DType>>::value) {
    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY;
  }
  // Second, set which of lhs or rhs is considered first and second operand.
  // This is needed since libxsmm assumes that the copy operation always copies
  // the first operand. So, if we need to copy rhs, we need to set that as the
  // first operand. For rhs, we also set whether to use implicit indices or
  // provided indices.
  // TODO(Steve): fix this long line in a separate PR.
  if (std::is_same<Op, op::CopyLhs<DType>>::value) {
    opredop_flags =
        (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN);  // NOLINT
  } else if (std::is_same<Op, op::CopyRhs<DType>>::value) {
    opredop_flags =
        (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIN_VECIDX);  // NOLINT
    // Without explicit edge ids the edge features are laid out contiguously
    // in traversal order, so the kernel uses implicit (sequential) indices.
    if (!has_idx) {
      opredop_flags =
          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VECIDX);  // NOLINT
    }
  } else {
    opredop_flags =
        (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN);  // NOLINT
    if (has_idx) {
      opredop_flags =
          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_INDEXED_VEC);  // NOLINT
    } else {
      opredop_flags =
          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VEC);  // NOLINT
    }
  }
  // Third, we set the Redop in the opredop_flags
  opredop_flags =
      (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | redop_flag);
  // Fourth, in case of Cmp Redop, set whether to record argmax/argmin for
  // lhs/rhs
  if (is_cmp) {
    if (Op::use_lhs) {
      opredop_flags =
          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_0);  // NOLINT
    }
    if (Op::use_rhs) {
      opredop_flags =
          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_1);  // NOLINT
    }
  }
  // Dispatch the JIT'ed kernel; index width follows sizeof(IdType).
  libxsmm_meltwfunction_opreduce_vecs_idx kernel = nullptr;
  if (std::is_same<DType, float>::value) {
    kernel = libxsmm_dispatch_meltw_opreduce_vecs_idx(
        N, &_ld, &_ld, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
        (sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32,
        opredop_flags, 0);
  } else {  // assume bf16
    kernel = libxsmm_dispatch_meltw_opreduce_vecs_idx(
        N, &_ld, &_ld, LIBXSMM_DATATYPE_BF16, LIBXSMM_DATATYPE_BF16,
        (sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32,
        opredop_flags, 0);
  }

  if (kernel == nullptr) {
    LOG(FATAL) << "Failed to generate libxsmm kernel for the SpMM operation."
                  "To disable libxsmm, use dgl.use_libxsmm(false).";
  }
  return kernel;
}

/**
 * @brief Use libxsmm to perform SpMM-Sum on all blocks.
 * @param block_csr_array The array containing csr matrices of all blocks.
 * @param B The feature on source nodes.
 * @param E The feature on edges.
 * @param C The result feature on destination nodes.
 * @param has_idx For the edge features, are there indices available.
 * @param N Feature size.
 * @param num_M_blocks Number of blocks to create along the rows of adjacency
 *        matrix.
 * @param num_K_blocks Number of blocks to create along the columns of adjacency
 *        matrix.
 * @param M_block_size block size along the rows of adjacency matrix.
 * @param kernel The libxsmm kernel.
 */
template <typename IdType, typename DType>
inline void SpMMBlockwiseOpSum(
    CSRMatrixInternal<IdType, IdType> *block_csr_array, const DType *B,
    const DType *E, DType *C, bool has_idx, IdType N, IdType num_M_blocks,
    IdType num_K_blocks, IdType M_block_size,
    libxsmm_meltwfunction_opreduce_vecs_idx kernel) {
  const DType *in_matrix1 = B;
  const DType *in_matrix2 = E;
  DType *output = C;
  // One parallel region for all column blocks; the worksharing `for` inside
  // hands row blocks to threads. Different k's accumulate into the same
  // output rows, but row block m always maps to the same output slice.
#pragma omp parallel
  {
    for (IdType k = 0; k < num_K_blocks; k++) {
#pragma omp for schedule(dynamic)
      for (IdType m = 0; m < num_M_blocks; m++) {
        CSRMatrixInternal<IdType, IdType> cur_csr =
            block_csr_array[m * num_K_blocks + k];

        // Tile rows are 0-based; offset back to global destination rows.
        const IdType M_start = m * M_block_size;
        for (IdType i = 0; i < cur_csr.num_rows; i++) {
          const IdType row_start = cur_csr.indptr[i];
          const IdType row_end = cur_csr.indptr[i + 1];
          const IdType dst = i + M_start;

          // One kernel invocation reduces all of dst's in-edges in this tile.
          libxsmm_meltw_opreduce_vecs_idx_param params;
          params.n = row_end - row_start;
          params.indices = &cur_csr.indices[row_start];
          params.in_matrix = in_matrix1;
          params.out_vec = &output[dst * N];
          params.scale_vals = nullptr;
          if (has_idx) {
            // Edge features are gathered through explicit edge ids.
            params.in_matrix2 = in_matrix2;
            params.indices2 = &cur_csr.data[row_start];
          } else {
            // Edge features are contiguous in edge-traversal order.
            params.in_matrix2 = &in_matrix2[row_start * N];
          }
          kernel(&params);
        }
      }
    }
  }
}

/**
 * @brief Use libxsmm to perform SpMM-Max/Min on all blocks.
 * @param block_csr_array The array containing csr matrices of all blocks.
 * @param B The feature on source nodes.
 * @param E The feature on edges.
 * @param C The result feature on destination nodes.
 * @param argB Arg-Min/Max on source nodes.
 * @param argE Arg-Min/Max on edges.
 * @param has_idx For the edge features, are there indices available.
 * @param N Feature size.
 * @param num_M_blocks Number of blocks to create along the rows of adjacency
 *        matrix.
 * @param num_K_blocks Number of blocks to create along the columns of adjacency
 *        matrix.
 * @param M_block_size block size along the rows of adjacency matrix.
 * @param kernel The libxsmm kernel.
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
inline void SpMMBlockwiseOpCmp(
    CSRMatrixInternal<IdType, IdType> *block_csr_array, const DType *B,
    const DType *E, DType *C, IdType *argB, IdType *argE, bool has_idx,
    IdType N, IdType num_M_blocks, IdType num_K_blocks, IdType M_block_size,
    libxsmm_meltwfunction_opreduce_vecs_idx kernel) {
  const DType *in_matrix1 = B;
  const DType *in_matrix2 = E;
  DType *output = C;
  // Arg-min/max outputs, written alongside the reduced features.
  IdType *out_matrix1 = argB;
  IdType *out_matrix2 = argE;

  // Same blocking scheme as SpMMBlockwiseOpSum: one parallel region, row
  // blocks distributed dynamically, column blocks reduced sequentially.
#pragma omp parallel
  {
    for (IdType k = 0; k < num_K_blocks; k++) {
#pragma omp for schedule(dynamic)
      for (IdType m = 0; m < num_M_blocks; m++) {
        CSRMatrixInternal<IdType, IdType> cur_csr =
            block_csr_array[m * num_K_blocks + k];

        // Tile rows are 0-based; offset back to global destination rows.
        const IdType M_start = m * M_block_size;
        for (IdType i = 0; i < cur_csr.num_rows; i++) {
          const IdType row_start = cur_csr.indptr[i];
          const IdType row_end = cur_csr.indptr[i + 1];
          const IdType dst = i + M_start;

          libxsmm_meltw_opreduce_vecs_idx_param params;
          params.n = row_end - row_start;
          params.indices = &cur_csr.indices[row_start];
          params.in_matrix = in_matrix1;
          params.out_vec = &output[dst * N];
          // Per-element argop offsets recorded by the cmp kernel
          // (RECORD_ARGOP_OFF_VEC_* flags set at dispatch time).
          params.argop_off_vec_0 = &out_matrix1[dst * N];
          params.argop_off_vec_1 = &out_matrix2[dst * N];
          params.scale_vals = nullptr;
          if (has_idx) {
            // Edge features are gathered through explicit edge ids.
            params.in_matrix2 = in_matrix2;
            params.indices2 = &cur_csr.data[row_start];
          } else {
            // Edge features are contiguous in edge-traversal order.
            params.in_matrix2 = &in_matrix2[row_start * N];
          }
          kernel(&params);
        }
      }
    }
  }
}

399
/**
400
401
402
 * @brief Free the tiled CSR matrix data.
 * @param block_csr_array The array containing csr matrices of all blocks.
 * @param num_M_blocks Number of blocks to create along the rows of adjacency
403
 *        matrix.
404
 * @param num_K_blocks Number of blocks to create along the columns of adjacency
405
 *        matrix.
406
407
 * @param use_lhs Whether to use lhs.
 * @param use_rhs Whether to use rhs.
408
409
410
 */
template <typename IdType>
inline void SpMMFreeBlocks(
411
412
    CSRMatrixInternal<IdType, IdType> *block_csr_array, IdType num_M_blocks,
    IdType num_K_blocks, bool use_lhs, bool use_rhs) {
413
414
  if (num_K_blocks > 1) {
    free(block_csr_array[0].indptr);
415
416
    if (use_lhs) free(block_csr_array[0].indices);
    if (use_rhs) free(block_csr_array[0].data);
417
418
419
420
  }
  free(block_csr_array);
}

/**
 * @brief Optimized CPU kernel of SpMM-Sum/Max/Min on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes.
 * @param arge Arg-Min/Max on edges.
 * @note it uses libxsmm, blocking and dynamic thread scheduling.
 *       Pipeline: pick block sizes from an LLC-based heuristic, tile the
 *       CSR (SpMMCreateBlocks), JIT a libxsmm kernel, run the blockwise
 *       sum or cmp driver, then free the tiles (SpMMFreeBlocks).
 */
template <typename IdType, typename DType, typename Op, typename Redop>
void SpMMRedopCsrOpt(
    const BcastOff &bcast, const CSRMatrix &csr, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge) {
  int32_t llc_size = GetLLCSize();

#ifdef DEBUG
  uint64_t startTick, endTick;
  startTick = __rdtsc();
#endif  // DEBUG

  // Edge data present => gather edge features through explicit edge ids.
  const bool has_idx = !IsNullArray(csr.data);

  DType *C = out.Ptr<DType>();
  const DType *B = ufeat.Ptr<DType>();
  const DType *E = efeat.Ptr<DType>();
  // Only initialized (and only consumed) for Max/Min reductions; the Sum
  // path never reads them.
  IdType *argB, *argE;
  if (std::is_same<Redop, op::Max<DType>>::value ||
      std::is_same<Redop, op::Min<DType>>::value) {
    argB = argu.Ptr<IdType>();
    argE = arge.Ptr<IdType>();
  }

  const int nthreads = omp_get_max_threads();
  const IdType M = csr.num_rows;
  const IdType N = bcast.out_len;
  const IdType K = csr.num_cols;
  const IdType *indptr = csr.indptr.Ptr<IdType>();
  CHECK_NOTNULL(indptr);
  const IdType total_nnz = indptr[M];
  // Nothing to compute for an empty matrix or empty features.
  if (M <= 0 || K <= 0 || N <= 0 || total_nnz <= 0) return;

  const double avg_degree = total_nnz * 1.0 / M;
  const double nnz_prob = avg_degree / K;

  // Column-block heuristic: size K tiles so that the expected working set
  // of gathered neighbor features stays within (a fraction of) the LLC.
  IdType K_block_size = std::min(
      (int64_t)K,
      (int64_t)(llc_size / (N * sizeof(DType) * nnz_prob * BLOCKING_HEURISTIC_PARAM)));  // NOLINT
  // Row blocks: enough per thread for dynamic scheduling to balance load.
  IdType M_block_size = M / (nthreads * NUM_BLOCKS_PER_THREAD);
  if (M_block_size == 0) M_block_size = 1;
  if (K_block_size == 0) K_block_size = 1;

  IdType num_M_blocks = (M + M_block_size - 1) / M_block_size;
  IdType num_K_blocks = (K + K_block_size - 1) / K_block_size;

  // NOTE(review): aligned_alloc result is not checked for nullptr here or
  // in SpMMCreateBlocks — an OOM would surface as a segfault.
  CSRMatrixInternal<IdType, IdType> *block_csr_array =
      (CSRMatrixInternal<IdType, IdType> *)aligned_alloc(
          64, sizeof(CSRMatrixInternal<IdType, IdType>) * num_M_blocks *
                  num_K_blocks);

#ifdef DEBUG
  endTick = __rdtsc();
  if (std::is_same<Redop, op::Max<DType>>::value) {
    LOG(INFO) << "Redop = Max";
  } else if (std::is_same<Redop, op::Min<DType>>::value) {
    LOG(INFO) << "Redop = Min";
  } else if (std::is_same<Redop, op::Add<DType>>::value) {
    LOG(INFO) << "Redop = Add";
  }
  LOG(INFO) << "nthreads = " << nthreads << ", llc_size = " << llc_size;
  LOG(INFO) << "M = " << M << ", K = " << K << ", N = " << N;
  LOG(INFO) << "use_lhs = " << Op::use_lhs << ", use_rhs = " << Op::use_rhs;
  LOG(INFO) << "total_nnz = " << total_nnz << ", avg_degree = " << avg_degree;
  LOG(INFO) << "has_idx = " << has_idx;
  LOG(INFO) << "nnz_prob = " << nnz_prob;
  LOG(INFO) << "K_block_size = " << K_block_size
            << ", M_block_size = " << M_block_size;
  LOG(INFO) << "num_K_blocks = " << num_K_blocks
            << ", num_M_blocks = " << num_M_blocks;
  LOG(INFO) << "stage0 ticks = " << (endTick - startTick);
  startTick = __rdtsc();
#endif  // DEBUG

  // Stage 1: tile the CSR matrix.
  SpMMCreateBlocks(
      csr, block_csr_array, num_M_blocks, num_K_blocks, M_block_size,
      K_block_size, Op::use_lhs, Op::use_rhs);

#ifdef DEBUG
  endTick = __rdtsc();
  LOG(INFO) << "stage1 ticks = " << (endTick - startTick);
  startTick = __rdtsc();
#endif  // DEBUG

  // Stage 2: JIT (or fetch cached) libxsmm kernel for the chosen reduction.
  libxsmm_meltwfunction_opreduce_vecs_idx kernel = nullptr;
  if (std::is_same<Redop, op::Max<DType>>::value) {
    kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(
        has_idx, N, LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MAX, true);
  } else if (std::is_same<Redop, op::Min<DType>>::value) {
    kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(
        has_idx, N, LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MIN, true);
  } else if (std::is_same<Redop, op::Add<DType>>::value) {
    kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(
        has_idx, N, LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_SUM, false);
  }

#ifdef DEBUG
  endTick = __rdtsc();
  LOG(INFO) << "stage2 ticks = " << (endTick - startTick);
  startTick = __rdtsc();
#endif  // DEBUG

  // Stage 3: run the blockwise driver matching the reduction kind.
  if (std::is_same<Redop, op::Max<DType>>::value ||
      std::is_same<Redop, op::Min<DType>>::value) {
    SpMMBlockwiseOpCmp<IdType, DType, Op, Redop>(
        block_csr_array, B, E, C, argB, argE, has_idx, N, num_M_blocks,
        num_K_blocks, M_block_size, kernel);
  } else {
    SpMMBlockwiseOpSum(
        block_csr_array, B, E, C, has_idx, N, num_M_blocks, num_K_blocks,
        M_block_size, kernel);
  }

#ifdef DEBUG
  endTick = __rdtsc();
  LOG(INFO) << "stage3 ticks = " << (endTick - startTick);
  startTick = __rdtsc();
#endif  // DEBUG

  // Stage 4: release tile buffers and the block array.
  SpMMFreeBlocks(
      block_csr_array, num_M_blocks, num_K_blocks, Op::use_lhs, Op::use_rhs);

#ifdef DEBUG
  endTick = __rdtsc();
  LOG(INFO) << "stage4 ticks = " << (endTick - startTick);
#endif  // DEBUG
}

559
/**
560
561
562
563
564
565
566
 * @brief Optimized CPU kernel of SpMM-Sum on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @note it uses libxsmm, blocking and dynamic thread scheduling.
567
568
 */
template <typename IdType, typename DType, typename Op>
569
570
571
void SpMMSumCsrLibxsmm(
    const BcastOff &bcast, const CSRMatrix &csr, NDArray ufeat, NDArray efeat,
    NDArray out) {
572
  NDArray dummy;
573
574
  SpMMRedopCsrOpt<IdType, DType, Op, op::Add<DType>>(
      bcast, csr, ufeat, efeat, out, dummy, dummy);
575
576
}

/**
 * @brief Optimized CPU kernel of SpMM-Min/Max on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes.
 * @param arge Arg-Min/Max on edges.
 * @note it uses libxsmm, blocking and dynamic thread scheduling.
 *       Thin forwarder: the comparison functor Cmp (op::Max/op::Min) is
 *       passed through as the reduction operator of the generic driver.
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsrLibxsmm(
    const BcastOff &bcast, const CSRMatrix &csr, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge) {
  SpMMRedopCsrOpt<IdType, DType, Op, Cmp>(
      bcast, csr, ufeat, efeat, out, argu, arge);
}

}  // namespace cpu
}  // namespace aten
}  // namespace dgl

#endif  // USE_LIBXSMM
#endif  // _WIN32

#endif  // DGL_ARRAY_CPU_SPMM_BLOCKING_LIBXSMM_H_