/**
 *  Copyright (c) 2021 Intel Corporation
 * @file array/cpu/spmm_blocking_libxsmm.h
 * @brief SPMM CPU kernel function header.
 * @author Sanchit Misra <sanchit.misra@intel.com>,
 *         Ramanarayan Mohanty <ramanarayan.mohanty@intel.com>,
 *         Vasimuddin Md <vasimuddin.md@intel.com>,
 *         Sasikanth Avancha <sasikanth.avancha@intel.com>
 */
#ifndef DGL_ARRAY_CPU_SPMM_BLOCKING_LIBXSMM_H_
#define DGL_ARRAY_CPU_SPMM_BLOCKING_LIBXSMM_H_

#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dmlc/logging.h>

#include <algorithm>

#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
#include <libxsmm.h>
#include <unistd.h>
#ifdef DEBUG
#include <x86intrin.h>
#endif  // DEBUG
#include <dmlc/omp.h>

#define NUM_BLOCKS_PER_THREAD 20
#define BLOCKING_HEURISTIC_PARAM 500

namespace dgl {
namespace aten {
namespace cpu {

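// Non-owning view of a CSR (sub-)matrix. The tiling code below fills arrays of
// these with pointers into either the original CSR buffers or the per-block
// scratch buffers.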
template <typename IdType, typename DType>
struct CSRMatrixInternal {
  IdType num_rows;
  IdType num_cols;
  IdType *indptr;
  IdType *indices;
  DType *data;
};

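/**
 * @brief Return the size of the last-level cache (LLC) in bytes, falling back
 *        to DGL_CPU_LLC_SIZE when sysconf cannot report it.
 */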
inline int32_t GetLLCSize() {
  int32_t cache_size = sysconf(_SC_LEVEL3_CACHE_SIZE);
  if (cache_size < 0) cache_size = DGL_CPU_LLC_SIZE;
  return cache_size;
}

/**
 * @brief Tile the CSR matrix so that the column tiles and the corresponding
 *        neighbor features roughly fit into the LLC, while the row tiles are
 *        assigned to OMP threads.
 * @param csr The Csr matrix.
 * @param block_csr_array The array containing csr matrices of all blocks.
 * @param num_M_blocks Number of blocks to create along the rows of the
 *        adjacency matrix.
 * @param num_K_blocks Number of blocks to create along the columns of the
 *        adjacency matrix.
 * @param M_block_size Block size along the rows of the adjacency matrix.
 * @param K_block_size Block size along the columns of the adjacency matrix.
 * @param use_lhs Whether to use lhs.
 * @param use_rhs Whether to use rhs.
 */
template <typename IdType>
inline void SpMMCreateBlocks(
    const CSRMatrix &csr, CSRMatrixInternal<IdType, IdType> *block_csr_array,
    IdType num_M_blocks, IdType num_K_blocks, IdType M_block_size,
    IdType K_block_size, bool use_lhs, bool use_rhs) {
  const IdType M = csr.num_rows;
  const IdType K = csr.num_cols;
  IdType *indptr = csr.indptr.Ptr<IdType>();
  IdType *indices = csr.indices.Ptr<IdType>();
  IdType *edges = csr.data.Ptr<IdType>();
  CHECK_NOTNULL(indptr);
  if (use_lhs) CHECK_NOTNULL(indices);
  if (use_rhs) CHECK_NOTNULL(edges);

  if (num_K_blocks > 1) {
    IdType *indptr_block_buf = reinterpret_cast<IdType *>(aligned_alloc(
        64, (M_block_size + 1) * num_M_blocks * num_K_blocks * sizeof(IdType)));
    IdType *indices_block_buf = reinterpret_cast<IdType *>(
        aligned_alloc(64, indptr[M] * sizeof(IdType)));
    IdType *edges_block_buf = reinterpret_cast<IdType *>(
        aligned_alloc(64, indptr[M] * sizeof(IdType)));

#pragma omp parallel
    {
      IdType *my_cur_col_id = reinterpret_cast<IdType *>(
          aligned_alloc(64, 2 * M_block_size * sizeof(IdType)));

#pragma omp for
      for (IdType m = 0; m < num_M_blocks; m++) {
        const IdType M_start = m * M_block_size;
        const IdType M_end = std::min((m + 1) * M_block_size, M);
        const IdType nnz = indptr[M_end] - indptr[M_start];

        IdType cur_indices_id = 0;
        IdType *my_indices_block_buf, *my_edges_block_buf;
        if (use_lhs) my_indices_block_buf = indices_block_buf + indptr[M_start];
        if (use_rhs) my_edges_block_buf = edges_block_buf + indptr[M_start];

        for (IdType i = M_start; i < M_end; i++) {
          my_cur_col_id[(i - M_start) * 2] = indptr[i];
          my_cur_col_id[(i - M_start) * 2 + 1] = indptr[i + 1];
        }
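        // Sweep the column tiles left to right: my_cur_col_id[2 * r] tracks
        // how far row r's neighbor list (assumed sorted by column) has been
        // consumed, so each pass copies only the edges whose source id falls
        // inside the current column tile.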
        for (IdType k = 0; k < num_K_blocks; k++) {
          const IdType K_start = k * K_block_size;
          const IdType K_end = std::min((k + 1) * K_block_size, K);
          CSRMatrixInternal<IdType, IdType> cur_csr;
          cur_csr.num_rows = M_end - M_start;
          cur_csr.num_cols = K_end - K_start;
          // Create csr_ij
          IdType *cur_csr_indptr =
              indptr_block_buf + (m * num_K_blocks + k) * (M_block_size + 1);
          IdType *cur_csr_indices = nullptr, *cur_csr_edges = nullptr;
          if (use_lhs) cur_csr_indices = my_indices_block_buf + cur_indices_id;
          if (use_rhs) cur_csr_edges = my_edges_block_buf + cur_indices_id;
          IdType cur_nnz = 0;
          for (IdType i = M_start; i < M_end; i++) {
            const IdType row_start = my_cur_col_id[(i - M_start) * 2];
            const IdType row_end = my_cur_col_id[(i - M_start) * 2 + 1];
            cur_csr_indptr[i - M_start] = cur_nnz;
            IdType eid;
            for (eid = row_start; eid < row_end; eid++) {
              const IdType src = indices[eid];
              const IdType edge = edges[eid];
              if (src >= K_end) {
                break;
              }
              CHECK_LT(cur_indices_id + cur_nnz, nnz);
              if (use_lhs) cur_csr_indices[cur_nnz] = src;
              if (use_rhs) cur_csr_edges[cur_nnz] = edge;
              cur_nnz++;
            }
            my_cur_col_id[(i - M_start) * 2] = eid;
          }
          cur_csr_indptr[cur_csr.num_rows] = cur_nnz;
          cur_indices_id += cur_nnz;
          cur_csr.indptr = cur_csr_indptr;
          if (use_lhs) cur_csr.indices = cur_csr_indices;
          if (use_rhs) cur_csr.data = cur_csr_edges;
          block_csr_array[m * num_K_blocks + k] = cur_csr;
        }
        CHECK_EQ(nnz, cur_indices_id);
      }
      free(my_cur_col_id);
    }
  } else {
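    // With a single column tile there is nothing to repack: each row block
    // simply aliases the original indptr/indices/data buffers.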
#pragma omp parallel for
    for (IdType m = 0; m < num_M_blocks; m++) {
      const IdType M_start = m * M_block_size;
      const IdType M_end = std::min((m + 1) * M_block_size, M);

      CSRMatrixInternal<IdType, IdType> cur_csr;
      cur_csr.num_rows = M_end - M_start;
      cur_csr.num_cols = K;
      cur_csr.indptr = indptr + M_start;
      cur_csr.indices = indices;
      cur_csr.data = edges;

      block_csr_array[m] = cur_csr;
    }
  }
}

/**
 * @brief Create the libxsmm kernel.
 * @param has_idx For the edge features, whether explicit edge indices are
 *        available.
 * @param N Feature size.
 * @param redop_flag Flag specifying the reduction operation.
 * @param is_cmp Whether the reduction operation is a compare operation.
 * @note libxsmm_dispatch_meltw_opreduce_vecs_idx creates a JIT'ed kernel.
 *       Given a node u, the kernel performs an elementwise "Op" on the
 *       features of the neighbors and/or the edges incident on u.
 *       Subsequently, it performs an elementwise "Redop" on all such
 *       features created and stores into the feature of node u.
 *       It uses a SIMD and a cache efficient design and also provides
 *       support to enable software prefetching if needed. For IdType,
 *       it supports INT32 and INT64. For DType, it supports BF16 and FP32.
 *       It supports all the "Ops" and "Redops" supported by DGL. Once a
 *       kernel is generated by libxsmm_dispatch_meltw_opreduce_vecs_idx,
 *       it is cached for the entire duration of the execution of the program,
 *       so subsequent requests for the same kernel simply return the cached
 *       copy.
 */
template <typename IdType, typename DType, typename Op>
inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel(
    bool has_idx, IdType N, libxsmm_meltw_opreduce_vecs_flags redop_flag,
    bool is_cmp) {
  int _ld = N;
  libxsmm_meltw_opreduce_vecs_flags opredop_flags;
  // First, set the Op in the opredop_flags
  if (std::is_same<Op, op::Add<DType>>::value) {
    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_ADD;
  } else if (std::is_same<Op, op::Sub<DType>>::value) {
    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_SUB;
  } else if (std::is_same<Op, op::Mul<DType>>::value) {
    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_MUL;
  } else if (std::is_same<Op, op::Div<DType>>::value) {
    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_DIV;
  } else if (std::is_same<Op, op::CopyLhs<DType>>::value) {
    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY;
  } else if (std::is_same<Op, op::CopyRhs<DType>>::value) {
    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY;
  }
  // Second, set which of lhs or rhs is considered first and second operand.
  // This is needed since libxsmm assumes that the copy operation always copies
  // the first operand. So, if we need to copy rhs, we need to set that as the
  // first operand. For rhs, we also set whether to use implicit indices or
  // provided indices.
  // TODO(Steve): fix this long line in a separate PR.
  if (std::is_same<Op, op::CopyLhs<DType>>::value) {
    opredop_flags =
        (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN);  // NOLINT
  } else if (std::is_same<Op, op::CopyRhs<DType>>::value) {
    opredop_flags =
        (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIN_VECIDX);  // NOLINT
    if (!has_idx) {
      opredop_flags =
          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VECIDX);  // NOLINT
    }
  } else {
    opredop_flags =
        (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN);  // NOLINT
    if (has_idx) {
      opredop_flags =
          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_INDEXED_VEC);  // NOLINT
    } else {
      opredop_flags =
          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VEC);  // NOLINT
    }
  }
  // Third, we set the Redop in the opredop_flags
  opredop_flags =
      (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | redop_flag);
  // Fourth, in case of Cmp Redop, set whether to record argmax/argmin for
  // lhs/rhs
  if (is_cmp) {
    if (Op::use_lhs) {
      opredop_flags =
          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_0);  // NOLINT
    }
    if (Op::use_rhs) {
      opredop_flags =
          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_1);  // NOLINT
    }
  }
  libxsmm_meltwfunction_opreduce_vecs_idx kernel = nullptr;
  if (std::is_same<DType, float>::value) {
    kernel = libxsmm_dispatch_meltw_opreduce_vecs_idx(
        N, &_ld, &_ld, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
        (sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32,
        opredop_flags);
  }
  if (kernel == nullptr) {
    LOG(FATAL) << "Failed to generate libxsmm kernel for the SpMM operation. "
                  "To disable libxsmm, use dgl.use_libxsmm(false).";
  }
  return kernel;
}

/**
 * @brief Use libxsmm to perform SpMM-Sum on all blocks.
 * @param block_csr_array The array containing csr matrices of all blocks.
 * @param B The feature on source nodes.
 * @param E The feature on edges.
 * @param C The result feature on destination nodes.
 * @param has_idx For the edge features, whether explicit edge indices are
 *        available.
 * @param N Feature size.
 * @param num_M_blocks Number of blocks to create along the rows of the
 *        adjacency matrix.
 * @param num_K_blocks Number of blocks to create along the columns of the
 *        adjacency matrix.
 * @param M_block_size Block size along the rows of the adjacency matrix.
 * @param kernel The libxsmm kernel.
 */
template <typename IdType, typename DType>
inline void SpMMBlockwiseOpSum(
    CSRMatrixInternal<IdType, IdType> *block_csr_array, const DType *B,
    const DType *E, DType *C, bool has_idx, IdType N, IdType num_M_blocks,
    IdType num_K_blocks, IdType M_block_size,
    libxsmm_meltwfunction_opreduce_vecs_idx kernel) {
  DType(*in_matrix1)[N] = (DType(*)[N])B;
  DType(*in_matrix2)[N] = (DType(*)[N])E;
  DType(*output)[N] = (DType(*)[N])C;
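  // Column (K) tiles are swept sequentially inside one parallel region; only
  // the row (M) tiles are work-shared, with dynamic scheduling, and the
  // implicit barrier at the end of each omp-for keeps updates to a given
  // output row ordered across column tiles.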
#pragma omp parallel
  {
    for (IdType k = 0; k < num_K_blocks; k++) {
#pragma omp for schedule(dynamic)
      for (IdType m = 0; m < num_M_blocks; m++) {
        CSRMatrixInternal<IdType, IdType> cur_csr =
            block_csr_array[m * num_K_blocks + k];

        const IdType M_start = m * M_block_size;
        for (IdType i = 0; i < cur_csr.num_rows; i++) {
          const IdType row_start = cur_csr.indptr[i];
          const IdType row_end = cur_csr.indptr[i + 1];
          const IdType dst = i + M_start;

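          // One kernel call reduces all in-edges of `dst` that fall in this
          // tile: `indices` selects the neighbor rows of in_matrix (B) to
          // gather, while in_matrix2 (E) supplies the edge features, addressed
          // through `indices2` when explicit edge ids exist and by position
          // otherwise; the reduced result is written to output[dst].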
          libxsmm_meltw_opreduce_vecs_idx_param params;
          params.n = row_end - row_start;
          params.indices = &cur_csr.indices[row_start];
          params.in_matrix = in_matrix1;
          params.out_vec = &output[dst][0];
          params.scale_vals = nullptr;
          if (has_idx) {
            params.in_matrix2 = in_matrix2;
            params.indices2 = &cur_csr.data[row_start];
          } else {
            params.in_matrix2 = &in_matrix2[row_start];
          }
          kernel(&params);
        }
      }
    }
  }
}

/**
 * @brief Use libxsmm to perform SpMM-Max/Min on all blocks.
 * @param block_csr_array The array containing csr matrices of all blocks.
 * @param B The feature on source nodes.
 * @param E The feature on edges.
 * @param C The result feature on destination nodes.
 * @param argB Arg-Min/Max on source nodes.
 * @param argE Arg-Min/Max on edges.
 * @param has_idx For the edge features, whether explicit edge indices are
 *        available.
 * @param N Feature size.
 * @param num_M_blocks Number of blocks to create along the rows of the
 *        adjacency matrix.
 * @param num_K_blocks Number of blocks to create along the columns of the
 *        adjacency matrix.
 * @param M_block_size Block size along the rows of the adjacency matrix.
 * @param kernel The libxsmm kernel.
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
inline void SpMMBlockwiseOpCmp(
    CSRMatrixInternal<IdType, IdType> *block_csr_array, const DType *B,
    const DType *E, DType *C, IdType *argB, IdType *argE, bool has_idx,
    IdType N, IdType num_M_blocks, IdType num_K_blocks, IdType M_block_size,
    libxsmm_meltwfunction_opreduce_vecs_idx kernel) {
  DType(*in_matrix1)[N] = (DType(*)[N])B;
  DType(*in_matrix2)[N] = (DType(*)[N])E;
  DType(*output)[N] = (DType(*)[N])C;
  IdType(*out_matrix1)[N] = (IdType(*)[N])argB;
  IdType(*out_matrix2)[N] = (IdType(*)[N])argE;

#pragma omp parallel
  {
    for (IdType k = 0; k < num_K_blocks; k++) {
#pragma omp for schedule(dynamic)
      for (IdType m = 0; m < num_M_blocks; m++) {
        CSRMatrixInternal<IdType, IdType> cur_csr =
            block_csr_array[m * num_K_blocks + k];

        const IdType M_start = m * M_block_size;
        for (IdType i = 0; i < cur_csr.num_rows; i++) {
          const IdType row_start = cur_csr.indptr[i];
          const IdType row_end = cur_csr.indptr[i + 1];
          const IdType dst = i + M_start;

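          // Same gather/op/reduce pattern as the Sum path; in addition the
          // kernel records the winning offsets in argop_off_vec_0/1, which
          // back the argB/argE (Arg-Min/Max) outputs.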
          libxsmm_meltw_opreduce_vecs_idx_param params;
          params.n = row_end - row_start;
          params.indices = &cur_csr.indices[row_start];
          params.in_matrix = in_matrix1;
          params.out_vec = &output[dst][0];
          params.argop_off_vec_0 = &out_matrix1[dst][0];
          params.argop_off_vec_1 = &out_matrix2[dst][0];
          params.scale_vals = nullptr;
          if (has_idx) {
            params.in_matrix2 = in_matrix2;
            params.indices2 = &cur_csr.data[row_start];
          } else {
            params.in_matrix2 = &in_matrix2[row_start];
          }
          kernel(&params);
        }
      }
    }
  }
}

/**
 * @brief Free the tiled CSR matrix data.
 * @param block_csr_array The array containing csr matrices of all blocks.
 * @param num_M_blocks Number of blocks to create along the rows of the
 *        adjacency matrix.
 * @param num_K_blocks Number of blocks to create along the columns of the
 *        adjacency matrix.
 * @param use_lhs Whether to use lhs.
 * @param use_rhs Whether to use rhs.
 */
template <typename IdType>
inline void SpMMFreeBlocks(
    CSRMatrixInternal<IdType, IdType> *block_csr_array, IdType num_M_blocks,
    IdType num_K_blocks, bool use_lhs, bool use_rhs) {
  if (num_K_blocks > 1) {
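    // Block 0's pointers reference the start of the shared scratch buffers
    // allocated in SpMMCreateBlocks, so freeing them releases every block.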
    free(block_csr_array[0].indptr);
    if (use_lhs) free(block_csr_array[0].indices);
    if (use_rhs) free(block_csr_array[0].data);
  }
  free(block_csr_array);
}

/**
 * @brief Optimized CPU kernel of SpMM-Sum/Max/Min on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes.
 * @param arge Arg-Min/Max on edges.
 * @note It uses libxsmm, blocking, and dynamic thread scheduling.
 */
template <typename IdType, typename DType, typename Op, typename Redop>
void SpMMRedopCsrOpt(
    const BcastOff &bcast, const CSRMatrix &csr, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge) {
  int32_t llc_size = GetLLCSize();

#ifdef DEBUG
  uint64_t startTick, endTick;
  startTick = __rdtsc();
#endif  // DEBUG

  const bool has_idx = !IsNullArray(csr.data);

  DType *C = out.Ptr<DType>();
  const DType *B = ufeat.Ptr<DType>();
  const DType *E = efeat.Ptr<DType>();
  IdType *argB = nullptr, *argE = nullptr;
  if (std::is_same<Redop, op::Max<DType>>::value ||
      std::is_same<Redop, op::Min<DType>>::value) {
    argB = argu.Ptr<IdType>();
    argE = arge.Ptr<IdType>();
  }

  const int nthreads = omp_get_max_threads();
  const IdType M = csr.num_rows;
  const IdType N = bcast.out_len;
  const IdType K = csr.num_cols;
  const IdType *indptr = csr.indptr.Ptr<IdType>();
  CHECK_NOTNULL(indptr);
  const IdType total_nnz = indptr[M];
  if (M <= 0 || K <= 0 || N <= 0 || total_nnz <= 0) return;

  const double avg_degree = total_nnz * 1.0 / M;
  const double nnz_prob = avg_degree / K;

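  // Blocking heuristic: a destination row is expected to touch about
  // K_block_size * nnz_prob neighbors per column tile, i.e. roughly
  // K_block_size * nnz_prob * N * sizeof(DType) bytes of source features, so
  // K_block_size is capped so that this (scaled by BLOCKING_HEURISTIC_PARAM)
  // stays within the LLC; rows are split into about NUM_BLOCKS_PER_THREAD
  // blocks per OMP thread for dynamic load balancing.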
  IdType K_block_size = std::min(
      (int64_t)K,
      (int64_t)(llc_size / (N * sizeof(DType) * nnz_prob * BLOCKING_HEURISTIC_PARAM)));  // NOLINT
  IdType M_block_size = M / (nthreads * NUM_BLOCKS_PER_THREAD);
  if (M_block_size == 0) M_block_size = 1;
  if (K_block_size == 0) K_block_size = 1;

  IdType num_M_blocks = (M + M_block_size - 1) / M_block_size;
  IdType num_K_blocks = (K + K_block_size - 1) / K_block_size;

  CSRMatrixInternal<IdType, IdType> *block_csr_array =
      (CSRMatrixInternal<IdType, IdType> *)aligned_alloc(
          64, sizeof(CSRMatrixInternal<IdType, IdType>) * num_M_blocks *
                  num_K_blocks);

#ifdef DEBUG
  endTick = __rdtsc();
  if (std::is_same<Redop, op::Max<DType>>::value) {
    LOG(INFO) << "Redop = Max";
  } else if (std::is_same<Redop, op::Min<DType>>::value) {
    LOG(INFO) << "Redop = Min";
  } else if (std::is_same<Redop, op::Add<DType>>::value) {
    LOG(INFO) << "Redop = Add";
  }
  LOG(INFO) << "nthreads = " << nthreads << ", llc_size = " << llc_size;
  LOG(INFO) << "M = " << M << ", K = " << K << ", N = " << N;
  LOG(INFO) << "use_lhs = " << Op::use_lhs << ", use_rhs = " << Op::use_rhs;
  LOG(INFO) << "total_nnz = " << total_nnz << ", avg_degree = " << avg_degree;
  LOG(INFO) << "has_idx = " << has_idx;
  LOG(INFO) << "nnz_prob = " << nnz_prob;
  LOG(INFO) << "K_block_size = " << K_block_size
            << ", M_block_size = " << M_block_size;
  LOG(INFO) << "num_K_blocks = " << num_K_blocks
            << ", num_M_blocks = " << num_M_blocks;
  LOG(INFO) << "stage0 ticks = " << (endTick - startTick);
  startTick = __rdtsc();
#endif  // DEBUG

  SpMMCreateBlocks(
      csr, block_csr_array, num_M_blocks, num_K_blocks, M_block_size,
      K_block_size, Op::use_lhs, Op::use_rhs);

#ifdef DEBUG
  endTick = __rdtsc();
  LOG(INFO) << "stage1 ticks = " << (endTick - startTick);
  startTick = __rdtsc();
#endif  // DEBUG

  libxsmm_meltwfunction_opreduce_vecs_idx kernel = nullptr;
  if (std::is_same<Redop, op::Max<DType>>::value) {
    kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(
        has_idx, N, LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MAX, true);
  } else if (std::is_same<Redop, op::Min<DType>>::value) {
    kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(
        has_idx, N, LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MIN, true);
  } else if (std::is_same<Redop, op::Add<DType>>::value) {
    kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(
        has_idx, N, LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_SUM, false);
  }

#ifdef DEBUG
  endTick = __rdtsc();
  LOG(INFO) << "stage2 ticks = " << (endTick - startTick);
  startTick = __rdtsc();
#endif  // DEBUG

  if (std::is_same<Redop, op::Max<DType>>::value ||
      std::is_same<Redop, op::Min<DType>>::value) {
    SpMMBlockwiseOpCmp<IdType, DType, Op, Redop>(
        block_csr_array, B, E, C, argB, argE, has_idx, N, num_M_blocks,
        num_K_blocks, M_block_size, kernel);
  } else {
    SpMMBlockwiseOpSum(
        block_csr_array, B, E, C, has_idx, N, num_M_blocks, num_K_blocks,
        M_block_size, kernel);
  }

#ifdef DEBUG
  endTick = __rdtsc();
  LOG(INFO) << "stage3 ticks = " << (endTick - startTick);
  startTick = __rdtsc();
#endif  // DEBUG

  SpMMFreeBlocks(
      block_csr_array, num_M_blocks, num_K_blocks, Op::use_lhs, Op::use_rhs);

#ifdef DEBUG
  endTick = __rdtsc();
  LOG(INFO) << "stage4 ticks = " << (endTick - startTick);
#endif  // DEBUG
}

/**
 * @brief Optimized CPU kernel of SpMM-Sum on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @note It uses libxsmm, blocking, and dynamic thread scheduling.
 */
template <typename IdType, typename DType, typename Op>
void SpMMSumCsrLibxsmm(
    const BcastOff &bcast, const CSRMatrix &csr, NDArray ufeat, NDArray efeat,
    NDArray out) {
  NDArray dummy;
  SpMMRedopCsrOpt<IdType, DType, Op, op::Add<DType>>(
      bcast, csr, ufeat, efeat, out, dummy, dummy);
}

/**
 * @brief Optimized CPU kernel of SpMM-Min/Max on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes.
 * @param arge Arg-Min/Max on edges.
 * @note It uses libxsmm, blocking, and dynamic thread scheduling.
 */
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsrLibxsmm(
    const BcastOff &bcast, const CSRMatrix &csr, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge) {
  SpMMRedopCsrOpt<IdType, DType, Op, Cmp>(
      bcast, csr, ufeat, efeat, out, argu, arge);
}

}  // namespace cpu
}  // namespace aten
}  // namespace dgl

#endif  // USE_LIBXSMM
#endif  // USE_AVX
#endif  // _WIN32

#endif  // DGL_ARRAY_CPU_SPMM_BLOCKING_LIBXSMM_H_