Unverified Commit fac75e16 authored by sanchit-misra's avatar sanchit-misra Committed by GitHub
Browse files

[CPU][Kernel] Single socket spmm (#3024)



* optimizations of spmm for CPU

* Added names of contributors

* Minor code cleanup

* Moved the spmm optimization code to a new header file

* Moved to DGL's logging method

* removed duplicate code between SpMMSumCsr and SpMMCmpCsr

* Changes made to follow Google coding style

* Fixed lint errors in spmm.h

* Fixed some lint errors from spmm_blocking_libxsmm.h

* Fixed lint errors from spmm_blocking_libxsmm.h

* Added comments to SpMMCreateLibxsmmKernel

* to enable building of tests, and other cosmetic changes

* disabling libxsmm on windows

* Put a condition to avoid opt impl for FP64 as libxsmm does not have FP64 support yet

* cosmetic changes and documentation

* cosmetic changes

* to pass lint tests

* replaced multiple allocations for buffers of indices and edges with a single allocation
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 186ef592
...@@ -35,3 +35,6 @@ ...@@ -35,3 +35,6 @@
[submodule "third_party/nccl"] [submodule "third_party/nccl"]
path = third_party/nccl path = third_party/nccl
url = https://github.com/nvidia/nccl url = https://github.com/nvidia/nccl
[submodule "third_party/libxsmm"]
path = third_party/libxsmm
url = https://github.com/hfp/libxsmm.git
...@@ -27,6 +27,7 @@ dgl_option(USE_NCCL "Build with NCCL support" OFF) ...@@ -27,6 +27,7 @@ dgl_option(USE_NCCL "Build with NCCL support" OFF)
dgl_option(USE_SYSTEM_NCCL "Build using system's NCCL library" OFF) dgl_option(USE_SYSTEM_NCCL "Build using system's NCCL library" OFF)
dgl_option(USE_OPENMP "Build with OpenMP" ON) dgl_option(USE_OPENMP "Build with OpenMP" ON)
dgl_option(USE_AVX "Build with AVX optimization" ON) dgl_option(USE_AVX "Build with AVX optimization" ON)
dgl_option(USE_LIBXSMM "Build with LIBXSMM library optimization" ON)
dgl_option(USE_FP16 "Build with fp16 support to enable mixed precision training" OFF) dgl_option(USE_FP16 "Build with fp16 support to enable mixed precision training" OFF)
dgl_option(USE_TVM "Build with TVM kernels" OFF) dgl_option(USE_TVM "Build with TVM kernels" OFF)
dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF) dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
...@@ -36,7 +37,7 @@ dgl_option(USE_HDFS "Build with HDFS support" OFF) # Set env HADOOP_HDFS_HOME if ...@@ -36,7 +37,7 @@ dgl_option(USE_HDFS "Build with HDFS support" OFF) # Set env HADOOP_HDFS_HOME if
# Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG # Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG
if (NOT MSVC) if (NOT MSVC)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -DDEBUG -O0 -g3 -ggdb")
endif(NOT MSVC) endif(NOT MSVC)
if(USE_CUDA) if(USE_CUDA)
...@@ -89,10 +90,10 @@ if(MSVC) ...@@ -89,10 +90,10 @@ if(MSVC)
else(MSVC) else(MSVC)
include(CheckCXXCompilerFlag) include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11) check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11)
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}") set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -march=native ${CMAKE_C_FLAGS}")
# We still use c++11 flag in CPU build because gcc5.4 (our default compiler) is # We still use c++11 flag in CPU build because gcc5.4 (our default compiler) is
# not fully compatible with c++14 feature. # not fully compatible with c++14 feature.
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++11 -march=native ${CMAKE_CXX_FLAGS}")
if(NOT APPLE) if(NOT APPLE)
set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--warn-common ${CMAKE_SHARED_LINKER_FLAGS}") set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--warn-common ${CMAKE_SHARED_LINKER_FLAGS}")
endif(NOT APPLE) endif(NOT APPLE)
...@@ -108,9 +109,15 @@ if(USE_OPENMP) ...@@ -108,9 +109,15 @@ if(USE_OPENMP)
endif(USE_OPENMP) endif(USE_OPENMP)
if(USE_AVX) if(USE_AVX)
if(USE_LIBXSMM)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_AVX -DUSE_LIBXSMM -DDGL_CPU_LLC_SIZE=40000000")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_AVX -DUSE_LIBXSMM -DDGL_CPU_LLC_SIZE=40000000")
message(STATUS "Build with LIBXSMM optimization.")
else(USE_LIBXSMM)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_AVX") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_AVX")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_AVX") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_AVX")
message(STATUS "Build with AVX optimization.") message(STATUS "Build with AVX optimization.")
endif(USE_LIBXSMM)
endif(USE_AVX) endif(USE_AVX)
# Build with fp16 to support mixed precision training. # Build with fp16 to support mixed precision training.
...@@ -194,6 +201,7 @@ target_include_directories(dgl PRIVATE "third_party/xbyak/") ...@@ -194,6 +201,7 @@ target_include_directories(dgl PRIVATE "third_party/xbyak/")
target_include_directories(dgl PRIVATE "third_party/METIS/include/") target_include_directories(dgl PRIVATE "third_party/METIS/include/")
target_include_directories(dgl PRIVATE "tensoradapter/include") target_include_directories(dgl PRIVATE "tensoradapter/include")
target_include_directories(dgl PRIVATE "third_party/nanoflann/include") target_include_directories(dgl PRIVATE "third_party/nanoflann/include")
target_include_directories(dgl PRIVATE "third_party/libxsmm/include")
# For serialization # For serialization
if (USE_HDFS) if (USE_HDFS)
...@@ -213,6 +221,15 @@ if(NOT MSVC) ...@@ -213,6 +221,15 @@ if(NOT MSVC)
list(APPEND DGL_LINKER_LIBS metis) list(APPEND DGL_LINKER_LIBS metis)
endif(NOT MSVC) endif(NOT MSVC)
# Compile LIBXSMM
if((NOT MSVC) AND USE_LIBXSMM)
add_custom_target(libxsmm COMMAND make realclean COMMAND make -j BLAS=0
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/third_party/libxsmm
)
add_dependencies(dgl libxsmm)
list(APPEND DGL_LINKER_LIBS -L${CMAKE_SOURCE_DIR}/third_party/libxsmm/lib/ xsmm)
endif((NOT MSVC) AND USE_LIBXSMM)
# Compile TVM Runtime and Featgraph # Compile TVM Runtime and Featgraph
# (NOTE) We compile a dynamic library called featgraph_runtime, which the DGL library links to. # (NOTE) We compile a dynamic library called featgraph_runtime, which the DGL library links to.
# Kernels are packed in a separate dynamic library called featgraph_kernels, which DGL # Kernels are packed in a separate dynamic library called featgraph_kernels, which DGL
...@@ -287,6 +304,7 @@ if(BUILD_CPP_TEST) ...@@ -287,6 +304,7 @@ if(BUILD_CPP_TEST)
include_directories("third_party/xbyak") include_directories("third_party/xbyak")
include_directories("third_party/dmlc-core/include") include_directories("third_party/dmlc-core/include")
include_directories("third_party/phmap") include_directories("third_party/phmap")
include_directories("third_party/libxsmm/include")
file(GLOB_RECURSE TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/tests/cpp/*.cc) file(GLOB_RECURSE TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/tests/cpp/*.cc)
add_executable(runUnitTests ${TEST_SRC_FILES}) add_executable(runUnitTests ${TEST_SRC_FILES})
target_link_libraries(runUnitTests gtest gtest_main) target_link_libraries(runUnitTests gtest gtest_main)
......
...@@ -15,44 +15,37 @@ ...@@ -15,44 +15,37 @@
#if !defined(_WIN32) #if !defined(_WIN32)
#ifdef USE_AVX #ifdef USE_AVX
#include "intel/cpu_support.h" #include "intel/cpu_support.h"
#ifdef USE_LIBXSMM
#include "spmm_blocking_libxsmm.h"
#endif // USE_LIBXSMM
#endif // USE_AVX #endif // USE_AVX
#endif // _WIN32 #endif // _WIN32
namespace dgl { namespace dgl {
namespace aten { namespace aten {
namespace cpu { namespace cpu {
#if !defined(_WIN32)
#ifdef USE_AVX
/*! /*!
* \brief CPU kernel of SpMM on Csr format. * \brief CPU kernel of SpMM on Csr format using Xbyak.
* \param cpu_spec JIT'ed kernel
* \param bcast Broadcast information. * \param bcast Broadcast information.
* \param csr The Csr matrix. * \param csr The Csr matrix.
* \param ufeat The feature on source nodes. * \param X The feature on source nodes.
* \param efeat The feature on edges. * \param W The feature on edges.
* \param out The result feature on destination nodes. * \param O The result feature on destination nodes.
* \note it uses node parallel strategy, different threads are responsible * \note it uses node parallel strategy, different threads are responsible
* for the computation of different nodes. * for the computation of different nodes. For each edge, it uses the
* JIT'ed kernel.
*/ */
template <typename IdType, typename DType, typename Op> template <typename IdType, typename DType, typename Op>
void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, void SpMMSumCsrXbyak(dgl::ElemWiseAddUpdate<Op>* cpu_spec, const BcastOff& bcast,
NDArray efeat, NDArray out) { const CSRMatrix& csr, const DType* X, const DType* W, DType* O) {
const bool has_idx = !IsNullArray(csr.data); const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = csr.indptr.Ptr<IdType>(); const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>(); const IdType* indices = csr.indices.Ptr<IdType>();
const IdType* edges = csr.data.Ptr<IdType>(); const IdType* edges = csr.data.Ptr<IdType>();
const DType* X = ufeat.Ptr<DType>();
const DType* W = efeat.Ptr<DType>();
int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len; int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
DType* O = out.Ptr<DType>();
#if !defined(_WIN32)
#ifdef USE_AVX
typedef dgl::ElemWiseAddUpdate<Op> ElemWiseUpd;
/* Prepare an assembler kernel */
static std::unique_ptr<ElemWiseUpd> asm_kernel_ptr(
(dgl::IntelKernel<>::IsEnabled()) ? new ElemWiseUpd() : nullptr);
/* Distribute the kernel among OMP threads */
ElemWiseUpd* cpu_spec = (asm_kernel_ptr && asm_kernel_ptr->applicable())
? asm_kernel_ptr.get()
: nullptr;
if (cpu_spec && dim > 16 && !bcast.use_bcast) {
#pragma omp parallel for #pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) { for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1]; const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
...@@ -63,10 +56,29 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, ...@@ -63,10 +56,29 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
cpu_spec->run(out_off, X + cid * lhs_dim, W + eid * rhs_dim, dim); cpu_spec->run(out_off, X + cid * lhs_dim, W + eid * rhs_dim, dim);
} }
} }
} else { }
#endif // USE_AVX #endif // USE_AVX
#endif // _WIN32 #endif // _WIN32
/*!
* \brief Naive CPU kernel of SpMM on Csr format.
* \param cpu_spec JIT'ed kernel
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param X The feature on source nodes.
* \param W The feature on edges.
* \param O The result feature on destination nodes.
* \note it uses node parallel strategy, different threads are responsible
* for the computation of different nodes.
*/
template <typename IdType, typename DType, typename Op>
void SpMMSumCsrNaive(const BcastOff& bcast, const CSRMatrix& csr, const DType* X,
const DType* W, DType* O) {
const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>();
const IdType* edges = csr.data.Ptr<IdType>();
int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
#pragma omp parallel for #pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) { for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1]; const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
...@@ -85,9 +97,69 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, ...@@ -85,9 +97,69 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
} }
} }
} }
}
/*!
* \brief CPU kernel of SpMM on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result feature on destination nodes.
* \note it uses node parallel strategy, different threads are responsible
* for the computation of different nodes.
*/
template <typename IdType, typename DType, typename Op>
void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
NDArray efeat, NDArray out) {
const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>();
const IdType* edges = csr.data.Ptr<IdType>();
const DType* X = ufeat.Ptr<DType>();
const DType* W = efeat.Ptr<DType>();
int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
DType* O = out.Ptr<DType>();
CHECK_NOTNULL(indptr);
CHECK_NOTNULL(O);
if (Op::use_lhs) {
CHECK_NOTNULL(indices);
CHECK_NOTNULL(X);
}
if (Op::use_rhs) {
if (has_idx)
CHECK_NOTNULL(edges);
CHECK_NOTNULL(W);
}
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
const bool no_libxsmm =
bcast.use_bcast || std::is_same<DType, double>::value;
if (!no_libxsmm) {
SpMMSumCsrLibxsmm<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
} else {
#endif // USE_LIBXSMM
typedef dgl::ElemWiseAddUpdate<Op> ElemWiseUpd;
/* Prepare an assembler kernel */
static std::unique_ptr<ElemWiseUpd> asm_kernel_ptr(
(dgl::IntelKernel<>::IsEnabled()) ? new ElemWiseUpd() : nullptr);
/* Distribute the kernel among OMP threads */
ElemWiseUpd* cpu_spec = (asm_kernel_ptr && asm_kernel_ptr->applicable())
? asm_kernel_ptr.get()
: nullptr;
if (cpu_spec && dim > 16 && !bcast.use_bcast) {
SpMMSumCsrXbyak<IdType, DType, Op>(cpu_spec, bcast, csr, X, W, O);
} else {
#endif // USE_AVX
#endif // _WIN32
SpMMSumCsrNaive<IdType, DType, Op>(bcast, csr, X, W, O);
#if !defined(_WIN32) #if !defined(_WIN32)
#ifdef USE_AVX #ifdef USE_AVX
} }
#ifdef USE_LIBXSMM
}
#endif // USE_LIBXSMM
#endif // USE_AVX #endif // USE_AVX
#endif // _WIN32 #endif // _WIN32
} }
...@@ -172,6 +244,32 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, ...@@ -172,6 +244,32 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
DType* O = static_cast<DType*>(out->data); DType* O = static_cast<DType*>(out->data);
IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr; IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr; IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
CHECK_NOTNULL(indptr);
CHECK_NOTNULL(O);
if (Op::use_lhs) {
CHECK_NOTNULL(indices);
CHECK_NOTNULL(X);
CHECK_NOTNULL(argX);
}
if (Op::use_rhs) {
if (has_idx)
CHECK_NOTNULL(edges);
CHECK_NOTNULL(W);
CHECK_NOTNULL(argW);
}
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
const bool no_libxsmm =
bcast.use_bcast || std::is_same<DType, double>::value;
if (!no_libxsmm) {
SpMMCmpCsrLibxsmm<IdType, DType, Op, Cmp>(bcast, csr, ufeat, efeat, out, argu, arge);
} else {
#endif // USE_LIBXSMM
#endif // USE_AVX
#endif // _WIN32
#pragma omp parallel for #pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) { for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1]; const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
...@@ -197,6 +295,13 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, ...@@ -197,6 +295,13 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
} }
} }
} }
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
}
#endif // USE_LIBXSMM
#endif // USE_AVX
#endif // _WIN32
} }
/*! /*!
......
/*!
* Copyright (c) 2021 Intel Corporation
* \file array/cpu/spmm.h
* \brief SPMM CPU kernel function header.
* \author Sanchit Misra <sanchit.misra@intel.com>,
* Ramanarayan Mohanty <ramanarayan.mohanty@intel.com>,
* Vasimuddin Md <vasimuddin.md@intel.com>,
* Sasikanth Avancha <sasikanth.avancha@intel.com>
*/
#ifndef DGL_ARRAY_CPU_SPMM_BLOCKING_LIBXSMM_H_
#define DGL_ARRAY_CPU_SPMM_BLOCKING_LIBXSMM_H_
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dmlc/logging.h>
#include <algorithm>
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
#include <unistd.h>
#include <libxsmm.h>
#ifdef DEBUG
#include <x86intrin.h>
#endif // DEBUG
#include <dmlc/omp.h>
#define NUM_BLOCKS_PER_THREAD 20
#define BLOCKING_HEURISTIC_PARAM 500
namespace dgl {
namespace aten {
namespace cpu {
template <typename IdType, typename DType>
struct CSRMatrixInternal {
IdType num_rows;
IdType num_cols;
IdType *indptr;
IdType *indices;
DType *data;
};
int32_t GetLLCSize() {
int32_t cache_size = sysconf(_SC_LEVEL3_CACHE_SIZE);
if (cache_size < 0) cache_size = DGL_CPU_LLC_SIZE;
return cache_size;
}
/*!
* \brief Tile the CSR matrix to roughly make sure that the column tiles and
* corresponding neighbor features fit into LLC and the row tiles
* are assigned to OMP threads.
* \param csr The Csr matrix.
* \param block_csr_array The array containing csr matrices of all blocks.
* \param num_M_blocks Number of blocks to create along the rows of adjacency matrix.
* \param num_K_blocks Number of blocks to create along the columns of adjacency matrix.
* \param M_block_size block size along the rows of adjacency matrix.
* \param K_block_size block size along the columns of adjacency matrix.
* \param use_lhs Whether to use lhs.
* \param use_rhs Whether to use rhs.
*/
template <typename IdType>
inline void SpMMCreateBlocks(
const CSRMatrix& csr,
CSRMatrixInternal<IdType, IdType> *block_csr_array,
IdType num_M_blocks,
IdType num_K_blocks,
IdType M_block_size,
IdType K_block_size,
bool use_lhs, bool use_rhs) {
const IdType M = csr.num_rows;
const IdType K = csr.num_cols;
IdType* indptr = csr.indptr.Ptr<IdType>();
IdType* indices = csr.indices.Ptr<IdType>();
IdType* edges = csr.data.Ptr<IdType>();
CHECK_NOTNULL(indptr);
if (use_lhs)
CHECK_NOTNULL(indices);
if (use_rhs)
CHECK_NOTNULL(edges);
if (num_K_blocks > 1) {
IdType *indptr_block_buf = reinterpret_cast<IdType *>(aligned_alloc(64,
(M_block_size + 1) * num_M_blocks *
num_K_blocks * sizeof(IdType)));
IdType *indices_block_buf = reinterpret_cast<IdType *>(aligned_alloc(64,
indptr[M] * sizeof(IdType)));
IdType *edges_block_buf = reinterpret_cast<IdType *>(aligned_alloc(64,
indptr[M] * sizeof(IdType)));
#pragma omp parallel
{
IdType *my_cur_col_id = reinterpret_cast<IdType *>(aligned_alloc(64, 2 * M_block_size *
sizeof(IdType)));
#pragma omp for
for (IdType m = 0; m < num_M_blocks; m++) {
const IdType M_start = m * M_block_size;
const IdType M_end = std::min((m + 1) * M_block_size, M);
const IdType nnz = indptr[M_end] - indptr[M_start];
IdType cur_indices_id = 0;
IdType *my_indices_block_buf, *my_edges_block_buf;
if (use_lhs)
my_indices_block_buf = indices_block_buf + indptr[M_start];
if (use_rhs)
my_edges_block_buf = edges_block_buf + indptr[M_start];
for (IdType i = M_start; i < M_end; i++) {
my_cur_col_id[(i - M_start) * 2] = indptr[i];
my_cur_col_id[(i - M_start) * 2 + 1] = indptr[i + 1];
}
for (IdType k = 0; k < num_K_blocks; k++) {
const IdType K_start = k * K_block_size;
const IdType K_end = std::min((k + 1) * K_block_size, K);
CSRMatrixInternal<IdType, IdType> cur_csr;
cur_csr.num_rows = M_end - M_start;
cur_csr.num_cols = K_end - K_start;
// Create csr_ij
IdType *cur_csr_indptr = indptr_block_buf + (m * num_K_blocks + k) * (M_block_size + 1);
IdType *cur_csr_indices = nullptr, *cur_csr_edges = nullptr;
if (use_lhs)
cur_csr_indices = my_indices_block_buf + cur_indices_id;
if (use_rhs)
cur_csr_edges = my_edges_block_buf + cur_indices_id;
IdType cur_nnz = 0;
for (IdType i = M_start; i < M_end; i++) {
const IdType row_start = my_cur_col_id[(i - M_start) * 2];
const IdType row_end = my_cur_col_id[(i - M_start) * 2 + 1];
cur_csr_indptr[i - M_start] = cur_nnz;
IdType eid;
for (eid = row_start; eid < row_end; eid++) {
const IdType src = indices[eid];
const IdType edge = edges[eid];
if (src >= K_end) {
break;
}
CHECK_LT(cur_indices_id + cur_nnz, nnz);
if (use_lhs)
cur_csr_indices[cur_nnz] = src;
if (use_rhs)
cur_csr_edges[cur_nnz] = edge;
cur_nnz++;
}
my_cur_col_id[(i - M_start) * 2] = eid;
}
cur_csr_indptr[cur_csr.num_rows] = cur_nnz;
cur_indices_id += cur_nnz;
cur_csr.indptr = cur_csr_indptr;
if (use_lhs)
cur_csr.indices = cur_csr_indices;
if (use_rhs)
cur_csr.data = cur_csr_edges;
block_csr_array[m * num_K_blocks + k] = cur_csr;
}
CHECK_EQ(nnz, cur_indices_id);
}
free(my_cur_col_id);
}
} else {
#pragma omp for
for (IdType m = 0; m < num_M_blocks; m++) {
const IdType M_start = m * M_block_size;
const IdType M_end = std::min((m + 1) * M_block_size, M);
CSRMatrixInternal<IdType, IdType> cur_csr;
cur_csr.num_rows = M_end - M_start;
cur_csr.num_cols = K;
cur_csr.indptr = indptr + M_start;
cur_csr.indices = indices;
cur_csr.data = edges;
block_csr_array[m] = cur_csr;
}
}
}
/*!
* \brief Create libxsmm kernel.
* \param has_idx For the edge features, are there indices available.
* \param N Feature size.
* \param redop_flag Flag specifying the reduction operation.
* \param is_cmp Is the reduction operation a compare operation.
* \note libxsmm_dispatch_meltw_opreduce_vecs_idx creates a JIT'ed kernel.
* Given a node u, the kernel performs an elementwise "Op" on the
* features of the neighbors and/or the edges incident on u.
* Subsequently, it performs an elementwise "Redop" on all such
* features created and stores into the feature of node u.
* It uses a SIMD and a cache efficient design and also provides
* support to enable software prefetching if needed. For IdType,
* it supports INT32 and INT64. For DType, it supports BF16 and FP32.
* It supports all the "Ops" and "Redops" supported by DGL. Once a
* kernel is generated by libxsmm_dispatch_meltw_opreduce_vecs_idx,
* it is cached for the entire duration of the execution of a program
* so that subsequently if the kernel is needed again, it just returns
* the cached copy.
*/
template <typename IdType, typename DType, typename Op>
inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel(
bool has_idx,
IdType N,
libxsmm_meltw_opreduce_vecs_flags redop_flag,
bool is_cmp) {
int _ld = N;
libxsmm_meltw_opreduce_vecs_flags opredop_flags;
// First, set the Op in the opredop_flags
if (std::is_same<Op, op::Add<DType>>::value) {
opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_ADD;
} else if (std::is_same<Op, op::Sub<DType>>::value) {
opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_SUB;
} else if (std::is_same<Op, op::Mul<DType>>::value) {
opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_MUL;
} else if (std::is_same<Op, op::Div<DType>>::value) {
opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_DIV;
} else if (std::is_same<Op, op::CopyLhs<DType>>::value) {
opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY;
} else if (std::is_same<Op, op::CopyRhs<DType>>::value) {
opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY;
}
// Second, set which of lhs or rhs is considered first and second operand.
// This is needed since libxsmm assumes that the copy operation always copies the first operand.
// So, if we need to copy rhs, we need to set that as the first operand.
// For rhs, we also set whether to use implicit indices or provided indices.
if (std::is_same<Op, op::CopyLhs<DType>>::value) {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN);
} else if (std::is_same<Op, op::CopyRhs<DType>>::value) {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIN_VECIDX);
if (!has_idx) {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VECIDX);
}
} else {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN);
if (has_idx) {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_INDEXED_VEC);
} else {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VEC);
}
}
// Third, we set the Redop in the opredop_flags
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | redop_flag);
// Fourth, in case of Cmp Redop, set whether to record argmax/argmin for lhs/rhs
if (is_cmp) {
if (Op::use_lhs) {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_0);
}
if (Op::use_rhs) {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_1);
}
}
libxsmm_meltwfunction_opreduce_vecs_idx kernel = nullptr;
if (std::is_same<DType, float>::value) {
kernel = libxsmm_dispatch_meltw_opreduce_vecs_idx(
N, &_ld, &_ld, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
(sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32, opredop_flags);
}
if (kernel == nullptr) {
LOG(FATAL) << "Failed to generate libxsmm kernel for the SpMM operation!";
}
return kernel;
}
/*!
* \brief Use libxsmm to perform SpMM-Sum on all blocks.
* \param block_csr_array The array containing csr matrices of all blocks.
* \param B The feature on source nodes.
* \param E The feature on edges.
* \param C The result feature on destination nodes.
* \param has_idx For the edge features, are there indices available.
* \param N Feature size.
* \param num_M_blocks Number of blocks to create along the rows of adjacency matrix.
* \param num_K_blocks Number of blocks to create along the columns of adjacency matrix.
* \param M_block_size block size along the rows of adjacency matrix.
* \param kernel The libxsmm kernel.
*/
template <typename IdType, typename DType>
inline void SpMMBlockwiseOpSum(
CSRMatrixInternal<IdType, IdType> *block_csr_array,
const DType *B, const DType *E, DType *C, bool has_idx, IdType N,
IdType num_M_blocks, IdType num_K_blocks, IdType M_block_size,
libxsmm_meltwfunction_opreduce_vecs_idx kernel) {
DType (*in_matrix1)[N] = (DType (*)[N])B;
DType (*in_matrix2)[N] = (DType (*)[N])E;
DType (*output)[N] = (DType (*)[N])C;
#pragma omp parallel
{
for (IdType k = 0; k < num_K_blocks; k++) {
#pragma omp for schedule(dynamic)
for (IdType m = 0; m < num_M_blocks; m++) {
CSRMatrixInternal<IdType, IdType> cur_csr = block_csr_array[m * num_K_blocks + k];
const IdType M_start = m * M_block_size;
for (IdType i = 0; i < cur_csr.num_rows; i++) {
const IdType row_start = cur_csr.indptr[i];
const IdType row_end = cur_csr.indptr[i + 1];
const IdType dst = i + M_start;
libxsmm_meltw_opreduce_vecs_idx_param params;
params.n = row_end - row_start;
params.indices = &cur_csr.indices[row_start];
params.in_matrix = in_matrix1;
params.out_vec = &output[dst][0];
params.scale_vals = nullptr;
if (has_idx) {
params.in_matrix2 = in_matrix2;
params.indices2 = &cur_csr.data[row_start];
} else {
params.in_matrix2 = &in_matrix2[row_start];
}
kernel(&params);
}
}
}
}
}
/*!
* \brief Use libxsmm to perform SpMM-Max/Min on all blocks.
* \param block_csr_array The array containing csr matrices of all blocks.
* \param B The feature on source nodes.
* \param E The feature on edges.
* \param C The result feature on destination nodes.
* \param argB Arg-Min/Max on source nodes.
* \param argE Arg-Min/Max on edges.
* \param has_idx For the edge features, are there indices available.
* \param N Feature size.
* \param num_M_blocks Number of blocks to create along the rows of adjacency matrix.
* \param num_K_blocks Number of blocks to create along the columns of adjacency matrix.
* \param M_block_size block size along the rows of adjacency matrix.
* \param kernel The libxsmm kernel.
*/
template <typename IdType, typename DType, typename Op, typename Cmp>
inline void SpMMBlockwiseOpCmp(
CSRMatrixInternal<IdType, IdType> *block_csr_array,
const DType *B, const DType *E, DType *C, IdType *argB, IdType *argE,
bool has_idx, IdType N,
IdType num_M_blocks, IdType num_K_blocks, IdType M_block_size,
libxsmm_meltwfunction_opreduce_vecs_idx kernel) {
DType (*in_matrix1)[N] = (DType (*)[N])B;
DType (*in_matrix2)[N] = (DType (*)[N])E;
DType (*output)[N] = (DType (*)[N])C;
IdType (*out_matrix1)[N] = (IdType (*)[N])argB;
IdType (*out_matrix2)[N] = (IdType (*)[N])argE;
#pragma omp parallel
{
for (IdType k = 0; k < num_K_blocks; k++) {
#pragma omp for schedule(dynamic)
for (IdType m = 0; m < num_M_blocks; m++) {
CSRMatrixInternal<IdType, IdType> cur_csr = block_csr_array[m * num_K_blocks + k];
const IdType M_start = m * M_block_size;
for (IdType i = 0; i < cur_csr.num_rows; i++) {
const IdType row_start = cur_csr.indptr[i];
const IdType row_end = cur_csr.indptr[i + 1];
const IdType dst = i + M_start;
libxsmm_meltw_opreduce_vecs_idx_param params;
params.n = row_end - row_start;
params.indices = &cur_csr.indices[row_start];
params.in_matrix = in_matrix1;
params.out_vec = &output[dst][0];
params.argop_off_vec_0 = &out_matrix1[dst][0];
params.argop_off_vec_1 = &out_matrix2[dst][0];
params.scale_vals = nullptr;
if (has_idx) {
params.in_matrix2 = in_matrix2;
params.indices2 = &cur_csr.data[row_start];
} else {
params.in_matrix2 = &in_matrix2[row_start];
}
kernel(&params);
}
}
}
}
}
/*!
* \brief Free the tiled CSR matrix data.
* \param block_csr_array The array containing csr matrices of all blocks.
* \param num_M_blocks Number of blocks to create along the rows of adjacency matrix.
* \param num_K_blocks Number of blocks to create along the columns of adjacency matrix.
* \param use_lhs Whether to use lhs.
* \param use_rhs Whether to use rhs.
*/
template <typename IdType>
inline void SpMMFreeBlocks(
CSRMatrixInternal<IdType, IdType> *block_csr_array,
IdType num_M_blocks, IdType num_K_blocks,
bool use_lhs, bool use_rhs) {
if (num_K_blocks > 1) {
free(block_csr_array[0].indptr);
if (use_lhs)
free(block_csr_array[0].indices);
if (use_rhs)
free(block_csr_array[0].data);
}
free(block_csr_array);
}
/*!
* \brief Optimized CPU kernel of SpMM-Sum/Max/Min on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result feature on destination nodes.
* \param argu Arg-Min/Max on source nodes.
* \param arge Arg-Min/Max on edges.
* \note it uses libxsmm, blocking and dynamic thread scheduling.
*/
template <typename IdType, typename DType, typename Op, typename Redop>
void SpMMRedopCsrOpt(
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat, NDArray efeat,
NDArray out,
NDArray argu, NDArray arge) {
int32_t llc_size = GetLLCSize();
#ifdef DEBUG
uint64_t startTick, endTick;
startTick = __rdtsc();
#endif // DEBUG
const bool has_idx = !IsNullArray(csr.data);
DType* C = out.Ptr<DType>();
const DType* B = ufeat.Ptr<DType>();
const DType* E = efeat.Ptr<DType>();
IdType *argB, *argE;
if (std::is_same<Redop, op::Max<DType>>::value || std::is_same<Redop, op::Min<DType>>::value) {
argB = argu.Ptr<IdType>();
argE = arge.Ptr<IdType>();
}
const int nthreads = omp_get_max_threads();
const IdType M = csr.num_rows;
const IdType N = bcast.out_len;
const IdType K = csr.num_cols;
const IdType* indptr = csr.indptr.Ptr<IdType>();
CHECK_NOTNULL(indptr);
const int total_nnz = indptr[M];
if (M <= 0 || K <= 0 || N <= 0 || total_nnz <= 0) return;
const float avg_degree = total_nnz * 1.0 / M;
const float nnz_prob = avg_degree / K;
IdType K_block_size = llc_size / (N * sizeof(DType) * nnz_prob * BLOCKING_HEURISTIC_PARAM);
IdType M_block_size = M / (nthreads * NUM_BLOCKS_PER_THREAD);
if (M_block_size == 0) M_block_size = 1;
if (K_block_size == 0) K_block_size = 1;
IdType num_M_blocks = (M + M_block_size - 1) / M_block_size;
IdType num_K_blocks = (K + K_block_size - 1) / K_block_size;
CSRMatrixInternal<IdType, IdType> *block_csr_array =
(CSRMatrixInternal<IdType, IdType> *)aligned_alloc(64,
sizeof(CSRMatrixInternal<IdType, IdType>) * num_M_blocks * num_K_blocks);
#ifdef DEBUG
endTick = __rdtsc();
if (std::is_same<Redop, op::Max<DType>>::value) {
LOG(INFO) << "Redop = Max";
} else if (std::is_same<Redop, op::Min<DType>>::value) {
LOG(INFO) << "Redop = Min";
} else if (std::is_same<Redop, op::Add<DType>>::value) {
LOG(INFO) << "Redop = Add";
}
LOG(INFO) << "nthreads = " << nthreads << ", llc_size = " << llc_size;
LOG(INFO) << "M = " << M << ", K = " << K << ", N = " << N;
LOG(INFO) << "use_lhs = " << Op::use_lhs << ", use_rhs = " << Op::use_rhs;
LOG(INFO) << "total_nnz = " << total_nnz << ", avg_degree = " << avg_degree;
LOG(INFO) << "has_idx = " << has_idx;
LOG(INFO) << "nnz_prob = " << nnz_prob;
LOG(INFO) << "K_block_size = " << K_block_size << ", M_block_size = " << M_block_size;
LOG(INFO) << "num_K_blocks = " << num_K_blocks << ", num_M_blocks = " << num_M_blocks;
LOG(INFO) << "stage0 ticks = " << (endTick - startTick);
startTick = __rdtsc();
#endif // DEBUG
SpMMCreateBlocks(csr, block_csr_array, num_M_blocks, num_K_blocks, M_block_size, K_block_size,
Op::use_lhs, Op::use_rhs);
#ifdef DEBUG
endTick = __rdtsc();
LOG(INFO) << "stage1 ticks = " << (endTick - startTick);
startTick = __rdtsc();
#endif // DEBUG
libxsmm_meltwfunction_opreduce_vecs_idx kernel = nullptr;
if (std::is_same<Redop, op::Max<DType>>::value) {
kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(has_idx, N,
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MAX,
true);
} else if (std::is_same<Redop, op::Min<DType>>::value) {
kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(has_idx, N,
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MIN,
true);
} else if (std::is_same<Redop, op::Add<DType>>::value) {
kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(has_idx, N,
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_SUM,
false);
}
#ifdef DEBUG
endTick = __rdtsc();
LOG(INFO) << "stage2 ticks = " << (endTick - startTick);
startTick = __rdtsc();
#endif // DEBUG
if (std::is_same<Redop, op::Max<DType>>::value || std::is_same<Redop, op::Min<DType>>::value) {
SpMMBlockwiseOpCmp<IdType, DType, Op, Redop>(block_csr_array, B, E, C, argB, argE, has_idx, N,
num_M_blocks, num_K_blocks, M_block_size, kernel);
} else {
SpMMBlockwiseOpSum(block_csr_array, B, E, C, has_idx, N, num_M_blocks, num_K_blocks,
M_block_size, kernel);
}
#ifdef DEBUG
endTick = __rdtsc();
LOG(INFO) << "stage3 ticks = " << (endTick - startTick);
startTick = __rdtsc();
#endif // DEBUG
SpMMFreeBlocks(block_csr_array, num_M_blocks, num_K_blocks, Op::use_lhs, Op::use_rhs);
#ifdef DEBUG
endTick = __rdtsc();
LOG(INFO) << "stage4 ticks = " << (endTick - startTick);
#endif // DEBUG
}
/*!
* \brief Optimized CPU kernel of SpMM-Sum on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result feature on destination nodes.
* \note it uses libxsmm, blocking and dynamic thread scheduling.
*/
template <typename IdType, typename DType, typename Op>
void SpMMSumCsrLibxsmm(const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out) {
NDArray dummy;
SpMMRedopCsrOpt<IdType, DType, Op, op::Add<DType>>(bcast, csr, ufeat, efeat, out, dummy, dummy);
}
/*!
* \brief Optimized CPU kernel of SpMM-Min/Max on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result feature on destination nodes.
* \param argu Arg-Min/Max on source nodes.
* \param arge Arg-Min/Max on edges.
* \note it uses libxsmm, blocking and dynamic thread scheduling.
*/
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsrLibxsmm(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
NDArray efeat, NDArray out, NDArray argu, NDArray arge) {
SpMMRedopCsrOpt<IdType, DType, Op, Cmp>(bcast, csr, ufeat, efeat, out, argu, arge);
}
} // namespace cpu
} // namespace aten
} // namespace dgl
#endif // USE_LIBXSMM
#endif // USE_AVX
#endif // _WIN32
#endif // DGL_ARRAY_CPU_SPMM_BLOCKING_LIBXSMM_H_
Subproject commit fa687556130b6298430f1c0555a14cf79ab6101c
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment