Unverified commit 9a00cf19, authored by Xin Yao, committed by GitHub

[Feature] Import PyTorch's CUDA stream management (#4503)

* add set_stream

* add .record_stream for NDArray and HeteroGraph

* refactor dgl stream Python APIs

* test record_stream

* add unit test for record stream

* use pytorch's stream

* fix lint

* fix cpu build

* address comments

* address comments

* add record stream tests for dgl.graph

* record frames and update dataloader

* add docstring

* update frame

* add backend check for record_stream

* remove CUDAThreadEntry::stream

* record stream for newly created formats

* fix bug

* fix cpp test

* fix None c_void_p to c_handle
parent 099b173f
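Taken together, the diff below routes DGL's CUDA kernels, cuBLAS/cuSPARSE handles, and CUB calls onto the stream that PyTorch considers current (via runtime::getCurrentCUDAStream()), and adds a record_stream hook for NDArray and graph indices that mirrors torch.Tensor.record_stream. A minimal usage sketch follows; the exact Python surface (for example DGLGraph.record_stream) is inferred from the commit message and its tests, not from this diff, so treat those names as assumptions.

# Sketch only: assumes a CUDA build of DGL with this PR and that dgl.graph(...)
# exposes record_stream() as the commit message and tests describe.
import torch
import dgl

g = dgl.graph(([0, 1, 2], [1, 2, 0])).to("cuda")   # allocated on the current (default) stream
feat = torch.randn(3, 16, device="cuda")

side = torch.cuda.Stream()
with torch.cuda.stream(side):
    # DGL kernels launched here now pick up `side` through
    # runtime::getCurrentCUDAStream() instead of CUDAThreadEntry::stream.
    h = dgl.ops.copy_u_sum(g, feat)
    # Tell PyTorch's caching allocator that these buffers are in use on `side`,
    # so their memory is not handed back to the default stream too early.
    g.record_stream(side)
    feat.record_stream(side)
    h.record_stream(side)

torch.cuda.current_stream().wait_stream(side)   # order later default-stream work after `side`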
@@ -23,7 +23,7 @@ namespace impl {
 template <DLDeviceType XPU, typename IdType>
 bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const auto& ctx = csr.indptr->ctx;
   IdArray rows = aten::VecToIdArray<int64_t>({row}, sizeof(IdType) * 8, ctx);
   IdArray cols = aten::VecToIdArray<int64_t>({col}, sizeof(IdType) * 8, ctx);
@@ -33,7 +33,7 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
   const IdType* data = nullptr;
   // TODO(minjie): use binary search for sorted csr
   CUDA_KERNEL_CALL(dgl::cuda::_LinearSearchKernel,
-      1, 1, 0, thr_entry->stream,
+      1, 1, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), data,
      rows.Ptr<IdType>(), cols.Ptr<IdType>(),
      1, 1, 1,
@@ -55,13 +55,13 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
     return rst;
   const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
   const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int nt = dgl::cuda::FindNumThreads(rstlen);
   const int nb = (rstlen + nt - 1) / nt;
   const IdType* data = nullptr;
   // TODO(minjie): use binary search for sorted csr
   CUDA_KERNEL_CALL(dgl::cuda::_LinearSearchKernel,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), data,
      row.Ptr<IdType>(), col.Ptr<IdType>(),
      row_stride, col_stride, rstlen,
@@ -100,7 +100,7 @@ bool CSRHasDuplicate(CSRMatrix csr) {
   if (!csr.sorted)
     csr = CSRSort(csr);
   const auto& ctx = csr.indptr->ctx;
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   auto device = runtime::DeviceAPI::Get(ctx);
   // We allocate a workspace of num_rows bytes. It wastes a little bit memory but should
   // be fine.
@@ -108,7 +108,7 @@ bool CSRHasDuplicate(CSRMatrix csr) {
   const int nt = dgl::cuda::FindNumThreads(csr.num_rows);
   const int nb = (csr.num_rows + nt - 1) / nt;
   CUDA_KERNEL_CALL(_SegmentHasNoDuplicate,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
      csr.num_rows, flags);
   bool ret = dgl::cuda::AllTrue(flags, csr.num_rows, ctx);
@@ -148,7 +148,7 @@ __global__ void _CSRGetRowNNZKernel(
 template <DLDeviceType XPU, typename IdType>
 NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const auto len = rows->shape[0];
   const IdType* vid_data = static_cast<IdType*>(rows->data);
   const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
@@ -157,7 +157,7 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
   const int nt = dgl::cuda::FindNumThreads(len);
   const int nb = (len + nt - 1) / nt;
   CUDA_KERNEL_CALL(_CSRGetRowNNZKernel,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
      vid_data, indptr_data, rst_data, len);
   return rst;
 }
@@ -245,7 +245,7 @@ __global__ void _SegmentCopyKernel(
 template <DLDeviceType XPU, typename IdType>
 CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int64_t len = rows->shape[0];
   IdArray ret_indptr = aten::CumSum(aten::CSRGetRowNNZ(csr, rows), true);
   const int64_t nnz = aten::IndexSelect<IdType>(ret_indptr, len);
@@ -256,14 +256,14 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
   // Copy indices.
   IdArray ret_indices = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);
   CUDA_KERNEL_CALL(_SegmentCopyKernel,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
      rows.Ptr<IdType>(), nnz, len,
      ret_indptr.Ptr<IdType>(), ret_indices.Ptr<IdType>());
   // Copy data.
   IdArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);
   CUDA_KERNEL_CALL(_SegmentCopyKernel,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), CSRHasData(csr)? csr.data.Ptr<IdType>() : nullptr,
      rows.Ptr<IdType>(), nnz, len,
      ret_indptr.Ptr<IdType>(), ret_data.Ptr<IdType>());
@@ -358,14 +358,14 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
   const int64_t nnz = csr.indices->shape[0];
   const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
   const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   // Generate a 0-1 mask for matched (row, col) positions.
   IdArray mask = Full(0, nnz, nbits, ctx);
   const int nt = dgl::cuda::FindNumThreads(len);
   const int nb = (len + nt - 1) / nt;
   CUDA_KERNEL_CALL(_SegmentMaskKernel,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
      row.Ptr<IdType>(), col.Ptr<IdType>(),
      row_stride, col_stride, len,
@@ -381,7 +381,7 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
   const int nt2 = dgl::cuda::FindNumThreads(idx->shape[0]);
   const int nb2 = (idx->shape[0] + nt - 1) / nt;
   CUDA_KERNEL_CALL(_SortedSearchKernel,
-      nb2, nt2, 0, thr_entry->stream,
+      nb2, nt2, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.num_rows,
      idx.Ptr<IdType>(), idx->shape[0],
      ret_row.Ptr<IdType>());
@@ -424,7 +424,7 @@ __global__ void _SegmentMaskColKernel(
 template <DLDeviceType XPU, typename IdType>
 CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const auto& ctx = rows->ctx;
   const auto& dtype = rows->dtype;
   const auto nbits = dtype.bits;
@@ -464,17 +464,17 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
   size_t workspace_size = 0;
   CUDA_CALL(cub::DeviceRadixSort::SortKeys(
      nullptr, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0],
-     0, sizeof(IdType)*8, thr_entry->stream));
+     0, sizeof(IdType)*8, stream));
   void *workspace = device->AllocWorkspace(ctx, workspace_size);
   CUDA_CALL(cub::DeviceRadixSort::SortKeys(
      workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0],
-     0, sizeof(IdType)*8, thr_entry->stream));
+     0, sizeof(IdType)*8, stream));
   device->FreeWorkspace(ctx, workspace);
   // Execute SegmentMaskColKernel
   int nb = (nnz_csr + nt - 1) / nt;
   CUDA_KERNEL_CALL(_SegmentMaskColKernel,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), csr.num_rows, nnz_csr,
      ptr_sorted_cols, cols_size,
      mask.Ptr<IdType>(), count.Ptr<IdType>());
......
@@ -103,9 +103,10 @@ void _Transpose(const DType* in, DType* out,
                 int row, int col) {
   DType alpha = 1., beta = 0.;
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   if (!thr_entry->cublas_handle)
     CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
-  CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, thr_entry->stream));
+  CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
   CUBLAS_CALL(Xgeam<DType>(
      thr_entry->cublas_handle,
      CUBLAS_OP_T,
@@ -123,10 +124,10 @@ void _Transpose(const DType* in, DType* out,
 template <>
 void _Transpose<half>(const half* in, half* out,
                       int row, int col) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int nt = FindNumThreads(row);
   int nb = col;
-  CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, thr_entry->stream, in, out, col, row);
+  CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);
 }
 /*
@@ -149,7 +150,7 @@ __global__ void _IndexSelectKernel(const DType* array, const IdType* index,
 */
 template<typename DType, typename IdType>
 NDArray _IndexSelect(NDArray array, NDArray index) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const DType* array_data = static_cast<DType*>(array->data);
   const IdType* idx_data = static_cast<IdType*>(index->data);
   const int64_t arr_len = array->shape[0];
@@ -160,7 +161,7 @@ NDArray _IndexSelect(NDArray array, NDArray index) {
   DType* ret_data = static_cast<DType*>(ret->data);
   const int nt = FindNumThreads(len);
   const int nb = (len + nt - 1) / nt;
-  CUDA_KERNEL_CALL(_IndexSelectKernel, nb, nt, 0, thr_entry->stream,
+  CUDA_KERNEL_CALL(_IndexSelectKernel, nb, nt, 0, stream,
      array_data, idx_data, len, ret_data);
   return ret;
 }
@@ -223,11 +224,12 @@ void CusparseCsrmm2(
   // device
   auto device = runtime::DeviceAPI::Get(ctx);
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   // allocate cusparse handle if needed
   if (!thr_entry->cusparse_handle) {
     CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
   }
-  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
+  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
   // all one data array
   DType* valptr = nullptr;
   if (!A_data) {
@@ -678,7 +680,7 @@ void SpMMCoo(
   DType *out_data = out.Ptr<DType>();
   Idx *argu_data = argu.Ptr<Idx>(),
       *arge_data = arge.Ptr<Idx>();
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int64_t N = coo.num_rows, M = coo.num_cols, E = coo.row->shape[0];
   int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
@@ -689,7 +691,7 @@ void SpMMCoo(
   int64_t out_size = out.NumElements();
   const int nt = FindNumThreads(out_size);
   const int nb = (out_size + nt - 1) / nt;
-  CUDA_KERNEL_CALL(_FillKernel, nb, nt, 0, thr_entry->stream,
+  CUDA_KERNEL_CALL(_FillKernel, nb, nt, 0, stream,
      out_data, out_size, ReduceOp::zero());
   const int ntx = FindNumThreads(len);
@@ -703,7 +705,7 @@ void SpMMCoo(
   BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
     CUDA_KERNEL_CALL((SpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
-        nblks, nthrs, 0, thr_entry->stream,
+        nblks, nthrs, 0, stream,
        ufeat_data, efeat_data, out_data, argu_data, arge_data,
        row, col, edge_map,
        N, M, E,
@@ -711,7 +713,7 @@ void SpMMCoo(
        lhs_len, rhs_len, len);
     if (ReduceOp::require_arg) {
       CUDA_KERNEL_CALL((ArgSpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
-          nblks, nthrs, 0, thr_entry->stream,
+          nblks, nthrs, 0, stream,
          ufeat_data, efeat_data, out_data, argu_data, arge_data,
          row, col, edge_map,
          N, M, E,
@@ -751,7 +753,7 @@ void SpMMCsr(
   Idx* argu_data = argu.Ptr<Idx>();
   Idx* arge_data = arge.Ptr<Idx>();
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
   int64_t len = bcast.out_len,
@@ -768,7 +770,7 @@ void SpMMCsr(
   BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
     CUDA_KERNEL_CALL((SpMMCsrKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
-        nblks, nthrs, 0, thr_entry->stream,
+        nblks, nthrs, 0, stream,
        ufeat_data, efeat_data, out_data, argu_data, arge_data,
        indptr, indices, edge_map,
        csr.num_rows, csr.num_cols,
@@ -817,7 +819,7 @@ void SpMMCmpCsrHetero(
   Idx* argu_data = argu.Ptr<Idx>();
   Idx* arge_data = arge.Ptr<Idx>();
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
   int64_t len = bcast.out_len,
@@ -833,7 +835,7 @@ void SpMMCmpCsrHetero(
   BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
     CUDA_KERNEL_CALL((SpMMCmpCsrHeteroKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
-        nblks, nthrs, 0, thr_entry->stream,
+        nblks, nthrs, 0, stream,
        ufeat_data, efeat_data, out_data, argu_data, arge_data,
        static_cast<Idx*>(argu_ntype->data),
        static_cast<Idx*>(arge_etype->data),
......
@@ -118,7 +118,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
       }
     }
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
     const dgl_type_t src_id = ufeat_ntids[etype];
     const dgl_type_t dst_id = out_ntids[etype];
@@ -135,7 +135,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
          static_cast<DType*>(vec_ufeat[src_id]->data),
          nullptr,
          out,
-          x_length, thr_entry->stream);
+          x_length, stream);
     } else if (op == "mul" && is_scalar_efeat &&
        cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
       NDArray efeat = vec_efeat[etype];
@@ -147,7 +147,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
          static_cast<DType*>(efeat->data),
          // TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11
          static_cast<DType*>((*vec_out)[dst_id]->data),
-          x_length, thr_entry->stream);
+          x_length, stream);
     } else {  // general kernel
       NDArray ufeat = (vec_ufeat.size() == 0) ?
          NullArray() : vec_ufeat[src_id];
......
@@ -16,7 +16,7 @@ bool AllTrue(int8_t* flags, int64_t length, const DLContext& ctx) {
   int8_t* rst = static_cast<int8_t*>(device->AllocWorkspace(ctx, 1));
   // Call CUB's reduction
   size_t workspace_size = 0;
-  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   CUDA_CALL(cub::DeviceReduce::Min(nullptr, workspace_size, flags, rst, length, stream));
   void* workspace = device->AllocWorkspace(ctx, workspace_size);
   CUDA_CALL(cub::DeviceReduce::Min(workspace, workspace_size, flags, rst, length, stream));
......
@@ -135,10 +135,10 @@ __global__ void _FillKernel(DType* ptr, size_t length, DType val) {
 /*! \brief Fill the vector started from ptr of size length with val */
 template <typename DType>
 void _Fill(DType* ptr, size_t length, DType val) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int nt = FindNumThreads(length);
   int nb = (length + nt - 1) / nt;  // on x-axis, no need to worry about upperbound.
-  CUDA_KERNEL_CALL(cuda::_FillKernel, nb, nt, 0, thr_entry->stream, ptr, length, val);
+  CUDA_KERNEL_CALL(cuda::_FillKernel, nb, nt, 0, stream, ptr, length, val);
 }
 /*!
......
@@ -16,7 +16,7 @@ namespace impl {
 template<typename DType, typename IdType>
 NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const DType* array_data = static_cast<DType*>(array->data);
   const IdType* idx_data = static_cast<IdType*>(index->data);
   const int64_t arr_len = array->shape[0];
@@ -41,7 +41,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
     const int nt = cuda::FindNumThreads(len);
     const int nb = (len + nt - 1) / nt;
     CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0,
-        thr_entry->stream, array_data, idx_data, len, arr_len, ret_data);
+        stream, array_data, idx_data, len, arr_len, ret_data);
   } else {
     dim3 block(256, 1);
     while (static_cast<int64_t>(block.x) >= 2*num_feat) {
@@ -51,11 +51,11 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
     const dim3 grid((len+block.y-1)/block.y);
     if (num_feat * sizeof(DType) < 2 * CACHE_LINE_SIZE) {
       CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0,
-          thr_entry->stream, array_data, num_feat, idx_data,
+          stream, array_data, num_feat, idx_data,
          len, arr_len, ret_data);
     } else {
       CUDA_KERNEL_CALL(IndexSelectMultiKernelAligned, grid, block, 0,
-          thr_entry->stream, array_data, num_feat, idx_data,
+          stream, array_data, num_feat, idx_data,
          len, arr_len, ret_data);
     }
   }
@@ -75,7 +75,7 @@ template NDArray IndexSelectCPUFromGPU<int64_t, int64_t>(NDArray, IdArray);
 template<typename DType, typename IdType>
 void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   DType* dest_data = static_cast<DType*>(dest->data);
   const DType* source_data = static_cast<DType*>(source->data);
   const IdType* idx_data = static_cast<IdType*>(index->data);
@@ -99,7 +99,7 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
     const int nt = cuda::FindNumThreads(len);
     const int nb = (len + nt - 1) / nt;
     CUDA_KERNEL_CALL(IndexScatterSingleKernel, nb, nt, 0,
-        thr_entry->stream, source_data, idx_data, len, arr_len, dest_data);
+        stream, source_data, idx_data, len, arr_len, dest_data);
   } else {
     dim3 block(256, 1);
     while (static_cast<int64_t>(block.x) >= 2*num_feat) {
@@ -108,7 +108,7 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
     }
     const dim3 grid((len+block.y-1)/block.y);
     CUDA_KERNEL_CALL(IndexScatterMultiKernel, grid, block, 0,
-        thr_entry->stream, source_data, num_feat, idx_data,
+        stream, source_data, num_feat, idx_data,
        len, arr_len, dest_data);
   }
 }
......
@@ -117,18 +117,18 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
 template<typename IdType>
 bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
   // initial done signal
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
-  CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, thr_entry->stream);
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
+  CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, stream);
   // generate color prop for each node
   uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
   auto num_threads = cuda::FindNumThreads(num_nodes);
   auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
-  CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, thr_entry->stream,
+  CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, stream,
      prop, num_nodes, seed);
   // call kernel
-  CUDA_KERNEL_CALL(colorize_kernel, num_blocks, num_threads, 0, thr_entry->stream,
+  CUDA_KERNEL_CALL(colorize_kernel, num_blocks, num_threads, 0, stream,
      prop, num_nodes, result_data);
   bool done_h = false;
   CUDA_CALL(cudaMemcpyFromSymbol(&done_h, done_d, sizeof(done_h), 0, cudaMemcpyDeviceToHost));
@@ -152,7 +152,7 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
 */
 template <DLDeviceType XPU, typename FloatType, typename IdType>
 void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const auto& ctx = result->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
   device->SetDevice(ctx);
@@ -175,9 +175,9 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
   auto num_threads = cuda::FindNumThreads(num_nodes);
   auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
   while (!Colorize<IdType>(result_data, num_nodes, prop)) {
-    CUDA_KERNEL_CALL(weighted_propose_kernel, num_blocks, num_threads, 0, thr_entry->stream,
+    CUDA_KERNEL_CALL(weighted_propose_kernel, num_blocks, num_threads, 0, stream,
        indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data);
-    CUDA_KERNEL_CALL(weighted_respond_kernel, num_blocks, num_threads, 0, thr_entry->stream,
+    CUDA_KERNEL_CALL(weighted_respond_kernel, num_blocks, num_threads, 0, stream,
        indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data);
   }
   device->FreeWorkspace(ctx, prop);
@@ -209,14 +209,14 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
   device->SetDevice(ctx);
   // generate random weights
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   NDArray weight = NDArray::Empty(
      {num_edges}, DLDataType{kDLFloat, sizeof(float) * 8, 1}, ctx);
   float *weight_data = static_cast<float*>(weight->data);
   uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
   auto num_threads = cuda::FindNumThreads(num_edges);
   auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_edges, num_threads));
-  CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, thr_entry->stream,
+  CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, stream,
      weight_data, num_edges, seed);
   WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result);
......
@@ -92,7 +92,7 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size
 template <DLDeviceType XPU, typename FloatType, typename IdType>
 void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
                           NDArray dist, IdArray start_idx, IdArray result) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const FloatType* array_data = static_cast<FloatType*>(array->data);
@@ -110,7 +110,7 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
   CUDA_CALL(cudaSetDevice(array->ctx.device_id));
   CUDA_KERNEL_CALL(fps_kernel,
-      batch_size, THREADS, 0, thr_entry->stream,
+      batch_size, THREADS, 0, stream,
      array_data, batch_size, sample_points,
      point_in_batch, dim, start_idx_data, dist_data, ret_data);
 }
......
@@ -276,6 +276,11 @@ void HeteroGraph::UnpinMemory_() {
     g->UnpinMemory_();
 }
 
+void HeteroGraph::RecordStream(DGLStreamHandle stream) {
+  for (auto g : relation_graphs_)
+    g->RecordStream(stream);
+}
+
 std::string HeteroGraph::SharedMemName() const {
   return shared_mem_ ? shared_mem_->GetName() : "";
 }
......
@@ -251,6 +251,12 @@ class HeteroGraph : public BaseHeteroGraph {
    */
   void UnpinMemory_();
 
+  /*!
+   * \brief Record stream for this graph.
+   * \param stream The stream that is using the graph
+   */
+  void RecordStream(DGLStreamHandle stream) override;
+
   /*! \brief Copy the data to shared memory.
    *
    * Also save names of node types and edge types of the HeteroGraph object to shared memory
......
@@ -493,6 +493,15 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpinMemory_")
     *rv = hg;
   });
 
+DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroRecordStream")
+.set_body([] (DGLArgs args, DGLRetValue* rv) {
+    HeteroGraphRef hg = args[0];
+    DGLStreamHandle stream = args[1];
+    auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
+    hgindex->RecordStream(stream);
+    *rv = hg;
+  });
+
 DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyToSharedMem")
 .set_body([] (DGLArgs args, DGLRetValue* rv) {
     HeteroGraphRef hg = args[0];
......
@@ -198,7 +198,7 @@ std::pair<IdArray, IdArray> RandomWalkUniform(
     h_graphs[etype].data = (CSRHasData(csr) ? static_cast<const IdType*>(csr.data->data) : nullptr);
   }
   // use cuda stream from local thread
-  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   auto device = DeviceAPI::Get(ctx);
   auto d_graphs = static_cast<GraphKernelData<IdType>*>(
      device->AllocWorkspace(ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
@@ -269,7 +269,7 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
   IdType *traces_data = traces.Ptr<IdType>();
   IdType *eids_data = eids.Ptr<IdType>();
-  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   auto device = DeviceAPI::Get(ctx);
   // new probs and prob sums pointers
   assert(num_etypes == static_cast<int64_t>(prob.size()));
@@ -426,7 +426,7 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart(
   auto device = dgl::runtime::DeviceAPI::Get(device_ctx);
   // use cuda stream from local thread
-  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   device->CopyDataFromTo(
      &restart_prob, 0, restart_prob_array.Ptr<double>(), 0,
      sizeof(double),
@@ -486,7 +486,7 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
   const int64_t num_dst_nodes = (dst->shape[0] / num_samples_per_node);
   auto ctx = src->ctx;
   // use cuda stream from local thread
-  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   auto frequency_hashmap = FrequencyHashmap<IdxType>(num_dst_nodes,
      num_samples_per_node, ctx, stream);
   auto ret = frequency_hashmap.Topk(src_data, dst_data, src->dtype,
......
@@ -88,10 +88,9 @@ CompactGraphsGPU(
     const std::vector<HeteroGraphPtr> &graphs,
     const std::vector<IdArray> &always_preserve) {
   const auto& ctx = graphs[0]->Context();
   auto device = runtime::DeviceAPI::Get(ctx);
-  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   CHECK_EQ(ctx.device_type, kDLGPU);
......
@@ -168,7 +168,7 @@ ToBlockGPU(
   const auto& ctx = graph->Context();
   auto device = runtime::DeviceAPI::Get(ctx);
-  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   CHECK_EQ(ctx.device_type, kDLGPU);
   for (const auto& nodes : rhs_nodes) {
......
@@ -437,7 +437,7 @@ template <typename FloatType, typename IdType>
 void BruteForceKNNCuda(const NDArray& data_points, const IdArray& data_offsets,
                        const NDArray& query_points, const IdArray& query_offsets,
                        const int k, IdArray result) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const auto& ctx = data_points->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
   const int64_t batch_size = data_offsets->shape[0] - 1;
@@ -454,7 +454,7 @@ void BruteForceKNNCuda(const NDArray& data_points, const IdArray& data_offsets,
   const int64_t block_size = cuda::FindNumThreads(query_points->shape[0]);
   const int64_t num_blocks = (query_points->shape[0] - 1) / block_size + 1;
-  CUDA_KERNEL_CALL(BruteforceKnnKernel, num_blocks, block_size, 0, thr_entry->stream,
+  CUDA_KERNEL_CALL(BruteforceKnnKernel, num_blocks, block_size, 0, stream,
      data_points_data, data_offsets_data, query_points_data, query_offsets_data,
      k, dists, query_out, data_out, batch_size, feature_size);
@@ -480,7 +480,7 @@ template <typename FloatType, typename IdType>
 void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_offsets,
                              const NDArray& query_points, const IdArray& query_offsets,
                              const int k, IdArray result) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const auto& ctx = data_points->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
   const int64_t batch_size = data_offsets->shape[0] - 1;
@@ -512,17 +512,17 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off
   int64_t temp_block_size = cuda::FindNumThreads(batch_size);
   int64_t temp_num_blocks = (batch_size - 1) / temp_block_size + 1;
   CUDA_KERNEL_CALL(GetNumBlockPerSegment, temp_num_blocks,
-      temp_block_size, 0, thr_entry->stream,
+      temp_block_size, 0, stream,
      query_offsets_data, num_block_per_segment,
      batch_size, block_size);
   size_t prefix_temp_size = 0;
   CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      nullptr, prefix_temp_size, num_block_per_segment,
-     num_block_prefixsum, batch_size, thr_entry->stream));
+     num_block_prefixsum, batch_size, stream));
   void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
   CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      prefix_temp, prefix_temp_size, num_block_per_segment,
-     num_block_prefixsum, batch_size, thr_entry->stream));
+     num_block_prefixsum, batch_size, stream));
   device->FreeWorkspace(ctx, prefix_temp);
   int64_t num_blocks = 0, final_elem = 0, copyoffset = (batch_size - 1) * sizeof(IdType);
@@ -547,13 +547,13 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off
      ctx, num_blocks * sizeof(IdType)));
   CUDA_KERNEL_CALL(
      GetBlockInfo, temp_num_blocks, temp_block_size, 0,
-     thr_entry->stream, num_block_prefixsum, block_batch_id,
+     stream, num_block_prefixsum, block_batch_id,
      local_block_id, batch_size, num_blocks);
   FloatType* dists = static_cast<FloatType*>(device->AllocWorkspace(
      ctx, k * query_points->shape[0] * sizeof(FloatType)));
   CUDA_KERNEL_CALL(BruteforceKnnShareKernel, num_blocks, block_size,
-      single_shared_mem * block_size, thr_entry->stream, data_points_data,
+      single_shared_mem * block_size, stream, data_points_data,
      data_offsets_data, query_points_data, query_offsets_data,
      block_batch_id, local_block_id, k, dists, query_out,
      data_out, batch_size, feature_size);
@@ -834,7 +834,7 @@ template <DLDeviceType XPU, typename FloatType, typename IdType>
 void NNDescent(const NDArray& points, const IdArray& offsets,
                IdArray result, const int k, const int num_iters,
                const int num_candidates, const double delta) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const auto& ctx = points->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
   const int64_t num_nodes = points->shape[0];
@@ -872,7 +872,7 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
      device->AllocWorkspace(ctx, sizeof(IdType)));
   CUDA_CALL(cub::DeviceReduce::Sum(
-      nullptr, sum_temp_size, num_updates, total_num_updates_d, num_nodes, thr_entry->stream));
+      nullptr, sum_temp_size, num_updates, total_num_updates_d, num_nodes, stream));
   IdType* sum_temp_storage = static_cast<IdType*>(
      device->AllocWorkspace(ctx, sum_temp_size));
@@ -880,7 +880,7 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
     seed = RandomEngine::ThreadLocal()->RandInt<uint64_t>(
        std::numeric_limits<uint64_t>::max());
     CUDA_KERNEL_CALL(
-        impl::RandomInitNeighborsKernel, num_blocks, block_size, 0, thr_entry->stream,
+        impl::RandomInitNeighborsKernel, num_blocks, block_size, 0, stream,
        points_data, offsets_data, central_nodes, neighbors, distances, flags, k,
        feature_size, batch_size, seed);
@@ -890,19 +890,19 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
        std::numeric_limits<uint64_t>::max());
     CUDA_KERNEL_CALL(
        impl::FindCandidatesKernel, num_blocks, block_size, 0,
-        thr_entry->stream, offsets_data, new_candidates, old_candidates, neighbors,
+        stream, offsets_data, new_candidates, old_candidates, neighbors,
        flags, seed, batch_size, num_candidates, k);
     // update
     CUDA_KERNEL_CALL(
-        impl::UpdateNeighborsKernel, num_blocks, block_size, 0, thr_entry->stream,
+        impl::UpdateNeighborsKernel, num_blocks, block_size, 0, stream,
        points_data, offsets_data, neighbors, new_candidates, old_candidates, distances,
        flags, num_updates, batch_size, num_candidates, k, feature_size);
     total_num_updates = 0;
     CUDA_CALL(cub::DeviceReduce::Sum(
        sum_temp_storage, sum_temp_size, num_updates, total_num_updates_d, num_nodes,
-        thr_entry->stream));
+        stream));
     device->CopyDataFromTo(
        total_num_updates_d, 0, &total_num_updates, 0,
        sizeof(IdType), ctx, DLContext{kDLCPU, 0},
......
@@ -170,6 +170,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
     adj_.UnpinMemory_();
   }
 
+  /*! \brief Record stream for the adj_: COOMatrix of the COO graph. */
+  void RecordStream(DGLStreamHandle stream) override {
+    adj_.RecordStream(stream);
+  }
+
   bool IsMultigraph() const override {
     return aten::COOHasDuplicate(adj_);
   }
@@ -575,6 +580,11 @@ class UnitGraph::CSR : public BaseHeteroGraph {
     adj_.UnpinMemory_();
   }
 
+  /*! \brief Record stream for the adj_: CSRMatrix of the CSR graph. */
+  void RecordStream(DGLStreamHandle stream) override {
+    adj_.RecordStream(stream);
+  }
+
   bool IsMultigraph() const override {
     return aten::CSRHasDuplicate(adj_);
   }
@@ -1313,6 +1323,16 @@ void UnitGraph::UnpinMemory_() {
     this->coo_->UnpinMemory_();
 }
 
+void UnitGraph::RecordStream(DGLStreamHandle stream) {
+  if (this->in_csr_->defined())
+    this->in_csr_->RecordStream(stream);
+  if (this->out_csr_->defined())
+    this->out_csr_->RecordStream(stream);
+  if (this->coo_->defined())
+    this->coo_->RecordStream(stream);
+  this->recorded_streams.push_back(stream);
+}
+
 void UnitGraph::InvalidateCSR() {
   this->out_csr_ = CSRPtr(new CSR());
 }
@@ -1402,8 +1422,12 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
     }
-    if (inplace && IsPinned())
-      in_csr_->PinMemory_();
+    if (inplace) {
+      if (IsPinned())
+        in_csr_->PinMemory_();
+      for (auto stream : recorded_streams)
+        in_csr_->RecordStream(stream);
+    }
   }
   return ret;
 }
@@ -1434,8 +1458,12 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
     }
-    if (inplace && IsPinned())
-      out_csr_->PinMemory_();
+    if (inplace) {
+      if (IsPinned())
+        out_csr_->PinMemory_();
+      for (auto stream : recorded_streams)
+        out_csr_->RecordStream(stream);
+    }
   }
   return ret;
 }
@@ -1464,8 +1492,12 @@ UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
     }
-    if (inplace && IsPinned())
-      coo_->PinMemory_();
+    if (inplace) {
+      if (IsPinned())
+        coo_->PinMemory_();
+      for (auto stream : recorded_streams)
+        coo_->RecordStream(stream);
+    }
   }
   return ret;
 }
......
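The UnitGraph hunks above also record streams for formats that do not exist yet: CSR/CSC/COO arrays are materialized lazily, so GetInCSR/GetOutCSR/GetCOO replay recorded_streams onto a freshly created format; otherwise its new allocations would only be tied to whichever stream happened to be current at creation time. A hedged Python sketch of the situation this guards against (again assuming the record_stream binding described in the commit message; which op triggers which format is an assumption):

# Sketch only: by default dgl.graph() materializes COO eagerly and CSR/CSC lazily.
import torch
import dgl

s = torch.cuda.Stream()
g = dgl.graph(([0, 1, 2], [1, 2, 0])).to("cuda")   # only the COO arrays exist here

with torch.cuda.stream(s):
    g.record_stream(s)   # records the existing COO arrays and remembers `s` in recorded_streams
    # An op that needs the in-edge CSR/CSC format triggers lazy creation; the code above
    # replays recorded_streams onto the new arrays, so the caching allocator cannot
    # recycle them while work queued on `s` still reads the graph.
    out = dgl.ops.copy_u_sum(g, torch.randn(3, 8, device="cuda"))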
...@@ -228,6 +228,12 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -228,6 +228,12 @@ class UnitGraph : public BaseHeteroGraph {
*/ */
void UnpinMemory_(); void UnpinMemory_();
/*!
* \brief Record stream for this graph.
* \param stream The stream that is using the graph
*/
void RecordStream(DGLStreamHandle stream) override;
/*!
* \brief Create in-edge CSR format of the unit graph.
* \param inplace if true and the in-edge CSR format does not exist, the created
...@@ -375,6 +381,8 @@ class UnitGraph : public BaseHeteroGraph {
* \brief Storage format restriction.
*/
dgl_format_code_t formats_;
/*! \brief which streams have recorded the graph */
std::vector<DGLStreamHandle> recorded_streams;
};
}; // namespace dgl
...
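A hedged usage sketch of the new method from the producer side of a pipeline: after building or copying a graph on a side stream, record the stream that will keep reading it so its buffers are not recycled prematurely. The wrapper function, include paths, and the assumption that the base class exposes RecordStream are illustrative, not part of this change.
#include <dgl/base_heterograph.h>
#include <cuda_runtime.h>

// `graph` is any heterograph backed by UnitGraph storage; `consumer` is the
// raw CUDA stream that will keep using its adjacency arrays.
void RecordGraphOnStream(const dgl::HeteroGraphPtr& graph, cudaStream_t consumer) {
  graph->RecordStream(static_cast<DGLStreamHandle>(consumer));
  // Formats materialized later via GetInCSR/GetOutCSR/GetCOO are replayed onto
  // `consumer` through the recorded_streams bookkeeping above.
}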
...@@ -261,7 +261,7 @@ GeneratePermutationFromRemainder(
const auto& ctx = in_idx->ctx;
auto device = DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int64_t num_in = in_idx->shape[0];
...@@ -392,7 +392,7 @@ IdArray MapToLocalFromRemainder(
const int num_parts,
IdArray global_idx) {
const auto& ctx = global_idx->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1) {
IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx,
...@@ -439,7 +439,7 @@ IdArray MapToGlobalFromRemainder(
"/" << num_parts;
const auto& ctx = local_idx->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1) {
IdArray global_idx = aten::NewIdArray(local_idx->shape[0], ctx,
...@@ -492,7 +492,7 @@ GeneratePermutationFromRange(
const auto& ctx = in_idx->ctx;
auto device = DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int64_t num_in = in_idx->shape[0];
...@@ -628,7 +628,7 @@ IdArray MapToLocalFromRange(
IdArray range,
IdArray global_idx) {
const auto& ctx = global_idx->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1 && global_idx->shape[0] > 0) {
IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx,
...@@ -690,7 +690,7 @@ IdArray MapToGlobalFromRange(
"/" << num_parts;
const auto& ctx = local_idx->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1 && local_idx->shape[0] > 0) {
IdArray global_idx = aten::NewIdArray(local_idx->shape[0], ctx,
...
...@@ -158,8 +158,6 @@ struct cusparse_idtype<int64_t> {
/*! \brief Thread local workspace */
class CUDAThreadEntry {
public:
/*! \brief The cuda stream */
cudaStream_t stream{nullptr};
/*! \brief The cusparse handler */
cusparseHandle_t cusparse_handle{nullptr};
/*! \brief The cublas handler */
...@@ -173,6 +171,9 @@ class CUDAThreadEntry {
// get the threadlocal workspace
static CUDAThreadEntry* ThreadLocal();
};
/*! \brief Get the current CUDA stream */
cudaStream_t getCurrentCUDAStream();
} // namespace runtime
} // namespace dgl
#endif // DGL_RUNTIME_CUDA_CUDA_COMMON_H_
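As a quick illustration of the intended call pattern, new CUDA code queries the ambient stream at launch time instead of caching one per thread. The kernel, wrapper function, and include path below are placeholders under that assumption, not DGL code.
#include <cstdint>
#include "cuda_common.h"  // assumed relative include for getCurrentCUDAStream

// `FillKernel` only demonstrates the launch pattern.
__global__ void FillKernel(float* out, int64_t n, float value) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = value;
}

void Fill(float* out, int64_t n, float value) {
  // Pick up whichever stream PyTorch currently considers "current".
  cudaStream_t stream = dgl::runtime::getCurrentCUDAStream();
  const int nt = 256;
  const int nb = static_cast<int>((n + nt - 1) / nt);
  FillKernel<<<nb, nt, 0, stream>>>(out, n, value);
}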
...@@ -111,7 +111,7 @@ class CUDADeviceAPI final : public DeviceAPI {
// Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CUDAAllocWorkspace(nbytes, getCurrentCUDAStream());
CHECK_EQ(256 % alignment, 0U)
<< "CUDA space is aligned at 256 bytes";
...@@ -169,7 +169,7 @@ class CUDADeviceAPI final : public DeviceAPI {
DGLContext ctx_from,
DGLContext ctx_to,
DGLType type_hint) final {
auto stream = GetStream();
CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint, stream);
}
...@@ -203,13 +203,16 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
}
/*! NOTE: If the backend is PyTorch, we will use PyTorch's stream management,
* so just avoid calling our SetStream/CreateStream unless
* you really need advanced stream control.
* TODO(Xin): Redirect this to PyTorch or remove it.
* PyTorch allows external CUDA streams to be set as current since v1.11.
*/
void SetStream(DGLContext ctx, DGLStreamHandle stream) final {}
DGLStreamHandle GetStream() const final {
return static_cast<DGLStreamHandle>(getCurrentCUDAStream());
}
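Because SetStream is now a no-op, stream selection is expected to go through PyTorch instead. A hedged sketch of how an external stream could be made current so that getCurrentCUDAStream(), and hence DGL's kernels and copies, pick it up; this assumes a libtorch backend of at least v1.11 and is not code from this change.
#include <c10/cuda/CUDAStream.h>

// Make a raw CUDA stream current through PyTorch; DGL work launched on this
// thread and device afterwards runs on `raw_stream`, because GetStream() and
// getCurrentCUDAStream() both resolve to PyTorch's notion of "current".
void UseExternalStream(cudaStream_t raw_stream, c10::DeviceIndex device) {
  auto torch_stream = c10::cuda::getStreamFromExternal(raw_stream, device);
  c10::cuda::setCurrentCUDAStream(torch_stream);
}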
/*! NOTE: cudaHostRegister can be called from an arbitrary GPU device,
...@@ -271,7 +274,7 @@ class CUDADeviceAPI final : public DeviceAPI {
// Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CUDAAllocWorkspace(size, getCurrentCUDAStream());
return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
...@@ -309,19 +312,22 @@ class CUDADeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
// TODO(cliu): CUDA streams depend on the current device, so the device should
// be set before setting the stream.
CUDAThreadEntry::CUDAThreadEntry()
: pool(kDLGPU, CUDADeviceAPI::Global()) {
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
stream = td->CUDAGetCurrentStream();
}
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
return CUDAThreadStore::Get();
}
cudaStream_t getCurrentCUDAStream() {
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CUDAGetCurrentStream();
else // return the default stream when the tensor adapter (TA) is not available
return nullptr;
}
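For reference, a hedged sketch of what the dispatcher call is expected to resolve to when the PyTorch tensor adapter is loaded; the wrapper function is illustrative, not the actual adapter code.
#include <ATen/cuda/CUDAContext.h>

// PyTorch keeps a current stream per device; .stream() exposes the raw
// cudaStream_t that DGL's kernel launches and copies then use.
cudaStream_t CurrentTorchStream() {
  return at::cuda::getCurrentCUDAStream().stream();
}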
DGL_REGISTER_GLOBAL("device_api.gpu")
.set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get();
...