"vscode:/vscode.git/clone" did not exist on "155608d3cb3d04319bf2167c717dc9b8885cbefd"
Unverified Commit 1c9d2a03 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Feature] Unify the cuda stream used in core library (#4480)



* Use an internal cuda stream for CopyDataFromTo

* small fix white space

* Fix to compile

* Make stream optional in copydata for compile

* fix lint issue

* Update cub functions to use internal stream

* Lint check

* Update CopyTo/CopyFrom/CopyFromTo to use internal stream

* Address comments

* Fix backward CUDA stream

* Avoid overloading CopyFromTo()

* Minor comment update

* Overload copydatafromto in cuda device api
Co-authored-by: default avatarxiny <xiny@nvidia.com>
parent 62af41c2
......@@ -51,6 +51,11 @@ void* CUDARawAlloc(size_t nbytes, cudaStream_t stream);
* \param ptr Pointer to the memory to be freed.
*/
void CUDARawDelete(void* ptr);
/*!
* \brief Get the current CUDA stream.
*/
cudaStream_t CUDACurrentStream();
#endif // DGL_USE_CUDA
}
......
......@@ -8,6 +8,7 @@
#include <c10/core/CPUAllocator.h>
#ifdef DGL_USE_CUDA
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#endif // DGL_USE_CUDA
......@@ -32,6 +33,10 @@ TA_EXPORTS void* CUDARawAlloc(size_t nbytes, cudaStream_t stream) {
TA_EXPORTS void CUDARawDelete(void* ptr) {
c10::cuda::CUDACachingAllocator::raw_delete(ptr);
}
TA_EXPORTS cudaStream_t CUDACurrentStream() {
return at::cuda::getCurrentCUDAStream();
}
#endif // DGL_USE_CUDA
};
......
......@@ -353,22 +353,23 @@ void _TestUnitGraph_CopyTo(const DLContext &src_ctx,
auto device = dgl::runtime::DeviceAPI::Get(dst_ctx);
auto stream = device->CreateStream(dst_ctx);
device->SetStream(dst_ctx, stream);
auto g = dgl::UnitGraph::CreateFromCSC(2, csr);
ASSERT_EQ(g->GetCreatedFormats(), 4);
auto cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
auto cg = dgl::UnitGraph::CopyTo(g, dst_ctx);
device->StreamSync(dst_ctx, stream);
ASSERT_EQ(cg->GetCreatedFormats(), 4);
g = dgl::UnitGraph::CreateFromCSR(2, csr);
ASSERT_EQ(g->GetCreatedFormats(), 2);
cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
cg = dgl::UnitGraph::CopyTo(g, dst_ctx);
device->StreamSync(dst_ctx, stream);
ASSERT_EQ(cg->GetCreatedFormats(), 2);
g = dgl::UnitGraph::CreateFromCOO(2, coo);
ASSERT_EQ(g->GetCreatedFormats(), 1);
cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
cg = dgl::UnitGraph::CopyTo(g, dst_ctx);
device->StreamSync(dst_ctx, stream);
ASSERT_EQ(cg->GetCreatedFormats(), 1);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment