Unverified commit 8ae50c42, authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] clang-format auto fix. (#4804)



* [Misc] clang-format auto fix.

* manual

* manual

* manual

* manual

* todo

* fix
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 81831111
@@ -6,10 +6,10 @@
#ifndef DGL_ARRAY_CUDA_GE_SPMM_CUH_
#define DGL_ARRAY_CUDA_GE_SPMM_CUH_

#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "atomic.cuh"
#include "macro.cuh"

namespace dgl {

@@ -23,23 +23,19 @@ namespace cuda {
 * \note GE-SpMM: https://arxiv.org/pdf/2007.03179.pdf
 * The grid dimension x and y are reordered for better performance.
 */
template <typename Idx, typename DType, typename BinaryOp>
__global__ void GESpMMKernel(
    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
    DType* __restrict__ out, const Idx* __restrict__ indptr,
    const Idx* __restrict__ indices, const int64_t num_rows,
    const int64_t num_cols, const int64_t feat_len) {
  const Idx rid =
      blockIdx.x * blockDim.y + threadIdx.y;  // over vertices dimension
  const Idx fid = (blockIdx.y * 64) + threadIdx.x;  // over feature dimension

  if (rid < num_rows && fid < feat_len) {
    const Idx low = __ldg(indptr + rid), high = __ldg(indptr + rid + 1);
    DType accum_0 = 0., accum_1 = 0.;

    if (blockIdx.y != gridDim.y - 1) {  // fid + 32 < feat_len
      for (Idx left = low; left < high; left += 32) {

@@ -109,24 +105,21 @@ __global__ void GESpMMKernel(
        }
        out[feat_len * rid + fid] = accum_0;
        if (fid + 32 < feat_len) out[feat_len * rid + fid + 32] = accum_1;
      }
    }
  }
}

template <typename Idx, typename DType, typename BinaryOp>
void GESpMMCsr(
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    int64_t feat_len) {
  const Idx* indptr = csr.indptr.Ptr<Idx>();
  const Idx* indices = csr.indices.Ptr<Idx>();
  const DType* ufeat_data = ufeat.Ptr<DType>();
  const DType* efeat_data = efeat.Ptr<DType>();
  DType* out_data = out.Ptr<DType>();

  cudaStream_t stream = runtime::getCurrentCUDAStream();

@@ -138,12 +131,10 @@ void GESpMMCsr(
  const dim3 nthrs(ntx, nty);
  const int sh_mem_size = 0;
  CUDA_KERNEL_CALL(
      (GESpMMKernel<Idx, DType, BinaryOp>), nblks, nthrs, sh_mem_size, stream,
      ufeat_data, efeat_data, out_data, indptr, indices, csr.num_rows,
      csr.num_cols, feat_len);
}

}  // namespace cuda
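In the kernel above each thread owns one row (rid, built from blockIdx.x and threadIdx.y) and two feature columns 32 apart (fid and fid + 32), which is where accum_0 and accum_1 come from. As a reading aid, here is a minimal host-side reference of the same CSR SpMM semantics for the common copy-lhs/sum case (the BinaryOp template generalizes this); it is an illustrative sketch, not code from this commit:

#include <cstdint>
#include <vector>

// Illustrative CPU reference of the GE-SpMM semantics above (copy-lhs + sum):
// out[r, f] = sum over nonzeros (r, c) of ufeat[c, f].
void SpMMCsrSumRef(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices,
    const std::vector<float>& ufeat, int64_t feat_len,
    std::vector<float>* out) {
  const int64_t num_rows = static_cast<int64_t>(indptr.size()) - 1;
  out->assign(num_rows * feat_len, 0.f);
  for (int64_t r = 0; r < num_rows; ++r)
    for (int64_t e = indptr[r]; e < indptr[r + 1]; ++e)
      for (int64_t f = 0; f < feat_len; ++f)
        (*out)[r * feat_len + f] += ufeat[indices[e] * feat_len + f];
}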
@@ -8,44 +8,46 @@
///////////////////////// Dispatchers //////////////////////////

/* Macro used for switching between broadcasting and non-broadcasting kernels.
 * It also copies the auxiliary information for calculating broadcasting offsets
 * to GPU.
 */
#define BCAST_IDX_CTX_SWITCH(BCAST, EDGE_MAP, CTX, LHS_OFF, RHS_OFF, ...)     \
  do {                                                                        \
    const BcastOff &info = (BCAST);                                           \
    if (!info.use_bcast) {                                                    \
      constexpr bool UseBcast = false;                                        \
      if ((EDGE_MAP)) {                                                       \
        constexpr bool UseIdx = true;                                         \
        { __VA_ARGS__ }                                                       \
      } else {                                                                \
        constexpr bool UseIdx = false;                                        \
        { __VA_ARGS__ }                                                       \
      }                                                                       \
    } else {                                                                  \
      constexpr bool UseBcast = true;                                         \
      const DGLContext ctx = (CTX);                                           \
      const auto device = runtime::DeviceAPI::Get(ctx);                       \
      (LHS_OFF) = static_cast<int64_t *>(device->AllocWorkspace(              \
          ctx, sizeof(int64_t) * info.lhs_offset.size()));                    \
      CUDA_CALL(cudaMemcpy(                                                   \
          (LHS_OFF), &info.lhs_offset[0],                                     \
          sizeof(int64_t) * info.lhs_offset.size(), cudaMemcpyHostToDevice)); \
      (RHS_OFF) = static_cast<int64_t *>(device->AllocWorkspace(              \
          ctx, sizeof(int64_t) * info.rhs_offset.size()));                    \
      CUDA_CALL(cudaMemcpy(                                                   \
          (RHS_OFF), &info.rhs_offset[0],                                     \
          sizeof(int64_t) * info.rhs_offset.size(), cudaMemcpyHostToDevice)); \
      if ((EDGE_MAP)) {                                                       \
        constexpr bool UseIdx = true;                                         \
        { __VA_ARGS__ }                                                       \
      } else {                                                                \
        constexpr bool UseIdx = false;                                        \
        { __VA_ARGS__ }                                                       \
      }                                                                       \
      device->FreeWorkspace(ctx, (LHS_OFF));                                  \
      device->FreeWorkspace(ctx, (RHS_OFF));                                  \
    }                                                                         \
  } while (0)

#endif  // DGL_ARRAY_CUDA_MACRO_CUH_
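The macro above turns a runtime condition into a constexpr bool so the kernel-launching body passed in __VA_ARGS__ can use UseBcast/UseIdx as template arguments; the broadcasting branch additionally copies the lhs/rhs offset arrays to the GPU and frees them afterwards. A stripped-down sketch of just that dispatch idea, with hypothetical names and none of the offset handling, is:

// Illustrative, simplified version of the dispatch pattern used by
// BCAST_IDX_CTX_SWITCH above (names are hypothetical, not from this diff).
#define SIMPLE_BOOL_SWITCH(cond, BoolName, ...) \
  do {                                          \
    if (cond) {                                 \
      constexpr bool BoolName = true;           \
      { __VA_ARGS__ }                           \
    } else {                                    \
      constexpr bool BoolName = false;          \
      { __VA_ARGS__ }                           \
    }                                           \
  } while (0)

template <bool UseIdx>
void RunKernel() { /* UseIdx is a compile-time constant here */ }

void Dispatch(bool has_edge_map) {
  // The same source line instantiates RunKernel<true> or RunKernel<false>.
  SIMPLE_BOOL_SWITCH(has_edge_map, UseIdx, { RunKernel<UseIdx>(); });
}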
@@ -4,14 +4,14 @@
 * \brief rowwise sampling
 */
#include <curand_kernel.h>
#include <dgl/array.h>
#include <dgl/array_iterator.h>
#include <dgl/random.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"
#include "./utils.h"

using namespace dgl::runtime;

@@ -23,20 +23,15 @@ namespace {
template <typename IdType>
__global__ void _GlobalUniformNegativeSamplingKernel(
    const IdType* __restrict__ indptr, const IdType* __restrict__ indices,
    IdType* __restrict__ row, IdType* __restrict__ col, int64_t num_row,
    int64_t num_col, int64_t num_samples, int num_trials,
    bool exclude_self_loops, int32_t random_seed) {
  int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  curandStatePhilox4_32_10_t
      rng;  // this allows generating 4 32-bit ints at a time
  curand_init(random_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);
  while (tx < num_samples) {

@@ -50,8 +45,7 @@ __global__ void _GlobalUniformNegativeSamplingKernel(
      int64_t u = static_cast<int64_t>(((y_lo << 32L) | z) % num_row);
      int64_t v = static_cast<int64_t>(((y_hi << 32L) | w) % num_col);
      if (exclude_self_loops && (u == v)) continue;

      // binary search of v among indptr[u:u+1]
      int64_t b = indptr[u], e = indptr[u + 1] - 1;

@@ -81,48 +75,47 @@ __global__ void _GlobalUniformNegativeSamplingKernel(
template <typename DType>
struct IsNotMinusOne {
  __device__ __forceinline__ bool operator()(const std::pair<DType, DType>& a) {
    return a.first != -1;
  }
};

/*!
 * \brief Sort ordered pairs in ascending order, using \a tmp_major and \a
 * tmp_minor as temporary buffers, each with \a n elements.
 */
template <typename IdType>
void SortOrderedPairs(
    runtime::DeviceAPI* device, DGLContext ctx, IdType* major, IdType* minor,
    IdType* tmp_major, IdType* tmp_minor, int64_t n, cudaStream_t stream) {
  // Sort ordered pairs in lexicographical order by two radix sorts since
  // cub's radix sorts are stable.
  // We need a 2*n auxiliary storage to store the results form the first radix
  // sort.
  size_t s1 = 0, s2 = 0;
  void* tmp1 = nullptr;
  void* tmp2 = nullptr;

  // Radix sort by minor key first, reorder the major key in the progress.
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
      tmp1, s1, minor, tmp_minor, major, tmp_major, n, 0, sizeof(IdType) * 8,
      stream));
  tmp1 = device->AllocWorkspace(ctx, s1);
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
      tmp1, s1, minor, tmp_minor, major, tmp_major, n, 0, sizeof(IdType) * 8,
      stream));

  // Radix sort by major key next.
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
      tmp2, s2, tmp_major, major, tmp_minor, minor, n, 0, sizeof(IdType) * 8,
      stream));
  tmp2 = (s2 > s1) ? device->AllocWorkspace(ctx, s2)
                   : tmp1;  // reuse buffer if s2 <= s1
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
      tmp2, s2, tmp_major, major, tmp_minor, minor, n, 0, sizeof(IdType) * 8,
      stream));

  if (tmp1 != tmp2) device->FreeWorkspace(ctx, tmp2);
  device->FreeWorkspace(ctx, tmp1);
}
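SortOrderedPairs exploits the stability of cub's radix sort: sorting by the minor key first and then re-sorting by the major key leaves equal major keys in minor-key order, i.e. lexicographic order overall. A small host-side illustration of the same idea with std::stable_sort follows; it is illustrative only and is not the CUDA path used above.

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Two stable sorts (minor key, then major key) give lexicographic order,
// which is exactly why two stable radix sorts suffice on the GPU.
void LexSortPairsRef(std::vector<std::pair<int64_t, int64_t>>* pairs) {
  std::stable_sort(
      pairs->begin(), pairs->end(),
      [](const auto& a, const auto& b) { return a.second < b.second; });
  std::stable_sort(
      pairs->begin(), pairs->end(),
      [](const auto& a, const auto& b) { return a.first < b.first; });
}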
@@ -130,17 +123,14 @@ void SortOrderedPairs(
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy) {
  auto ctx = csr.indptr->ctx;
  auto dtype = csr.indptr->dtype;
  const int64_t num_row = csr.num_rows;
  const int64_t num_col = csr.num_cols;
  const int64_t num_actual_samples =
      static_cast<int64_t>(num_samples * (1 + redundancy));
  IdArray row = Full<IdType>(-1, num_actual_samples, ctx);
  IdArray col = Full<IdType>(-1, num_actual_samples, ctx);
  IdArray out_row = IdArray::Empty({num_actual_samples}, dtype, ctx);

@@ -156,22 +146,25 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
  std::pair<IdArray, IdArray> result;
  int64_t num_out;
  CUDA_KERNEL_CALL(
      _GlobalUniformNegativeSamplingKernel, nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), row_data, col_data,
      num_row, num_col, num_actual_samples, num_trials, exclude_self_loops,
      RandomEngine::ThreadLocal()->RandInt32());

  size_t tmp_size = 0;
  int64_t* num_out_cuda =
      static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));
  IsNotMinusOne<IdType> op;
  PairIterator<IdType> begin(row_data, col_data);
  PairIterator<IdType> out_begin(out_row_data, out_col_data);
  CUDA_CALL(cub::DeviceSelect::If(
      nullptr, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples, op,
      stream));
  void* tmp = device->AllocWorkspace(ctx, tmp_size);
  CUDA_CALL(cub::DeviceSelect::If(
      tmp, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples, op,
      stream));
  num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda);

  if (!replace) {

@@ -182,28 +175,33 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
    PairIterator<IdType> unique_begin(unique_row_data, unique_col_data);

    SortOrderedPairs(
        device, ctx, out_row_data, out_col_data, unique_row_data,
        unique_col_data, num_out, stream);

    size_t tmp_size_unique = 0;
    void* tmp_unique = nullptr;
    CUDA_CALL(cub::DeviceSelect::Unique(
        nullptr, tmp_size_unique, out_begin, unique_begin, num_out_cuda,
        num_out, stream));
    tmp_unique = (tmp_size_unique > tmp_size)
                     ? device->AllocWorkspace(ctx, tmp_size_unique)
                     : tmp;  // reuse buffer
    CUDA_CALL(cub::DeviceSelect::Unique(
        tmp_unique, tmp_size_unique, out_begin, unique_begin, num_out_cuda,
        num_out, stream));
    num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda);
    num_out = std::min(num_samples, num_out);
    result = {
        unique_row.CreateView({num_out}, dtype),
        unique_col.CreateView({num_out}, dtype)};

    if (tmp_unique != tmp) device->FreeWorkspace(ctx, tmp_unique);
  } else {
    num_out = std::min(num_samples, num_out);
    result = {
        out_row.CreateView({num_out}, dtype),
        out_col.CreateView({num_out}, dtype)};
  }

  device->FreeWorkspace(ctx, tmp);

@@ -211,10 +209,10 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
  return result;
}

template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<
    kDGLCUDA, int32_t>(const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<
    kDGLCUDA, int64_t>(const CSRMatrix&, int64_t, int, bool, bool, double);

};  // namespace impl
};  // namespace aten
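The routine above leans repeatedly on the standard cub two-call idiom: each device-wide primitive (DeviceSelect::If, DeviceSelect::Unique, DeviceRadixSort::SortPairs) is first invoked with a null temporary buffer to query the required workspace size, then invoked again with the allocated buffer to do the work. A standalone sketch of that idiom, using plain cudaMalloc instead of DGL's workspace allocator and with hypothetical names, is:

#include <cub/cub.cuh>

// Illustrative only: the cub "two-call" pattern used above.
struct IsNonNegative {
  __device__ __forceinline__ bool operator()(int v) const { return v >= 0; }
};

void CompactNonNegative(
    const int* d_in, int* d_out, int* d_num_selected, int n,
    cudaStream_t stream) {
  size_t tmp_bytes = 0;
  void* d_tmp = nullptr;
  // First call: only reports the required temporary storage size.
  cub::DeviceSelect::If(
      d_tmp, tmp_bytes, d_in, d_out, d_num_selected, n, IsNonNegative(),
      stream);
  cudaMalloc(&d_tmp, tmp_bytes);
  // Second call: performs the actual stream compaction.
  cub::DeviceSelect::If(
      d_tmp, tmp_bytes, d_in, d_out, d_num_selected, n, IsNonNegative(),
      stream);
  cudaFree(d_tmp);
}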
@@ -4,15 +4,15 @@
 * \brief uniform rowwise sampling
 */
#include <curand_kernel.h>
#include <dgl/random.h>
#include <dgl/runtime/device_api.h>

#include <numeric>

#include "../../array/cuda/atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"

using namespace dgl::aten::cuda;

@@ -25,29 +25,28 @@ namespace {
constexpr int BLOCK_SIZE = 128;

/**
 * @brief Compute the size of each row in the sampled CSR, without replacement.
 *
 * @tparam IdType The type of node and edge indexes.
 * @param num_picks The number of non-zero entries to pick per row.
 * @param num_rows The number of rows to pick.
 * @param in_rows The set of rows to pick.
 * @param in_ptr The index where each row's edges start.
 * @param out_deg The size of each row in the sampled matrix, as indexed by
 * `in_rows` (output).
 */
template <typename IdType>
__global__ void _CSRRowWiseSampleDegreeKernel(
    const int64_t num_picks, const int64_t num_rows,
    const IdType* const in_rows, const IdType* const in_ptr,
    IdType* const out_deg) {
  const int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
  if (tIdx < num_rows) {
    const int in_row = in_rows[tIdx];
    const int out_row = tIdx;
    out_deg[out_row] = min(
        static_cast<IdType>(num_picks), in_ptr[in_row + 1] - in_ptr[in_row]);

    if (out_row == num_rows - 1) {
      // make the prefixsum work

@@ -57,23 +56,21 @@ __global__ void _CSRRowWiseSampleDegreeKernel(
}

/**
 * @brief Compute the size of each row in the sampled CSR, with replacement.
 *
 * @tparam IdType The type of node and edge indexes.
 * @param num_picks The number of non-zero entries to pick per row.
 * @param num_rows The number of rows to pick.
 * @param in_rows The set of rows to pick.
 * @param in_ptr The index where each row's edges start.
 * @param out_deg The size of each row in the sampled matrix, as indexed by
 * `in_rows` (output).
 */
template <typename IdType>
__global__ void _CSRRowWiseSampleDegreeReplaceKernel(
    const int64_t num_picks, const int64_t num_rows,
    const IdType* const in_rows, const IdType* const in_ptr,
    IdType* const out_deg) {
  const int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
  if (tIdx < num_rows) {

@@ -94,41 +91,36 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
}

/**
 * @brief Perform row-wise uniform sampling on a CSR matrix,
 * and generate a COO matrix, without replacement.
 *
 * @tparam IdType The ID type used for matrices.
 * @tparam TILE_SIZE The number of rows covered by each threadblock.
 * @param rand_seed The random seed to use.
 * @param num_picks The number of non-zeros to pick per row.
 * @param num_rows The number of rows to pick.
 * @param in_rows The set of rows to pick.
 * @param in_ptr The indptr array of the input CSR.
 * @param in_index The indices array of the input CSR.
 * @param data The data array of the input CSR.
 * @param out_ptr The offset to write each row to in the output COO.
 * @param out_rows The rows of the output COO (output).
 * @param out_cols The columns of the output COO (output).
 * @param out_idxs The data array of the output COO (output).
 */
template <typename IdType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleUniformKernel(
    const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,
    const IdType* const in_rows, const IdType* const in_ptr,
    const IdType* const in_index, const IdType* const data,
    const IdType* const out_ptr, IdType* const out_rows, IdType* const out_cols,
    IdType* const out_idxs) {
  // we assign one warp per row
  assert(blockDim.x == BLOCK_SIZE);

  int64_t out_row = blockIdx.x * TILE_SIZE;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);

@@ -177,41 +169,36 @@ __global__ void _CSRRowWiseSampleUniformKernel(
}

/**
 * @brief Perform row-wise uniform sampling on a CSR matrix,
 * and generate a COO matrix, with replacement.
 *
 * @tparam IdType The ID type used for matrices.
 * @tparam TILE_SIZE The number of rows covered by each threadblock.
 * @param rand_seed The random seed to use.
 * @param num_picks The number of non-zeros to pick per row.
 * @param num_rows The number of rows to pick.
 * @param in_rows The set of rows to pick.
 * @param in_ptr The indptr array of the input CSR.
 * @param in_index The indices array of the input CSR.
 * @param data The data array of the input CSR.
 * @param out_ptr The offset to write each row to in the output COO.
 * @param out_rows The rows of the output COO (output).
 * @param out_cols The columns of the output COO (output).
 * @param out_idxs The data array of the output COO (output).
 */
template <typename IdType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleUniformReplaceKernel(
    const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,
    const IdType* const in_rows, const IdType* const in_ptr,
    const IdType* const in_index, const IdType* const data,
    const IdType* const out_ptr, IdType* const out_rows, IdType* const out_cols,
    IdType* const out_idxs) {
  // we assign one warp per row
  assert(blockDim.x == BLOCK_SIZE);

  int64_t out_row = blockIdx.x * TILE_SIZE;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);

@@ -229,7 +216,8 @@ __global__ void _CSRRowWiseSampleUniformReplaceKernel(
        const int64_t out_idx = out_row_start + idx;
        out_rows[out_idx] = row;
        out_cols[out_idx] = in_index[in_row_start + edge];
        out_idxs[out_idx] =
            data ? data[in_row_start + edge] : in_row_start + edge;
      }
    }
    out_row += 1;

@@ -248,11 +236,14 @@ COOMatrix _CSRRowWiseSamplingUniform(
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const int64_t num_rows = rows->shape[0];
  const IdType* const slice_rows = static_cast<const IdType*>(rows->data);

  IdArray picked_row =
      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  IdArray picked_col =
      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  IdArray picked_idx =
      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  IdType* const out_rows = static_cast<IdType*>(picked_row->data);
  IdType* const out_cols = static_cast<IdType*>(picked_col->data);
  IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);

@@ -261,65 +252,52 @@ COOMatrix _CSRRowWiseSamplingUniform(
  const IdType* in_cols = mat.indices.Ptr<IdType>();
  const IdType* data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr;
  if (mat.is_pinned) {
    CUDA_CALL(cudaHostGetDevicePointer(&in_ptr, mat.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(cudaHostGetDevicePointer(&in_cols, mat.indices.Ptr<IdType>(), 0));
    if (CSRHasData(mat)) {
      CUDA_CALL(cudaHostGetDevicePointer(&data, mat.data.Ptr<IdType>(), 0));
    }
  }

  // compute degree
  IdType* out_deg = static_cast<IdType*>(
      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
  if (replace) {
    const dim3 block(512);
    const dim3 grid((num_rows + block.x - 1) / block.x);
    CUDA_KERNEL_CALL(
        _CSRRowWiseSampleDegreeReplaceKernel, grid, block, 0, stream, num_picks,
        num_rows, slice_rows, in_ptr, out_deg);
  } else {
    const dim3 block(512);
    const dim3 grid((num_rows + block.x - 1) / block.x);
    CUDA_KERNEL_CALL(
        _CSRRowWiseSampleDegreeKernel, grid, block, 0, stream, num_picks,
        num_rows, slice_rows, in_ptr, out_deg);
  }

  // fill out_ptr
  IdType* out_ptr = static_cast<IdType*>(
      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
  size_t prefix_temp_size = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      nullptr, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));
  void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      prefix_temp, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));
  device->FreeWorkspace(ctx, prefix_temp);
  device->FreeWorkspace(ctx, out_deg);

  cudaEvent_t copyEvent;
  CUDA_CALL(cudaEventCreate(&copyEvent));
  // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and
  // wait on a cudaevent
  IdType new_len;
  // copy using the internal current stream
  device->CopyDataFromTo(
      out_ptr, num_rows * sizeof(new_len), &new_len, 0, sizeof(new_len), ctx,
      DGLContext{kDGLCPU, 0}, mat.indptr->dtype);
  CUDA_CALL(cudaEventRecord(copyEvent, stream));

  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);

@@ -331,36 +309,16 @@ COOMatrix _CSRRowWiseSamplingUniform(
    const dim3 block(BLOCK_SIZE);
    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
    CUDA_KERNEL_CALL(
        (_CSRRowWiseSampleUniformReplaceKernel<IdType, TILE_SIZE>), grid, block,
        0, stream, random_seed, num_picks, num_rows, slice_rows, in_ptr,
        in_cols, data, out_ptr, out_rows, out_cols, out_idxs);
  } else {  // without replacement
    const dim3 block(BLOCK_SIZE);
    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
    CUDA_KERNEL_CALL(
        (_CSRRowWiseSampleUniformKernel<IdType, TILE_SIZE>), grid, block, 0,
        stream, random_seed, num_picks, num_rows, slice_rows, in_ptr, in_cols,
        data, out_ptr, out_rows, out_cols, out_idxs);
  }

  device->FreeWorkspace(ctx, out_ptr);

@@ -372,8 +330,8 @@ COOMatrix _CSRRowWiseSamplingUniform(
  picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
  picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);

  return COOMatrix(
      mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
}

template <DGLDeviceType XPU, typename IdType>

@@ -383,9 +341,11 @@ COOMatrix CSRRowWiseSamplingUniform(
    // Basically this is UnitGraph::InEdges().
    COOMatrix coo = CSRToCOO(CSRSliceRows(mat, rows), false);
    IdArray sliced_rows = IndexSelect(rows, coo.row);
    return COOMatrix(
        mat.num_rows, mat.num_cols, sliced_rows, coo.col, coo.data);
  } else {
    return _CSRRowWiseSamplingUniform<XPU, IdType>(
        mat, rows, num_picks, replace);
  }
}
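The sampling path above works in three steps: a degree kernel clamps each requested row's degree to num_picks, cub::DeviceScan::ExclusiveSum turns those degrees into per-row write offsets (out_ptr), and the sampling kernels then write each row's picks into its slot. A host-side sketch of the offset computation, with a hypothetical helper name and illustrative only:

#include <algorithm>
#include <cstdint>
#include <vector>

// out_ptr[i] is where row rows[i] writes its picks; out_ptr.back() is the
// total number of sampled edges.
std::vector<int64_t> SampleOffsetsRef(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& rows,
    int64_t num_picks) {
  std::vector<int64_t> out_ptr(rows.size() + 1, 0);
  for (size_t i = 0; i < rows.size(); ++i) {
    const int64_t deg = indptr[rows[i] + 1] - indptr[rows[i]];
    out_ptr[i + 1] = out_ptr[i] + std::min(num_picks, deg);
  }
  return out_ptr;
}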
@@ -8,9 +8,10 @@
#include <string>
#include <vector>

#include "../../runtime/cuda/cuda_common.h"
#include "./atomic.cuh"
#include "./utils.h"

namespace dgl {

@@ -24,11 +25,9 @@ namespace cuda {
 * \note each blockthread is responsible for aggregation on a row
 * in the result tensor.
 */
template <typename IdType, typename DType, typename ReduceOp>
__global__ void SegmentReduceKernel(
    const DType* feat, const IdType* offsets, DType* out, IdType* arg,
    int64_t n, int64_t dim) {
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    int col = blockIdx.y * blockDim.x + threadIdx.x;

@@ -39,8 +38,7 @@ __global__ void SegmentReduceKernel(
        ReduceOp::Call(&local_accum, &local_arg, feat[i * dim + col], i);
      }
      out[row * dim + col] = local_accum;
      if (ReduceOp::require_arg) arg[row * dim + col] = local_arg;
      col += gridDim.y * blockDim.x;
    }
  }

@@ -53,8 +51,7 @@ __global__ void SegmentReduceKernel(
 */
template <typename IdType, typename DType>
__global__ void ScatterAddKernel(
    const DType* feat, const IdType* idx, DType* out, int64_t n, int64_t dim) {
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    const int write_row = idx[row];
    int col = blockIdx.y * blockDim.x + threadIdx.x;

@@ -73,7 +70,7 @@ __global__ void ScatterAddKernel(
template <typename IdType, typename DType>
__global__ void UpdateGradMinMaxHeteroKernel(
    const DType* feat, const IdType* idx, const IdType* idx_type, DType* out,
    int64_t n, int64_t dim, int type) {
  unsigned int tId = threadIdx.x;
  unsigned int laneId = tId & 31;

@@ -100,8 +97,7 @@ __global__ void UpdateGradMinMaxHeteroKernel(
 */
template <typename IdType, typename DType>
__global__ void BackwardSegmentCmpKernel(
    const DType* feat, const IdType* arg, DType* out, int64_t n, int64_t dim) {
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    int col = blockIdx.y * blockDim.x + threadIdx.x;
    while (col < dim) {

@@ -122,11 +118,7 @@ __global__ void BackwardSegmentCmpKernel(
 * \param arg An auxiliary tensor storing ArgMax/Min information,
 */
template <typename IdType, typename DType, typename ReduceOp>
void SegmentReduce(NDArray feat, NDArray offsets, NDArray out, NDArray arg) {
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();

@@ -135,8 +127,7 @@ void SegmentReduce(
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  int64_t n = out->shape[0];
  int64_t dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];

  const int nbx = FindNumBlocks<'x'>(n);
  const int ntx = FindNumThreads(dim);

@@ -145,10 +136,9 @@ void SegmentReduce(
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  // TODO(zihao): try cub's DeviceSegmentedReduce and compare the performance.
  CUDA_KERNEL_CALL(
      (SegmentReduceKernel<IdType, DType, ReduceOp>), nblks, nthrs, 0, stream,
      feat_data, offsets_data, out_data, arg_data, n, dim);
}

/*!

@@ -159,19 +149,15 @@ void SegmentReduce(
 * \param out The output tensor.
 */
template <typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* idx_data = idx.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  int64_t n = feat->shape[0];
  int64_t dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];

  const int nbx = FindNumBlocks<'x'>(n);
  const int ntx = FindNumThreads(dim);

@@ -179,10 +165,9 @@ void ScatterAdd(
  const int nty = 1;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  CUDA_KERNEL_CALL(
      (ScatterAddKernel<IdType, DType>), nblks, nthrs, 0, stream, feat_data,
      idx_data, out_data, n, dim);
}

/*!

@@ -195,24 +180,26 @@ void ScatterAdd(
 * \param list_out List of the output tensors.
 */
template <typename IdType, typename DType>
void UpdateGradMinMax_hetero(
    const HeteroGraphPtr& graph, const std::string& op,
    const std::vector<NDArray>& list_feat, const std::vector<NDArray>& list_idx,
    const std::vector<NDArray>& list_idx_types,
    std::vector<NDArray>* list_out) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  if (op == "copy_lhs" || op == "copy_rhs") {
    std::vector<std::vector<dgl_id_t>> src_dst_ntypes(
        graph->NumVertexTypes(), std::vector<dgl_id_t>());
    for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
      auto pair = graph->meta_graph()->FindEdge(etype);
      const dgl_id_t dst_ntype = pair.first;  // graph is reversed
      const dgl_id_t src_ntype = pair.second;
      auto same_src_dst_ntype = std::find(
          std::begin(src_dst_ntypes[dst_ntype]),
          std::end(src_dst_ntypes[dst_ntype]), src_ntype);
      // if op is "copy_lhs", relation type with same src and dst node type will
      // be updated once
      if (op == "copy_lhs" &&
          same_src_dst_ntype != std::end(src_dst_ntypes[dst_ntype]))
        continue;
      src_dst_ntypes[dst_ntype].push_back(src_ntype);
      const DType* feat_data = list_feat[dst_ntype].Ptr<DType>();

@@ -229,35 +216,31 @@ void UpdateGradMinMax_hetero(const HeteroGraphPtr& graph,
      const int nbx = FindNumBlocks<'x'>((n * th_per_row + ntx - 1) / ntx);
      const dim3 nblks(nbx);
      const dim3 nthrs(ntx);
      CUDA_KERNEL_CALL(
          (UpdateGradMinMaxHeteroKernel<IdType, DType>), nblks, nthrs, 0,
          stream, feat_data, idx_data, idx_type_data, out_data, n, dim, type);
    }
  }
}

/*!
 * \brief CUDA implementation of backward phase of Segment Reduce with Min/Max
 * reducer.
 * \note math equation: out[arg[i, k], k] = feat[i, k] \param feat The input
 * tensor.
 * \param arg The ArgMin/Max information, used for indexing.
 * \param out The output tensor.
 */
template <typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* arg_data = arg.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  int64_t n = feat->shape[0];
  int64_t dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];

  const int nbx = FindNumBlocks<'x'>(n);
  const int ntx = FindNumThreads(dim);

@@ -265,10 +248,9 @@ void BackwardSegmentCmp(
  const int nty = 1;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  CUDA_KERNEL_CALL(
      (BackwardSegmentCmpKernel<IdType, DType>), nblks, nthrs, 0, stream,
      feat_data, arg_data, out_data, n, dim);
}

}  // namespace cuda
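SegmentReduceKernel above assigns one block row (blockIdx.x) per output row and strides over the feature dimension; for the sum reducer the semantics reduce to accumulating each segment of rows given by consecutive offsets. A host-side reference sketch of those semantics (illustrative only, not from this commit):

#include <cstdint>
#include <vector>

// out[row, col] aggregates feat[offsets[row] .. offsets[row + 1], col].
void SegmentSumRef(
    const std::vector<float>& feat, const std::vector<int64_t>& offsets,
    int64_t dim, std::vector<float>* out) {
  const int64_t n = static_cast<int64_t>(offsets.size()) - 1;
  out->assign(n * dim, 0.f);
  for (int64_t row = 0; row < n; ++row)
    for (int64_t i = offsets[row]; i < offsets[row + 1]; ++i)
      for (int64_t col = 0; col < dim; ++col)
        (*out)[row * dim + col] += feat[i * dim + col];
}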
@@ -4,12 +4,14 @@
 * \brief COO operator GPU implementation
 */
#include <dgl/array.h>

#include <numeric>
#include <unordered_set>
#include <vector>

#include "../../runtime/cuda/cuda_common.h"
#include "./atomic.cuh"
#include "./utils.h"

namespace dgl {

@@ -19,9 +21,8 @@ using namespace cuda;
namespace aten {
namespace impl {

template <typename IdType>
__device__ void _warpReduce(volatile IdType* sdata, IdType tid) {
  sdata[tid] += sdata[tid + 32];
  sdata[tid] += sdata[tid + 16];
  sdata[tid] += sdata[tid + 8];

@@ -32,10 +33,8 @@ __device__ void _warpReduce(volatile IdType *sdata, IdType tid) {
template <typename IdType>
__global__ void _COOGetRowNNZKernel(
    const IdType* __restrict__ row_indices, IdType* __restrict__ glb_cnt,
    const int64_t row_query, IdType nnz) {
  __shared__ IdType local_cnt[1024];
  IdType tx = threadIdx.x;
  IdType bx = blockIdx.x;

@@ -80,10 +79,9 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
  IdType nb = dgl::cuda::FindNumBlocks<'x'>((nnz + nt - 1) / nt);
  NDArray rst = NDArray::Empty({1}, coo.row->dtype, coo.row->ctx);
  _Fill(rst.Ptr<IdType>(), 1, IdType(0));
  CUDA_KERNEL_CALL(
      _COOGetRowNNZKernel, nb, nt, 0, stream, coo.row.Ptr<IdType>(),
      rst.Ptr<IdType>(), row, nnz);
  rst = rst.CopyTo(DGLContext{kDGLCPU, 0});
  return *rst.Ptr<IdType>();
}

@@ -93,8 +91,7 @@ template int64_t COOGetRowNNZ<kDGLCUDA, int64_t>(COOMatrix, int64_t);
template <typename IdType>
__global__ void _COOGetAllRowNNZKernel(
    const IdType* __restrict__ row_indices, IdType* __restrict__ glb_cnts,
    IdType nnz) {
  IdType eid = blockIdx.x * blockDim.x + threadIdx.x;
  while (eid < nnz) {

@@ -118,20 +115,18 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
    IdType nb = dgl::cuda::FindNumBlocks<'x'>((nnz + nt - 1) / nt);
    NDArray rst = NDArray::Empty({1}, coo.row->dtype, coo.row->ctx);
    _Fill(rst.Ptr<IdType>(), 1, IdType(0));
    CUDA_KERNEL_CALL(
        _COOGetRowNNZKernel, nb, nt, 0, stream, coo.row.Ptr<IdType>(),
        rst.Ptr<IdType>(), row, nnz);
    return rst;
  } else {
    IdType nt = 1024;
    IdType nb = dgl::cuda::FindNumBlocks<'x'>((nnz + nt - 1) / nt);
    NDArray in_degrees = NDArray::Empty({num_rows}, rows->dtype, rows->ctx);
    _Fill(in_degrees.Ptr<IdType>(), num_rows, IdType(0));
    CUDA_KERNEL_CALL(
        _COOGetAllRowNNZKernel, nb, nt, 0, stream, coo.row.Ptr<IdType>(),
        in_degrees.Ptr<IdType>(), nnz);
    return IndexSelect(in_degrees, rows);
  }
}
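_COOGetRowNNZKernel above counts, with a shared-memory tree reduction plus one atomic add per block, how many COO entries fall in the queried row. The host-side equivalent is just a count over the row array; an illustrative sketch, not code from this commit:

#include <algorithm>
#include <cstdint>
#include <vector>

// Number of nonzeros whose row index equals row_query.
int64_t COOGetRowNNZRef(
    const std::vector<int64_t>& coo_row, int64_t row_query) {
  return std::count(coo_row.begin(), coo_row.end(), row_query);
}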
@@ -4,13 +4,15 @@
 * \brief CSR operator CPU implementation
 */
#include <dgl/array.h>

#include <numeric>
#include <unordered_set>
#include <vector>

#include "../../runtime/cuda/cuda_common.h"
#include "./atomic.cuh"
#include "./dgl_cub.cuh"
#include "./utils.h"

namespace dgl {

@@ -32,12 +34,11 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
  IdArray out = aten::NewIdArray(1, ctx, sizeof(IdType) * 8);
  const IdType* data = nullptr;
  // TODO(minjie): use binary search for sorted csr
  CUDA_KERNEL_CALL(
      dgl::cuda::_LinearSearchKernel, 1, 1, 0, stream, csr.indptr.Ptr<IdType>(),
      csr.indices.Ptr<IdType>(), data, rows.Ptr<IdType>(), cols.Ptr<IdType>(),
      1, 1, 1, static_cast<IdType*>(nullptr), static_cast<IdType>(-1),
      out.Ptr<IdType>());
  out = out.CopyTo(DGLContext{kDGLCPU, 0});
  return *out.Ptr<IdType>() != -1;
}

@@ -51,8 +52,7 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  const auto collen = col->shape[0];
  const auto rstlen = std::max(rowlen, collen);
  NDArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);
  if (rstlen == 0) return rst;
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

@@ -62,18 +62,17 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
  const IdType* indices_data = csr.indices.Ptr<IdType>();
  if (csr.is_pinned) {
    CUDA_CALL(
        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(
        cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
  }
  // TODO(minjie): use binary search for sorted csr
  CUDA_KERNEL_CALL(
      dgl::cuda::_LinearSearchKernel, nb, nt, 0, stream, indptr_data,
      indices_data, data, row.Ptr<IdType>(), col.Ptr<IdType>(), row_stride,
      col_stride, rstlen, static_cast<IdType*>(nullptr),
      static_cast<IdType>(-1), rst.Ptr<IdType>());
  return rst != -1;
}

@@ -88,8 +87,8 @@ template NDArray CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, NDArray, NDArray);
 */
template <typename IdType>
__global__ void _SegmentHasNoDuplicate(
    const IdType* indptr, const IdType* indices, int64_t num_rows,
    int8_t* flags) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < num_rows) { while (tx < num_rows) {
...@@ -102,23 +101,21 @@ __global__ void _SegmentHasNoDuplicate( ...@@ -102,23 +101,21 @@ __global__ void _SegmentHasNoDuplicate(
} }
} }
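
A host-side sketch of the per-row test the kernel encodes (assuming, as CSRHasDuplicate below ensures via CSRSort, that column indices inside each row are sorted; names are hypothetical):

#include <cstdint>
#include <vector>

// True iff no row of the sorted CSR repeats a column index, i.e. the
// condition every per-row flag must satisfy for CSRHasDuplicate to
// return false.
bool CSRRowsHaveNoDuplicateRef(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices) {
  const int64_t num_rows = static_cast<int64_t>(indptr.size()) - 1;
  for (int64_t r = 0; r < num_rows; ++r)
    for (int64_t i = indptr[r] + 1; i < indptr[r + 1]; ++i)
      if (indices[i] == indices[i - 1]) return false;  // adjacent equal => dup
  return true;
}
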
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) { bool CSRHasDuplicate(CSRMatrix csr) {
if (!csr.sorted) if (!csr.sorted) csr = CSRSort(csr);
csr = CSRSort(csr);
const auto& ctx = csr.indptr->ctx; const auto& ctx = csr.indptr->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
  // We allocate a workspace of num_rows bytes. It wastes a little bit of memory but should   // We allocate a workspace of num_rows bytes. It wastes a little bit of memory
// be fine. // but should be fine.
int8_t* flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows)); int8_t* flags =
static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
const int nt = dgl::cuda::FindNumThreads(csr.num_rows); const int nt = dgl::cuda::FindNumThreads(csr.num_rows);
const int nb = (csr.num_rows + nt - 1) / nt; const int nb = (csr.num_rows + nt - 1) / nt;
CUDA_KERNEL_CALL(_SegmentHasNoDuplicate, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _SegmentHasNoDuplicate, nb, nt, 0, stream, csr.indptr.Ptr<IdType>(),
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), csr.indices.Ptr<IdType>(), csr.num_rows, flags);
csr.num_rows, flags);
bool ret = dgl::cuda::AllTrue(flags, csr.num_rows, ctx); bool ret = dgl::cuda::AllTrue(flags, csr.num_rows, ctx);
device->FreeWorkspace(ctx, flags); device->FreeWorkspace(ctx, flags);
return !ret; return !ret;
...@@ -141,10 +138,7 @@ template int64_t CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, int64_t); ...@@ -141,10 +138,7 @@ template int64_t CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, int64_t);
template <typename IdType> template <typename IdType>
__global__ void _CSRGetRowNNZKernel( __global__ void _CSRGetRowNNZKernel(
const IdType* vid, const IdType* vid, const IdType* indptr, IdType* out, int64_t length) {
const IdType* indptr,
IdType* out,
int64_t length) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x; int stride_x = gridDim.x * blockDim.x;
while (tx < length) { while (tx < length) {
...@@ -161,28 +155,29 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) { ...@@ -161,28 +155,29 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
const IdType* vid_data = rows.Ptr<IdType>(); const IdType* vid_data = rows.Ptr<IdType>();
const IdType* indptr_data = csr.indptr.Ptr<IdType>(); const IdType* indptr_data = csr.indptr.Ptr<IdType>();
if (csr.is_pinned) { if (csr.is_pinned) {
CUDA_CALL(cudaHostGetDevicePointer( CUDA_CALL(
&indptr_data, csr.indptr.Ptr<IdType>(), 0)); cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
} }
NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx); NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
IdType* rst_data = static_cast<IdType*>(rst->data); IdType* rst_data = static_cast<IdType*>(rst->data);
const int nt = dgl::cuda::FindNumThreads(len); const int nt = dgl::cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt; const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(_CSRGetRowNNZKernel, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _CSRGetRowNNZKernel, nb, nt, 0, stream, vid_data, indptr_data, rst_data,
vid_data, indptr_data, rst_data, len); len);
return rst; return rst;
} }
template NDArray CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, NDArray); template NDArray CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, NDArray); template NDArray CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, NDArray);
///////////////////////////// CSRGetRowColumnIndices ///////////////////////////// ////////////////////////// CSRGetRowColumnIndices //////////////////////////////
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) { NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row); const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
const int64_t offset = aten::IndexSelect<IdType>(csr.indptr, row) * sizeof(IdType); const int64_t offset =
aten::IndexSelect<IdType>(csr.indptr, row) * sizeof(IdType);
return csr.indices.CreateView({len}, csr.indices->dtype, offset); return csr.indices.CreateView({len}, csr.indices->dtype, offset);
} }
...@@ -194,11 +189,13 @@ template NDArray CSRGetRowColumnIndices<kDGLCUDA, int64_t>(CSRMatrix, int64_t); ...@@ -194,11 +189,13 @@ template NDArray CSRGetRowColumnIndices<kDGLCUDA, int64_t>(CSRMatrix, int64_t);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) { NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row); const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
const int64_t offset = aten::IndexSelect<IdType>(csr.indptr, row) * sizeof(IdType); const int64_t offset =
aten::IndexSelect<IdType>(csr.indptr, row) * sizeof(IdType);
if (aten::CSRHasData(csr)) if (aten::CSRHasData(csr))
return csr.data.CreateView({len}, csr.data->dtype, offset); return csr.data.CreateView({len}, csr.data->dtype, offset);
else else
return aten::Range(offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx); return aten::Range(
offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
} }
template NDArray CSRGetRowData<kDGLCUDA, int32_t>(CSRMatrix, int64_t); template NDArray CSRGetRowData<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
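
Both row getters above are zero-copy: they read a single indptr entry and return a view over the next indptr[row+1] - indptr[row] elements (the multiplication by sizeof(IdType) only converts that element offset into the byte offset CreateView expects). A small sketch of the bookkeeping, with hypothetical names and std::vector instead of NDArray:

#include <cstdint>
#include <utility>
#include <vector>

// Return (offset, length) of row `row` inside csr.indices / csr.data, which
// is all that CSRGetRowColumnIndices / CSRGetRowData need to build a view.
std::pair<int64_t, int64_t> CSRRowSpanRef(
    const std::vector<int64_t>& indptr, int64_t row) {
  const int64_t offset = indptr[row];
  const int64_t length = indptr[row + 1] - indptr[row];
  return {offset, length};
}
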
...@@ -218,13 +215,13 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) { ...@@ -218,13 +215,13 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
{nnz}, csr.indices->dtype, st_pos * sizeof(IdType)); {nnz}, csr.indices->dtype, st_pos * sizeof(IdType));
IdArray ret_data; IdArray ret_data;
if (CSRHasData(csr)) if (CSRHasData(csr))
ret_data = csr.data.CreateView({nnz}, csr.data->dtype, st_pos * sizeof(IdType)); ret_data =
csr.data.CreateView({nnz}, csr.data->dtype, st_pos * sizeof(IdType));
else else
ret_data = aten::Range(st_pos, ed_pos, ret_data =
csr.indptr->dtype.bits, csr.indptr->ctx); aten::Range(st_pos, ed_pos, csr.indptr->dtype.bits, csr.indptr->ctx);
return CSRMatrix(num_rows, csr.num_cols, return CSRMatrix(
ret_indptr, ret_indices, ret_data, num_rows, csr.num_cols, ret_indptr, ret_indices, ret_data, csr.sorted);
csr.sorted);
} }
template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t); template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);
...@@ -232,25 +229,25 @@ template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t); ...@@ -232,25 +229,25 @@ template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);
/*! /*!
* \brief Copy data segment to output buffers * \brief Copy data segment to output buffers
* *
* For the i^th row r = row[i], copy the data from indptr[r] ~ indptr[r+1] * For the i^th row r = row[i], copy the data from indptr[r] ~ indptr[r+1]
* to the out_data from out_indptr[i] ~ out_indptr[i+1] * to the out_data from out_indptr[i] ~ out_indptr[i+1]
* *
* If the provided `data` array is nullptr, write the read index to the out_data. * If the provided `data` array is nullptr, write the read index to the
* out_data.
* *
*/ */
template <typename IdType, typename DType> template <typename IdType, typename DType>
__global__ void _SegmentCopyKernel( __global__ void _SegmentCopyKernel(
const IdType* indptr, const DType* data, const IdType* indptr, const DType* data, const IdType* row, int64_t length,
const IdType* row, int64_t length, int64_t n_row, int64_t n_row, const IdType* out_indptr, DType* out_data) {
const IdType* out_indptr, DType* out_data) {
IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x; IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < length) { while (tx < length) {
IdType rpos = dgl::cuda::_UpperBound(out_indptr, n_row, tx) - 1; IdType rpos = dgl::cuda::_UpperBound(out_indptr, n_row, tx) - 1;
IdType rofs = tx - out_indptr[rpos]; IdType rofs = tx - out_indptr[rpos];
const IdType u = row[rpos]; const IdType u = row[rpos];
out_data[tx] = data? data[indptr[u]+rofs] : indptr[u]+rofs; out_data[tx] = data ? data[indptr[u] + rofs] : indptr[u] + rofs;
tx += stride_x; tx += stride_x;
} }
} }
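
A serial reference for the kernel above (hypothetical names; out_data is assumed to be pre-sized to out_indptr.back(), and out_indptr to have one more entry than row): each output slot is filled from the corresponding source segment, or with the source index itself when data is null.

#include <cstdint>
#include <vector>

// Serial reference for _SegmentCopyKernel: gather the rows listed in `row`
// into a packed output whose layout is described by out_indptr.
void SegmentCopyRef(
    const std::vector<int64_t>& indptr, const int64_t* data,
    const std::vector<int64_t>& row, const std::vector<int64_t>& out_indptr,
    std::vector<int64_t>* out_data) {
  for (size_t rpos = 0; rpos + 1 < out_indptr.size(); ++rpos) {
    const int64_t u = row[rpos];
    const int64_t seg_len = out_indptr[rpos + 1] - out_indptr[rpos];
    for (int64_t ofs = 0; ofs < seg_len; ++ofs) {
      const int64_t src = indptr[u] + ofs;
      // When no data array exists, the read position itself becomes the data.
      (*out_data)[out_indptr[rpos] + ofs] = data ? data[src] : src;
    }
  }
}
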
...@@ -272,42 +269,39 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { ...@@ -272,42 +269,39 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
const IdType* indices_data = csr.indices.Ptr<IdType>(); const IdType* indices_data = csr.indices.Ptr<IdType>();
const IdType* data_data = CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr; const IdType* data_data = CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr;
if (csr.is_pinned) { if (csr.is_pinned) {
CUDA_CALL(cudaHostGetDevicePointer( CUDA_CALL(
&indptr_data, csr.indptr.Ptr<IdType>(), 0)); cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
CUDA_CALL(cudaHostGetDevicePointer( CUDA_CALL(
&indices_data, csr.indices.Ptr<IdType>(), 0)); cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
if (CSRHasData(csr)) { if (CSRHasData(csr)) {
CUDA_CALL(cudaHostGetDevicePointer( CUDA_CALL(
&data_data, csr.data.Ptr<IdType>(), 0)); cudaHostGetDevicePointer(&data_data, csr.data.Ptr<IdType>(), 0));
} }
} }
CUDA_KERNEL_CALL(_SegmentCopyKernel, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _SegmentCopyKernel, nb, nt, 0, stream, indptr_data, indices_data,
indptr_data, indices_data, rows.Ptr<IdType>(), nnz, len, ret_indptr.Ptr<IdType>(),
rows.Ptr<IdType>(), nnz, len, ret_indices.Ptr<IdType>());
ret_indptr.Ptr<IdType>(), ret_indices.Ptr<IdType>());
// Copy data. // Copy data.
IdArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx); IdArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);
CUDA_KERNEL_CALL(_SegmentCopyKernel, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _SegmentCopyKernel, nb, nt, 0, stream, indptr_data, data_data,
indptr_data, data_data, rows.Ptr<IdType>(), nnz, len, ret_indptr.Ptr<IdType>(),
rows.Ptr<IdType>(), nnz, len, ret_data.Ptr<IdType>());
ret_indptr.Ptr<IdType>(), ret_data.Ptr<IdType>()); return CSRMatrix(
return CSRMatrix(len, csr.num_cols, len, csr.num_cols, ret_indptr, ret_indices, ret_data, csr.sorted);
ret_indptr, ret_indices, ret_data,
csr.sorted);
} }
template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix , NDArray); template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix, NDArray);
template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix , NDArray); template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix, NDArray);
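
The row-slice path above boils down to: the selected rows' lengths are prefix-summed into the new indptr, then indices and data are gathered segment by segment (the two _SegmentCopyKernel launches). A sketch of the indptr construction only, with hypothetical names:

#include <cstdint>
#include <vector>

// Reference for the new indptr of CSRSliceRows(csr, rows): cumulative sum of
// the selected rows' lengths.
std::vector<int64_t> SliceRowsIndptrRef(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& rows) {
  std::vector<int64_t> ret_indptr(rows.size() + 1, 0);
  for (size_t i = 0; i < rows.size(); ++i)
    ret_indptr[i + 1] =
        ret_indptr[i] + (indptr[rows[i] + 1] - indptr[rows[i]]);
  return ret_indptr;
}
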
///////////////////////////// CSRGetDataAndIndices ///////////////////////////// ///////////////////////////// CSRGetDataAndIndices /////////////////////////////
/*! /*!
* \brief Generate a 0-1 mask for each index that hits the provided (row, col) * \brief Generate a 0-1 mask for each index that hits the provided (row, col)
* index. * index.
* *
* Examples: * Examples:
* Given a CSR matrix (with duplicate entries) as follows: * Given a CSR matrix (with duplicate entries) as follows:
* [[0, 1, 2, 0, 0], * [[0, 1, 2, 0, 0],
...@@ -319,10 +313,9 @@ template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix , NDArray); ...@@ -319,10 +313,9 @@ template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix , NDArray);
*/ */
template <typename IdType> template <typename IdType>
__global__ void _SegmentMaskKernel( __global__ void _SegmentMaskKernel(
const IdType* indptr, const IdType* indices, const IdType* indptr, const IdType* indices, const IdType* row,
const IdType* row, const IdType* col, const IdType* col, int64_t row_stride, int64_t col_stride, int64_t length,
int64_t row_stride, int64_t col_stride, IdType* mask) {
int64_t length, IdType* mask) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < length) { while (tx < length) {
...@@ -350,9 +343,8 @@ __global__ void _SegmentMaskKernel( ...@@ -350,9 +343,8 @@ __global__ void _SegmentMaskKernel(
*/ */
template <typename IdType> template <typename IdType>
__global__ void _SortedSearchKernel( __global__ void _SortedSearchKernel(
const IdType* hay, int64_t hay_size, const IdType* hay, int64_t hay_size, const IdType* needles,
const IdType* needles, int64_t num_needles, int64_t num_needles, IdType* pos) {
IdType* pos) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < num_needles) { while (tx < num_needles) {
...@@ -367,18 +359,18 @@ __global__ void _SortedSearchKernel( ...@@ -367,18 +359,18 @@ __global__ void _SortedSearchKernel(
hi = mid; hi = mid;
} }
} }
pos[tx] = (hay[hi] == ele)? hi : hi - 1; pos[tx] = (hay[hi] == ele) ? hi : hi - 1;
tx += stride_x; tx += stride_x;
} }
} }
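
The search above returns, for each needle, the largest index i with hay[i] <= needle; with indptr as the haystack this maps a nonzero position back to its row. A standard-library sketch of the same contract (assuming every needle is >= hay[0], which holds for indptr since indptr[0] == 0):

#include <algorithm>
#include <cstdint>
#include <vector>

// For each needle, the largest index i with hay[i] <= needle.
std::vector<int64_t> SortedSearchRef(
    const std::vector<int64_t>& hay, const std::vector<int64_t>& needles) {
  std::vector<int64_t> pos(needles.size());
  for (size_t k = 0; k < needles.size(); ++k) {
    auto it = std::upper_bound(hay.begin(), hay.end(), needles[k]);
    pos[k] = static_cast<int64_t>(it - hay.begin()) - 1;
  }
  return pos;
}
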
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray col) { std::vector<NDArray> CSRGetDataAndIndices(
CSRMatrix csr, NDArray row, NDArray col) {
const auto rowlen = row->shape[0]; const auto rowlen = row->shape[0];
const auto collen = col->shape[0]; const auto collen = col->shape[0];
const auto len = std::max(rowlen, collen); const auto len = std::max(rowlen, collen);
if (len == 0) if (len == 0) return {NullArray(), NullArray(), NullArray()};
return {NullArray(), NullArray(), NullArray()};
const auto& ctx = row->ctx; const auto& ctx = row->ctx;
const auto nbits = row->dtype.bits; const auto nbits = row->dtype.bits;
...@@ -390,21 +382,19 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co ...@@ -390,21 +382,19 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
const IdType* indptr_data = csr.indptr.Ptr<IdType>(); const IdType* indptr_data = csr.indptr.Ptr<IdType>();
const IdType* indices_data = csr.indices.Ptr<IdType>(); const IdType* indices_data = csr.indices.Ptr<IdType>();
if (csr.is_pinned) { if (csr.is_pinned) {
CUDA_CALL(cudaHostGetDevicePointer( CUDA_CALL(
&indptr_data, csr.indptr.Ptr<IdType>(), 0)); cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
CUDA_CALL(cudaHostGetDevicePointer( CUDA_CALL(
&indices_data, csr.indices.Ptr<IdType>(), 0)); cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
} }
// Generate a 0-1 mask for matched (row, col) positions. // Generate a 0-1 mask for matched (row, col) positions.
IdArray mask = Full(0, nnz, nbits, ctx); IdArray mask = Full(0, nnz, nbits, ctx);
const int nt = dgl::cuda::FindNumThreads(len); const int nt = dgl::cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt; const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(_SegmentMaskKernel, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _SegmentMaskKernel, nb, nt, 0, stream, indptr_data, indices_data,
indptr_data, indices_data, row.Ptr<IdType>(), col.Ptr<IdType>(), row_stride, col_stride, len,
row.Ptr<IdType>(), col.Ptr<IdType>(),
row_stride, col_stride, len,
mask.Ptr<IdType>()); mask.Ptr<IdType>());
IdArray idx = AsNumBits(NonZero(mask), nbits); IdArray idx = AsNumBits(NonZero(mask), nbits);
...@@ -416,15 +406,13 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co ...@@ -416,15 +406,13 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
IdArray ret_row = NewIdArray(idx->shape[0], ctx, nbits); IdArray ret_row = NewIdArray(idx->shape[0], ctx, nbits);
const int nt2 = dgl::cuda::FindNumThreads(idx->shape[0]); const int nt2 = dgl::cuda::FindNumThreads(idx->shape[0]);
const int nb2 = (idx->shape[0] + nt - 1) / nt; const int nb2 = (idx->shape[0] + nt - 1) / nt;
CUDA_KERNEL_CALL(_SortedSearchKernel, CUDA_KERNEL_CALL(
nb2, nt2, 0, stream, _SortedSearchKernel, nb2, nt2, 0, stream, indptr_data, csr.num_rows,
indptr_data, csr.num_rows, idx.Ptr<IdType>(), idx->shape[0], ret_row.Ptr<IdType>());
idx.Ptr<IdType>(), idx->shape[0],
ret_row.Ptr<IdType>());
// Column & data can be obtained by index select. // Column & data can be obtained by index select.
IdArray ret_col = IndexSelect(csr.indices, idx); IdArray ret_col = IndexSelect(csr.indices, idx);
IdArray ret_data = CSRHasData(csr)? IndexSelect(csr.data, idx) : idx; IdArray ret_data = CSRHasData(csr) ? IndexSelect(csr.data, idx) : idx;
return {ret_row, ret_col, ret_data}; return {ret_row, ret_col, ret_data};
} }
...@@ -436,14 +424,14 @@ template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int64_t>( ...@@ -436,14 +424,14 @@ template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int64_t>(
///////////////////////////// CSRSliceMatrix ///////////////////////////// ///////////////////////////// CSRSliceMatrix /////////////////////////////
/*! /*!
* \brief Generate a 0-1 mask for each index whose column is in the provided set. * \brief Generate a 0-1 mask for each index whose column is in the provided
* It also counts the number of masked values per row. * set. It also counts the number of masked values per row.
*/ */
template <typename IdType> template <typename IdType>
__global__ void _SegmentMaskColKernel( __global__ void _SegmentMaskColKernel(
const IdType* indptr, const IdType* indices, int64_t num_rows, int64_t num_nnz, const IdType* indptr, const IdType* indices, int64_t num_rows,
const IdType* col, int64_t col_len, int64_t num_nnz, const IdType* col, int64_t col_len, IdType* mask,
IdType* mask, IdType* count) { IdType* count) {
IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x; IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < num_nnz) { while (tx < num_nnz) {
...@@ -452,14 +440,15 @@ __global__ void _SegmentMaskColKernel( ...@@ -452,14 +440,15 @@ __global__ void _SegmentMaskColKernel(
IdType i = dgl::cuda::_BinarySearch(col, col_len, cur_c); IdType i = dgl::cuda::_BinarySearch(col, col_len, cur_c);
if (i < col_len) { if (i < col_len) {
mask[tx] = 1; mask[tx] = 1;
cuda::AtomicAdd(count+rpos, IdType(1)); cuda::AtomicAdd(count + rpos, IdType(1));
} }
tx += stride_x; tx += stride_x;
} }
} }
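
A serial reference for the kernel above (hypothetical names): the GPU version walks nonzeros, locates the owning row with an upper bound on indptr and bumps the per-row counter atomically, whereas this sketch simply loops over rows; `cols` must be sorted, matching the SortKeys call performed before the launch below.

#include <algorithm>
#include <cstdint>
#include <vector>

// Mark every nonzero whose column is in the sorted `cols` set and count the
// hits per row.
void SegmentMaskColRef(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices,
    const std::vector<int64_t>& cols,  // must be sorted
    std::vector<int64_t>* mask, std::vector<int64_t>* count) {
  const int64_t num_rows = static_cast<int64_t>(indptr.size()) - 1;
  mask->assign(indices.size(), 0);
  count->assign(num_rows, 0);
  for (int64_t r = 0; r < num_rows; ++r) {
    for (int64_t e = indptr[r]; e < indptr[r + 1]; ++e) {
      if (std::binary_search(cols.begin(), cols.end(), indices[e])) {
        (*mask)[e] = 1;
        ++(*count)[r];
      }
    }
  }
}
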
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) { CSRMatrix CSRSliceMatrix(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = rows->ctx; const auto& ctx = rows->ctx;
const auto& dtype = rows->dtype; const auto& dtype = rows->dtype;
...@@ -468,23 +457,24 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray ...@@ -468,23 +457,24 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
const int64_t new_ncols = cols->shape[0]; const int64_t new_ncols = cols->shape[0];
if (new_nrows == 0 || new_ncols == 0) if (new_nrows == 0 || new_ncols == 0)
return CSRMatrix(new_nrows, new_ncols, return CSRMatrix(
Full(0, new_nrows + 1, nbits, ctx), new_nrows, new_ncols, Full(0, new_nrows + 1, nbits, ctx),
NullArray(dtype, ctx), NullArray(dtype, ctx)); NullArray(dtype, ctx), NullArray(dtype, ctx));
// First slice rows // First slice rows
csr = CSRSliceRows(csr, rows); csr = CSRSliceRows(csr, rows);
if (csr.indices->shape[0] == 0) if (csr.indices->shape[0] == 0)
return CSRMatrix(new_nrows, new_ncols, return CSRMatrix(
Full(0, new_nrows + 1, nbits, ctx), new_nrows, new_ncols, Full(0, new_nrows + 1, nbits, ctx),
NullArray(dtype, ctx), NullArray(dtype, ctx)); NullArray(dtype, ctx), NullArray(dtype, ctx));
// Generate a 0-1 mask for matched (row, col) positions. // Generate a 0-1 mask for matched (row, col) positions.
IdArray mask = Full(0, csr.indices->shape[0], nbits, ctx); IdArray mask = Full(0, csr.indices->shape[0], nbits, ctx);
// A count for how many masked values per row. // A count for how many masked values per row.
IdArray count = NewIdArray(csr.num_rows, ctx, nbits); IdArray count = NewIdArray(csr.num_rows, ctx, nbits);
CUDA_CALL(cudaMemset(count.Ptr<IdType>(), 0, sizeof(IdType) * (csr.num_rows))); CUDA_CALL(
cudaMemset(count.Ptr<IdType>(), 0, sizeof(IdType) * (csr.num_rows)));
const int64_t nnz_csr = csr.indices->shape[0]; const int64_t nnz_csr = csr.indices->shape[0];
const int nt = 256; const int nt = 256;
...@@ -499,51 +489,49 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray ...@@ -499,51 +489,49 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
auto ptr_cols = cols.Ptr<IdType>(); auto ptr_cols = cols.Ptr<IdType>();
size_t workspace_size = 0; size_t workspace_size = 0;
CUDA_CALL(cub::DeviceRadixSort::SortKeys( CUDA_CALL(cub::DeviceRadixSort::SortKeys(
nullptr, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0], nullptr, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0], 0,
0, sizeof(IdType)*8, stream)); sizeof(IdType) * 8, stream));
void *workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortKeys( CUDA_CALL(cub::DeviceRadixSort::SortKeys(
workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0], workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0], 0,
0, sizeof(IdType)*8, stream)); sizeof(IdType) * 8, stream));
device->FreeWorkspace(ctx, workspace); device->FreeWorkspace(ctx, workspace);
const IdType* indptr_data = csr.indptr.Ptr<IdType>(); const IdType* indptr_data = csr.indptr.Ptr<IdType>();
const IdType* indices_data = csr.indices.Ptr<IdType>(); const IdType* indices_data = csr.indices.Ptr<IdType>();
if (csr.is_pinned) { if (csr.is_pinned) {
CUDA_CALL(cudaHostGetDevicePointer( CUDA_CALL(
&indptr_data, csr.indptr.Ptr<IdType>(), 0)); cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
CUDA_CALL(cudaHostGetDevicePointer( CUDA_CALL(
&indices_data, csr.indices.Ptr<IdType>(), 0)); cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
} }
// Execute SegmentMaskColKernel // Execute SegmentMaskColKernel
int nb = (nnz_csr + nt - 1) / nt; int nb = (nnz_csr + nt - 1) / nt;
CUDA_KERNEL_CALL(_SegmentMaskColKernel, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _SegmentMaskColKernel, nb, nt, 0, stream, indptr_data, indices_data,
indptr_data, indices_data, csr.num_rows, nnz_csr, csr.num_rows, nnz_csr, ptr_sorted_cols, cols_size, mask.Ptr<IdType>(),
ptr_sorted_cols, cols_size, count.Ptr<IdType>());
mask.Ptr<IdType>(), count.Ptr<IdType>());
IdArray idx = AsNumBits(NonZero(mask), nbits); IdArray idx = AsNumBits(NonZero(mask), nbits);
if (idx->shape[0] == 0) if (idx->shape[0] == 0)
return CSRMatrix(new_nrows, new_ncols, return CSRMatrix(
Full(0, new_nrows + 1, nbits, ctx), new_nrows, new_ncols, Full(0, new_nrows + 1, nbits, ctx),
NullArray(dtype, ctx), NullArray(dtype, ctx)); NullArray(dtype, ctx), NullArray(dtype, ctx));
// Indptr needs to be adjusted according to the new nnz per row. // Indptr needs to be adjusted according to the new nnz per row.
IdArray ret_indptr = CumSum(count, true); IdArray ret_indptr = CumSum(count, true);
// Column & data can be obtained by index select. // Column & data can be obtained by index select.
IdArray ret_col = IndexSelect(csr.indices, idx); IdArray ret_col = IndexSelect(csr.indices, idx);
IdArray ret_data = CSRHasData(csr)? IndexSelect(csr.data, idx) : idx; IdArray ret_data = CSRHasData(csr) ? IndexSelect(csr.data, idx) : idx;
// Relabel column // Relabel column
IdArray col_hash = NewIdArray(csr.num_cols, ctx, nbits); IdArray col_hash = NewIdArray(csr.num_cols, ctx, nbits);
Scatter_(cols, Range(0, cols->shape[0], nbits, ctx), col_hash); Scatter_(cols, Range(0, cols->shape[0], nbits, ctx), col_hash);
ret_col = IndexSelect(col_hash, ret_col); ret_col = IndexSelect(col_hash, ret_col);
return CSRMatrix(new_nrows, new_ncols, ret_indptr, return CSRMatrix(new_nrows, new_ncols, ret_indptr, ret_col, ret_data);
ret_col, ret_data);
} }
template CSRMatrix CSRSliceMatrix<kDGLCUDA, int32_t>( template CSRMatrix CSRSliceMatrix<kDGLCUDA, int32_t>(
......
...@@ -4,17 +4,18 @@ ...@@ -4,17 +4,18 @@
* \brief Array index select GPU implementation * \brief Array index select GPU implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../../runtime/cuda/cuda_common.h" #include "../../../runtime/cuda/cuda_common.h"
#include "../array_index_select.cuh" #include "../array_index_select.cuh"
#include "./array_index_select_uvm.cuh"
#include "../utils.h" #include "../utils.h"
#include "./array_index_select_uvm.cuh"
namespace dgl { namespace dgl {
using runtime::NDArray; using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template<typename DType, typename IdType> template <typename DType, typename IdType>
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) { NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const IdType* idx_data = static_cast<IdType*>(index->data); const IdType* idx_data = static_cast<IdType*>(index->data);
...@@ -34,31 +35,31 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) { ...@@ -34,31 +35,31 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
} }
NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx); NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx);
if (len == 0) if (len == 0) return ret;
return ret;
DType* ret_data = static_cast<DType*>(ret->data); DType* ret_data = static_cast<DType*>(ret->data);
if (num_feat == 1) { if (num_feat == 1) {
const int nt = cuda::FindNumThreads(len); const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt; const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0, CUDA_KERNEL_CALL(
stream, array_data, idx_data, len, arr_len, ret_data); IndexSelectSingleKernel, nb, nt, 0, stream, array_data, idx_data, len,
arr_len, ret_data);
} else { } else {
dim3 block(256, 1); dim3 block(256, 1);
while (static_cast<int64_t>(block.x) >= 2*num_feat) { while (static_cast<int64_t>(block.x) >= 2 * num_feat) {
block.x /= 2; block.x /= 2;
block.y *= 2; block.y *= 2;
} }
const dim3 grid((len+block.y-1)/block.y); const dim3 grid((len + block.y - 1) / block.y);
if (num_feat * sizeof(DType) < 2 * CACHE_LINE_SIZE) { if (num_feat * sizeof(DType) < 2 * CACHE_LINE_SIZE) {
CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0, CUDA_KERNEL_CALL(
stream, array_data, num_feat, idx_data, IndexSelectMultiKernel, grid, block, 0, stream, array_data, num_feat,
len, arr_len, ret_data); idx_data, len, arr_len, ret_data);
} else { } else {
CUDA_KERNEL_CALL(IndexSelectMultiKernelAligned, grid, block, 0, CUDA_KERNEL_CALL(
stream, array_data, num_feat, idx_data, IndexSelectMultiKernelAligned, grid, block, 0, stream, array_data,
len, arr_len, ret_data); num_feat, idx_data, len, arr_len, ret_data);
} }
} }
return ret; return ret;
} }
...@@ -73,8 +74,7 @@ template NDArray IndexSelectCPUFromGPU<int32_t, int64_t>(NDArray, IdArray); ...@@ -73,8 +74,7 @@ template NDArray IndexSelectCPUFromGPU<int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int64_t, int32_t>(NDArray, IdArray); template NDArray IndexSelectCPUFromGPU<int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int64_t, int64_t>(NDArray, IdArray); template NDArray IndexSelectCPUFromGPU<int64_t, int64_t>(NDArray, IdArray);
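
All three kernels dispatched above implement the same gather, ret[i, :] = array[index[i], :]; they differ only in how threads are laid over the feature dimension. A reference of the math (float features and hypothetical names purely for illustration):

#include <cstdint>
#include <vector>

// Row gather with num_feat contiguous features per row:
// ret[i, :] = array[index[i], :].
std::vector<float> IndexSelectRef(
    const std::vector<float>& array, int64_t num_feat,
    const std::vector<int64_t>& index) {
  std::vector<float> ret(index.size() * num_feat);
  for (size_t i = 0; i < index.size(); ++i)
    for (int64_t f = 0; f < num_feat; ++f)
      ret[i * num_feat + f] = array[index[i] * num_feat + f];
  return ret;
}

The dim3 heuristic in the host code keeps halving block.x while it is at least 2 * num_feat, handing the freed threads to block.y so each block covers more output rows.
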
template <typename DType, typename IdType>
template<typename DType, typename IdType>
void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) { void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType* source_data = static_cast<DType*>(source->data); const DType* source_data = static_cast<DType*>(source->data);
...@@ -94,24 +94,24 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) { ...@@ -94,24 +94,24 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
num_feat *= source->shape[d]; num_feat *= source->shape[d];
} }
if (len == 0) if (len == 0) return;
return;
if (num_feat == 1) { if (num_feat == 1) {
const int nt = cuda::FindNumThreads(len); const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt; const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(IndexScatterSingleKernel, nb, nt, 0, CUDA_KERNEL_CALL(
stream, source_data, idx_data, len, arr_len, dest_data); IndexScatterSingleKernel, nb, nt, 0, stream, source_data, idx_data, len,
arr_len, dest_data);
} else { } else {
dim3 block(256, 1); dim3 block(256, 1);
while (static_cast<int64_t>(block.x) >= 2*num_feat) { while (static_cast<int64_t>(block.x) >= 2 * num_feat) {
block.x /= 2; block.x /= 2;
block.y *= 2; block.y *= 2;
} }
const dim3 grid((len+block.y-1)/block.y); const dim3 grid((len + block.y - 1) / block.y);
CUDA_KERNEL_CALL(IndexScatterMultiKernel, grid, block, 0, CUDA_KERNEL_CALL(
stream, source_data, num_feat, idx_data, IndexScatterMultiKernel, grid, block, 0, stream, source_data, num_feat,
len, arr_len, dest_data); idx_data, len, arr_len, dest_data);
} }
} }
......
...@@ -14,31 +14,28 @@ namespace aten { ...@@ -14,31 +14,28 @@ namespace aten {
namespace impl { namespace impl {
/* This is a cross-device access version of IndexSelectMultiKernel. /* This is a cross-device access version of IndexSelectMultiKernel.
* Since the memory access over PCIe is more sensitive to the * Since the memory access over PCIe is more sensitive to the
 * data access alignment (cacheline), we need a separate version here.  * data access alignment (cacheline), we need a separate version here.
*/ */
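
A host-side sketch of the offset the kernel computes before its column loop (the cache-line size is passed in as a parameter here, since the CACHE_LINE_SIZE constant is defined elsewhere): threads shift their starting column back by this amount so the warp's loads begin on the cache-line boundary preceding the row, and the `col >= 0` guard below discards the lanes that land before the row.

#include <cstdint>

// Distance, in elements, from the preceding cache-line boundary to the start
// of row `in_row`; threads subtract it from their starting column.
template <typename DType>
int64_t AlignmentOffsetRef(
    const DType* array, int64_t in_row, int64_t num_feat,
    int64_t cache_line_size) {  // e.g. 64 or 128 bytes
  const std::uintptr_t addr =
      reinterpret_cast<std::uintptr_t>(&array[in_row * num_feat]);
  return static_cast<int64_t>((addr % cache_line_size) / sizeof(DType));
}
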
template <typename DType, typename IdType> template <typename DType, typename IdType>
__global__ void IndexSelectMultiKernelAligned( __global__ void IndexSelectMultiKernelAligned(
const DType* const array, const DType* const array, const int64_t num_feat, const IdType* const index,
const int64_t num_feat, const int64_t length, const int64_t arr_len, DType* const out) {
const IdType* const index, int64_t out_row = blockIdx.x * blockDim.y + threadIdx.y;
const int64_t length,
const int64_t arr_len,
DType* const out) {
int64_t out_row = blockIdx.x*blockDim.y+threadIdx.y;
const int64_t stride = blockDim.y*gridDim.x; const int64_t stride = blockDim.y * gridDim.x;
while (out_row < length) { while (out_row < length) {
int64_t col = threadIdx.x; int64_t col = threadIdx.x;
const int64_t in_row = index[out_row]; const int64_t in_row = index[out_row];
assert(in_row >= 0 && in_row < arr_len); assert(in_row >= 0 && in_row < arr_len);
const int64_t idx_offset = const int64_t idx_offset =
((uint64_t)(&array[in_row*num_feat]) % CACHE_LINE_SIZE) / sizeof(DType); ((uint64_t)(&array[in_row * num_feat]) % CACHE_LINE_SIZE) /
sizeof(DType);
col = col - idx_offset; col = col - idx_offset;
while (col < num_feat) { while (col < num_feat) {
if (col >= 0) if (col >= 0)
out[out_row*num_feat+col] = array[in_row*num_feat+col]; out[out_row * num_feat + col] = array[in_row * num_feat + col];
col += blockDim.x; col += blockDim.x;
} }
out_row += stride; out_row += stride;
......
...@@ -6,49 +6,49 @@ ...@@ -6,49 +6,49 @@
#include "./filter.h" #include "./filter.h"
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
namespace dgl { namespace dgl {
namespace array { namespace array {
using namespace dgl::runtime; using namespace dgl::runtime;
template<DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
FilterRef CreateSetFilter(IdArray set); FilterRef CreateSetFilter(IdArray set);
DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterCreateFromSet") DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterCreateFromSet")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
IdArray array = args[0]; IdArray array = args[0];
auto ctx = array->ctx; auto ctx = array->ctx;
// TODO(nv-dlasalle): Implement CPU version. // TODO(nv-dlasalle): Implement CPU version.
if (ctx.device_type == kDGLCUDA) { if (ctx.device_type == kDGLCUDA) {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
ATEN_ID_TYPE_SWITCH(array->dtype, IdType, { ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
*rv = CreateSetFilter<kDGLCUDA, IdType>(array); *rv = CreateSetFilter<kDGLCUDA, IdType>(array);
});
#else
LOG(FATAL) << "GPU support not compiled.";
#endif
} else {
LOG(FATAL) << "CPU support not yet implemented.";
}
}); });
#else
LOG(FATAL) << "GPU support not compiled.";
#endif
} else {
LOG(FATAL) << "CPU support not yet implemented.";
}
});
DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterFindIncludedIndices") DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterFindIncludedIndices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
FilterRef filter = args[0]; FilterRef filter = args[0];
IdArray array = args[1]; IdArray array = args[1];
*rv = filter->find_included_indices(array); *rv = filter->find_included_indices(array);
}); });
DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterFindExcludedIndices") DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterFindExcludedIndices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
FilterRef filter = args[0]; FilterRef filter = args[0];
IdArray array = args[1]; IdArray array = args[1];
*rv = filter->find_excluded_indices(array); *rv = filter->find_excluded_indices(array);
}); });
} // namespace array } // namespace array
} // namespace dgl } // namespace dgl
...@@ -4,12 +4,11 @@ ...@@ -4,12 +4,11 @@
* \brief Object for selecting items in a set, or selecting items not in a set. * \brief Object for selecting items in a set, or selecting items not in a set.
*/ */
#ifndef DGL_ARRAY_FILTER_H_ #ifndef DGL_ARRAY_FILTER_H_
#define DGL_ARRAY_FILTER_H_ #define DGL_ARRAY_FILTER_H_
#include <dgl/runtime/object.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/runtime/object.h>
namespace dgl { namespace dgl {
namespace array { namespace array {
...@@ -28,8 +27,7 @@ class Filter : public runtime::Object { ...@@ -28,8 +27,7 @@ class Filter : public runtime::Object {
* @return The indices of the items from `test` that are selected by * @return The indices of the items from `test` that are selected by
* this filter. * this filter.
*/ */
virtual IdArray find_included_indices( virtual IdArray find_included_indices(IdArray test) = 0;
IdArray test) = 0;
/** /**
* @brief From the test set of items, get the indices of those which are * @brief From the test set of items, get the indices of those which are
...@@ -40,8 +38,7 @@ class Filter : public runtime::Object { ...@@ -40,8 +38,7 @@ class Filter : public runtime::Object {
* @return The indices of the items from `test` that are not selected by this * @return The indices of the items from `test` that are not selected by this
* filter. * filter.
*/ */
virtual IdArray find_excluded_indices( virtual IdArray find_excluded_indices(IdArray test) = 0;
IdArray test) = 0;
}; };
DGL_DEFINE_OBJECT_REF(FilterRef, Filter); DGL_DEFINE_OBJECT_REF(FilterRef, Filter);
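
A purely hypothetical, CPU-only stand-in for this interface (DGL's real set filter is implemented elsewhere with IdArray and CUDA); it only illustrates what the two pure virtual methods are expected to return: positions into `test`, not the item values themselves.

#include <cstdint>
#include <unordered_set>
#include <vector>

// Minimal stand-in for the Filter contract using plain std::vector instead
// of IdArray: keep the indices of test items that are (or are not) in the set.
class SetFilterSketch {
 public:
  explicit SetFilterSketch(std::vector<int64_t> set)
      : set_(set.begin(), set.end()) {}

  std::vector<int64_t> find_included_indices(
      const std::vector<int64_t>& test) const {
    return Select(test, /*included=*/true);
  }
  std::vector<int64_t> find_excluded_indices(
      const std::vector<int64_t>& test) const {
    return Select(test, /*included=*/false);
  }

 private:
  std::vector<int64_t> Select(
      const std::vector<int64_t>& test, bool included) const {
    std::vector<int64_t> idx;
    for (size_t i = 0; i < test.size(); ++i)
      if ((set_.count(test[i]) > 0) == included)
        idx.push_back(static_cast<int64_t>(i));
    return idx;
  }
  std::unordered_set<int64_t> set_;
};
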
...@@ -50,4 +47,3 @@ DGL_DEFINE_OBJECT_REF(FilterRef, Filter); ...@@ -50,4 +47,3 @@ DGL_DEFINE_OBJECT_REF(FilterRef, Filter);
} // namespace dgl } // namespace dgl
#endif // DGL_ARRAY_FILTER_H_ #endif // DGL_ARRAY_FILTER_H_
...@@ -10,48 +10,51 @@ Copyright (c) 2021 Intel Corporation ...@@ -10,48 +10,51 @@ Copyright (c) 2021 Intel Corporation
Nesreen K. Ahmed <nesreen.k.ahmed@intel.com> Nesreen K. Ahmed <nesreen.k.ahmed@intel.com>
*/ */
#include <stdint.h> #include <dgl/base_heterograph.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/packed_func_ext.h>
#include <dgl/random.h> #include <dgl/random.h>
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <dgl/packed_func_ext.h> #include <stdint.h>
#include <dgl/base_heterograph.h>
#include <vector> #include <vector>
#ifdef USE_TVM #ifdef USE_TVM
#include <featgraph.h> #include <featgraph.h>
#endif // USE_TVM #endif // USE_TVM
#include "kernel_decl.h"
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./check.h" #include "./check.h"
#include "kernel_decl.h"
using namespace dgl::runtime; using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace aten { namespace aten {
template<typename IdType> template <typename IdType>
int32_t Ver2partition(IdType in_val, int64_t* node_map, int32_t num_parts) { int32_t Ver2partition(IdType in_val, int64_t *node_map, int32_t num_parts) {
int32_t pos = 0; int32_t pos = 0;
for (int32_t p=0; p < num_parts; p++) { for (int32_t p = 0; p < num_parts; p++) {
if (in_val < node_map[p]) if (in_val < node_map[p]) return pos;
return pos;
pos = pos + 1; pos = pos + 1;
} }
LOG(FATAL) << "Error: Unexpected output in Ver2partition!"; LOG(FATAL) << "Error: Unexpected output in Ver2partition!";
} }
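
A usage sketch for the helper above with made-up numbers: node_map holds the exclusive upper bound of the global ID range owned by each partition, so the function returns the first partition whose bound exceeds the queried ID.

#include <cassert>
#include <cstdint>

int32_t Ver2partitionRef(
    int64_t in_val, const int64_t* node_map, int32_t num_parts) {
  for (int32_t p = 0; p < num_parts; p++)
    if (in_val < node_map[p]) return p;
  return -1;  // unexpected: in_val not covered by node_map
}

int main() {
  // Three partitions owning global IDs [0,10), [10,25), [25,40).
  const int64_t node_map[] = {10, 25, 40};
  assert(Ver2partitionRef(3, node_map, 3) == 0);
  assert(Ver2partitionRef(10, node_map, 3) == 1);
  assert(Ver2partitionRef(39, node_map, 3) == 2);
  return 0;
}
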
/*! \brief Identifies the least loaded partition/community for a given edge assignment.*/ /*!
int32_t LeastLoad(int64_t* community_edges, int32_t nc) {  * \brief Identifies the least loaded partition/community for a given edge
* assignment.
*/
int32_t LeastLoad(int64_t *community_edges, int32_t nc) {
std::vector<int> loc; std::vector<int> loc;
int32_t min = 1e9; int32_t min = 1e9;
for (int32_t i=0; i < nc; i++) { for (int32_t i = 0; i < nc; i++) {
if (community_edges[i] < min) { if (community_edges[i] < min) {
min = community_edges[i]; min = community_edges[i];
} }
} }
for (int32_t i=0; i < nc; i++) { for (int32_t i = 0; i < nc; i++) {
if (community_edges[i] == min) { if (community_edges[i] == min) {
loc.push_back(i); loc.push_back(i);
} }
...@@ -62,43 +65,37 @@ int32_t LeastLoad(int64_t* community_edges, int32_t nc) { ...@@ -62,43 +65,37 @@ int32_t LeastLoad(int64_t* community_edges, int32_t nc) {
return loc[r]; return loc[r];
} }
/*! \brief Libra - vertex-cut based graph partitioning. /*!
  It takes the list of edges from the input DGL graph and distributes them among nc partitions.  * \brief Libra - vertex-cut based graph partitioning.
  During edge distribution, Libra assigns a given edge to a partition based on the end vertices;  * It takes the list of edges from the input DGL graph and distributes them among nc
  in doing so, it tries to minimize the splitting of the graph vertices. In case of conflict,  * partitions. During edge distribution, Libra assigns a given edge to a partition
  Libra assigns an edge to the least loaded partition/community.  * based on the end vertices; in doing so, it tries to minimize the splitting
  \param[in] nc Number of partitions/communities  * of the graph vertices. In case of conflict, Libra assigns an edge to the least
\param[in] node_degree per node degree * loaded partition/community.
  \param[in] edgenum_unassigned number of unassigned edges per node  * \param[in] node_degree per node degree
  \param[out] community_weights weight of the created partitions  * \param[in] edgenum_unassigned number of unassigned edges per node
\param[in] u src nodes * \param[in] edgenum_unassigned node degree
\param[in] v dst nodes * \param[out] community_weights weight of the created partitions
\param[out] w weight per edge * \param[in] u src nodes
\param[out] out partition assignment of the edges * \param[in] v dst nodes
\param[in] N_n number of nodes in the input graph * \param[out] w weight per edge
\param[in] N_e number of edges in the input graph * \param[out] out partition assignment of the edges
\param[in] prefix output/partition storage location * \param[in] N_n number of nodes in the input graph
*/ * \param[in] N_e number of edges in the input graph
template<typename IdType, typename IdType2> * \param[in] prefix output/partition storage location
*/
template <typename IdType, typename IdType2>
void LibraVertexCut( void LibraVertexCut(
int32_t nc, int32_t nc, NDArray node_degree, NDArray edgenum_unassigned,
NDArray node_degree, NDArray community_weights, NDArray u, NDArray v, NDArray w, NDArray out,
NDArray edgenum_unassigned, int64_t N_n, int64_t N_e, const std::string &prefix) {
NDArray community_weights, int32_t *out_ptr = out.Ptr<int32_t>();
NDArray u, IdType2 *node_degree_ptr = node_degree.Ptr<IdType2>();
NDArray v,
NDArray w,
NDArray out,
int64_t N_n,
int64_t N_e,
const std::string& prefix) {
int32_t *out_ptr = out.Ptr<int32_t>();
IdType2 *node_degree_ptr = node_degree.Ptr<IdType2>();
IdType2 *edgenum_unassigned_ptr = edgenum_unassigned.Ptr<IdType2>(); IdType2 *edgenum_unassigned_ptr = edgenum_unassigned.Ptr<IdType2>();
IdType *u_ptr = u.Ptr<IdType>(); IdType *u_ptr = u.Ptr<IdType>();
IdType *v_ptr = v.Ptr<IdType>(); IdType *v_ptr = v.Ptr<IdType>();
int64_t *w_ptr = w.Ptr<int64_t>(); int64_t *w_ptr = w.Ptr<int64_t>();
int64_t *community_weights_ptr = community_weights.Ptr<int64_t>(); int64_t *community_weights_ptr = community_weights.Ptr<int64_t>();
std::vector<std::vector<int32_t> > node_assignments(N_n); std::vector<std::vector<int32_t> > node_assignments(N_n);
std::vector<IdType2> replication_list; std::vector<IdType2> replication_list;
...@@ -106,17 +103,18 @@ void LibraVertexCut( ...@@ -106,17 +103,18 @@ void LibraVertexCut(
int64_t *community_edges = new int64_t[nc](); int64_t *community_edges = new int64_t[nc]();
int64_t *cache = new int64_t[nc](); int64_t *cache = new int64_t[nc]();
int64_t meter = static_cast<int>(N_e/100); int64_t meter = static_cast<int>(N_e / 100);
for (int64_t i=0; i < N_e; i++) { for (int64_t i = 0; i < N_e; i++) {
IdType u = u_ptr[i]; // edge end vertex 1 IdType u = u_ptr[i]; // edge end vertex 1
IdType v = v_ptr[i]; // edge end vertex 2 IdType v = v_ptr[i]; // edge end vertex 2
int64_t w = w_ptr[i]; // edge weight int64_t w = w_ptr[i]; // edge weight
CHECK(u < N_n); CHECK(u < N_n);
CHECK(v < N_n); CHECK(v < N_n);
if (i % meter == 0) { if (i % meter == 0) {
fprintf(stderr, "."); fflush(0); fprintf(stderr, ".");
fflush(0);
} }
if (node_assignments[u].size() == 0 && node_assignments[v].size() == 0) { if (node_assignments[u].size() == 0 && node_assignments[v].size() == 0) {
...@@ -127,17 +125,17 @@ void LibraVertexCut( ...@@ -127,17 +125,17 @@ void LibraVertexCut(
community_edges[c]++; community_edges[c]++;
community_weights_ptr[c] = community_weights_ptr[c] + w; community_weights_ptr[c] = community_weights_ptr[c] + w;
node_assignments[u].push_back(c); node_assignments[u].push_back(c);
if (u != v) if (u != v) node_assignments[v].push_back(c);
node_assignments[v].push_back(c);
CHECK(node_assignments[u].size() <= static_cast<size_t>(nc)) << CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
"[bug] 1. generated splits (u) are greater than nc!"; << "[bug] 1. generated splits (u) are greater than nc!";
CHECK(node_assignments[v].size() <= static_cast<size_t>(nc)) << CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
"[bug] 1. generated splits (v) are greater than nc!"; << "[bug] 1. generated splits (v) are greater than nc!";
edgenum_unassigned_ptr[u]--; edgenum_unassigned_ptr[u]--;
edgenum_unassigned_ptr[v]--; edgenum_unassigned_ptr[v]--;
} else if (node_assignments[u].size() != 0 && node_assignments[v].size() == 0) { } else if (
for (uint32_t j=0; j < node_assignments[u].size(); j++) { node_assignments[u].size() != 0 && node_assignments[v].size() == 0) {
for (uint32_t j = 0; j < node_assignments[u].size(); j++) {
int32_t cind = node_assignments[u][j]; int32_t cind = node_assignments[u][j];
cache[j] = community_edges[cind]; cache[j] = community_edges[cind];
} }
...@@ -148,12 +146,13 @@ void LibraVertexCut( ...@@ -148,12 +146,13 @@ void LibraVertexCut(
community_weights_ptr[c] = community_weights_ptr[c] + w; community_weights_ptr[c] = community_weights_ptr[c] + w;
node_assignments[v].push_back(c); node_assignments[v].push_back(c);
CHECK(node_assignments[v].size() <= static_cast<size_t>(nc)) << CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
"[bug] 2. generated splits (v) are greater than nc!"; << "[bug] 2. generated splits (v) are greater than nc!";
edgenum_unassigned_ptr[u]--; edgenum_unassigned_ptr[u]--;
edgenum_unassigned_ptr[v]--; edgenum_unassigned_ptr[v]--;
} else if (node_assignments[v].size() != 0 && node_assignments[u].size() == 0) { } else if (
for (uint32_t j=0; j < node_assignments[v].size(); j++) { node_assignments[v].size() != 0 && node_assignments[u].size() == 0) {
for (uint32_t j = 0; j < node_assignments[v].size(); j++) {
int32_t cind = node_assignments[v][j]; int32_t cind = node_assignments[v][j];
cache[j] = community_edges[cind]; cache[j] = community_edges[cind];
} }
...@@ -166,30 +165,32 @@ void LibraVertexCut( ...@@ -166,30 +165,32 @@ void LibraVertexCut(
community_weights_ptr[c] = community_weights_ptr[c] + w; community_weights_ptr[c] = community_weights_ptr[c] + w;
node_assignments[u].push_back(c); node_assignments[u].push_back(c);
CHECK(node_assignments[u].size() <= static_cast<size_t>(nc)) << CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
"[bug] 3. generated splits (u) are greater than nc!"; << "[bug] 3. generated splits (u) are greater than nc!";
edgenum_unassigned_ptr[u]--; edgenum_unassigned_ptr[u]--;
edgenum_unassigned_ptr[v]--; edgenum_unassigned_ptr[v]--;
} else { } else {
std::vector<int> setv(nc), intersetv; std::vector<int> setv(nc), intersetv;
for (int32_t j=0; j < nc; j++) setv[j] = 0; for (int32_t j = 0; j < nc; j++) setv[j] = 0;
int32_t interset = 0; int32_t interset = 0;
CHECK(node_assignments[u].size() <= static_cast<size_t>(nc)) << CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
"[bug] 4. generated splits (u) are greater than nc!"; << "[bug] 4. generated splits (u) are greater than nc!";
CHECK(node_assignments[v].size() <= static_cast<size_t>(nc)) << CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
"[bug] 4. generated splits (v) are greater than nc!"; << "[bug] 4. generated splits (v) are greater than nc!";
for (size_t j=0; j < node_assignments[v].size(); j++) { for (size_t j = 0; j < node_assignments[v].size(); j++) {
CHECK(node_assignments[v][j] < nc) << "[bug] 4. Part assigned (v) greater than nc!"; CHECK(node_assignments[v][j] < nc)
<< "[bug] 4. Part assigned (v) greater than nc!";
setv[node_assignments[v][j]]++; setv[node_assignments[v][j]]++;
} }
for (size_t j=0; j < node_assignments[u].size(); j++) { for (size_t j = 0; j < node_assignments[u].size(); j++) {
CHECK(node_assignments[u][j] < nc) << "[bug] 4. Part assigned (u) greater than nc!"; CHECK(node_assignments[u][j] < nc)
<< "[bug] 4. Part assigned (u) greater than nc!";
setv[node_assignments[u][j]]++; setv[node_assignments[u][j]]++;
} }
for (int32_t j=0; j < nc; j++) { for (int32_t j = 0; j < nc; j++) {
CHECK(setv[j] <= 2) << "[bug] 4. unexpected computed value !!!"; CHECK(setv[j] <= 2) << "[bug] 4. unexpected computed value !!!";
if (setv[j] == 2) { if (setv[j] == 2) {
interset++; interset++;
...@@ -197,7 +198,7 @@ void LibraVertexCut( ...@@ -197,7 +198,7 @@ void LibraVertexCut(
} }
} }
if (interset) { if (interset) {
for (size_t j=0; j < intersetv.size(); j++) { for (size_t j = 0; j < intersetv.size(); j++) {
int32_t cind = intersetv[j]; int32_t cind = intersetv[j];
cache[j] = community_edges[cind]; cache[j] = community_edges[cind];
} }
...@@ -211,7 +212,7 @@ void LibraVertexCut( ...@@ -211,7 +212,7 @@ void LibraVertexCut(
edgenum_unassigned_ptr[v]--; edgenum_unassigned_ptr[v]--;
} else { } else {
if (node_degree_ptr[u] < node_degree_ptr[v]) { if (node_degree_ptr[u] < node_degree_ptr[v]) {
for (uint32_t j=0; j < node_assignments[u].size(); j++) { for (uint32_t j = 0; j < node_assignments[u].size(); j++) {
int32_t cind = node_assignments[u][j]; int32_t cind = node_assignments[u][j];
cache[j] = community_edges[cind]; cache[j] = community_edges[cind];
} }
...@@ -222,37 +223,36 @@ void LibraVertexCut( ...@@ -222,37 +223,36 @@ void LibraVertexCut(
community_edges[c]++; community_edges[c]++;
community_weights_ptr[c] = community_weights_ptr[c] + w; community_weights_ptr[c] = community_weights_ptr[c] + w;
for (uint32_t j=0; j < node_assignments[v].size(); j++) { for (uint32_t j = 0; j < node_assignments[v].size(); j++) {
CHECK(node_assignments[v][j] != c) << CHECK(node_assignments[v][j] != c)
"[bug] 5. duplicate partition (v) assignment !!"; << "[bug] 5. duplicate partition (v) assignment !!";
} }
node_assignments[v].push_back(c); node_assignments[v].push_back(c);
CHECK(node_assignments[v].size() <= static_cast<size_t>(nc)) << CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
"[bug] 5. generated splits (v) greater than nc!!"; << "[bug] 5. generated splits (v) greater than nc!!";
replication_list.push_back(v); replication_list.push_back(v);
edgenum_unassigned_ptr[u]--; edgenum_unassigned_ptr[u]--;
edgenum_unassigned_ptr[v]--; edgenum_unassigned_ptr[v]--;
} else { } else {
for (uint32_t j=0; j < node_assignments[v].size(); j++) { for (uint32_t j = 0; j < node_assignments[v].size(); j++) {
int32_t cind = node_assignments[v][j]; int32_t cind = node_assignments[v][j];
cache[j] = community_edges[cind]; cache[j] = community_edges[cind];
} }
int32_t cindex = LeastLoad(cache, node_assignments[v].size()); int32_t cindex = LeastLoad(cache, node_assignments[v].size());
int32_t c = node_assignments[v][cindex]; int32_t c = node_assignments[v][cindex];
CHECK(c < nc) << "[bug] 6. partition greater than nc !!"; CHECK(c < nc) << "[bug] 6. partition greater than nc !!";
out_ptr[i] = c; out_ptr[i] = c;
community_edges[c]++; community_edges[c]++;
community_weights_ptr[c] = community_weights_ptr[c] + w; community_weights_ptr[c] = community_weights_ptr[c] + w;
for (uint32_t j=0; j < node_assignments[u].size(); j++) { for (uint32_t j = 0; j < node_assignments[u].size(); j++) {
CHECK(node_assignments[u][j] != c) << CHECK(node_assignments[u][j] != c)
"[bug] 6. duplicate partition (u) assignment !!"; << "[bug] 6. duplicate partition (u) assignment !!";
} }
if (u != v) if (u != v) node_assignments[u].push_back(c);
node_assignments[u].push_back(c);
CHECK(node_assignments[u].size() <= static_cast<size_t>(nc)) << CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
"[bug] 6. generated splits (u) greater than nc!!"; << "[bug] 6. generated splits (u) greater than nc!!";
replication_list.push_back(u); replication_list.push_back(u);
edgenum_unassigned_ptr[u]--; edgenum_unassigned_ptr[u]--;
edgenum_unassigned_ptr[v]--; edgenum_unassigned_ptr[v]--;
...@@ -262,15 +262,17 @@ void LibraVertexCut( ...@@ -262,15 +262,17 @@ void LibraVertexCut(
} }
  delete[] cache;   delete[] cache;
for (int64_t c=0; c < nc; c++) { for (int64_t c = 0; c < nc; c++) {
std::string path = prefix + "/community" + std::to_string(c) +".txt"; std::string path = prefix + "/community" + std::to_string(c) + ".txt";
FILE *fp = fopen(path.c_str(), "w"); FILE *fp = fopen(path.c_str(), "w");
CHECK_NE(fp, static_cast<FILE*>(NULL)) << "Error: can not open file: " << path.c_str(); CHECK_NE(fp, static_cast<FILE *>(NULL))
<< "Error: can not open file: " << path.c_str();
for (int64_t i=0; i < N_e; i++) { for (int64_t i = 0; i < N_e; i++) {
if (out_ptr[i] == c) if (out_ptr[i] == c)
fprintf(fp, "%ld,%ld,%ld\n", static_cast<int64_t>(u_ptr[i]), fprintf(
fp, "%ld,%ld,%ld\n", static_cast<int64_t>(u_ptr[i]),
static_cast<int64_t>(v_ptr[i]), w_ptr[i]); static_cast<int64_t>(v_ptr[i]), w_ptr[i]);
} }
fclose(fp); fclose(fp);
...@@ -278,22 +280,21 @@ void LibraVertexCut( ...@@ -278,22 +280,21 @@ void LibraVertexCut(
std::string path = prefix + "/replicationlist.csv"; std::string path = prefix + "/replicationlist.csv";
FILE *fp = fopen(path.c_str(), "w"); FILE *fp = fopen(path.c_str(), "w");
CHECK_NE(fp, static_cast<FILE*>(NULL)) << "Error: can not open file: " << path.c_str(); CHECK_NE(fp, static_cast<FILE *>(NULL))
<< "Error: can not open file: " << path.c_str();
fprintf(fp, "## The Indices of Nodes that are replicated :: Header"); fprintf(fp, "## The Indices of Nodes that are replicated :: Header");
printf("\nTotal replication: %ld\n", replication_list.size()); printf("\nTotal replication: %ld\n", replication_list.size());
for (uint64_t i=0; i < replication_list.size(); i++) for (uint64_t i = 0; i < replication_list.size(); i++)
fprintf(fp, "%ld\n", static_cast<int64_t>(replication_list[i])); fprintf(fp, "%ld\n", static_cast<int64_t>(replication_list[i]));
printf("Community weights:\n"); printf("Community weights:\n");
for (int64_t c=0; c < nc; c++) for (int64_t c = 0; c < nc; c++) printf("%ld ", community_weights_ptr[c]);
printf("%ld ", community_weights_ptr[c]);
printf("\n"); printf("\n");
printf("Community edges:\n"); printf("Community edges:\n");
for (int64_t c=0; c < nc; c++) for (int64_t c = 0; c < nc; c++) printf("%ld ", community_edges[c]);
printf("%ld ", community_edges[c]);
printf("\n"); printf("\n");
delete community_edges; delete community_edges;
...@@ -301,86 +302,75 @@ void LibraVertexCut( ...@@ -301,86 +302,75 @@ void LibraVertexCut(
} }
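The greedy branch above defers the community choice to a LeastLoad helper whose body is outside this hunk. As a rough sketch of the idea only (an arg-min over the cached community loads, with first-match tie-breaking assumed here rather than taken from the actual implementation):

#include <cstdint>

// Sketch: index of the smallest entry in loads[0..n). The real LeastLoad
// used by LibraVertexCut may break ties differently (e.g. randomly) to
// spread edges across equally loaded communities.
static int32_t LeastLoadSketch(const int64_t *loads, int32_t n) {
  int32_t best = 0;
  for (int32_t i = 1; i < n; ++i) {
    if (loads[i] < loads[best]) best = i;
  }
  return best;
}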
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibraVertexCut") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibraVertexCut")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
int32_t nc = args[0]; int32_t nc = args[0];
NDArray node_degree = args[1]; NDArray node_degree = args[1];
NDArray edgenum_unassigned = args[2]; NDArray edgenum_unassigned = args[2];
NDArray community_weights = args[3]; NDArray community_weights = args[3];
NDArray u = args[4]; NDArray u = args[4];
NDArray v = args[5]; NDArray v = args[5];
NDArray w = args[6]; NDArray w = args[6];
NDArray out = args[7]; NDArray out = args[7];
int64_t N = args[8]; int64_t N = args[8];
int64_t N_e = args[9]; int64_t N_e = args[9];
std::string prefix = args[10]; std::string prefix = args[10];
ATEN_ID_TYPE_SWITCH(node_degree->dtype, IdType2, { ATEN_ID_TYPE_SWITCH(node_degree->dtype, IdType2, {
ATEN_ID_TYPE_SWITCH(u->dtype, IdType, { ATEN_ID_TYPE_SWITCH(u->dtype, IdType, {
LibraVertexCut<IdType, IdType2>(nc, LibraVertexCut<IdType, IdType2>(
node_degree, nc, node_degree, edgenum_unassigned, community_weights, u, v, w,
edgenum_unassigned, out, N, N_e, prefix);
community_weights,
u,
v,
w,
out,
N,
N_e,
prefix);
}); });
});
}); });
});
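For reference, the writer loops above emit one "community<c>.txt" per partition, holding a "src,dst,weight" triple per assigned edge, plus a "replicationlist.csv" of replicated node IDs. A minimal reader matching that fprintf format might look like the sketch below; note that the loader later in this file parses the weight column as float rather than int64, so this sketch mirrors the writer, not the loader.

#include <cinttypes>
#include <cstdint>
#include <cstdio>

// Sketch: parse one "src,dst,weight" line as written by LibraVertexCut.
// Returns true if a full triple was read.
static bool ReadCommunityLine(FILE *fp, int64_t *u, int64_t *v, int64_t *w) {
  return fscanf(fp, "%" SCNd64 ",%" SCNd64 ",%" SCNd64 "\n", u, v, w) == 3;
}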
/*!
* \brief
/*! \brief * 1. Builds dictionary (ldt) for assigning local node IDs to nodes in the
1. Builds dictionary (ldt) for assigning local node IDs to nodes in the partitions * partitions
2. Builds dictionary (gdt) for storing copies (local ID) of split nodes * 2. Builds dictionary (gdt) for storing copies (local ID) of split nodes
These dictionaries will be used in the subsequent stages to setup tracking of split nodes * These dictionaries will be used in the subsequent stages to setup
copies across the partition, setting up partition `ndata` dictionaries. * tracking of split nodes copies across the partition, setting up partition
\param[out] a local src node ID of an edge in a partition * `ndata` dictionaries.
\param[out] b local dst node ID of an edge in a partition * \param[out] a local src node ID of an edge in a partition
\param[-] indices temporary memory, keeps track of global node ID to local node ID in a partition * \param[out] b local dst node ID of an edge in a partition
\param[out] ldt_key per partition dict for storing global and local node IDs (consecutive) * \param[-] indices temporary memory, keeps track of global node ID to local
\param[out] gdt_key global dict for storing number of local nodes (or split nodes) for a * node ID in a partition
given global node ID * \param[out] ldt_key per partition dict for storing global and local node IDs
\param[out] gdt_value global dict, stores local node IDs (due to split) across partitions * (consecutive)
for a given global node ID * \param[out] gdt_key global dict for storing number of local nodes (or split
\param[out] node_map keeps track of range of local node IDs (consecutive) given to the nodes in * nodes) for a given global node ID
the partitions * \param[out] gdt_value global dict, stores local node IDs (due to split)
\param[in, out] offset start of the range of local node IDs for this partition * across partitions for a given global node ID
\param[in] nc number of partitions/communities * \param[out] node_map keeps track of range of local node IDs (consecutive)
\param[in] c current partition number * given to the nodes in the partitions
\param[in] fsize size of pre-allocated memory tensor * \param[in, out] offset start of the range of local node IDs for this
\param[in] prefix input Libra partition file location * partition
* \param[in] nc number of partitions/communities
* \param[in] c current partition number \param[in] fsize size of pre-allocated
* memory tensor
* \param[in] prefix input Libra partition file location
*/ */
List<Value> Libra2dglBuildDict( List<Value> Libra2dglBuildDict(
NDArray a, NDArray a, NDArray b, NDArray indices, NDArray ldt_key, NDArray gdt_key,
NDArray b, NDArray gdt_value, NDArray node_map, NDArray offset, int32_t nc, int32_t c,
NDArray indices, int64_t fsize, const std::string &prefix) {
NDArray ldt_key, int64_t *indices_ptr = indices.Ptr<int64_t>(); // 1D temp array
NDArray gdt_key, int64_t *ldt_key_ptr =
NDArray gdt_value, ldt_key.Ptr<int64_t>(); // 1D local nodes <-> global nodes
NDArray node_map, int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>(); // 1D #split copies per node
NDArray offset, int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>(); // 2D tensor
int32_t nc, int64_t *node_map_ptr = node_map.Ptr<int64_t>(); // 1D tensor
int32_t c, int64_t *offset_ptr = offset.Ptr<int64_t>(); // 1D tensor
int64_t fsize,
const std::string& prefix) {
int64_t *indices_ptr = indices.Ptr<int64_t>(); // 1D temp array
int64_t *ldt_key_ptr = ldt_key.Ptr<int64_t>(); // 1D local nodes <-> global nodes
int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>(); // 1D #split copies per node
int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>(); // 2D tensor
int64_t *node_map_ptr = node_map.Ptr<int64_t>(); // 1D tensor
int64_t *offset_ptr = offset.Ptr<int64_t>(); // 1D tensor
int32_t width = nc; int32_t width = nc;
int64_t *a_ptr = a.Ptr<int64_t>(); // stores local src and dst node ID, int64_t *a_ptr = a.Ptr<int64_t>(); // stores local src and dst node ID,
int64_t *b_ptr = b.Ptr<int64_t>(); // to create the partition graph int64_t *b_ptr = b.Ptr<int64_t>(); // to create the partition graph
int64_t N_n = indices->shape[0]; int64_t N_n = indices->shape[0];
int64_t num_nodes = ldt_key->shape[0]; int64_t num_nodes = ldt_key->shape[0];
for (int64_t i=0; i < N_n; i++) { for (int64_t i = 0; i < N_n; i++) {
indices_ptr[i] = -100; indices_ptr[i] = -100;
} }
...@@ -388,98 +378,106 @@ List<Value> Libra2dglBuildDict( ...@@ -388,98 +378,106 @@ List<Value> Libra2dglBuildDict(
int64_t edge = 0; int64_t edge = 0;
std::string path = prefix + "/community" + std::to_string(c) + ".txt"; std::string path = prefix + "/community" + std::to_string(c) + ".txt";
FILE *fp = fopen(path.c_str(), "r"); FILE *fp = fopen(path.c_str(), "r");
CHECK_NE(fp, static_cast<FILE*>(NULL)) << "Error: can not open file: " << path.c_str(); CHECK_NE(fp, static_cast<FILE *>(NULL))
<< "Error: can not open file: " << path.c_str();
while (!feof(fp) && edge < fsize) { while (!feof(fp) && edge < fsize) {
int64_t u, v; int64_t u, v;
float w; float w;
fscanf(fp, "%ld,%ld,%f\n", &u, &v, &w); // reading an edge - the src and dst global node IDs fscanf(
fp, "%ld,%ld,%f\n", &u, &v,
if (indices_ptr[u] == -100) { // if already not assigned a local node ID, local node ID is &w); // reading an edge - the src and dst global node IDs
ldt_key_ptr[pos] = u; // already assigned for this global node ID
CHECK(pos < num_nodes); // Sanity check if (indices_ptr[u] ==
indices_ptr[u] = pos++; // consecutive local node ID for a given global node ID -100) { // if already not assigned a local node ID, local node ID is
ldt_key_ptr[pos] = u; // already assigned for this global node ID
CHECK(pos < num_nodes); // Sanity check
indices_ptr[u] =
pos++; // consecutive local node ID for a given global node ID
} }
if (indices_ptr[v] == -100) { // if already not assigned a local node ID if (indices_ptr[v] == -100) { // if already not assigned a local node ID
ldt_key_ptr[pos] = v; ldt_key_ptr[pos] = v;
CHECK(pos < num_nodes); // Sanity check CHECK(pos < num_nodes); // Sanity check
indices_ptr[v] = pos++; indices_ptr[v] = pos++;
} }
a_ptr[edge] = indices_ptr[u]; // new local ID for an edge a_ptr[edge] = indices_ptr[u]; // new local ID for an edge
b_ptr[edge++] = indices_ptr[v]; // new local ID for an edge b_ptr[edge++] = indices_ptr[v]; // new local ID for an edge
} }
CHECK(edge <= fsize) << "[Bug] memory allocated for #edges per partition is not enough."; CHECK(edge <= fsize)
<< "[Bug] memory allocated for #edges per partition is not enough.";
fclose(fp); fclose(fp);
List<Value> ret; List<Value> ret;
ret.push_back(Value(MakeValue(pos))); // returns total number of nodes in this partition ret.push_back(Value(
ret.push_back(Value(MakeValue(edge))); // returns total number of edges in this partition MakeValue(pos))); // returns total number of nodes in this partition
ret.push_back(Value(
MakeValue(edge))); // returns total number of edges in this partition
for (int64_t i=0; i < pos; i++) { for (int64_t i = 0; i < pos; i++) {
int64_t u = ldt_key_ptr[i]; // global node ID int64_t u = ldt_key_ptr[i]; // global node ID
// int64_t v = indices_ptr[u]; // int64_t v = indices_ptr[u];
int64_t v = i; // local node ID int64_t v = i; // local node ID
int64_t *ind = &gdt_key_ptr[u]; // global dict, total number of local node IDs (an offset) int64_t *ind =
// as of now for a given global node ID &gdt_key_ptr[u]; // global dict, total number of local node IDs (an
int64_t *ptr = gdt_value_ptr + u*width; // offset) as of now for a given global node ID
ptr[*ind] = offset_ptr[0] + v; // stores a local node ID for the global node ID int64_t *ptr = gdt_value_ptr + u * width;
ptr[*ind] =
offset_ptr[0] + v; // stores a local node ID for the global node ID
(*ind)++; (*ind)++;
CHECK_NE(v, -100); CHECK_NE(v, -100);
CHECK(*ind <= nc); CHECK(*ind <= nc);
} }
node_map_ptr[c] = offset_ptr[0] + pos; // since local node IDs for a partition are consecutive, node_map_ptr[c] =
// we maintain the range of local node IDs like this offset_ptr[0] +
pos; // since local node IDs for a partition are consecutive,
// we maintain the range of local node IDs like this
offset_ptr[0] += pos; offset_ptr[0] += pos;
return ret; return ret;
} }
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglBuildDict") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglBuildDict")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
NDArray a = args[0]; NDArray a = args[0];
NDArray b = args[1]; NDArray b = args[1];
NDArray indices = args[2]; NDArray indices = args[2];
NDArray ldt_key = args[3]; NDArray ldt_key = args[3];
NDArray gdt_key = args[4]; NDArray gdt_key = args[4];
NDArray gdt_value = args[5]; NDArray gdt_value = args[5];
NDArray node_map = args[6]; NDArray node_map = args[6];
NDArray offset = args[7]; NDArray offset = args[7];
int32_t nc = args[8]; int32_t nc = args[8];
int32_t c = args[9]; int32_t c = args[9];
int64_t fsize = args[10]; int64_t fsize = args[10];
std::string prefix = args[11]; std::string prefix = args[11];
List<Value> ret = Libra2dglBuildDict(a, b, indices, ldt_key, gdt_key, List<Value> ret = Libra2dglBuildDict(
gdt_value, node_map, offset, a, b, indices, ldt_key, gdt_key, gdt_value, node_map, offset, nc, c,
nc, c, fsize, prefix); fsize, prefix);
*rv = ret; *rv = ret;
}); });
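The dictionary-building step above assigns consecutive local node IDs using a -100 sentinel in the temporary indices array and records the inverse mapping in ldt_key. A self-contained sketch of that pattern (the names below are illustrative only, not part of the DGL API):

#include <cstdint>
#include <vector>

// Toy version of the ldt construction: indices maps global -> local node ID
// (-100 means "not assigned yet"), ldt_key maps local -> global.
struct LocalIdMapSketch {
  std::vector<int64_t> indices;  // global -> local, sentinel -100
  std::vector<int64_t> ldt_key;  // local  -> global
  explicit LocalIdMapSketch(int64_t num_global)
      : indices(num_global, -100) {}
  int64_t GetOrAssign(int64_t global_id) {
    if (indices[global_id] == -100) {
      indices[global_id] = static_cast<int64_t>(ldt_key.size());
      ldt_key.push_back(global_id);
    }
    return indices[global_id];
  }
};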
/*!
/*! \brief sets up the 1-level tree among the clones of the split-nodes. * \brief sets up the 1-level tree among the clones of the split-nodes.
\param[in] gdt_key global dict for assigning consecutive node IDs to nodes across all the * \param[in] gdt_key global dict for assigning consecutive node IDs to nodes
partitions * across all the partitions
\param[in] gdt_value global dict for assigning consecutive node IDs to nodes across all the * \param[in] gdt_value global dict for assigning consecutive node IDs to nodes
partition * across all the partition
\param[out] lrtensor keeps the root node ID of 1-level tree * \param[out] lrtensor keeps the root node ID of 1-level tree
\param[in] nc number of partitions/communities * \param[in] nc number of partitions/communities
\param[in] Nn number of nodes in the input graph * \param[in] Nn number of nodes in the input graph
*/ */
void Libra2dglSetLR( void Libra2dglSetLR(
NDArray gdt_key, NDArray gdt_key, NDArray gdt_value, NDArray lrtensor, int32_t nc,
NDArray gdt_value, int64_t Nn) {
NDArray lrtensor, int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>(); // 1D tensor
int32_t nc, int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>(); // 2D tensor
int64_t Nn) { int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>(); // 1D tensor
int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>(); // 1D tensor
int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>(); // 2D tensor
int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>(); // 1D tensor
int32_t width = nc; int32_t width = nc;
int64_t cnt = 0; int64_t cnt = 0;
int64_t avg_split_copy = 0, scnt = 0; int64_t avg_split_copy = 0, scnt = 0;
for (int64_t i=0; i < Nn; i++) { for (int64_t i = 0; i < Nn; i++) {
if (gdt_key_ptr[i] <= 0) { if (gdt_key_ptr[i] <= 0) {
cnt++; cnt++;
} else { } else {
...@@ -487,7 +485,7 @@ void Libra2dglSetLR( ...@@ -487,7 +485,7 @@ void Libra2dglSetLR(
CHECK(val >= 0 && val < gdt_key_ptr[i]); CHECK(val >= 0 && val < gdt_key_ptr[i]);
CHECK(gdt_key_ptr[i] <= nc); CHECK(gdt_key_ptr[i] <= nc);
int64_t *ptr = gdt_value_ptr + i*width; int64_t *ptr = gdt_value_ptr + i * width;
lrtensor_ptr[i] = ptr[val]; lrtensor_ptr[i] = ptr[val];
} }
if (gdt_key_ptr[i] > 1) { if (gdt_key_ptr[i] > 1) {
...@@ -498,105 +496,87 @@ void Libra2dglSetLR( ...@@ -498,105 +496,87 @@ void Libra2dglSetLR(
} }
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglSetLR") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglSetLR")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
NDArray gdt_key = args[0]; NDArray gdt_key = args[0];
NDArray gdt_value = args[1]; NDArray gdt_value = args[1];
NDArray lrtensor = args[2]; NDArray lrtensor = args[2];
int32_t nc = args[3]; int32_t nc = args[3];
int64_t Nn = args[4]; int64_t Nn = args[4];
Libra2dglSetLR(gdt_key, gdt_value, lrtensor, nc, Nn); Libra2dglSetLR(gdt_key, gdt_value, lrtensor, nc, Nn);
}); });
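Libra2dglSetLR elects one clone of every split node as the root of its 1-level tree; the exact selection policy sits in the elided part of this hunk, so the sketch below simply takes the first recorded clone and should be read as an assumption, not the actual rule.

#include <cstdint>

// Sketch: gdt_value is a [Nn x nc] table of clone local IDs per global node,
// gdt_key holds the clone count. Record one clone per node in lrtensor.
static void SetRootsSketch(
    const int64_t *gdt_key, const int64_t *gdt_value, int64_t *lrtensor,
    int32_t nc, int64_t Nn) {
  const int32_t width = nc;
  for (int64_t i = 0; i < Nn; ++i) {
    if (gdt_key[i] > 0)
      lrtensor[i] = gdt_value[i * width + 0];  // assumed: pick the first clone
  }
}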
/*! /*!
\brief For each node in a partition, it creates a list of remote clone IDs; * \brief For each node in a partition, it creates a list of remote clone IDs;
also, for each node in a partition, it gathers the data (feats, label, train, test) * also, for each node in a partition, it gathers the data (feats, label,
from input graph. * train, test) from input graph.
\param[out] feat node features in current partition c * \param[out] feat node features in current partition c.
\param[in] gfeat input graph node features * \param[in] gfeat input graph node features.
\param[out] adj list of node IDs of remote clones * \param[out] adj list of node IDs of remote clones.
\param[out] inner_nodes marks whether a node is split or not * \param[out] inner_nodes marks whether a node is split or not.
\param[in] ldt_key per partition dict for tracking global to local node IDs * \param[in] ldt_key per partition dict for tracking global to local node IDs
\param[out] gdt_key global dict for storing number of local nodes (or split nodes) for a * \param[out] gdt_key global dict for storing number of local nodes (or split
given global node ID * nodes) for a given global node ID \param[out] gdt_value global
\param[out] gdt_value global dict, stores local node IDs (due to split) across partitions * dict, stores local node IDs (due to split) across partitions for
for a given global node ID * a given global node ID.
\param[in] node_map keeps track of range of local node IDs (consecutive) given to the nodes in * \param[in] node_map keeps track of range of local node IDs (consecutive)
the partitions * given to the nodes in the partitions.
\param[out] lr 1-level tree marking for local split nodes * \param[out] lr 1-level tree marking for local split nodes.
\param[in] lrtensor global (all the partitions) 1-level tree * \param[in] lrtensor global (all the partitions) 1-level tree.
\param[in] num_nodes number of nodes in current partition * \param[in] num_nodes number of nodes in current partition.
\param[in] nc number of partitions/communities * \param[in] nc number of partitions/communities.
\param[in] c current partition/community * \param[in] c current partition/community.
\param[in] feat_size node feature vector size * \param[in] feat_size node feature vector size.
\param[out] labels local (for this partition) labels * \param[out] labels local (for this partition) labels.
\param[out] trainm local (for this partition) training nodes * \param[out] trainm local (for this partition) training nodes.
\param[out] testm local (for this partition) testing nodes * \param[out] testm local (for this partition) testing nodes.
\param[out] valm local (for this partition) validation nodes * \param[out] valm local (for this partition) validation nodes.
\param[in] glabels global (input graph) labels * \param[in] glabels global (input graph) labels.
\param[in] gtrainm global (input graph) training nodes * \param[in] gtrainm global (input graph) training nodes.
\param[in] gtestm global (input graph) testing nodes * \param[in] gtestm global (input graph) testing nodes.
\param[in] gvalm global (input graph) validation nodes * \param[in] gvalm global (input graph) validation nodes.
\param[out] Nn number of nodes in the input graph * \param[out] Nn number of nodes in the input graph.
*/ */
template<typename IdType, typename IdType2, typename DType> template <typename IdType, typename IdType2, typename DType>
void Libra2dglBuildAdjlist( void Libra2dglBuildAdjlist(
NDArray feat, NDArray feat, NDArray gfeat, NDArray adj, NDArray inner_node,
NDArray gfeat, NDArray ldt_key, NDArray gdt_key, NDArray gdt_value, NDArray node_map,
NDArray adj, NDArray lr, NDArray lrtensor, int64_t num_nodes, int32_t nc, int32_t c,
NDArray inner_node, int32_t feat_size, NDArray labels, NDArray trainm, NDArray testm,
NDArray ldt_key, NDArray valm, NDArray glabels, NDArray gtrainm, NDArray gtestm,
NDArray gdt_key, NDArray gvalm, int64_t Nn) {
NDArray gdt_value, DType *feat_ptr = feat.Ptr<DType>(); // 2D tensor
NDArray node_map, DType *gfeat_ptr = gfeat.Ptr<DType>(); // 2D tensor
NDArray lr, int64_t *adj_ptr = adj.Ptr<int64_t>(); // 2D tensor
NDArray lrtensor,
int64_t num_nodes,
int32_t nc,
int32_t c,
int32_t feat_size,
NDArray labels ,
NDArray trainm ,
NDArray testm ,
NDArray valm ,
NDArray glabels,
NDArray gtrainm,
NDArray gtestm ,
NDArray gvalm,
int64_t Nn) {
DType *feat_ptr = feat.Ptr<DType>(); // 2D tensor
DType *gfeat_ptr = gfeat.Ptr<DType>(); // 2D tensor
int64_t *adj_ptr = adj.Ptr<int64_t>(); // 2D tensor
int32_t *inner_node_ptr = inner_node.Ptr<int32_t>(); int32_t *inner_node_ptr = inner_node.Ptr<int32_t>();
int64_t *ldt_key_ptr = ldt_key.Ptr<int64_t>(); int64_t *ldt_key_ptr = ldt_key.Ptr<int64_t>();
int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>(); int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();
int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>(); // 2D tensor int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>(); // 2D tensor
int64_t *node_map_ptr = node_map.Ptr<int64_t>(); int64_t *node_map_ptr = node_map.Ptr<int64_t>();
int64_t *lr_ptr = lr.Ptr<int64_t>(); int64_t *lr_ptr = lr.Ptr<int64_t>();
int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>(); int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>();
int32_t width = nc - 1; int32_t width = nc - 1;
runtime::parallel_for(0, num_nodes, [&] (int64_t s, int64_t e) { runtime::parallel_for(0, num_nodes, [&](int64_t s, int64_t e) {
for (int64_t i=s; i < e; i++) { for (int64_t i = s; i < e; i++) {
int64_t k = ldt_key_ptr[i]; int64_t k = ldt_key_ptr[i];
int64_t v = i; int64_t v = i;
int64_t ind = gdt_key_ptr[k]; int64_t ind = gdt_key_ptr[k];
int64_t *adj_ptr_ptr = adj_ptr + v*width; int64_t *adj_ptr_ptr = adj_ptr + v * width;
if (ind == 1) { if (ind == 1) {
for (int32_t j=0; j < width; j++) adj_ptr_ptr[j] = -1; for (int32_t j = 0; j < width; j++) adj_ptr_ptr[j] = -1;
inner_node_ptr[i] = 1; inner_node_ptr[i] = 1;
lr_ptr[i] = -200; lr_ptr[i] = -200;
} else { } else {
lr_ptr[i] = lrtensor_ptr[k]; lr_ptr[i] = lrtensor_ptr[k];
int64_t *ptr = gdt_value_ptr + k*nc; int64_t *ptr = gdt_value_ptr + k * nc;
int64_t pos = 0; int64_t pos = 0;
CHECK(ind <= nc); CHECK(ind <= nc);
int32_t flg = 0; int32_t flg = 0;
for (int64_t j=0; j < ind; j++) { for (int64_t j = 0; j < ind; j++) {
if (ptr[j] == lr_ptr[i]) flg = 1; if (ptr[j] == lr_ptr[i]) flg = 1;
if (c != Ver2partition<int64_t>(ptr[j], node_map_ptr, nc) ) if (c != Ver2partition<int64_t>(ptr[j], node_map_ptr, nc))
adj_ptr_ptr[pos++] = ptr[j]; adj_ptr_ptr[pos++] = ptr[j];
} }
CHECK_EQ(flg, 1); CHECK_EQ(flg, 1);
...@@ -608,15 +588,14 @@ void Libra2dglBuildAdjlist( ...@@ -608,15 +588,14 @@ void Libra2dglBuildAdjlist(
}); });
// gather // gather
runtime::parallel_for(0, num_nodes, [&] (int64_t s, int64_t e) { runtime::parallel_for(0, num_nodes, [&](int64_t s, int64_t e) {
for (int64_t i=s; i < e; i++) { for (int64_t i = s; i < e; i++) {
int64_t k = ldt_key_ptr[i]; int64_t k = ldt_key_ptr[i];
int64_t ind = i*feat_size; int64_t ind = i * feat_size;
DType *optr = gfeat_ptr + ind; DType *optr = gfeat_ptr + ind;
DType *iptr = feat_ptr + k*feat_size; DType *iptr = feat_ptr + k * feat_size;
for (int32_t j=0; j < feat_size; j++) for (int32_t j = 0; j < feat_size; j++) optr[j] = iptr[j];
optr[j] = iptr[j];
} }
IdType *labels_ptr = labels.Ptr<IdType>(); IdType *labels_ptr = labels.Ptr<IdType>();
...@@ -628,9 +607,9 @@ void Libra2dglBuildAdjlist( ...@@ -628,9 +607,9 @@ void Libra2dglBuildAdjlist(
IdType2 *valm_ptr = valm.Ptr<IdType2>(); IdType2 *valm_ptr = valm.Ptr<IdType2>();
IdType2 *gvalm_ptr = gvalm.Ptr<IdType2>(); IdType2 *gvalm_ptr = gvalm.Ptr<IdType2>();
for (int64_t i=0; i < num_nodes; i++) { for (int64_t i = 0; i < num_nodes; i++) {
int64_t k = ldt_key_ptr[i]; int64_t k = ldt_key_ptr[i];
CHECK(k >=0 && k < Nn); CHECK(k >= 0 && k < Nn);
glabels_ptr[i] = labels_ptr[k]; glabels_ptr[i] = labels_ptr[k];
gtrainm_ptr[i] = trainm_ptr[k]; gtrainm_ptr[i] = trainm_ptr[k];
gtestm_ptr[i] = testm_ptr[k]; gtestm_ptr[i] = testm_ptr[k];
...@@ -639,53 +618,43 @@ void Libra2dglBuildAdjlist( ...@@ -639,53 +618,43 @@ void Libra2dglBuildAdjlist(
}); });
} }
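The adjacency pass above calls Ver2partition to map a clone's contiguous local node ID back to the partition that owns it, using the node_map ranges produced by Libra2dglBuildDict. That helper is defined elsewhere in this file; the sketch below shows the lookup it has to perform, under the assumption that node_map[p] stores the exclusive upper bound of partition p's ID range.

#include <cstdint>

// Sketch: linear scan over the cumulative node_map ranges; returns the
// partition owning `id`, or -1 if the ID falls outside every range.
template <typename IdType>
static int32_t Ver2partitionSketch(
    IdType id, const int64_t *node_map, int32_t nc) {
  int64_t prev = 0;
  for (int32_t p = 0; p < nc; ++p) {
    if (id >= prev && id < node_map[p]) return p;
    prev = node_map[p];
  }
  return -1;
}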
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglBuildAdjlist") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglBuildAdjlist")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
NDArray feat = args[0]; NDArray feat = args[0];
NDArray gfeat = args[1]; NDArray gfeat = args[1];
NDArray adj = args[2]; NDArray adj = args[2];
NDArray inner_node = args[3]; NDArray inner_node = args[3];
NDArray ldt_key = args[4]; NDArray ldt_key = args[4];
NDArray gdt_key = args[5]; NDArray gdt_key = args[5];
NDArray gdt_value = args[6]; NDArray gdt_value = args[6];
NDArray node_map = args[7]; NDArray node_map = args[7];
NDArray lr = args[8]; NDArray lr = args[8];
NDArray lrtensor = args[9]; NDArray lrtensor = args[9];
int64_t num_nodes = args[10]; int64_t num_nodes = args[10];
int32_t nc = args[11]; int32_t nc = args[11];
int32_t c = args[12]; int32_t c = args[12];
int32_t feat_size = args[13]; int32_t feat_size = args[13];
NDArray labels = args[14]; NDArray labels = args[14];
NDArray trainm = args[15]; NDArray trainm = args[15];
NDArray testm = args[16]; NDArray testm = args[16];
NDArray valm = args[17]; NDArray valm = args[17];
NDArray glabels = args[18]; NDArray glabels = args[18];
NDArray gtrainm = args[19]; NDArray gtrainm = args[19];
NDArray gtestm = args[20]; NDArray gtestm = args[20];
NDArray gvalm = args[21]; NDArray gvalm = args[21];
int64_t Nn = args[22]; int64_t Nn = args[22];
ATEN_FLOAT_TYPE_SWITCH(feat->dtype, DType, "Features", { ATEN_FLOAT_TYPE_SWITCH(feat->dtype, DType, "Features", {
ATEN_ID_TYPE_SWITCH(trainm->dtype, IdType2, { ATEN_ID_TYPE_SWITCH(trainm->dtype, IdType2, {
ATEN_ID_BITS_SWITCH((glabels->dtype).bits, IdType, { ATEN_ID_BITS_SWITCH((glabels->dtype).bits, IdType, {
Libra2dglBuildAdjlist<IdType, IdType2, DType>(feat, gfeat, Libra2dglBuildAdjlist<IdType, IdType2, DType>(
adj, inner_node, feat, gfeat, adj, inner_node, ldt_key, gdt_key, gdt_value,
ldt_key, gdt_key, node_map, lr, lrtensor, num_nodes, nc, c, feat_size, labels,
gdt_value, trainm, testm, valm, glabels, gtrainm, gtestm, gvalm, Nn);
node_map, lr, });
lrtensor, num_nodes,
nc, c,
feat_size, labels,
trainm, testm,
valm, glabels,
gtrainm, gtestm,
gvalm, Nn);
});
}); });
});
}); });
});
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -4,17 +4,17 @@ ...@@ -4,17 +4,17 @@
* \brief COO union and partition * \brief COO union and partition
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <vector> #include <vector>
namespace dgl { namespace dgl {
namespace aten { namespace aten {
///////////////////////// COO Based Operations///////////////////////// ///////////////////////// COO Based Operations/////////////////////////
std::vector<COOMatrix> DisjointPartitionCooBySizes( std::vector<COOMatrix> DisjointPartitionCooBySizes(
const COOMatrix &coo, const COOMatrix &coo, const uint64_t batch_size,
const uint64_t batch_size, const std::vector<uint64_t> &edge_cumsum,
const std::vector<uint64_t> &edge_cumsum, const std::vector<uint64_t> &src_vertex_cumsum,
const std::vector<uint64_t> &src_vertex_cumsum, const std::vector<uint64_t> &dst_vertex_cumsum) {
const std::vector<uint64_t> &dst_vertex_cumsum) {
CHECK_EQ(edge_cumsum.size(), batch_size + 1); CHECK_EQ(edge_cumsum.size(), batch_size + 1);
CHECK_EQ(src_vertex_cumsum.size(), batch_size + 1); CHECK_EQ(src_vertex_cumsum.size(), batch_size + 1);
CHECK_EQ(dst_vertex_cumsum.size(), batch_size + 1); CHECK_EQ(dst_vertex_cumsum.size(), batch_size + 1);
...@@ -22,28 +22,23 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes( ...@@ -22,28 +22,23 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(
ret.resize(batch_size); ret.resize(batch_size);
for (size_t g = 0; g < batch_size; ++g) { for (size_t g = 0; g < batch_size; ++g) {
IdArray result_src = IndexSelect(coo.row, IdArray result_src =
edge_cumsum[g], IndexSelect(coo.row, edge_cumsum[g], edge_cumsum[g + 1]) -
edge_cumsum[g + 1]) - src_vertex_cumsum[g]; src_vertex_cumsum[g];
IdArray result_dst = IndexSelect(coo.col, IdArray result_dst =
edge_cumsum[g], IndexSelect(coo.col, edge_cumsum[g], edge_cumsum[g + 1]) -
edge_cumsum[g + 1]) - dst_vertex_cumsum[g]; dst_vertex_cumsum[g];
IdArray result_data = NullArray(); IdArray result_data = NullArray();
// has data index array // has data index array
if (COOHasData(coo)) { if (COOHasData(coo)) {
result_data = IndexSelect(coo.data, result_data = IndexSelect(coo.data, edge_cumsum[g], edge_cumsum[g + 1]) -
edge_cumsum[g], edge_cumsum[g];
edge_cumsum[g + 1]) - edge_cumsum[g];
} }
COOMatrix sub_coo = COOMatrix( COOMatrix sub_coo = COOMatrix(
src_vertex_cumsum[g+1]-src_vertex_cumsum[g], src_vertex_cumsum[g + 1] - src_vertex_cumsum[g],
dst_vertex_cumsum[g+1]-dst_vertex_cumsum[g], dst_vertex_cumsum[g + 1] - dst_vertex_cumsum[g], result_src, result_dst,
result_src, result_data, coo.row_sorted, coo.col_sorted);
result_dst,
result_data,
coo.row_sorted,
coo.col_sorted);
ret[g] = sub_coo; ret[g] = sub_coo;
} }
...@@ -51,44 +46,36 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes( ...@@ -51,44 +46,36 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(
} }
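A usage sketch for the partition routine above, with made-up sizes: each cumsum vector must have batch_size + 1 entries starting at 0, as enforced by the CHECK_EQ guards, and the declaration is assumed to be reachable via <dgl/array.h>, the same header this file includes.

#include <dgl/array.h>

#include <vector>

// Split a batched COO of two graphs (5 + 4 edges, 3 + 4 nodes on each side)
// back into per-graph matrices with zero-based vertex IDs.
std::vector<dgl::aten::COOMatrix> SplitBatchedCooExample(
    const dgl::aten::COOMatrix &batched) {
  const uint64_t batch_size = 2;
  const std::vector<uint64_t> edge_cumsum = {0, 5, 9};
  const std::vector<uint64_t> src_vertex_cumsum = {0, 3, 7};
  const std::vector<uint64_t> dst_vertex_cumsum = {0, 3, 7};
  return dgl::aten::DisjointPartitionCooBySizes(
      batched, batch_size, edge_cumsum, src_vertex_cumsum, dst_vertex_cumsum);
}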
COOMatrix COOSliceContiguousChunk( COOMatrix COOSliceContiguousChunk(
const COOMatrix &coo, const COOMatrix &coo, const std::vector<uint64_t> &edge_range,
const std::vector<uint64_t> &edge_range, const std::vector<uint64_t> &src_vertex_range,
const std::vector<uint64_t> &src_vertex_range, const std::vector<uint64_t> &dst_vertex_range) {
const std::vector<uint64_t> &dst_vertex_range) {
IdArray result_src = NullArray(coo.row->dtype, coo.row->ctx); IdArray result_src = NullArray(coo.row->dtype, coo.row->ctx);
IdArray result_dst = NullArray(coo.row->dtype, coo.row->ctx); IdArray result_dst = NullArray(coo.row->dtype, coo.row->ctx);
if (edge_range[1] != edge_range[0]) { if (edge_range[1] != edge_range[0]) {
// The chunk has edges // The chunk has edges
result_src = IndexSelect(coo.row, result_src = IndexSelect(coo.row, edge_range[0], edge_range[1]) -
edge_range[0], src_vertex_range[0];
edge_range[1]) - src_vertex_range[0]; result_dst = IndexSelect(coo.col, edge_range[0], edge_range[1]) -
result_dst = IndexSelect(coo.col, dst_vertex_range[0];
edge_range[0],
edge_range[1]) - dst_vertex_range[0];
} }
IdArray result_data = NullArray(); IdArray result_data = NullArray();
// has data index array // has data index array
if (COOHasData(coo)) { if (COOHasData(coo)) {
result_data = IndexSelect(coo.data, result_data =
edge_range[0], IndexSelect(coo.data, edge_range[0], edge_range[1]) - edge_range[0];
edge_range[1]) - edge_range[0];
} }
COOMatrix sub_coo = COOMatrix( COOMatrix sub_coo = COOMatrix(
src_vertex_range[1]-src_vertex_range[0], src_vertex_range[1] - src_vertex_range[0],
dst_vertex_range[1]-dst_vertex_range[0], dst_vertex_range[1] - dst_vertex_range[0], result_src, result_dst,
result_src, result_data, coo.row_sorted, coo.col_sorted);
result_dst,
result_data,
coo.row_sorted,
coo.col_sorted);
return sub_coo; return sub_coo;
} }
///////////////////////// CSR Based Operations///////////////////////// ///////////////////////// CSR Based Operations/////////////////////////
CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) { CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix> &csrs) {
uint64_t src_offset = 0, dst_offset = 0; uint64_t src_offset = 0, dst_offset = 0;
int64_t indices_offset = 0; int64_t indices_offset = 0;
bool has_data = false; bool has_data = false;
...@@ -112,10 +99,7 @@ CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) { ...@@ -112,10 +99,7 @@ CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) {
sorted &= csr.sorted; sorted &= csr.sorted;
IdArray indptr = csr.indptr + indices_offset; IdArray indptr = csr.indptr + indices_offset;
IdArray indices = csr.indices + dst_offset; IdArray indices = csr.indices + dst_offset;
if (i > 0) if (i > 0) indptr = IndexSelect(indptr, 1, indptr->shape[0]);
indptr = IndexSelect(indptr,
1,
indptr->shape[0]);
res_indptr[i] = indptr; res_indptr[i] = indptr;
res_indices[i] = indices; res_indices[i] = indices;
src_offset += csr.num_rows; src_offset += csr.num_rows;
...@@ -125,10 +109,9 @@ CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) { ...@@ -125,10 +109,9 @@ CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) {
if (has_data) { if (has_data) {
IdArray edges_data; IdArray edges_data;
if (CSRHasData(csr) == false) { if (CSRHasData(csr) == false) {
edges_data = Range(indices_offset, edges_data = Range(
indices_offset + csr.indices->shape[0], indices_offset, indices_offset + csr.indices->shape[0],
csr.indices->dtype.bits, csr.indices->dtype.bits, csr.indices->ctx);
csr.indices->ctx);
} else { } else {
edges_data = csr.data + indices_offset; edges_data = csr.data + indices_offset;
} }
...@@ -142,19 +125,15 @@ CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) { ...@@ -142,19 +125,15 @@ CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) {
IdArray result_data = has_data ? Concat(res_data) : NullArray(); IdArray result_data = has_data ? Concat(res_data) : NullArray();
return CSRMatrix( return CSRMatrix(
src_offset, dst_offset, src_offset, dst_offset, result_indptr, result_indices, result_data,
result_indptr, sorted);
result_indices,
result_data,
sorted);
} }
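The indptr handling above offsets each graph's indptr by the edges accumulated so far and drops the leading zero of every graph after the first. On plain vectors, the same stitching looks like this toy sketch (for example, {0,2,3} and {0,1,3} combine into {0,2,3,4,6}); the helper name is illustrative only.

#include <cstdint>

#include <vector>

// Toy version of the indptr concatenation used by DisjointUnionCsr.
// Assumes every indptr has at least its leading 0 entry.
static std::vector<int64_t> UnionIndptrSketch(
    const std::vector<std::vector<int64_t>> &indptrs) {
  std::vector<int64_t> out = {0};
  int64_t edge_offset = 0;
  for (const auto &ip : indptrs) {
    for (size_t r = 1; r < ip.size(); ++r)
      out.push_back(ip[r] + edge_offset);
    edge_offset += ip.back();
  }
  return out;
}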
std::vector<CSRMatrix> DisjointPartitionCsrBySizes( std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
const CSRMatrix &csr, const CSRMatrix &csr, const uint64_t batch_size,
const uint64_t batch_size, const std::vector<uint64_t> &edge_cumsum,
const std::vector<uint64_t> &edge_cumsum, const std::vector<uint64_t> &src_vertex_cumsum,
const std::vector<uint64_t> &src_vertex_cumsum, const std::vector<uint64_t> &dst_vertex_cumsum) {
const std::vector<uint64_t> &dst_vertex_cumsum) {
CHECK_EQ(edge_cumsum.size(), batch_size + 1); CHECK_EQ(edge_cumsum.size(), batch_size + 1);
CHECK_EQ(src_vertex_cumsum.size(), batch_size + 1); CHECK_EQ(src_vertex_cumsum.size(), batch_size + 1);
CHECK_EQ(dst_vertex_cumsum.size(), batch_size + 1); CHECK_EQ(dst_vertex_cumsum.size(), batch_size + 1);
...@@ -162,37 +141,32 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes( ...@@ -162,37 +141,32 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
ret.resize(batch_size); ret.resize(batch_size);
for (size_t g = 0; g < batch_size; ++g) { for (size_t g = 0; g < batch_size; ++g) {
uint64_t num_src = src_vertex_cumsum[g+1]-src_vertex_cumsum[g]; uint64_t num_src = src_vertex_cumsum[g + 1] - src_vertex_cumsum[g];
IdArray result_indptr; IdArray result_indptr;
if (g == 0) { if (g == 0) {
result_indptr = IndexSelect(csr.indptr, result_indptr =
0, IndexSelect(csr.indptr, 0, src_vertex_cumsum[1] + 1) - edge_cumsum[0];
src_vertex_cumsum[1] + 1) - edge_cumsum[0];
} else { } else {
result_indptr = IndexSelect(csr.indptr, result_indptr =
src_vertex_cumsum[g], IndexSelect(
src_vertex_cumsum[g+1] + 1) - edge_cumsum[g]; csr.indptr, src_vertex_cumsum[g], src_vertex_cumsum[g + 1] + 1) -
edge_cumsum[g];
} }
IdArray result_indices = IndexSelect(csr.indices, IdArray result_indices =
edge_cumsum[g], IndexSelect(csr.indices, edge_cumsum[g], edge_cumsum[g + 1]) -
edge_cumsum[g+1]) - dst_vertex_cumsum[g]; dst_vertex_cumsum[g];
IdArray result_data = NullArray(); IdArray result_data = NullArray();
// has data index array // has data index array
if (CSRHasData(csr)) { if (CSRHasData(csr)) {
result_data = IndexSelect(csr.data, result_data = IndexSelect(csr.data, edge_cumsum[g], edge_cumsum[g + 1]) -
edge_cumsum[g], edge_cumsum[g];
edge_cumsum[g+1]) - edge_cumsum[g];
} }
CSRMatrix sub_csr = CSRMatrix( CSRMatrix sub_csr = CSRMatrix(
num_src, num_src, dst_vertex_cumsum[g + 1] - dst_vertex_cumsum[g], result_indptr,
dst_vertex_cumsum[g+1]-dst_vertex_cumsum[g], result_indices, result_data, csr.sorted);
result_indptr,
result_indices,
result_data,
csr.sorted);
ret[g] = sub_csr; ret[g] = sub_csr;
} }
...@@ -200,36 +174,31 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes( ...@@ -200,36 +174,31 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
} }
CSRMatrix CSRSliceContiguousChunk( CSRMatrix CSRSliceContiguousChunk(
const CSRMatrix &csr, const CSRMatrix &csr, const std::vector<uint64_t> &edge_range,
const std::vector<uint64_t> &edge_range, const std::vector<uint64_t> &src_vertex_range,
const std::vector<uint64_t> &src_vertex_range, const std::vector<uint64_t> &dst_vertex_range) {
const std::vector<uint64_t> &dst_vertex_range) {
int64_t indptr_len = src_vertex_range[1] - src_vertex_range[0] + 1; int64_t indptr_len = src_vertex_range[1] - src_vertex_range[0] + 1;
IdArray result_indptr = Full(0, indptr_len, csr.indptr->dtype.bits, csr.indptr->ctx); IdArray result_indptr =
Full(0, indptr_len, csr.indptr->dtype.bits, csr.indptr->ctx);
IdArray result_indices = NullArray(csr.indptr->dtype, csr.indptr->ctx); IdArray result_indices = NullArray(csr.indptr->dtype, csr.indptr->ctx);
IdArray result_data = NullArray(); IdArray result_data = NullArray();
if (edge_range[1] != edge_range[0]) { if (edge_range[1] != edge_range[0]) {
// The chunk has edges // The chunk has edges
result_indptr = IndexSelect(csr.indptr, result_indptr =
src_vertex_range[0], IndexSelect(csr.indptr, src_vertex_range[0], src_vertex_range[1] + 1) -
src_vertex_range[1] + 1) - edge_range[0]; edge_range[0];
result_indices = IndexSelect(csr.indices, result_indices = IndexSelect(csr.indices, edge_range[0], edge_range[1]) -
edge_range[0], dst_vertex_range[0];
edge_range[1]) - dst_vertex_range[0];
if (CSRHasData(csr)) { if (CSRHasData(csr)) {
result_data = IndexSelect(csr.data, result_data =
edge_range[0], IndexSelect(csr.data, edge_range[0], edge_range[1]) - edge_range[0];
edge_range[1]) - edge_range[0];
} }
} }
CSRMatrix sub_csr = CSRMatrix( CSRMatrix sub_csr = CSRMatrix(
src_vertex_range[1]-src_vertex_range[0], src_vertex_range[1] - src_vertex_range[0],
dst_vertex_range[1]-dst_vertex_range[0], dst_vertex_range[1] - dst_vertex_range[0], result_indptr, result_indices,
result_indptr, result_data, csr.sorted);
result_indices,
result_data,
csr.sorted);
return sub_csr; return sub_csr;
} }
......
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
* \brief DGL array utilities implementation * \brief DGL array utilities implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <sstream> #include <sstream>
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./uvm_array_op.h" #include "./uvm_array_op.h"
...@@ -33,12 +35,14 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) { ...@@ -33,12 +35,14 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) { void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
CHECK(dest.IsPinned()) << "Destination array must be in pinned memory."; CHECK(dest.IsPinned()) << "Destination array must be in pinned memory.";
CHECK_EQ(index->ctx.device_type, kDGLCUDA) << "Index must be on the GPU."; CHECK_EQ(index->ctx.device_type, kDGLCUDA) << "Index must be on the GPU.";
CHECK_EQ(source->ctx.device_type, kDGLCUDA) << "Source array must be on the GPU."; CHECK_EQ(source->ctx.device_type, kDGLCUDA)
<< "Source array must be on the GPU.";
CHECK_EQ(dest->dtype, source->dtype) << "Destination array and source " CHECK_EQ(dest->dtype, source->dtype) << "Destination array and source "
"array must have the same dtype."; "array must have the same dtype.";
CHECK_GE(dest->ndim, 1) << "Destination array must have at least 1 dimension."; CHECK_GE(dest->ndim, 1)
<< "Destination array must have at least 1 dimension.";
CHECK_EQ(index->ndim, 1) << "Index must be a 1D array."; CHECK_EQ(index->ndim, 1) << "Index must be a 1D array.";
ATEN_DTYPE_BITS_ONLY_SWITCH(source->dtype, DType, "values", { ATEN_DTYPE_BITS_ONLY_SWITCH(source->dtype, DType, "values", {
...@@ -52,21 +56,19 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) { ...@@ -52,21 +56,19 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
} }
DGL_REGISTER_GLOBAL("ndarray.uvm._CAPI_DGLIndexSelectCPUFromGPU") DGL_REGISTER_GLOBAL("ndarray.uvm._CAPI_DGLIndexSelectCPUFromGPU")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray array = args[0]; NDArray array = args[0];
IdArray index = args[1]; IdArray index = args[1];
*rv = IndexSelectCPUFromGPU(array, index); *rv = IndexSelectCPUFromGPU(array, index);
}); });
DGL_REGISTER_GLOBAL("ndarray.uvm._CAPI_DGLIndexScatterGPUToCPU") DGL_REGISTER_GLOBAL("ndarray.uvm._CAPI_DGLIndexScatterGPUToCPU")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray dest = args[0]; NDArray dest = args[0];
IdArray index = args[1]; IdArray index = args[1];
NDArray source = args[2]; NDArray source = args[2];
IndexScatterGPUToCPU(dest, index, source); IndexScatterGPUToCPU(dest, index, source);
}); });
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define DGL_ARRAY_UVM_ARRAY_OP_H_ #define DGL_ARRAY_UVM_ARRAY_OP_H_
#include <dgl/array.h> #include <dgl/array.h>
#include <utility> #include <utility>
namespace dgl { namespace dgl {
......
...@@ -3,42 +3,43 @@ ...@@ -3,42 +3,43 @@
* \file c_runtime_api.cc * \file c_runtime_api.cc
* \brief DGL C API common implementations * \brief DGL C API common implementations
*/ */
#include <dgl/graph_interface.h>
#include "c_api_common.h" #include "c_api_common.h"
#include <dgl/graph_interface.h>
using dgl::runtime::DGLArgs; using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue; using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue; using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray; using dgl::runtime::NDArray;
using dgl::runtime::PackedFunc;
namespace dgl { namespace dgl {
PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) { PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) {
auto body = [vec](DGLArgs args, DGLRetValue* rv) { auto body = [vec](DGLArgs args, DGLRetValue* rv) {
const uint64_t which = args[0]; const uint64_t which = args[0];
if (which >= vec.size()) { if (which >= vec.size()) {
LOG(FATAL) << "invalid choice"; LOG(FATAL) << "invalid choice";
} else { } else {
*rv = std::move(vec[which]); *rv = std::move(vec[which]);
} }
}; };
return PackedFunc(body); return PackedFunc(body);
} }
PackedFunc ConvertEdgeArrayToPackedFunc(const EdgeArray& ea) { PackedFunc ConvertEdgeArrayToPackedFunc(const EdgeArray& ea) {
auto body = [ea] (DGLArgs args, DGLRetValue* rv) { auto body = [ea](DGLArgs args, DGLRetValue* rv) {
const int which = args[0]; const int which = args[0];
if (which == 0) { if (which == 0) {
*rv = std::move(ea.src); *rv = std::move(ea.src);
} else if (which == 1) { } else if (which == 1) {
*rv = std::move(ea.dst); *rv = std::move(ea.dst);
} else if (which == 2) { } else if (which == 2) {
*rv = std::move(ea.id); *rv = std::move(ea.id);
} else { } else {
LOG(FATAL) << "invalid choice"; LOG(FATAL) << "invalid choice";
} }
}; };
return PackedFunc(body); return PackedFunc(body);
} }
......
...@@ -6,15 +6,16 @@ ...@@ -6,15 +6,16 @@
#ifndef DGL_C_API_COMMON_H_ #ifndef DGL_C_API_COMMON_H_
#define DGL_C_API_COMMON_H_ #define DGL_C_API_COMMON_H_
#include <dgl/array.h>
#include <dgl/graph_interface.h>
#include <dgl/runtime/ndarray.h> #include <dgl/runtime/ndarray.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h> #include <dgl/runtime/registry.h>
#include <dgl/array.h>
#include <dgl/graph_interface.h>
#include <algorithm> #include <algorithm>
#include <vector>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector>
namespace dgl { namespace dgl {
...@@ -36,13 +37,13 @@ dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc( ...@@ -36,13 +37,13 @@ dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
* The data type of the NDArray will be IdType, which must be an integer type. * The data type of the NDArray will be IdType, which must be an integer type.
* The element type (DType) of the vector must be convertible to IdType. * The element type (DType) of the vector must be convertible to IdType.
*/ */
template<typename IdType, typename DType> template <typename IdType, typename DType>
dgl::runtime::NDArray CopyVectorToNDArray( dgl::runtime::NDArray CopyVectorToNDArray(const std::vector<DType>& vec) {
const std::vector<DType>& vec) {
using dgl::runtime::NDArray; using dgl::runtime::NDArray;
const int64_t len = vec.size(); const int64_t len = vec.size();
NDArray a = NDArray::Empty({len}, DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, NDArray a = NDArray::Empty(
DGLContext{kDGLCPU, 0}); {len}, DGLDataType{kDGLInt, sizeof(IdType) * 8, 1},
DGLContext{kDGLCPU, 0});
std::copy(vec.begin(), vec.end(), static_cast<IdType*>(a->data)); std::copy(vec.begin(), vec.end(), static_cast<IdType*>(a->data));
return a; return a;
} }
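A small usage sketch of the helper above, copying a std::vector<int> into a length-3 int64 CPU NDArray; the include path is assumed from this header's name.

#include <vector>

#include "c_api_common.h"  // assumed path for the declaration above

void CopyVectorExample() {
  std::vector<int> degrees = {1, 2, 3};
  dgl::runtime::NDArray arr =
      dgl::CopyVectorToNDArray<int64_t, int>(degrees);
  // arr now holds {1, 2, 3} as int64 values on CPU.
}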
......