Unverified Commit 0ff7127a authored by nv-dlasalle's avatar nv-dlasalle Committed by GitHub
Browse files

[Bugfix] Wrap cub with CUB_NS_PREFIX and remove dependency on Thrust to...


[Bugfix] Wrap cub with CUB_NS_PREFIX and remove dependency on Thrust to fix linking issues with Torch 1.8 (#2758)

* Wrap cub with prefixes and remove thrust

* Using counting iterator
Co-authored-by: Zihao Ye <expye@outlook.com>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 9aac93ff
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
* \brief Array cumsum GPU implementation * \brief Array cumsum GPU implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <cub/cub.cuh>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
#include "./dgl_cub.cuh"
namespace dgl { namespace dgl {
using runtime::NDArray; using runtime::NDArray;
......
...@@ -3,13 +3,11 @@ ...@@ -3,13 +3,11 @@
* \file array/cpu/array_nonzero.cc * \file array/cpu/array_nonzero.cc
* \brief Array nonzero CPU implementation * \brief Array nonzero CPU implementation
*/ */
#include <thrust/iterator/counting_iterator.h>
#include <thrust/copy.h>
#include <thrust/functional.h>
#include <thrust/device_vector.h>
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
#include "./dgl_cub.cuh"
namespace dgl { namespace dgl {
using runtime::NDArray; using runtime::NDArray;
...@@ -17,32 +15,59 @@ namespace aten { ...@@ -17,32 +15,59 @@ namespace aten {
namespace impl { namespace impl {
template <typename IdType> template <typename IdType>
struct IsNonZero { struct IsNonZeroIndex {
__device__ bool operator() (const IdType val) { explicit IsNonZeroIndex(const IdType * array) : array_(array) {
return val != 0; }
__device__ bool operator() (const int64_t index) {
return array_[index] != 0;
} }
const IdType * array_;
}; };
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
IdArray NonZero(IdArray array) { IdArray NonZero(IdArray array) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); const auto& ctx = array->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
const int64_t len = array->shape[0]; const int64_t len = array->shape[0];
IdArray ret = NewIdArray(len, array->ctx, 64); IdArray ret = NewIdArray(len, ctx, 64);
thrust::device_ptr<IdType> in_data(array.Ptr<IdType>());
thrust::device_ptr<int64_t> out_data(ret.Ptr<int64_t>()); cudaStream_t stream = 0;
// TODO(minjie): should take control of the memory allocator.
// See PyTorch's implementation here: const IdType * const in_data = static_cast<const IdType*>(array->data);
// https://github.com/pytorch/pytorch/blob/1f7557d173c8e9066ed9542ada8f4a09314a7e17/ int64_t * const out_data = static_cast<int64_t*>(ret->data);
// aten/src/THC/generic/THCTensorMath.cu#L104
auto startiter = thrust::make_counting_iterator<int64_t>(0); IsNonZeroIndex<IdType> comp(in_data);
auto enditer = startiter + len; cub::CountingInputIterator<int64_t> counter(0);
auto indices_end = thrust::copy_if(thrust::cuda::par.on(thr_entry->stream),
startiter, // room for cub to output on GPU
enditer, int64_t * d_num_nonzeros = static_cast<int64_t*>(
in_data, device->AllocWorkspace(ctx, sizeof(int64_t)));
out_data,
IsNonZero<IdType>()); size_t temp_size = 0;
const int64_t num_nonzeros = indices_end - out_data; cub::DeviceSelect::If(nullptr, temp_size, counter, out_data,
d_num_nonzeros, len, comp, stream);
void * temp = device->AllocWorkspace(ctx, temp_size);
cub::DeviceSelect::If(temp, temp_size, counter, out_data,
d_num_nonzeros, len, comp, stream);
device->FreeWorkspace(ctx, temp);
// copy number of selected elements from GPU to CPU
int64_t num_nonzeros;
device->CopyDataFromTo(
d_num_nonzeros, 0,
&num_nonzeros, 0,
sizeof(num_nonzeros),
ctx,
DGLContext{kDLCPU, 0},
DGLType{kDLInt, 64, 1},
stream);
device->FreeWorkspace(ctx, d_num_nonzeros);
device->StreamSync(ctx, stream);
// truncate array to size
return ret.CreateView({num_nonzeros}, ret->dtype, 0); return ret.CreateView({num_nonzeros}, ret->dtype, 0);
} }
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
* \brief Array sort GPU implementation * \brief Array sort GPU implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <cub/cub.cuh>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
#include "./dgl_cub.cuh"
namespace dgl { namespace dgl {
using runtime::NDArray; using runtime::NDArray;
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
* \brief Sort CSR index * \brief Sort CSR index
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <cub/cub.cuh>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
#include "./dgl_cub.cuh"
namespace dgl { namespace dgl {
......
/*!
 * Copyright (c) 2021 by Contributors
 * \file dgl_cub.cuh
 * \brief Wrapper to place cub in dgl namespace.
 */
#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_
#define DGL_ARRAY_CUDA_DGL_CUB_CUH_

// Include cub in a safe manner: CUB_NS_PREFIX/CUB_NS_POSTFIX wrap every cub
// symbol inside the dgl namespace, so this build's cub cannot collide with a
// different cub version linked in by other libraries (e.g. Torch 1.8).
#define CUB_NS_PREFIX namespace dgl {
#define CUB_NS_POSTFIX }
#include "cub/cub.cuh"
// Undefine immediately so a later (direct or transitive) include of cub in
// another translation unit is not silently wrapped as well.
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX

#endif  // DGL_ARRAY_CUDA_DGL_CUB_CUH_
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
*/ */
#include "./utils.h" #include "./utils.h"
#include <cub/cub.cuh> #include "./dgl_cub.cuh"
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
namespace dgl { namespace dgl {
......
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
* \brief Device level functions for within cuda kernels. * \brief Device level functions for within cuda kernels.
*/ */
#include <cub/cub.cuh>
#include <cassert> #include <cassert>
#include "cuda_hashtable.cuh" #include "cuda_hashtable.cuh"
#include "../../kernel/cuda/atomic.cuh" #include "../../kernel/cuda/atomic.cuh"
#include "../../array/cuda/dgl_cub.cuh"
using namespace dgl::kernel::cuda; using namespace dgl::kernel::cuda;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment