Unverified commit 9a00cf19, authored by Xin Yao, committed by GitHub

[Feature] Import PyTorch's CUDA stream management (#4503)

* add set_stream

* add .record_stream for NDArray and HeteroGraph

* refactor dgl stream Python APIs

* test record_stream

* add unit test for record stream

* use pytorch's stream

* fix lint

* fix cpu build

* address comments

* address comments

* add record stream tests for dgl.graph

* record frames and update dataloader

* add docstring

* update frame

* add backend check for record_stream

* remove CUDAThreadEntry::stream

* record stream for newly created formats

* fix bug

* fix cpp test

* fix None c_void_p to c_handle
parent 099b173f
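
The changes below wire DGL's CUDA backend to whatever stream PyTorch considers current, and the commit messages describe new record_stream hooks for NDArray, HeteroGraph, and frames. A minimal usage sketch of that workflow follows; the method names are taken from the commit messages above, and the exact Python signatures are assumptions since they are not part of this diff:

    import torch
    import dgl

    g = dgl.graph(([0, 1, 2], [1, 2, 3])).to("cuda")
    side = torch.cuda.Stream()

    with torch.cuda.stream(side):
        # DGL kernels launched here run on `side`, because the backend now
        # queries PyTorch's current stream (runtime::getCurrentCUDAStream()).
        sg = dgl.sampling.sample_neighbors(g, torch.tensor([0, 1], device="cuda"), 2)

    # Before consuming `sg` on the default stream, synchronize with the side
    # stream and mark the subgraph's memory as in use there so PyTorch's caching
    # allocator does not recycle it early (record_stream is the hook added by
    # this PR; its availability on subgraphs is assumed from the commit messages).
    torch.cuda.current_stream().wait_stream(side)
    sg.record_stream(torch.cuda.current_stream())

The C++ changes that follow do the corresponding plumbing: every kernel launch and cuSPARSE/cuBLAS handle now receives runtime::getCurrentCUDAStream() instead of the runtime's private CUDAThreadEntry stream.
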
@@ -27,7 +27,7 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
   IdType* keys_out = sorted_array.Ptr<IdType>();
   int64_t* values_out = sorted_idx.Ptr<int64_t>();
-  auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   if (num_bits == 0) {
     num_bits = sizeof(IdType)*8;
   }
...
@@ -23,11 +23,12 @@ CSRMatrix COOToCSR(COOMatrix coo) {
 template <>
 CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   // allocate cusparse handle if needed
   if (!thr_entry->cusparse_handle) {
     CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
   }
-  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
+  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
   bool row_sorted = coo.row_sorted;
   bool col_sorted = coo.col_sorted;
@@ -102,7 +103,7 @@ template <>
 CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
   const auto& ctx = coo.row->ctx;
   const auto nbits = coo.row->dtype.bits;
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   bool row_sorted = coo.row_sorted;
   bool col_sorted = coo.col_sorted;
   if (!row_sorted) {
@@ -123,7 +124,7 @@ CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
   const int nb = (coo.num_rows + nt - 1) / nt;
   IdArray indptr = Full(0, coo.num_rows + 1, nbits, ctx);
   CUDA_KERNEL_CALL(_SortedSearchKernelUpperBound,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
       coo.row.Ptr<int64_t>(), nnz,
       rowids.Ptr<int64_t>(), coo.num_rows,
       indptr.Ptr<int64_t>() + 1);
...
@@ -86,7 +86,7 @@ int _NumberOfBits(const T& range) {
 template <DLDeviceType XPU, typename IdType>
 void COOSort_(COOMatrix* coo, bool sort_column) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int row_bits = _NumberOfBits(coo->num_rows);
   const int64_t nnz = coo->row->shape[0];
@@ -100,13 +100,13 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
     IdArray pos = aten::NewIdArray(nnz, coo->row->ctx, coo->row->dtype.bits);
-    CUDA_KERNEL_CALL(_COOEncodeEdgesKernel, nb, nt, 0, thr_entry->stream,
+    CUDA_KERNEL_CALL(_COOEncodeEdgesKernel, nb, nt, 0, stream,
         coo->row.Ptr<IdType>(), coo->col.Ptr<IdType>(),
         nnz, col_bits, pos.Ptr<IdType>());
     auto sorted = Sort(pos, num_bits);
-    CUDA_KERNEL_CALL(_COODecodeEdgesKernel, nb, nt, 0, thr_entry->stream,
+    CUDA_KERNEL_CALL(_COODecodeEdgesKernel, nb, nt, 0, stream,
         sorted.first.Ptr<IdType>(), nnz, col_bits,
         coo->row.Ptr<IdType>(), coo->col.Ptr<IdType>());
@@ -159,7 +159,7 @@ template <DLDeviceType XPU, typename IdType>
 std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
   const int64_t nnz = coo.row->shape[0];
   const auto& ctx = coo.row->ctx;
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   auto device = runtime::DeviceAPI::Get(ctx);
   // We allocate a workspace of 2*nnz bytes. It wastes a little bit memory but should
   // be fine.
@@ -167,7 +167,7 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
   int8_t* col_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz));
   const int nt = cuda::FindNumThreads(nnz);
   const int nb = (nnz + nt - 1) / nt;
-  CUDA_KERNEL_CALL(_COOIsSortedKernel, nb, nt, 0, thr_entry->stream,
+  CUDA_KERNEL_CALL(_COOIsSortedKernel, nb, nt, 0, stream,
       coo.row.Ptr<IdType>(), coo.col.Ptr<IdType>(),
       nnz, row_flags, col_flags);
...
@@ -23,11 +23,12 @@ COOMatrix CSRToCOO(CSRMatrix csr) {
 template <>
 COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   // allocate cusparse handle if needed
   if (!thr_entry->cusparse_handle) {
     CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
   }
-  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
+  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
   NDArray indptr = csr.indptr, indices = csr.indices, data = csr.data;
   const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
@@ -78,16 +79,17 @@ __global__ void _RepeatKernel(
 template <>
 COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
   const auto& ctx = csr.indptr->ctx;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int64_t nnz = csr.indices->shape[0];
   const auto nbits = csr.indptr->dtype.bits;
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   IdArray rowids = Range(0, csr.num_rows, nbits, ctx);
   IdArray ret_row = NewIdArray(nnz, ctx, nbits);
   const int nt = 256;
   const int nb = (nnz + nt - 1) / nt;
   CUDA_KERNEL_CALL(_RepeatKernel,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
       rowids.Ptr<int64_t>(),
       csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>(),
       csr.num_rows, nnz);
@@ -114,11 +116,12 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   auto device = runtime::DeviceAPI::Get(coo.row->ctx);
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   // allocate cusparse handle if needed
   if (!thr_entry->cusparse_handle) {
     CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
   }
-  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
+  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
   NDArray row = coo.row, col = coo.col, data = coo.data;
   int32_t* row_ptr = static_cast<int32_t*>(row->data);
...
@@ -34,7 +34,7 @@ NDArray CSRGetData(
   if (rstlen == 0)
     return rst;
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int nt = cuda::FindNumThreads(rstlen);
   const int nb = (rstlen + nt - 1) / nt;
   if (return_eids)
@@ -43,7 +43,7 @@ NDArray CSRGetData(
   // TODO(minjie): use binary search for sorted csr
   CUDA_KERNEL_CALL(cuda::_LinearSearchKernel,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
       csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
       CSRHasData(csr)? csr.data.Ptr<IdType>() : nullptr,
       rows.Ptr<IdType>(), cols.Ptr<IdType>(),
...
@@ -37,13 +37,14 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
   auto ctx = A.indptr->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const DType* A_weights = A_weights_array.Ptr<DType>();
   const DType* B_weights = B_weights_array.Ptr<DType>();
   // allocate cusparse handle if needed
   if (!thr_entry->cusparse_handle) {
     CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
   }
-  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
+  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
   // all one data array
   cusparseSpMatDescr_t matA, matB, matC;
   IdArray dC_csrOffsets = IdArray::Empty({A.num_rows+1}, A.indptr->dtype, A.indptr->ctx);
@@ -145,6 +146,7 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
   auto ctx = A.indptr->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   auto idtype = A.indptr->dtype;
   auto dtype = A_weights_array->dtype;
   const DType* A_weights = A_weights_array.Ptr<DType>();
@@ -152,7 +154,7 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
   if (!thr_entry->cusparse_handle) {
     CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
   }
-  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
+  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
   CUSPARSE_CALL(cusparseSetPointerMode(
       thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST));
...
@@ -37,7 +37,7 @@ __global__ void _SegmentIsSorted(
 template <DLDeviceType XPU, typename IdType>
 bool CSRIsSorted(CSRMatrix csr) {
   const auto& ctx = csr.indptr->ctx;
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   auto device = runtime::DeviceAPI::Get(ctx);
   // We allocate a workspace of num_rows bytes. It wastes a little bit memory but should
   // be fine.
@@ -45,7 +45,7 @@ bool CSRIsSorted(CSRMatrix csr) {
   const int nt = cuda::FindNumThreads(csr.num_rows);
   const int nb = (csr.num_rows + nt - 1) / nt;
   CUDA_KERNEL_CALL(_SegmentIsSorted,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
       csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
       csr.num_rows, flags);
   bool ret = cuda::AllTrue(flags, csr.num_rows, ctx);
@@ -65,11 +65,12 @@ template <>
 void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   // allocate cusparse handle if needed
   if (!thr_entry->cusparse_handle) {
     CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
   }
-  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
+  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
   NDArray indptr = csr->indptr;
   NDArray indices = csr->indices;
@@ -108,7 +109,7 @@ void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
 template <>
 void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
   const auto& ctx = csr->indptr->ctx;
@@ -130,13 +131,13 @@ void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
   size_t workspace_size = 0;
   CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(nullptr, workspace_size,
       key_in, key_out, value_in, value_out,
-      nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, thr_entry->stream));
+      nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, stream));
   void* workspace = device->AllocWorkspace(ctx, workspace_size);
   // Compute
   CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(workspace, workspace_size,
       key_in, key_out, value_in, value_out,
-      nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, thr_entry->stream));
+      nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, stream));
   csr->sorted = true;
   csr->indices = new_indices;
...
@@ -33,12 +33,13 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
   auto ctx = A.indptr->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const DType* A_weights = A_weights_array.Ptr<DType>();
   const DType* B_weights = B_weights_array.Ptr<DType>();
   // allocate cusparse handle if needed
   if (!thr_entry->cusparse_handle)
     CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
-  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
+  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
   cusparseMatDescr_t matA, matB, matC;
   CUSPARSE_CALL(cusparseCreateMatDescr(&matA));
...
@@ -22,11 +22,12 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
 template <>
 CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   // allocate cusparse handle if needed
   if (!thr_entry->cusparse_handle) {
     CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
   }
-  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
+  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
   NDArray indptr = csr.indptr, indices = csr.indices, data = csr.data;
   const int64_t nnz = indices->shape[0];
...
@@ -18,7 +18,7 @@ namespace array {
 namespace {
-cudaStream_t cudaStream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+cudaStream_t cudaStream = runtime::getCurrentCUDAStream();
 template<typename IdType, bool include>
 __global__ void _IsInKernel(
@@ -99,7 +99,7 @@ IdArray _PerformFilter(
     device->FreeWorkspace(ctx, workspace);
   }
-  // copy number using the internal stream CUDAThreadEntry::ThreadLocal()->stream;
+  // copy number using the internal current stream;
   IdType num_unique;
   device->CopyDataFromTo(prefix+size, 0,
       &num_unique, 0,
...
@@ -96,7 +96,7 @@ void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
 template <DLDeviceType XPU, typename IdType>
 COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   auto device = runtime::DeviceAPI::Get(coos[0].row->ctx);
   uint64_t src_offset = 0, dst_offset = 0;
   bool has_data = false;
@@ -129,7 +129,6 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
   auto ctx = coos[0].row->ctx;
   auto dtype = coos[0].row->dtype;
-  auto stream = thr_entry->stream;
   IdType n_elements = 0;
   device->CopyDataFromTo(
...
@@ -200,6 +200,7 @@ void SegmentMM(const NDArray A,
     bool a_trans, bool b_trans) {
   SWITCH_BITS(bits, DType, {
     auto device = runtime::DeviceAPI::Get(A->ctx);
+    cudaStream_t stream = runtime::getCurrentCUDAStream();
     const DType *A_data = A.Ptr<DType>();
     const DType *B_data = B.Ptr<DType>();
     const IdType* seglen_A_data = seglen_A.Ptr<IdType>();
@@ -212,8 +213,7 @@ void SegmentMM(const NDArray A,
     auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
     if (!thr_entry->cublas_handle)
       CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
-    CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle,
-        thr_entry->stream));
+    CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
     IdType m_offset = 0;
     for (IdType etype = 0; etype < num_rel; ++etype) {
@@ -254,6 +254,7 @@ void SegmentMMBackwardB(const NDArray A,
     const NDArray seglen) {
   SWITCH_BITS(bits, DType, {
     auto device = runtime::DeviceAPI::Get(A->ctx);
+    cudaStream_t stream = runtime::getCurrentCUDAStream();
     const DType *A_data = A.Ptr<DType>();
     const DType *dC_data = dC.Ptr<DType>();
     const IdType* seglen_data = seglen.Ptr<IdType>();
@@ -266,8 +267,7 @@ void SegmentMMBackwardB(const NDArray A,
     auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
     if (!thr_entry->cublas_handle)
       CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
-    CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle,
-        thr_entry->stream));
+    CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
     IdType k_offset = 0;
     for (IdType etype = 0; etype < num_rel; ++etype) {
@@ -314,7 +314,7 @@ void GatherMM(const NDArray A,
     const NDArray idx_b) {
   SWITCH_BITS(bits, DType, {
     auto device = runtime::DeviceAPI::Get(A->ctx);
-    auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+    cudaStream_t stream = runtime::getCurrentCUDAStream();
     int64_t out_len = B->shape[2];  // cols of B
     int64_t in_len = A->shape[1];  // cols of A
     const int64_t tot_num_rows = A->shape[0];
@@ -324,7 +324,7 @@ void GatherMM(const NDArray A,
     const dim3 nblks(nbx);
     const dim3 nthrs(ntx);
     CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>),
-        nblks, nthrs, 0, thr_entry->stream,
+        nblks, nthrs, 0, stream,
         A.Ptr<DType>(),
         B.Ptr<DType>(),
         C.Ptr<DType>(),
@@ -357,7 +357,7 @@ void GatherMMScatter(const NDArray A,
     const NDArray idx_c) {
   SWITCH_BITS(bits, DType, {
     auto device = runtime::DeviceAPI::Get(A->ctx);
-    auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+    cudaStream_t stream = runtime::getCurrentCUDAStream();
     const IdType *idx_c_data = idx_c.Ptr<IdType>();
     int64_t out_len = (B->ndim == 2)? B->shape[1] : B->shape[2];  // cols of B
     int64_t in_len = A->shape[1];  // cols of A
@@ -369,7 +369,7 @@ void GatherMMScatter(const NDArray A,
     const dim3 nthrs(ntx);
     if (B->ndim == 3) {
       CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>),
-          nblks, nthrs, 0, thr_entry->stream,
+          nblks, nthrs, 0, stream,
           A.Ptr<DType>(),
           B.Ptr<DType>(),
           C.Ptr<DType>(),
@@ -381,7 +381,7 @@ void GatherMMScatter(const NDArray A,
       // Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i]
       // This kernel accesses rows of A in a transposed way w/o explicitly converting A
       CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel2<IdType, DType>),
-          nblks, nthrs, 0, thr_entry->stream,
+          nblks, nthrs, 0, stream,
          A.Ptr<DType>(),
          B.Ptr<DType>(),
          C.Ptr<DType>(),
...
@@ -128,7 +128,7 @@ void GESpMMCsr(
   const DType *efeat_data = efeat.Ptr<DType>();
   DType *out_data = out.Ptr<DType>();
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int ntx = 32;
   const int nty = 32;
@@ -139,7 +139,7 @@ void GESpMMCsr(
   const int sh_mem_size = 0;
   CUDA_KERNEL_CALL((GESpMMKernel<Idx, DType, BinaryOp>),
-      nblks, nthrs, sh_mem_size, thr_entry->stream,
+      nblks, nthrs, sh_mem_size, stream,
       ufeat_data, efeat_data, out_data,
       indptr, indices,
       csr.num_rows, csr.num_cols,
...
@@ -150,14 +150,14 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
   IdType* out_row_data = out_row.Ptr<IdType>();
   IdType* out_col_data = out_col.Ptr<IdType>();
   auto device = runtime::DeviceAPI::Get(ctx);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int nt = cuda::FindNumThreads(num_actual_samples);
   const int nb = (num_actual_samples + nt - 1) / nt;
   std::pair<IdArray, IdArray> result;
   int64_t num_out;
   CUDA_KERNEL_CALL(_GlobalUniformNegativeSamplingKernel,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
       csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
       row_data, col_data, num_row, num_col, num_actual_samples, num_trials,
       exclude_self_loops, RandomEngine::ThreadLocal()->RandInt32());
@@ -168,12 +168,10 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
   PairIterator<IdType> begin(row_data, col_data);
   PairIterator<IdType> out_begin(out_row_data, out_col_data);
   CUDA_CALL(cub::DeviceSelect::If(
-      nullptr, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples, op,
-      thr_entry->stream));
+      nullptr, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples, op, stream));
   void* tmp = device->AllocWorkspace(ctx, tmp_size);
   CUDA_CALL(cub::DeviceSelect::If(
-      tmp, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples, op,
-      thr_entry->stream));
+      tmp, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples, op, stream));
   num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda);
   if (!replace) {
@@ -185,19 +183,17 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
     SortOrderedPairs(
         device, ctx, out_row_data, out_col_data, unique_row_data, unique_col_data,
-        num_out, thr_entry->stream);
+        num_out, stream);
     size_t tmp_size_unique = 0;
     void* tmp_unique = nullptr;
     CUDA_CALL(cub::DeviceSelect::Unique(
-        nullptr, tmp_size_unique, out_begin, unique_begin, num_out_cuda, num_out,
-        thr_entry->stream));
+        nullptr, tmp_size_unique, out_begin, unique_begin, num_out_cuda, num_out, stream));
     tmp_unique = (tmp_size_unique > tmp_size) ?
         device->AllocWorkspace(ctx, tmp_size_unique) :
         tmp;  // reuse buffer
     CUDA_CALL(cub::DeviceSelect::Unique(
-        tmp_unique, tmp_size_unique, out_begin, unique_begin, num_out_cuda, num_out,
-        thr_entry->stream));
+        tmp_unique, tmp_size_unique, out_begin, unique_begin, num_out_cuda, num_out, stream));
     num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda);
     num_out = std::min(num_samples, num_out);
...
@@ -247,8 +247,7 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
     const bool replace) {
   const auto& ctx = rows->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
-  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int64_t num_rows = rows->shape[0];
   const IdType * const slice_rows = static_cast<const IdType*>(rows->data);
@@ -308,7 +307,7 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
   // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and wait on
   // a cudaevent
   IdType new_len;
-  // copy using the internal stream: CUDAThreadEntry::ThreadLocal->stream
+  // copy using the internal current stream
   device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0,
       sizeof(new_len),
       ctx,
...
@@ -424,8 +424,7 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
     bool replace) {
   const auto& ctx = rows->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
-  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int64_t num_rows = rows->shape[0];
   const IdType * const slice_rows = static_cast<const IdType*>(rows->data);
@@ -489,7 +488,7 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
   // TODO(Xin): The copy here is too small, and the overhead of creating
   // cuda events cannot be ignored. Just use synchronized copy.
   IdType temp_len;
-  // copy using the internal dgl stream: CUDAThreadEntry::ThreadLocal()->stream
+  // copy using the internal current stream.
   device->CopyDataFromTo(temp_ptr, num_rows * sizeof(temp_len), &temp_len, 0,
       sizeof(temp_len),
       ctx,
@@ -520,7 +519,7 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
   // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and wait on
   // a cudaevent
   IdType new_len;
-  // copy using the internal dgl stream: CUDAThreadEntry::ThreadLocal()->stream
+  // copy using the internal current stream.
   device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0,
       sizeof(new_len),
      ctx,
...
@@ -275,7 +275,7 @@ void SDDMMCoo(
   const DType *lhs_data = lhs.Ptr<DType>();
   const DType *rhs_data = rhs.Ptr<DType>();
   DType *out_data = out.Ptr<DType>();
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int64_t *lhs_off = nullptr, *rhs_off = nullptr;
   int64_t len = bcast.out_len,
@@ -295,7 +295,7 @@ void SDDMMCoo(
     const dim3 nthrs(ntx, nty);
     BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
       CUDA_KERNEL_CALL((SDDMMCooTreeReduceKernel<Idx, DType, UseBcast, UseIdx, LhsTarget, RhsTarget>),
-          nblks, nthrs, 0, thr_entry->stream,
+          nblks, nthrs, 0, stream,
          lhs_data, rhs_data, out_data,
          row, col, edge_map,
          coo.num_rows, coo.num_cols, nnz, reduce_dim,
@@ -311,7 +311,7 @@ void SDDMMCoo(
     const dim3 nthrs(ntx, nty);
     BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
       CUDA_KERNEL_CALL((SDDMMCooKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
-          nblks, nthrs, 0, thr_entry->stream,
+          nblks, nthrs, 0, stream,
          lhs_data, rhs_data, out_data,
          row, col, edge_map,
          coo.num_rows, coo.num_cols, nnz, reduce_dim,
@@ -343,7 +343,7 @@ void SDDMMCsr(
   const DType *lhs_data = lhs.Ptr<DType>();
   const DType *rhs_data = rhs.Ptr<DType>();
   DType *out_data = out.Ptr<DType>();
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0];
   int64_t *lhs_off = nullptr, *rhs_off = nullptr;
@@ -362,7 +362,7 @@ void SDDMMCsr(
   BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
     CUDA_KERNEL_CALL((SDDMMCsrKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
-        nblks, nthrs, 0, thr_entry->stream,
+        nblks, nthrs, 0, stream,
        lhs_data, rhs_data, out_data,
        indptr, indices, edge_map,
        N, M, E, reduce_dim,
...
@@ -24,7 +24,6 @@ void SDDMMCooHetero(const std::string& op,
     int rhs_target,
     const std::vector<dgl_type_t>& lhs_eid,
     const std::vector<dgl_type_t>& rhs_eid) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   SWITCH_BITS(bits, DType, {
     SWITCH_OP(op, Op, {
       SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
...
@@ -130,7 +130,7 @@ void SegmentReduce(
   DType* out_data = out.Ptr<DType>();
   IdType* arg_data = arg.Ptr<IdType>();
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int64_t n = out->shape[0];
   int64_t dim = 1;
   for (int i = 1; i < out->ndim; ++i)
@@ -144,7 +144,7 @@ void SegmentReduce(
   const dim3 nthrs(ntx, nty);
   // TODO(zihao): try cub's DeviceSegmentedReduce and compare the performance.
   CUDA_KERNEL_CALL((SegmentReduceKernel<IdType, DType, ReduceOp>),
-      nblks, nthrs, 0, thr_entry->stream,
+      nblks, nthrs, 0, stream,
       feat_data, offsets_data, out_data, arg_data,
       n, dim);
 }
@@ -164,8 +164,8 @@ void ScatterAdd(
   const DType* feat_data = feat.Ptr<DType>();
   const IdType* idx_data = idx.Ptr<IdType>();
   DType *out_data = out.Ptr<DType>();
-  auto *thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int64_t n = feat->shape[0];
   int64_t dim = 1;
   for (int i = 1; i < out->ndim; ++i)
@@ -178,7 +178,7 @@ void ScatterAdd(
   const dim3 nblks(nbx, nby);
   const dim3 nthrs(ntx, nty);
   CUDA_KERNEL_CALL((ScatterAddKernel<IdType, DType>),
-      nblks, nthrs, 0, thr_entry->stream,
+      nblks, nthrs, 0, stream,
       feat_data, idx_data, out_data,
       n, dim);
 }
@@ -199,6 +199,7 @@ void UpdateGradMinMax_hetero(const HeteroGraphPtr& graph,
     const std::vector<NDArray>& list_idx,
     const std::vector<NDArray>& list_idx_types,
     std::vector<NDArray>* list_out) {
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   if (op == "copy_lhs" || op == "copy_rhs") {
     std::vector<std::vector<dgl_id_t>> src_dst_ntypes(graph->NumVertexTypes(),
         std::vector<dgl_id_t>());
@@ -221,14 +222,13 @@ void UpdateGradMinMax_hetero(const HeteroGraphPtr& graph,
         for (int i = 1; i < (*list_out)[type]->ndim; ++i)
           dim *= (*list_out)[type]->shape[i];
         int n = list_feat[dst_ntype]->shape[0];
-        auto *thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
         const int th_per_row = 32;
         const int ntx = 128;
         const int nbx = FindNumBlocks<'x'>((n * th_per_row + ntx - 1) / ntx);
         const dim3 nblks(nbx);
         const dim3 nthrs(ntx);
         CUDA_KERNEL_CALL((UpdateGradMinMaxHeteroKernel<IdType, DType>),
-            nblks, nthrs, 0, thr_entry->stream,
+            nblks, nthrs, 0, stream,
             feat_data, idx_data, idx_type_data,
             out_data, n, dim, type);
       }
@@ -251,7 +251,7 @@ void BackwardSegmentCmp(
   const IdType* arg_data = arg.Ptr<IdType>();
   DType *out_data = out.Ptr<DType>();
-  auto *thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int64_t n = feat->shape[0];
   int64_t dim = 1;
   for (int i = 1; i < out->ndim; ++i)
@@ -264,7 +264,7 @@ void BackwardSegmentCmp(
   const dim3 nblks(nbx, nby);
   const dim3 nthrs(ntx, nty);
   CUDA_KERNEL_CALL((BackwardSegmentCmpKernel<IdType, DType>),
-      nblks, nthrs, 0, thr_entry->stream,
+      nblks, nthrs, 0, stream,
       feat_data, arg_data, out_data,
       n, dim);
 }
...
@@ -73,7 +73,7 @@ __global__ void _COOGetRowNNZKernel(
 template <DLDeviceType XPU, typename IdType>
 int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const auto& ctx = coo.row->ctx;
   IdType nnz = coo.row->shape[0];
   IdType nt = 1024;
@@ -81,7 +81,7 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
   NDArray rst = NDArray::Empty({1}, coo.row->dtype, coo.row->ctx);
   _Fill(rst.Ptr<IdType>(), 1, IdType(0));
   CUDA_KERNEL_CALL(_COOGetRowNNZKernel,
-      nb, nt, 0, thr_entry->stream,
+      nb, nt, 0, stream,
       coo.row.Ptr<IdType>(), rst.Ptr<IdType>(),
       row, nnz);
   rst = rst.CopyTo(DLContext{kDLCPU, 0});
@@ -106,7 +106,7 @@ __global__ void _COOGetAllRowNNZKernel(
 template <DLDeviceType XPU, typename IdType>
 NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const auto& ctx = coo.row->ctx;
   IdType nnz = coo.row->shape[0];
   IdType num_rows = coo.num_rows;
@@ -119,7 +119,7 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
     NDArray rst = NDArray::Empty({1}, coo.row->dtype, coo.row->ctx);
     _Fill(rst.Ptr<IdType>(), 1, IdType(0));
     CUDA_KERNEL_CALL(_COOGetRowNNZKernel,
-        nb, nt, 0, thr_entry->stream,
+        nb, nt, 0, stream,
         coo.row.Ptr<IdType>(), rst.Ptr<IdType>(),
         row, nnz);
     return rst;
@@ -129,7 +129,7 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
     NDArray in_degrees = NDArray::Empty({num_rows}, rows->dtype, rows->ctx);
     _Fill(in_degrees.Ptr<IdType>(), num_rows, IdType(0));
     CUDA_KERNEL_CALL(_COOGetAllRowNNZKernel,
-        nb, nt, 0, thr_entry->stream,
+        nb, nt, 0, stream,
         coo.row.Ptr<IdType>(), in_degrees.Ptr<IdType>(),
         nnz);
     return IndexSelect(in_degrees, rows);
...