"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "e178158616012f2e3bc83e28457127ce6dc0cf03"
Unverified Commit d798280f authored by Tianqi Zhang (张天启)'s avatar Tianqi Zhang (张天启) Committed by GitHub
Browse files

[Bugfix] Fix SetDevice issue for NeighborMatching (#3341)



* fix setdevice issue

* change to curand device API
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent d6eecf90
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/random.h> #include <dgl/random.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <curand.h> #include <curand_kernel.h>
#include <cstdint> #include <cstdint>
#include "../geometry_op.h" #include "../geometry_op.h"
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
...@@ -26,6 +26,15 @@ constexpr int EMPTY_IDX = -1; ...@@ -26,6 +26,15 @@ constexpr int EMPTY_IDX = -1;
__device__ bool done_d; __device__ bool done_d;
__global__ void init_done_kernel() { done_d = true; } __global__ void init_done_kernel() { done_d = true; }
__global__ void generate_uniform_kernel(float *ret_values, size_t num, uint64_t seed) {
size_t id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < num) {
curandState state;
curand_init(seed, id, 0, &state);
ret_values[id] = curand_uniform(&state);
}
}
template <typename IdType> template <typename IdType>
__global__ void colorize_kernel(const float *prop, int64_t num_elem, IdType *result) { __global__ void colorize_kernel(const float *prop, int64_t num_elem, IdType *result) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
...@@ -106,20 +115,21 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi ...@@ -106,20 +115,21 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
* process has finished. * process has finished.
*/ */
template<typename IdType> template<typename IdType>
bool Colorize(IdType * result_data, curandGenerator_t gen, int64_t num_nodes) { bool Colorize(IdType * result_data, int64_t num_nodes) {
// initial done signal // initial done signal
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, thr_entry->stream); CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, thr_entry->stream);
// generate color prop for each node // generate color prop for each node
float *prop; float *prop;
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
CUDA_CALL(cudaMalloc(reinterpret_cast<void **>(&prop), num_nodes * sizeof(float))); CUDA_CALL(cudaMalloc(reinterpret_cast<void **>(&prop), num_nodes * sizeof(float)));
CURAND_CALL(curandGenerateUniform(gen, prop, num_nodes)); CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, thr_entry->stream,
cudaDeviceSynchronize(); // wait for random number generation finish since curand is async prop, num_nodes, seed);
// call kernel // call kernel
auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
CUDA_KERNEL_CALL(colorize_kernel, num_blocks, num_threads, 0, thr_entry->stream, CUDA_KERNEL_CALL(colorize_kernel, num_blocks, num_threads, 0, thr_entry->stream,
prop, num_nodes, result_data); prop, num_nodes, result_data);
bool done_h = false; bool done_h = false;
...@@ -146,15 +156,13 @@ bool Colorize(IdType * result_data, curandGenerator_t gen, int64_t num_nodes) { ...@@ -146,15 +156,13 @@ bool Colorize(IdType * result_data, curandGenerator_t gen, int64_t num_nodes) {
template <DLDeviceType XPU, typename FloatType, typename IdType> template <DLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) { void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) { const auto& ctx = result->ctx;
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX); auto device = runtime::DeviceAPI::Get(ctx);
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT)); device->SetDevice(ctx);
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(thr_entry->curand_gen, seed));
}
// create proposal tensor // create proposal tensor
const int64_t num_nodes = result->shape[0]; const int64_t num_nodes = result->shape[0];
IdArray proposal = aten::Full(-1, num_nodes, sizeof(IdType) * 8, result->ctx); IdArray proposal = aten::Full(-1, num_nodes, sizeof(IdType) * 8, ctx);
// get data ptrs // get data ptrs
IdType *indptr_data = static_cast<IdType*>(csr.indptr->data); IdType *indptr_data = static_cast<IdType*>(csr.indptr->data);
...@@ -165,7 +173,7 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, ...@@ -165,7 +173,7 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
auto num_threads = cuda::FindNumThreads(num_nodes); auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads)); auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
while (!Colorize<IdType>(result_data, thr_entry->curand_gen, num_nodes)) { while (!Colorize<IdType>(result_data, num_nodes)) {
CUDA_KERNEL_CALL(weighted_propose_kernel, num_blocks, num_threads, 0, thr_entry->stream, CUDA_KERNEL_CALL(weighted_propose_kernel, num_blocks, num_threads, 0, thr_entry->stream,
indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data); indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data);
CUDA_KERNEL_CALL(weighted_respond_kernel, num_blocks, num_threads, 0, thr_entry->stream, CUDA_KERNEL_CALL(weighted_respond_kernel, num_blocks, num_threads, 0, thr_entry->stream,
...@@ -194,19 +202,20 @@ template void WeightedNeighborMatching<kDLGPU, double, int64_t>( ...@@ -194,19 +202,20 @@ template void WeightedNeighborMatching<kDLGPU, double, int64_t>(
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) { void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
const int64_t num_edges = csr.indices->shape[0]; const int64_t num_edges = csr.indices->shape[0];
const auto& ctx = result->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
device->SetDevice(ctx);
// generate random weights // generate random weights
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) {
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(thr_entry->curand_gen, seed));
}
NDArray weight = NDArray::Empty( NDArray weight = NDArray::Empty(
{num_edges}, DLDataType{kDLFloat, sizeof(float) * 8, 1}, result->ctx); {num_edges}, DLDataType{kDLFloat, sizeof(float) * 8, 1}, ctx);
float *weight_data = static_cast<float*>(weight->data); float *weight_data = static_cast<float*>(weight->data);
CURAND_CALL(curandGenerateUniform(thr_entry->curand_gen, weight_data, num_edges)); uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
cudaDeviceSynchronize(); auto num_threads = cuda::FindNumThreads(num_edges);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_edges, num_threads));
CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, thr_entry->stream,
weight_data, num_edges, seed);
WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result); WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment