Unverified commit 60b1c992 authored by Xin Yao, committed by GitHub
Browse files

[Fix] Wrap all CUDA runtime API/CUB calls with macro (#4083)



* Wrap all CUDA runtime API/CUB calls with macro

* remove the usage of explicit cudaMalloc in favor of AllocWorkspace

* fix typo
Co-authored-by: Israt Nisa <neesha295@gmail.com>
parent 966d1aa8
......@@ -33,11 +33,13 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
}
// Allocate workspace
size_t workspace_size = 0;
cub::DeviceScan::InclusiveSum(nullptr, workspace_size, in_d, out_d, len, thr_entry->stream);
CUDA_CALL(cub::DeviceScan::InclusiveSum(
nullptr, workspace_size, in_d, out_d, len, thr_entry->stream));
void* workspace = device->AllocWorkspace(array->ctx, workspace_size);
// Compute cumsum
cub::DeviceScan::InclusiveSum(workspace, workspace_size, in_d, out_d, len, thr_entry->stream);
CUDA_CALL(cub::DeviceScan::InclusiveSum(
workspace, workspace_size, in_d, out_d, len, thr_entry->stream));
device->FreeWorkspace(array->ctx, workspace);
......
......@@ -47,11 +47,11 @@ IdArray NonZero(IdArray array) {
device->AllocWorkspace(ctx, sizeof(int64_t)));
size_t temp_size = 0;
cub::DeviceSelect::If(nullptr, temp_size, counter, out_data,
d_num_nonzeros, len, comp, stream);
CUDA_CALL(cub::DeviceSelect::If(nullptr, temp_size, counter, out_data,
d_num_nonzeros, len, comp, stream));
void * temp = device->AllocWorkspace(ctx, temp_size);
cub::DeviceSelect::If(temp, temp_size, counter, out_data,
d_num_nonzeros, len, comp, stream);
CUDA_CALL(cub::DeviceSelect::If(temp, temp_size, counter, out_data,
d_num_nonzeros, len, comp, stream));
device->FreeWorkspace(ctx, temp);
// copy number of selected elements from GPU to CPU
......
......@@ -33,13 +33,13 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
// Allocate workspace
size_t workspace_size = 0;
cub::DeviceRadixSort::SortPairs(nullptr, workspace_size,
keys_in, keys_out, values_in, values_out, nitems, 0, num_bits);
CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, workspace_size,
keys_in, keys_out, values_in, values_out, nitems, 0, num_bits));
void* workspace = device->AllocWorkspace(ctx, workspace_size);
// Compute
cub::DeviceRadixSort::SortPairs(workspace, workspace_size,
keys_in, keys_out, values_in, values_out, nitems, 0, num_bits);
CUDA_CALL(cub::DeviceRadixSort::SortPairs(workspace, workspace_size,
keys_in, keys_out, values_in, values_out, nitems, 0, num_bits));
device->FreeWorkspace(ctx, workspace);
......
......@@ -128,15 +128,15 @@ void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
// Allocate workspace
size_t workspace_size = 0;
cub::DeviceSegmentedRadixSort::SortPairs(nullptr, workspace_size,
CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(nullptr, workspace_size,
key_in, key_out, value_in, value_out,
nnz, csr->num_rows, offsets, offsets + 1);
nnz, csr->num_rows, offsets, offsets + 1));
void* workspace = device->AllocWorkspace(ctx, workspace_size);
// Compute
cub::DeviceSegmentedRadixSort::SortPairs(workspace, workspace_size,
CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(workspace, workspace_size,
key_in, key_out, value_in, value_out,
nnz, csr->num_rows, offsets, offsets + 1);
nnz, csr->num_rows, offsets, offsets + 1));
csr->sorted = true;
csr->indices = new_indices;
......
......@@ -462,11 +462,11 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
auto ptr_sorted_cols = sorted_array.Ptr<IdType>();
auto ptr_cols = cols.Ptr<IdType>();
size_t workspace_size = 0;
cub::DeviceRadixSort::SortKeys(
nullptr, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0]);
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
nullptr, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0]));
void *workspace = device->AllocWorkspace(ctx, workspace_size);
cub::DeviceRadixSort::SortKeys(
workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0]);
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0]));
device->FreeWorkspace(ctx, workspace);
// Execute SegmentMaskColKernel
......
......@@ -115,17 +115,15 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
* process has finished.
*/
template<typename IdType>
bool Colorize(IdType * result_data, int64_t num_nodes) {
bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
// initial done signal
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, thr_entry->stream);
// generate color prop for each node
float *prop;
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
CUDA_CALL(cudaMalloc(reinterpret_cast<void **>(&prop), num_nodes * sizeof(float)));
CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, thr_entry->stream,
prop, num_nodes, seed);
......@@ -134,7 +132,6 @@ bool Colorize(IdType * result_data, int64_t num_nodes) {
prop, num_nodes, result_data);
bool done_h = false;
CUDA_CALL(cudaMemcpyFromSymbol(&done_h, done_d, sizeof(done_h), 0, cudaMemcpyDeviceToHost));
CUDA_CALL(cudaFree(prop));
return done_h;
}
......@@ -171,14 +168,19 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
IdType *proposal_data = static_cast<IdType*>(proposal->data);
FloatType *weight_data = static_cast<FloatType*>(weight->data);
// allocate workspace for prop used in Colorize()
float *prop = static_cast<float*>(
device->AllocWorkspace(ctx, num_nodes * sizeof(float)));
auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
while (!Colorize<IdType>(result_data, num_nodes)) {
while (!Colorize<IdType>(result_data, num_nodes, prop)) {
CUDA_KERNEL_CALL(weighted_propose_kernel, num_blocks, num_threads, 0, thr_entry->stream,
indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data);
CUDA_KERNEL_CALL(weighted_respond_kernel, num_blocks, num_threads, 0, thr_entry->stream,
indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data);
}
device->FreeWorkspace(ctx, prop);
}
template void WeightedNeighborMatching<kDLGPU, float, int32_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
......
......@@ -250,7 +250,7 @@ FrequencyHashmap<IdxType>::FrequencyHashmap(
constexpr int TILE_SIZE = BLOCK_SIZE * 8;
dim3 block(BLOCK_SIZE);
dim3 grid((num_dst * num_items_each_dst + TILE_SIZE - 1) / TILE_SIZE);
cudaMemset(dst_unique_edges, 0, (num_dst) * sizeof(IdxType));
CUDA_CALL(cudaMemset(dst_unique_edges, 0, (num_dst) * sizeof(IdxType)));
CUDA_KERNEL_CALL((_init_edge_table<IdxType, BLOCK_SIZE, TILE_SIZE>),
grid, block, 0, _stream,
edge_hashmap, (num_dst * num_items_each_dst));
......
......@@ -120,9 +120,9 @@ class CUDADeviceAPI final : public DeviceAPI {
if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
} else {
cudaMemcpyPeerAsync(to, ctx_to.device_id,
CUDA_CALL(cudaMemcpyPeerAsync(to, ctx_to.device_id,
from, ctx_from.device_id,
size, cu_stream);
size, cu_stream));
}
} else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) {
CUDA_CALL(cudaSetDevice(ctx_from.device_id));
......
......@@ -239,7 +239,7 @@ std::pair<IdArray, NDArray> SparsePush(
comm->AllToAll(send_sum, recv_sum.get(), 1, stream);
cudaEvent_t d2h;
cudaEventCreate(&d2h);
CUDA_CALL(cudaEventCreate(&d2h));
// compute the prefix sum of the recv values
Workspace<int64_t> recv_prefix(device, ctx, comm_size+1);
......@@ -269,11 +269,11 @@ std::pair<IdArray, NDArray> SparsePush(
recv_prefix.free();
// use an event to track when copying is done
cudaEventRecord(d2h, stream);
CUDA_CALL(cudaEventRecord(d2h, stream));
// allocate output space
cudaEventSynchronize(d2h);
cudaEventDestroy(d2h);
CUDA_CALL(cudaEventSynchronize(d2h));
CUDA_CALL(cudaEventDestroy(d2h));
IdArray recv_idx = aten::NewIdArray(
recv_prefix_host.back(), ctx, sizeof(IdType)*8);
......@@ -369,7 +369,7 @@ NDArray SparsePull(
}
cudaEvent_t d2h;
cudaEventCreate(&d2h);
CUDA_CALL(cudaEventCreate(&d2h));
std::vector<int64_t> request_prefix_host(comm_size+1);
device->CopyDataFromTo(
......@@ -420,11 +420,11 @@ NDArray SparsePull(
response_prefix.free();
// use an event to track when copying is done
cudaEventRecord(d2h, stream);
CUDA_CALL(cudaEventRecord(d2h, stream));
// allocate output space
cudaEventSynchronize(d2h);
cudaEventDestroy(d2h);
CUDA_CALL(cudaEventSynchronize(d2h));
CUDA_CALL(cudaEventDestroy(d2h));
// gather requested indexes
IdArray recv_idx = aten::NewIdArray(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment