Unverified commit 9a00cf19, authored by Xin Yao and committed by GitHub

[Feature] Import PyTorch's CUDA stream management (#4503)

* add set_stream

* add .record_stream for NDArray and HeteroGraph

* refactor dgl stream Python APIs

* test record_stream

* add unit test for record stream

* use pytorch's stream

* fix lint

* fix cpu build

* address comments

* address comments

* add record stream tests for dgl.graph

* record frames and update dataloader

* add docstring

* update frame

* add backend check for record_stream

* remove CUDAThreadEntry::stream

* record stream for newly created formats

* fix bug

* fix cpp test

* fix None c_void_p to c_handle
parent 099b173f
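
In short, the diff below replaces every use of the thread-local `CUDAThreadEntry::stream` with a new `runtime::getCurrentCUDAStream()` helper that asks PyTorch for its current stream and falls back to the default stream when no PyTorch backend is loaded. As a rough, hedged illustration (not DGL code) of what the helper resolves to when PyTorch is present, this is the ATen call that tracks `torch.cuda.stream(...)` / `torch.cuda.set_stream(...)` on the Python side:

```cpp
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream
#include <cuda_runtime.h>

// The stream PyTorch considers "current" for the current device; kernels
// launched on it are ordered with PyTorch ops running on the same stream.
cudaStream_t CurrentTorchStream() {
  return at::cuda::getCurrentCUDAStream().stream();
}
```

Every kernel launch and library call in the diff then takes this stream (the local `stream` variable) instead of the removed `thr_entry->stream`.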
......@@ -23,7 +23,7 @@ namespace impl {
template <DLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = csr.indptr->ctx;
IdArray rows = aten::VecToIdArray<int64_t>({row}, sizeof(IdType) * 8, ctx);
IdArray cols = aten::VecToIdArray<int64_t>({col}, sizeof(IdType) * 8, ctx);
......@@ -33,7 +33,7 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
const IdType* data = nullptr;
// TODO(minjie): use binary search for sorted csr
CUDA_KERNEL_CALL(dgl::cuda::_LinearSearchKernel,
1, 1, 0, thr_entry->stream,
1, 1, 0, stream,
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), data,
rows.Ptr<IdType>(), cols.Ptr<IdType>(),
1, 1, 1,
......@@ -55,13 +55,13 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
return rst;
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int nt = dgl::cuda::FindNumThreads(rstlen);
const int nb = (rstlen + nt - 1) / nt;
const IdType* data = nullptr;
// TODO(minjie): use binary search for sorted csr
CUDA_KERNEL_CALL(dgl::cuda::_LinearSearchKernel,
nb, nt, 0, thr_entry->stream,
nb, nt, 0, stream,
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), data,
row.Ptr<IdType>(), col.Ptr<IdType>(),
row_stride, col_stride, rstlen,
......@@ -100,7 +100,7 @@ bool CSRHasDuplicate(CSRMatrix csr) {
if (!csr.sorted)
csr = CSRSort(csr);
const auto& ctx = csr.indptr->ctx;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(ctx);
// We allocate a workspace of num_rows bytes. It wastes a little bit memory but should
// be fine.
......@@ -108,7 +108,7 @@ bool CSRHasDuplicate(CSRMatrix csr) {
const int nt = dgl::cuda::FindNumThreads(csr.num_rows);
const int nb = (csr.num_rows + nt - 1) / nt;
CUDA_KERNEL_CALL(_SegmentHasNoDuplicate,
nb, nt, 0, thr_entry->stream,
nb, nt, 0, stream,
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
csr.num_rows, flags);
bool ret = dgl::cuda::AllTrue(flags, csr.num_rows, ctx);
......@@ -148,7 +148,7 @@ __global__ void _CSRGetRowNNZKernel(
template <DLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto len = rows->shape[0];
const IdType* vid_data = static_cast<IdType*>(rows->data);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
......@@ -157,7 +157,7 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
const int nt = dgl::cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(_CSRGetRowNNZKernel,
nb, nt, 0, thr_entry->stream,
nb, nt, 0, stream,
vid_data, indptr_data, rst_data, len);
return rst;
}
......@@ -245,7 +245,7 @@ __global__ void _SegmentCopyKernel(
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int64_t len = rows->shape[0];
IdArray ret_indptr = aten::CumSum(aten::CSRGetRowNNZ(csr, rows), true);
const int64_t nnz = aten::IndexSelect<IdType>(ret_indptr, len);
......@@ -256,14 +256,14 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
// Copy indices.
IdArray ret_indices = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);
CUDA_KERNEL_CALL(_SegmentCopyKernel,
nb, nt, 0, thr_entry->stream,
nb, nt, 0, stream,
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
rows.Ptr<IdType>(), nnz, len,
ret_indptr.Ptr<IdType>(), ret_indices.Ptr<IdType>());
// Copy data.
IdArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);
CUDA_KERNEL_CALL(_SegmentCopyKernel,
nb, nt, 0, thr_entry->stream,
nb, nt, 0, stream,
csr.indptr.Ptr<IdType>(), CSRHasData(csr)? csr.data.Ptr<IdType>() : nullptr,
rows.Ptr<IdType>(), nnz, len,
ret_indptr.Ptr<IdType>(), ret_data.Ptr<IdType>());
......@@ -358,14 +358,14 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
const int64_t nnz = csr.indices->shape[0];
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
// Generate a 0-1 mask for matched (row, col) positions.
IdArray mask = Full(0, nnz, nbits, ctx);
const int nt = dgl::cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(_SegmentMaskKernel,
nb, nt, 0, thr_entry->stream,
nb, nt, 0, stream,
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
row.Ptr<IdType>(), col.Ptr<IdType>(),
row_stride, col_stride, len,
......@@ -381,7 +381,7 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
const int nt2 = dgl::cuda::FindNumThreads(idx->shape[0]);
const int nb2 = (idx->shape[0] + nt - 1) / nt;
CUDA_KERNEL_CALL(_SortedSearchKernel,
nb2, nt2, 0, thr_entry->stream,
nb2, nt2, 0, stream,
csr.indptr.Ptr<IdType>(), csr.num_rows,
idx.Ptr<IdType>(), idx->shape[0],
ret_row.Ptr<IdType>());
......@@ -424,7 +424,7 @@ __global__ void _SegmentMaskColKernel(
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = rows->ctx;
const auto& dtype = rows->dtype;
const auto nbits = dtype.bits;
......@@ -464,17 +464,17 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
size_t workspace_size = 0;
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
nullptr, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0],
0, sizeof(IdType)*8, thr_entry->stream));
0, sizeof(IdType)*8, stream));
void *workspace = device->AllocWorkspace(ctx, workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0],
0, sizeof(IdType)*8, thr_entry->stream));
0, sizeof(IdType)*8, stream));
device->FreeWorkspace(ctx, workspace);
// Execute SegmentMaskColKernel
int nb = (nnz_csr + nt - 1) / nt;
CUDA_KERNEL_CALL(_SegmentMaskColKernel,
nb, nt, 0, thr_entry->stream,
nb, nt, 0, stream,
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), csr.num_rows, nnz_csr,
ptr_sorted_cols, cols_size,
mask.Ptr<IdType>(), count.Ptr<IdType>());
......
......@@ -103,9 +103,10 @@ void _Transpose(const DType* in, DType* out,
int row, int col) {
DType alpha = 1., beta = 0.;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, thr_entry->stream));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
CUBLAS_CALL(Xgeam<DType>(
thr_entry->cublas_handle,
CUBLAS_OP_T,
......@@ -123,10 +124,10 @@ void _Transpose(const DType* in, DType* out,
template <>
void _Transpose<half>(const half* in, half* out,
int row, int col) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = FindNumThreads(row);
int nb = col;
CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, thr_entry->stream, in, out, col, row);
CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);
}
/*
......@@ -149,7 +150,7 @@ __global__ void _IndexSelectKernel(const DType* array, const IdType* index,
*/
template<typename DType, typename IdType>
NDArray _IndexSelect(NDArray array, NDArray index) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType* array_data = static_cast<DType*>(array->data);
const IdType* idx_data = static_cast<IdType*>(index->data);
const int64_t arr_len = array->shape[0];
......@@ -160,7 +161,7 @@ NDArray _IndexSelect(NDArray array, NDArray index) {
DType* ret_data = static_cast<DType*>(ret->data);
const int nt = FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(_IndexSelectKernel, nb, nt, 0, thr_entry->stream,
CUDA_KERNEL_CALL(_IndexSelectKernel, nb, nt, 0, stream,
array_data, idx_data, len, ret_data);
return ret;
}
......@@ -223,11 +224,12 @@ void CusparseCsrmm2(
// device
auto device = runtime::DeviceAPI::Get(ctx);
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed
if (!thr_entry->cusparse_handle) {
CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
}
CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
// all one data array
DType* valptr = nullptr;
if (!A_data) {
......@@ -678,7 +680,7 @@ void SpMMCoo(
DType *out_data = out.Ptr<DType>();
Idx *argu_data = argu.Ptr<Idx>(),
*arge_data = arge.Ptr<Idx>();
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int64_t N = coo.num_rows, M = coo.num_cols, E = coo.row->shape[0];
int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
......@@ -689,7 +691,7 @@ void SpMMCoo(
int64_t out_size = out.NumElements();
const int nt = FindNumThreads(out_size);
const int nb = (out_size + nt - 1) / nt;
CUDA_KERNEL_CALL(_FillKernel, nb, nt, 0, thr_entry->stream,
CUDA_KERNEL_CALL(_FillKernel, nb, nt, 0, stream,
out_data, out_size, ReduceOp::zero());
const int ntx = FindNumThreads(len);
......@@ -703,7 +705,7 @@ void SpMMCoo(
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
CUDA_KERNEL_CALL((SpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, thr_entry->stream,
nblks, nthrs, 0, stream,
ufeat_data, efeat_data, out_data, argu_data, arge_data,
row, col, edge_map,
N, M, E,
......@@ -711,7 +713,7 @@ void SpMMCoo(
lhs_len, rhs_len, len);
if (ReduceOp::require_arg) {
CUDA_KERNEL_CALL((ArgSpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, thr_entry->stream,
nblks, nthrs, 0, stream,
ufeat_data, efeat_data, out_data, argu_data, arge_data,
row, col, edge_map,
N, M, E,
......@@ -751,7 +753,7 @@ void SpMMCsr(
Idx* argu_data = argu.Ptr<Idx>();
Idx* arge_data = arge.Ptr<Idx>();
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
int64_t len = bcast.out_len,
......@@ -768,7 +770,7 @@ void SpMMCsr(
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
CUDA_KERNEL_CALL((SpMMCsrKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, thr_entry->stream,
nblks, nthrs, 0, stream,
ufeat_data, efeat_data, out_data, argu_data, arge_data,
indptr, indices, edge_map,
csr.num_rows, csr.num_cols,
......@@ -817,7 +819,7 @@ void SpMMCmpCsrHetero(
Idx* argu_data = argu.Ptr<Idx>();
Idx* arge_data = arge.Ptr<Idx>();
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
int64_t len = bcast.out_len,
......@@ -833,7 +835,7 @@ void SpMMCmpCsrHetero(
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
CUDA_KERNEL_CALL((SpMMCmpCsrHeteroKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, thr_entry->stream,
nblks, nthrs, 0, stream,
ufeat_data, efeat_data, out_data, argu_data, arge_data,
static_cast<Idx*>(argu_ntype->data),
static_cast<Idx*>(arge_etype->data),
......
......@@ -118,7 +118,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
}
}
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
const dgl_type_t src_id = ufeat_ntids[etype];
const dgl_type_t dst_id = out_ntids[etype];
......@@ -135,7 +135,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
static_cast<DType*>(vec_ufeat[src_id]->data),
nullptr,
out,
x_length, thr_entry->stream);
x_length, stream);
} else if (op == "mul" && is_scalar_efeat &&
cusparse_available<bits, IdType>(more_nnz)) { // cusparse
NDArray efeat = vec_efeat[etype];
......@@ -147,7 +147,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
static_cast<DType*>(efeat->data),
// TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11
static_cast<DType*>((*vec_out)[dst_id]->data),
x_length, thr_entry->stream);
x_length, stream);
} else { // general kernel
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
......
......@@ -16,7 +16,7 @@ bool AllTrue(int8_t* flags, int64_t length, const DLContext& ctx) {
int8_t* rst = static_cast<int8_t*>(device->AllocWorkspace(ctx, 1));
// Call CUB's reduction
size_t workspace_size = 0;
cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
CUDA_CALL(cub::DeviceReduce::Min(nullptr, workspace_size, flags, rst, length, stream));
void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUDA_CALL(cub::DeviceReduce::Min(workspace, workspace_size, flags, rst, length, stream));
......
......@@ -135,10 +135,10 @@ __global__ void _FillKernel(DType* ptr, size_t length, DType val) {
/*! \brief Fill the vector started from ptr of size length with val */
template <typename DType>
void _Fill(DType* ptr, size_t length, DType val) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = FindNumThreads(length);
int nb = (length + nt - 1) / nt; // on x-axis, no need to worry about upperbound.
CUDA_KERNEL_CALL(cuda::_FillKernel, nb, nt, 0, thr_entry->stream, ptr, length, val);
CUDA_KERNEL_CALL(cuda::_FillKernel, nb, nt, 0, stream, ptr, length, val);
}
/*!
......
......@@ -16,7 +16,7 @@ namespace impl {
template<typename DType, typename IdType>
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType* array_data = static_cast<DType*>(array->data);
const IdType* idx_data = static_cast<IdType*>(index->data);
const int64_t arr_len = array->shape[0];
......@@ -41,7 +41,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0,
thr_entry->stream, array_data, idx_data, len, arr_len, ret_data);
stream, array_data, idx_data, len, arr_len, ret_data);
} else {
dim3 block(256, 1);
while (static_cast<int64_t>(block.x) >= 2*num_feat) {
......@@ -51,11 +51,11 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
const dim3 grid((len+block.y-1)/block.y);
if (num_feat * sizeof(DType) < 2 * CACHE_LINE_SIZE) {
CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0,
thr_entry->stream, array_data, num_feat, idx_data,
stream, array_data, num_feat, idx_data,
len, arr_len, ret_data);
} else {
CUDA_KERNEL_CALL(IndexSelectMultiKernelAligned, grid, block, 0,
thr_entry->stream, array_data, num_feat, idx_data,
stream, array_data, num_feat, idx_data,
len, arr_len, ret_data);
}
}
......@@ -75,7 +75,7 @@ template NDArray IndexSelectCPUFromGPU<int64_t, int64_t>(NDArray, IdArray);
template<typename DType, typename IdType>
void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
DType* dest_data = static_cast<DType*>(dest->data);
const DType* source_data = static_cast<DType*>(source->data);
const IdType* idx_data = static_cast<IdType*>(index->data);
......@@ -99,7 +99,7 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(IndexScatterSingleKernel, nb, nt, 0,
thr_entry->stream, source_data, idx_data, len, arr_len, dest_data);
stream, source_data, idx_data, len, arr_len, dest_data);
} else {
dim3 block(256, 1);
while (static_cast<int64_t>(block.x) >= 2*num_feat) {
......@@ -108,7 +108,7 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
}
const dim3 grid((len+block.y-1)/block.y);
CUDA_KERNEL_CALL(IndexScatterMultiKernel, grid, block, 0,
thr_entry->stream, source_data, num_feat, idx_data,
stream, source_data, num_feat, idx_data,
len, arr_len, dest_data);
}
}
......
......@@ -117,18 +117,18 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
template<typename IdType>
bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
// initial done signal
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, thr_entry->stream);
cudaStream_t stream = runtime::getCurrentCUDAStream();
CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, stream);
// generate color prop for each node
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, thr_entry->stream,
CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, stream,
prop, num_nodes, seed);
// call kernel
CUDA_KERNEL_CALL(colorize_kernel, num_blocks, num_threads, 0, thr_entry->stream,
CUDA_KERNEL_CALL(colorize_kernel, num_blocks, num_threads, 0, stream,
prop, num_nodes, result_data);
bool done_h = false;
CUDA_CALL(cudaMemcpyFromSymbol(&done_h, done_d, sizeof(done_h), 0, cudaMemcpyDeviceToHost));
......@@ -152,7 +152,7 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
*/
template <DLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = result->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
device->SetDevice(ctx);
......@@ -175,9 +175,9 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
while (!Colorize<IdType>(result_data, num_nodes, prop)) {
CUDA_KERNEL_CALL(weighted_propose_kernel, num_blocks, num_threads, 0, thr_entry->stream,
CUDA_KERNEL_CALL(weighted_propose_kernel, num_blocks, num_threads, 0, stream,
indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data);
CUDA_KERNEL_CALL(weighted_respond_kernel, num_blocks, num_threads, 0, thr_entry->stream,
CUDA_KERNEL_CALL(weighted_respond_kernel, num_blocks, num_threads, 0, stream,
indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data);
}
device->FreeWorkspace(ctx, prop);
......@@ -209,14 +209,14 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
device->SetDevice(ctx);
// generate random weights
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
NDArray weight = NDArray::Empty(
{num_edges}, DLDataType{kDLFloat, sizeof(float) * 8, 1}, ctx);
float *weight_data = static_cast<float*>(weight->data);
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
auto num_threads = cuda::FindNumThreads(num_edges);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_edges, num_threads));
CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, thr_entry->stream,
CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, stream,
weight_data, num_edges, seed);
WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result);
......
......@@ -92,7 +92,7 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size
template <DLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const FloatType* array_data = static_cast<FloatType*>(array->data);
......@@ -110,7 +110,7 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
CUDA_CALL(cudaSetDevice(array->ctx.device_id));
CUDA_KERNEL_CALL(fps_kernel,
batch_size, THREADS, 0, thr_entry->stream,
batch_size, THREADS, 0, stream,
array_data, batch_size, sample_points,
point_in_batch, dim, start_idx_data, dist_data, ret_data);
}
......
......@@ -276,6 +276,11 @@ void HeteroGraph::UnpinMemory_() {
g->UnpinMemory_();
}
void HeteroGraph::RecordStream(DGLStreamHandle stream) {
for (auto g : relation_graphs_)
g->RecordStream(stream);
}
std::string HeteroGraph::SharedMemName() const {
return shared_mem_ ? shared_mem_->GetName() : "";
}
......
......@@ -251,6 +251,12 @@ class HeteroGraph : public BaseHeteroGraph {
*/
void UnpinMemory_();
/*!
* \brief Record stream for this graph.
* \param stream The stream that is using the graph
*/
void RecordStream(DGLStreamHandle stream) override;
/*! \brief Copy the data to shared memory.
*
* Also save names of node types and edge types of the HeteroGraph object to shared memory
......
......@@ -493,6 +493,15 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpinMemory_")
*rv = hg;
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroRecordStream")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
DGLStreamHandle stream = args[1];
auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
hgindex->RecordStream(stream);
*rv = hg;
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
......
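
For context, `record_stream` follows PyTorch's semantics: it tells the caching allocator that memory is still in use on a given stream, so after the memory is freed on its home stream it is not reused until the recorded stream's pending work completes. Below is a hedged caller-side sketch of the `RecordStream` entry point added above; `HeteroGraphPtr` and `DGLStreamHandle` are existing DGL types, the include path and helper name are illustrative, and end users would normally go through the Python `record_stream` API instead:

```cpp
#include <cuda_runtime.h>
#include <dgl/base_heterograph.h>  // HeteroGraphPtr (path assumed for this sketch)

// Graph structure produced on one stream, consumed on a side stream `side`.
void ConsumeOnSideStream(dgl::HeteroGraphPtr g, cudaStream_t side) {
  // Mark g's index arrays as in use on `side` so the allocator will not hand
  // their memory to a new allocation before `side` finishes reading them.
  g->RecordStream(static_cast<DGLStreamHandle>(side));
  // ... launch kernels on `side` that read the graph's CSR/COO arrays ...
}
```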
......@@ -198,7 +198,7 @@ std::pair<IdArray, IdArray> RandomWalkUniform(
h_graphs[etype].data = (CSRHasData(csr) ? static_cast<const IdType*>(csr.data->data) : nullptr);
}
// use cuda stream from local thread
cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = DeviceAPI::Get(ctx);
auto d_graphs = static_cast<GraphKernelData<IdType>*>(
device->AllocWorkspace(ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
......@@ -269,7 +269,7 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
IdType *traces_data = traces.Ptr<IdType>();
IdType *eids_data = eids.Ptr<IdType>();
cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = DeviceAPI::Get(ctx);
// new probs and prob sums pointers
assert(num_etypes == static_cast<int64_t>(prob.size()));
......@@ -426,7 +426,7 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart(
auto device = dgl::runtime::DeviceAPI::Get(device_ctx);
// use cuda stream from local thread
cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
device->CopyDataFromTo(
&restart_prob, 0, restart_prob_array.Ptr<double>(), 0,
sizeof(double),
......@@ -486,7 +486,7 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
const int64_t num_dst_nodes = (dst->shape[0] / num_samples_per_node);
auto ctx = src->ctx;
// use cuda stream from local thread
cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
auto frequency_hashmap = FrequencyHashmap<IdxType>(num_dst_nodes,
num_samples_per_node, ctx, stream);
auto ret = frequency_hashmap.Topk(src_data, dst_data, src->dtype,
......
......@@ -88,10 +88,9 @@ CompactGraphsGPU(
const std::vector<HeteroGraphPtr> &graphs,
const std::vector<IdArray> &always_preserve) {
const auto& ctx = graphs[0]->Context();
auto device = runtime::DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
CHECK_EQ(ctx.device_type, kDLGPU);
......
......@@ -168,7 +168,7 @@ ToBlockGPU(
const auto& ctx = graph->Context();
auto device = runtime::DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
CHECK_EQ(ctx.device_type, kDLGPU);
for (const auto& nodes : rhs_nodes) {
......
......@@ -437,7 +437,7 @@ template <typename FloatType, typename IdType>
void BruteForceKNNCuda(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = data_points->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
const int64_t batch_size = data_offsets->shape[0] - 1;
......@@ -454,7 +454,7 @@ void BruteForceKNNCuda(const NDArray& data_points, const IdArray& data_offsets,
const int64_t block_size = cuda::FindNumThreads(query_points->shape[0]);
const int64_t num_blocks = (query_points->shape[0] - 1) / block_size + 1;
CUDA_KERNEL_CALL(BruteforceKnnKernel, num_blocks, block_size, 0, thr_entry->stream,
CUDA_KERNEL_CALL(BruteforceKnnKernel, num_blocks, block_size, 0, stream,
data_points_data, data_offsets_data, query_points_data, query_offsets_data,
k, dists, query_out, data_out, batch_size, feature_size);
......@@ -480,7 +480,7 @@ template <typename FloatType, typename IdType>
void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = data_points->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
const int64_t batch_size = data_offsets->shape[0] - 1;
......@@ -512,17 +512,17 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off
int64_t temp_block_size = cuda::FindNumThreads(batch_size);
int64_t temp_num_blocks = (batch_size - 1) / temp_block_size + 1;
CUDA_KERNEL_CALL(GetNumBlockPerSegment, temp_num_blocks,
temp_block_size, 0, thr_entry->stream,
temp_block_size, 0, stream,
query_offsets_data, num_block_per_segment,
batch_size, block_size);
size_t prefix_temp_size = 0;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(
nullptr, prefix_temp_size, num_block_per_segment,
num_block_prefixsum, batch_size, thr_entry->stream));
num_block_prefixsum, batch_size, stream));
void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(
prefix_temp, prefix_temp_size, num_block_per_segment,
num_block_prefixsum, batch_size, thr_entry->stream));
num_block_prefixsum, batch_size, stream));
device->FreeWorkspace(ctx, prefix_temp);
int64_t num_blocks = 0, final_elem = 0, copyoffset = (batch_size - 1) * sizeof(IdType);
......@@ -547,13 +547,13 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off
ctx, num_blocks * sizeof(IdType)));
CUDA_KERNEL_CALL(
GetBlockInfo, temp_num_blocks, temp_block_size, 0,
thr_entry->stream, num_block_prefixsum, block_batch_id,
stream, num_block_prefixsum, block_batch_id,
local_block_id, batch_size, num_blocks);
FloatType* dists = static_cast<FloatType*>(device->AllocWorkspace(
ctx, k * query_points->shape[0] * sizeof(FloatType)));
CUDA_KERNEL_CALL(BruteforceKnnShareKernel, num_blocks, block_size,
single_shared_mem * block_size, thr_entry->stream, data_points_data,
single_shared_mem * block_size, stream, data_points_data,
data_offsets_data, query_points_data, query_offsets_data,
block_batch_id, local_block_id, k, dists, query_out,
data_out, batch_size, feature_size);
......@@ -834,7 +834,7 @@ template <DLDeviceType XPU, typename FloatType, typename IdType>
void NNDescent(const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = points->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
const int64_t num_nodes = points->shape[0];
......@@ -872,7 +872,7 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
device->AllocWorkspace(ctx, sizeof(IdType)));
CUDA_CALL(cub::DeviceReduce::Sum(
nullptr, sum_temp_size, num_updates, total_num_updates_d, num_nodes, thr_entry->stream));
nullptr, sum_temp_size, num_updates, total_num_updates_d, num_nodes, stream));
IdType* sum_temp_storage = static_cast<IdType*>(
device->AllocWorkspace(ctx, sum_temp_size));
......@@ -880,7 +880,7 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
seed = RandomEngine::ThreadLocal()->RandInt<uint64_t>(
std::numeric_limits<uint64_t>::max());
CUDA_KERNEL_CALL(
impl::RandomInitNeighborsKernel, num_blocks, block_size, 0, thr_entry->stream,
impl::RandomInitNeighborsKernel, num_blocks, block_size, 0, stream,
points_data, offsets_data, central_nodes, neighbors, distances, flags, k,
feature_size, batch_size, seed);
......@@ -890,19 +890,19 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
std::numeric_limits<uint64_t>::max());
CUDA_KERNEL_CALL(
impl::FindCandidatesKernel, num_blocks, block_size, 0,
thr_entry->stream, offsets_data, new_candidates, old_candidates, neighbors,
stream, offsets_data, new_candidates, old_candidates, neighbors,
flags, seed, batch_size, num_candidates, k);
// update
CUDA_KERNEL_CALL(
impl::UpdateNeighborsKernel, num_blocks, block_size, 0, thr_entry->stream,
impl::UpdateNeighborsKernel, num_blocks, block_size, 0, stream,
points_data, offsets_data, neighbors, new_candidates, old_candidates, distances,
flags, num_updates, batch_size, num_candidates, k, feature_size);
total_num_updates = 0;
CUDA_CALL(cub::DeviceReduce::Sum(
sum_temp_storage, sum_temp_size, num_updates, total_num_updates_d, num_nodes,
thr_entry->stream));
stream));
device->CopyDataFromTo(
total_num_updates_d, 0, &total_num_updates, 0,
sizeof(IdType), ctx, DLContext{kDLCPU, 0},
......
......@@ -170,6 +170,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
adj_.UnpinMemory_();
}
/*! \brief Record stream for the adj_: COOMatrix of the COO graph. */
void RecordStream(DGLStreamHandle stream) override {
adj_.RecordStream(stream);
}
bool IsMultigraph() const override {
return aten::COOHasDuplicate(adj_);
}
......@@ -575,6 +580,11 @@ class UnitGraph::CSR : public BaseHeteroGraph {
adj_.UnpinMemory_();
}
/*! \brief Record stream for the adj_: CSRMatrix of the CSR graph. */
void RecordStream(DGLStreamHandle stream) override {
adj_.RecordStream(stream);
}
bool IsMultigraph() const override {
return aten::CSRHasDuplicate(adj_);
}
......@@ -1313,6 +1323,16 @@ void UnitGraph::UnpinMemory_() {
this->coo_->UnpinMemory_();
}
void UnitGraph::RecordStream(DGLStreamHandle stream) {
if (this->in_csr_->defined())
this->in_csr_->RecordStream(stream);
if (this->out_csr_->defined())
this->out_csr_->RecordStream(stream);
if (this->coo_->defined())
this->coo_->RecordStream(stream);
this->recorded_streams.push_back(stream);
}
void UnitGraph::InvalidateCSR() {
this->out_csr_ = CSRPtr(new CSR());
}
......@@ -1402,8 +1422,12 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
else
ret = std::make_shared<CSR>(meta_graph(), newadj);
}
if (inplace && IsPinned())
in_csr_->PinMemory_();
if (inplace) {
if (IsPinned())
in_csr_->PinMemory_();
for (auto stream : recorded_streams)
in_csr_->RecordStream(stream);
}
}
return ret;
}
......@@ -1434,8 +1458,12 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
else
ret = std::make_shared<CSR>(meta_graph(), newadj);
}
if (inplace && IsPinned())
out_csr_->PinMemory_();
if (inplace) {
if (IsPinned())
out_csr_->PinMemory_();
for (auto stream : recorded_streams)
out_csr_->RecordStream(stream);
}
}
return ret;
}
......@@ -1464,8 +1492,12 @@ UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
else
ret = std::make_shared<COO>(meta_graph(), newadj);
}
if (inplace && IsPinned())
coo_->PinMemory_();
if (inplace) {
if (IsPinned())
coo_->PinMemory_();
for (auto stream : recorded_streams)
coo_->RecordStream(stream);
}
}
return ret;
}
......
......@@ -228,6 +228,12 @@ class UnitGraph : public BaseHeteroGraph {
*/
void UnpinMemory_();
/*!
* \brief Record stream for this graph.
* \param stream The stream that is using the graph
*/
void RecordStream(DGLStreamHandle stream) override;
/*!
* \brief Create in-edge CSR format of the unit graph.
* \param inplace if true and the in-edge CSR format does not exist, the created
......@@ -375,6 +381,8 @@ class UnitGraph : public BaseHeteroGraph {
* \brief Storage format restriction.
*/
dgl_format_code_t formats_;
/*! \brief which streams have recorded the graph */
std::vector<DGLStreamHandle> recorded_streams;
};
}; // namespace dgl
......
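
The `recorded_streams` member added above backs the "record stream for newly created formats" item in the commit message: a CSR/COO materialized lazily after `RecordStream` was called must also be recorded, otherwise its freshly allocated arrays could be reclaimed while a recorded stream still reads them. A stripped-down sketch of that idea with hypothetical `Format`/`Graph` types (not DGL's actual classes):

```cpp
#include <cuda_runtime.h>
#include <memory>
#include <vector>

struct Format {
  // Mark this format's buffers as in use on stream `s` (record_stream semantics).
  void RecordStream(cudaStream_t s) { /* forward to the allocator */ }
};

struct Graph {
  std::shared_ptr<Format> csr_;                // created on demand
  std::vector<cudaStream_t> recorded_streams;  // streams recorded so far

  void RecordStream(cudaStream_t s) {
    if (csr_) csr_->RecordStream(s);
    recorded_streams.push_back(s);             // remember for future formats
  }

  std::shared_ptr<Format> GetCSR() {
    if (!csr_) {
      csr_ = std::make_shared<Format>();       // newly materialized format
      for (auto s : recorded_streams)          // replay earlier RecordStream calls
        csr_->RecordStream(s);
    }
    return csr_;
  }
};
```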
......@@ -261,7 +261,7 @@ GeneratePermutationFromRemainder(
const auto& ctx = in_idx->ctx;
auto device = DeviceAPI::Get(ctx);
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int64_t num_in = in_idx->shape[0];
......@@ -392,7 +392,7 @@ IdArray MapToLocalFromRemainder(
const int num_parts,
IdArray global_idx) {
const auto& ctx = global_idx->ctx;
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1) {
IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx,
......@@ -439,7 +439,7 @@ IdArray MapToGlobalFromRemainder(
"/" << num_parts;
const auto& ctx = local_idx->ctx;
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1) {
IdArray global_idx = aten::NewIdArray(local_idx->shape[0], ctx,
......@@ -492,7 +492,7 @@ GeneratePermutationFromRange(
const auto& ctx = in_idx->ctx;
auto device = DeviceAPI::Get(ctx);
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int64_t num_in = in_idx->shape[0];
......@@ -628,7 +628,7 @@ IdArray MapToLocalFromRange(
IdArray range,
IdArray global_idx) {
const auto& ctx = global_idx->ctx;
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1 && global_idx->shape[0] > 0) {
IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx,
......@@ -690,7 +690,7 @@ IdArray MapToGlobalFromRange(
"/" << num_parts;
const auto& ctx = local_idx->ctx;
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1 && local_idx->shape[0] > 0) {
IdArray global_idx = aten::NewIdArray(local_idx->shape[0], ctx,
......
......@@ -158,8 +158,6 @@ struct cusparse_idtype<int64_t> {
/*! \brief Thread local workspace */
class CUDAThreadEntry {
public:
/*! \brief The cuda stream */
cudaStream_t stream{nullptr};
/*! \brief The cusparse handler */
cusparseHandle_t cusparse_handle{nullptr};
/*! \brief The cublas handler */
......@@ -173,6 +171,9 @@ class CUDAThreadEntry {
// get the threadlocal workspace
static CUDAThreadEntry* ThreadLocal();
};
/*! \brief Get the current CUDA stream */
cudaStream_t getCurrentCUDAStream();
} // namespace runtime
} // namespace dgl
#endif // DGL_RUNTIME_CUDA_CUDA_COMMON_H_
......@@ -111,7 +111,7 @@ class CUDADeviceAPI final : public DeviceAPI {
// Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CUDAAllocWorkspace(nbytes, CUDAThreadEntry::ThreadLocal()->stream);
return td->CUDAAllocWorkspace(nbytes, getCurrentCUDAStream());
CHECK_EQ(256 % alignment, 0U)
<< "CUDA space is aligned at 256 bytes";
......@@ -169,7 +169,7 @@ class CUDADeviceAPI final : public DeviceAPI {
DGLContext ctx_from,
DGLContext ctx_to,
DGLType type_hint) final {
auto stream = static_cast<DGLStreamHandle>(CUDAThreadEntry::ThreadLocal()->stream);
auto stream = GetStream();
CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint, stream);
}
......@@ -203,13 +203,16 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
}
void SetStream(DGLContext ctx, DGLStreamHandle stream) final {
CUDAThreadEntry::ThreadLocal()
->stream = static_cast<cudaStream_t>(stream);
}
/*! NOTE: If the backend is PyTorch, we will use PyTorch's stream management,
* so just avoid calling our SetStream/CreateStream unless
* you really need advanced stream control.
* TODO(Xin): Redirect this to PyTorch or remove it.
* PyTorch allows external CUDA streams to be set as current since v1.11.
*/
void SetStream(DGLContext ctx, DGLStreamHandle stream) final {}
DGLStreamHandle GetStream() const final {
return static_cast<DGLStreamHandle>(CUDAThreadEntry::ThreadLocal()->stream);
return static_cast<DGLStreamHandle>(getCurrentCUDAStream());
}
/*! NOTE: cudaHostRegister can be called from an arbitrary GPU device,
......@@ -271,7 +274,7 @@ class CUDADeviceAPI final : public DeviceAPI {
// Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CUDAAllocWorkspace(size, CUDAThreadEntry::ThreadLocal()->stream);
return td->CUDAAllocWorkspace(size, getCurrentCUDAStream());
return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
......@@ -309,19 +312,22 @@ class CUDADeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
// TODO(cliu): cuda streams should depend on the current device, therefore we should set device
// before setting stream.
CUDAThreadEntry::CUDAThreadEntry()
: pool(kDLGPU, CUDADeviceAPI::Global()) {
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
stream = td->CUDAGetCurrentStream();
}
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
return CUDAThreadStore::Get();
}
cudaStream_t getCurrentCUDAStream() {
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CUDAGetCurrentStream();
else // return the default stream when TA is not available
return nullptr;
}
DGL_REGISTER_GLOBAL("device_api.gpu")
.set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get();
......
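
With `SetStream` now a no-op and `GetStream` returning PyTorch's current stream, stream selection moves entirely to the PyTorch side. A hedged sketch of what that means for C++ callers (the libtorch APIs below exist; whether a given DGL op runs inside such a guard is up to the user):

```cpp
#include <c10/cuda/CUDAGuard.h>   // c10::cuda::CUDAStreamGuard
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getStreamFromPool

// Run work on a side stream: DGL kernels launched inside the guard pick up
// this stream through getCurrentCUDAStream(); no DGL-side SetStream is needed.
void RunOnSideStream() {
  c10::cuda::CUDAStream side = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard guard(side);  // makes `side` the current stream
  // ... call DGL ops here; the guard restores the previous stream on exit ...
}
```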