Unverified Commit 9a00cf19 authored by Xin Yao, committed by GitHub

[Feature] Import PyTorch's CUDA stream management (#4503)

* add set_stream

* add .record_stream for NDArray and HeteroGraph

* refactor dgl stream Python APIs

* test record_stream

* add unit test for record stream

* use pytorch's stream

* fix lint

* fix cpu build

* address comments

* address comments

* add record stream tests for dgl.graph

* record frames and update dataloader

* add docstring

* update frame

* add backend check for record_stream

* remove CUDAThreadEntry::stream

* record stream for newly created formats

* fix bug

* fix cpp test

* fix None c_void_p to c_handle
parent 099b173f
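For context, the workflow this change enables looks roughly like the sketch below: DGL operations (graph copies, frame allocations, message-passing kernels) now follow whatever CUDA stream PyTorch considers current, and `record_stream` tells PyTorch's caching allocator when memory produced on one stream is consumed on another. This is an illustrative sketch distilled from the new unit tests further down, not code from this commit; it assumes a CUDA build with the PyTorch TensorAdapter available.

```python
import torch
import dgl
import dgl.ops as OPS

dev = torch.device('cuda')
g = dgl.rand_graph(10, 20)                 # graph built on CPU
x = torch.ones(g.num_nodes(), 10)

side = torch.cuda.Stream()                 # background stream for host-to-device copies
with torch.cuda.stream(side):
    g_gpu = g.to(dev)                      # DGL copies now run on PyTorch's current stream
    x_gpu = x.to(dev, non_blocking=True)

# Consume on the default stream: wait for the copies, then tell the caching
# allocator that the default stream is using this memory.
torch.cuda.current_stream().wait_stream(side)
g_gpu.record_stream(torch.cuda.current_stream())
x_gpu.record_stream(torch.cuda.current_stream())

y = OPS.copy_u_sum(g_gpu, x_gpu)           # runs on the current (default) stream
```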
@@ -153,7 +153,7 @@ std::pair<IdArray, NDArray> SparsePush(
       "device";
   auto device = DeviceAPI::Get(ctx);
-  cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   CHECK_LE(in_idx->ndim, 1) << "The tensor of sending indices must be of "
       "dimension one (or empty).";
@@ -214,7 +214,7 @@ std::pair<IdArray, NDArray> SparsePush(
   }
   std::vector<int64_t> send_prefix_host(comm_size+1);
-  // copy using the same stream (CUDAThreadEntry->ThreadLocal()->stream), no need to sync
+  // copy using the same stream (local current stream), no need to sync
   device->CopyDataFromTo(
       send_prefix.get(),
       0,
@@ -252,7 +252,7 @@ std::pair<IdArray, NDArray> SparsePush(
   // finally copy the prefixsum sum down to the host
   std::vector<int64_t> recv_prefix_host(comm_size+1);
-  // copy using the same stream (CUDAThreadEntry->ThreadLocal()->stream), no need to sync
+  // copy using the same stream (local current stream), no need to sync
   device->CopyDataFromTo(
       recv_prefix.get(),
       0,
@@ -306,7 +306,7 @@ NDArray SparsePull(
       "values must be on the same device";
   auto device = DeviceAPI::Get(ctx);
-  cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   CHECK_LE(req_idx->ndim, 1) << "The tensor of requested indices must be of "
       "dimension one (or empty).";
@@ -368,7 +368,7 @@ NDArray SparsePull(
   CUDA_CALL(cudaEventCreate(&d2h));
   std::vector<int64_t> request_prefix_host(comm_size+1);
-  // copy using the same stream (CUDAThreadEntry->ThreadLocal()->stream), no need to sync
+  // copy using the same stream (local current stream), no need to sync
   device->CopyDataFromTo(
       request_prefix.get(),
       0,
@@ -403,7 +403,7 @@ NDArray SparsePull(
   // finally copy the prefixsum sum down to the host
   std::vector<int64_t> response_prefix_host(comm_size+1);
-  // copy using the same stream (CUDAThreadEntry->ThreadLocal()->stream), no need to sync
+  // copy using the same stream (local current stream), no need to sync
   device->CopyDataFromTo(
       response_prefix.get(),
       0,
@@ -622,7 +622,7 @@ void NCCLCommunicator::AllToAllV(
   auto device = runtime::DeviceAPI::Get(ctx);
   auto dtype = DLDataTypeTraits<DType>::dtype;
-  // copy using the same stream (CUDAThreadEntry->ThreadLocal()->stream), no need to sync
+  // copy using the same stream (local current stream), no need to sync
   device->CopyDataFromTo(send, send_prefix[0],
       recv, recv_prefix[0],
       sizeof(DType)*send_prefix[1]-send_prefix[0],
@@ -684,7 +684,7 @@ void NCCLCommunicator::AllToAll(
   auto device = runtime::DeviceAPI::Get(ctx);
   auto dtype = DLDataTypeTraits<IdType>::dtype;
-  // copy using the same stream (CUDAThreadEntry->ThreadLocal()->stream), no need to sync
+  // copy using the same stream (local current stream), no need to sync
   device->CopyDataFromTo(send, 0, recv, 0, count, ctx, ctx, dtype);
 #endif
 }
......
@@ -250,7 +250,7 @@ void NDArray::CopyFromTo(DLTensor* from,
   // api manager.
   DGLContext ctx = from->ctx.device_type != kDLCPU ? from->ctx : to->ctx;
-  // default: local cuda stream: CUDAThreadEntry->ThreadLocal()->stream
+  // default: local current cuda stream
   DeviceAPI::Get(ctx)->CopyDataFromTo(
       from->data, static_cast<size_t>(from->byte_offset),
       to->data, static_cast<size_t>(to->byte_offset),
@@ -279,6 +279,15 @@ void NDArray::UnpinContainer(NDArray::Container* ptr) {
   ptr->pinned_by_dgl_ = false;
 }

+void NDArray::RecordStream(DGLArray* tensor, DGLStreamHandle stream) {
+  TensorDispatcher* td = TensorDispatcher::Global();
+  CHECK(td->IsAvailable()) << "RecordStream only works when TensorAdaptor is available.";
+  CHECK_EQ(tensor->ctx.device_type, kDLGPU)
+      << "RecordStream only works with GPU tensors.";
+  td->RecordStream(tensor->data, stream, tensor->ctx.device_id);
+}
+
 template<typename T>
 NDArray NDArray::FromVector(const std::vector<T>& vec, DLContext ctx) {
   const DLDataType dtype = DLDataTypeTraits<T>::dtype;
@@ -556,3 +565,9 @@ int DGLArrayUnpinData(DGLArrayHandle handle,
   NDArray::UnpinContainer(nd_container);
   API_END();
 }
+
+int DGLArrayRecordStream(DGLArrayHandle handle, DGLStreamHandle stream) {
+  API_BEGIN();
+  NDArray::RecordStream(handle, stream);
+  API_END();
+}
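On the Python side, this new `DGLArrayRecordStream` entry point is what ultimately backs the `record_stream(...)` calls used in the tests below. A rough sketch of how such a binding could be wired through DGL's ctypes FFI is shown here; the function name and placement of the wrapper are assumptions for illustration, only `DGLArrayRecordStream` itself comes from this diff.

```python
# Hypothetical sketch of a Python-side binding (wrapper name/placement assumed;
# only the DGLArrayRecordStream C symbol is introduced by this commit).
from dgl._ffi.base import _LIB, check_call

def record_stream(ndarray, dgl_stream_handle):
    """Tell the CUDA caching allocator that the given stream uses this array's memory."""
    check_call(_LIB.DGLArrayRecordStream(ndarray.handle, dgl_stream_handle))
```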
@@ -56,6 +56,15 @@ void CUDARawDelete(void* ptr);
 /*!
  * \brief Get the current CUDA stream.
  */
 cudaStream_t CUDACurrentStream();
+
+/*!
+ * \brief Let the caching allocator know which streams are using this tensor.
+ *
+ * \param ptr Pointer of the tensor to be recorded.
+ * \param stream The stream that is using this tensor.
+ * \param device_id Device of the tensor.
+ */
+void RecordStream(void* ptr, cudaStream_t stream, int device_id);
 #endif  // DGL_USE_CUDA
 }
......
@@ -7,6 +7,8 @@
 #include <tensoradapter_exports.h>
 #include <c10/core/CPUAllocator.h>
 #ifdef DGL_USE_CUDA
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAStream.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/cuda/CUDAStream.h>
 #include <cuda_runtime.h>
@@ -26,6 +28,7 @@ TA_EXPORTS void CPURawDelete(void* ptr) {
 #ifdef DGL_USE_CUDA
 TA_EXPORTS void* CUDARawAlloc(size_t nbytes, cudaStream_t stream) {
+  at::globalContext().lazyInitCUDA();
   return c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(
       nbytes, stream);
 }
@@ -37,6 +40,23 @@ TA_EXPORTS void CUDARawDelete(void* ptr) {
 TA_EXPORTS cudaStream_t CUDACurrentStream() {
   return at::cuda::getCurrentCUDAStream();
 }
+
+TA_EXPORTS void RecordStream(void* ptr, cudaStream_t stream, int device_id) {
+  c10::DataPtr data_ptr{
+      ptr, ptr, &c10::cuda::CUDACachingAllocator::raw_delete,
+      c10::Device(c10::DeviceType::CUDA, device_id)};
+  c10::cuda::CUDACachingAllocator::recordStream(
+      data_ptr,
+      // getStreamFromExternal doesn't exist before PyTorch 1.10, just copy it here
+      c10::cuda::CUDAStream(
+          c10::cuda::CUDAStream::UNCHECKED,
+          c10::Stream(
+              c10::Stream::UNSAFE,
+              c10::Device(c10::DeviceType::CUDA, device_id),
+              reinterpret_cast<int64_t>(stream)))
+  );
+  data_ptr.release_context();
+}
 #endif  // DGL_USE_CUDA
};
......
@@ -352,8 +352,8 @@ void _TestUnitGraph_CopyTo(const DLContext &src_ctx,
   const aten::COOMatrix &coo = COO1<IdType>(src_ctx);
   auto device = dgl::runtime::DeviceAPI::Get(dst_ctx);
-  auto stream = device->CreateStream(dst_ctx);
-  device->SetStream(dst_ctx, stream);
+  // We don't allow SetStream in DGL for now.
+  auto stream = nullptr;
   auto g = dgl::UnitGraph::CreateFromCSC(2, csr);
   ASSERT_EQ(g->GetCreatedFormats(), 4);
......
+from statistics import mean
+import unittest
+import numpy as np
+import torch
+import dgl
+import dgl.ndarray as nd
 from dgl import rand_graph
-import dgl._ffi.streams as FS
 import dgl.ops as OPS
-import unittest
+from dgl._ffi.streams import to_dgl_stream_handle, _dgl_get_stream
+from dgl.utils import to_dgl_context
 import backend as F
-import torch
+# borrowed from PyTorch, torch/testing/_internal/common_utils.py
+def _get_cycles_per_ms() -> float:
+    """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep
+    """
+    def measure() -> float:
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        torch.cuda._sleep(1000000)
+        end.record()
+        end.synchronize()
+        cycles_per_ms = 1000000 / start.elapsed_time(end)
+        return cycles_per_ms
+
+    # Get 10 values and remove the 2 max and 2 min and return the avg.
+    # This is to avoid system disturbance that skew the results, e.g.
+    # the very first cuda call likely does a bunch of init, which takes
+    # much longer than subsequent calls.
+    num = 10
+    vals = []
+    for _ in range(num):
+        vals.append(measure())
+    vals = sorted(vals)
+    return mean(vals[2 : num - 2])
 @unittest.skipIf(F._default_context_str == 'cpu', reason="stream only runs on GPU.")
 def test_basics():
     g = rand_graph(10, 20, device=F.cpu())
     x = torch.ones(g.num_nodes(), 10)
+    result = OPS.copy_u_sum(g, x).to(F.ctx())

-    # launch on default stream fetched via torch.cuda
-    s = torch.cuda.default_stream(device=F.ctx())
-    with torch.cuda.stream(s):
-        xx = x.to(device=F.ctx(), non_blocking=True)
-    with FS.stream(s):
+    # launch on default stream used in DGL
+    xx = x.to(device=F.ctx())
     gg = g.to(device=F.ctx())
-    s.synchronize()
     OPS.copy_u_sum(gg, xx)
+    assert torch.equal(OPS.copy_u_sum(gg, xx), result)

     # launch on new stream created via torch.cuda
     s = torch.cuda.Stream(device=F.ctx())
     with torch.cuda.stream(s):
         xx = x.to(device=F.ctx(), non_blocking=True)
-    with FS.stream(s):
         gg = g.to(device=F.ctx())
-    s.synchronize()
         OPS.copy_u_sum(gg, xx)
+    s.synchronize()
+    assert torch.equal(OPS.copy_u_sum(gg, xx), result)

-    # launch on default stream used in DGL
-    xx = x.to(device=F.ctx())
-    gg = g.to(device=F.ctx())
-    OPS.copy_u_sum(gg, xx)
+@unittest.skipIf(F._default_context_str == 'cpu', reason="stream only runs on GPU.")
+def test_set_get_stream():
+    current_stream = torch.cuda.current_stream()
+    # test setting another stream
+    s = torch.cuda.Stream(device=F.ctx())
+    torch.cuda.set_stream(s)
+    assert to_dgl_stream_handle(s).value == _dgl_get_stream(to_dgl_context(F.ctx())).value
+    # revert to default stream
+    torch.cuda.set_stream(current_stream)
+@unittest.skipIf(F._default_context_str == 'cpu', reason="stream only runs on GPU.")
+# borrowed from PyTorch, test/test_cuda.py: test_record_stream()
+def test_record_stream_ndarray():
+    cycles_per_ms = _get_cycles_per_ms()
+
+    t = nd.array(np.array([1., 2., 3., 4.], dtype=np.float32), ctx=nd.cpu())
+    t.pin_memory_()
+    result = nd.empty([4], ctx=nd.gpu(0))
+    stream = torch.cuda.Stream()
+    ptr = [None]
+
+    # Performs the CPU->GPU copy in a background stream
+    def perform_copy():
+        with torch.cuda.stream(stream):
+            tmp = t.copyto(nd.gpu(0))
+            ptr[0] = F.from_dgl_nd(tmp).data_ptr()
+        torch.cuda.current_stream().wait_stream(stream)
+        tmp.record_stream(
+            to_dgl_stream_handle(torch.cuda.current_stream()))
+        torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the copy
+        result.copyfrom(tmp)
+
+    perform_copy()
+    with torch.cuda.stream(stream):
+        tmp2 = nd.empty([4], ctx=nd.gpu(0))
+        assert F.from_dgl_nd(tmp2).data_ptr() != ptr[0], 'allocation re-used too soon'
+
+    assert torch.equal(F.from_dgl_nd(result).cpu(), torch.tensor([1., 2., 3., 4.]))
+
+    # Check that the block will be re-used after the main stream finishes
+    torch.cuda.current_stream().synchronize()
+    with torch.cuda.stream(stream):
+        tmp3 = nd.empty([4], ctx=nd.gpu(0))
+        assert F.from_dgl_nd(tmp3).data_ptr() == ptr[0], 'allocation not re-used'
+
+@unittest.skipIf(F._default_context_str == 'cpu', reason="stream only runs on GPU.")
+def test_record_stream_graph_positive():
+    cycles_per_ms = _get_cycles_per_ms()
+
+    g = rand_graph(10, 20, device=F.cpu())
+    x = torch.ones(g.num_nodes(), 10)
+    result = OPS.copy_u_sum(g, x).to(F.ctx())
+    stream = torch.cuda.Stream()
+    results2 = torch.zeros_like(result)
+
+    # Performs the computing in a background stream
+    def perform_computing():
+        with torch.cuda.stream(stream):
+            g2 = g.to(F.ctx())
+        torch.cuda.current_stream().wait_stream(stream)
+        g2.record_stream(torch.cuda.current_stream())
+        torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the computing
+        results2.copy_(OPS.copy_u_sum(g2, x))
+
+    x = x.to(F.ctx())
+    perform_computing()
+    with torch.cuda.stream(stream):
+        # since we have called record stream for g2, g3 won't reuse its memory
+        g3 = rand_graph(10, 20, device=F.ctx())
+    torch.cuda.current_stream().synchronize()
+    assert torch.equal(result, results2)
+
+@unittest.skipIf(F._default_context_str == 'cpu', reason="stream only runs on GPU.")
+def test_record_stream_graph_negative():
+    cycles_per_ms = _get_cycles_per_ms()
+
+    g = rand_graph(10, 20, device=F.cpu())
+    x = torch.ones(g.num_nodes(), 10)
+    result = OPS.copy_u_sum(g, x).to(F.ctx())
+    stream = torch.cuda.Stream()
+    results2 = torch.zeros_like(result)
+
+    # Performs the computing in a background stream
+    def perform_computing():
+        with torch.cuda.stream(stream):
+            g2 = g.to(F.ctx())
+        torch.cuda.current_stream().wait_stream(stream)
+        # omit record_stream will produce a wrong result
+        # g2.record_stream(torch.cuda.current_stream())
+        torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the computing
+        results2.copy_(OPS.copy_u_sum(g2, x))
+
+    x = x.to(F.ctx())
+    perform_computing()
+    with torch.cuda.stream(stream):
+        # g3 will reuse g2's memory block, resulting a wrong result
+        g3 = rand_graph(10, 20, device=F.ctx())
+    torch.cuda.current_stream().synchronize()
+    assert not torch.equal(result, results2)

 if __name__ == '__main__':
     test_basics()
+    test_set_get_stream()
+    test_record_stream_ndarray()
+    test_record_stream_graph_positive()
+    test_record_stream_graph_negative()