"tools/python/vscode:/vscode.git/clone" did not exist on "859ccf5e8b9ce86f85682f72abb4ca42ae3a1e66"
Unverified commit e234fcfa authored by Rhett Ying, committed by GitHub

[Feature] enable create/set/free cuda stream for internal use (#3334)

* [Feature] enable create/set/free cuda stream for internal use

* add unit test

* fix unit test failure on mxnet and tf

* refactor stream wrapper

* fix lint error

* fix lint error
parent 5cf48fc6
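In short, the change adds a DGLGetStream entry point plus a Python context manager so that DGL's internal CUDA work (e.g. copying a graph to the GPU) can be routed onto a caller-chosen stream. A minimal usage sketch, assuming a CUDA build of DGL and PyTorch (it mirrors the unit test added at the end of this diff):

import torch
import dgl
import dgl._ffi.streams as FS

g = dgl.rand_graph(10, 20)               # graph starts on CPU
s = torch.cuda.Stream(device='cuda:0')   # user-managed stream
with torch.cuda.stream(s):               # PyTorch work goes to s
    with FS.stream(s):                   # DGL's internal copies go to s
        g_gpu = g.to('cuda:0')
s.synchronize()                          # wait for the async copy on s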
...@@ -516,6 +516,16 @@ DGL_DLL int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream
 */
DGL_DLL int DGLSetStream(int device_type, int device_id, DGLStreamHandle handle);
/*!
 * \brief Get the runtime stream of the current thread.
 *
 * \param device_type The device type of context.
 * \param device_id The device id of context.
 * \param handle Pointer to receive the stream handle.
 * \return 0 when success, -1 when failure happens
 */
DGL_DLL int DGLGetStream(int device_type, int device_id, DGLStreamHandle* handle);
/*!
 * \brief Wait until all computations on stream completes.
 *
...
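Illustratively, the two entry points pair up as save/restore when driven through ctypes — this is essentially what the Python wrapper added below (dgl._ffi.streams) does. A sketch, assuming DLPack's kDLGPU value of 2 for device_type:

import ctypes
from dgl._ffi.base import _LIB, check_call

prev = ctypes.c_void_p()
# Save the stream DGL currently targets on GPU 0 (device_type 2 == kDLGPU).
check_call(_LIB.DGLGetStream(2, 0, ctypes.byref(prev)))
# ... make DGL target another stream, run work on it ...
# Then hand the saved handle back.
check_call(_LIB.DGLSetStream(2, 0, prev))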
...@@ -125,6 +125,10 @@ class DeviceAPI {
   * \param stream The stream to be set.
   */
  virtual void SetStream(DGLContext ctx, DGLStreamHandle stream) {}
  /*!
   * \brief Get the current stream of the device.
   */
  virtual DGLStreamHandle GetStream() const { return nullptr; }
  /*!
   * \brief Synchronize 2 streams of execution.
   *
...
...@@ -256,3 +256,5 @@ class DGLArray(ctypes.Structure):
                ("byte_offset", ctypes.c_uint64)]
DGLArrayHandle = ctypes.POINTER(DGLArray)
DGLStreamHandle = ctypes.c_void_p
...
# pylint: disable=invalid-name, unused-import
"""Runtime stream API, mainly for internal use only."""
from __future__ import absolute_import
import ctypes
from .base import _LIB, check_call, _FFI_MODE
from .runtime_ctypes import DGLStreamHandle
from .ndarray import context
from ..utils import to_dgl_context
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
class StreamContext(object):
    """Context manager that selects a given stream.

    All CUDA kernels queued within its scope will be enqueued
    on the selected stream.
    """
    def __init__(self, cuda_stream):
        """Create a stream context instance.

        Parameters
        ----------
        cuda_stream : torch.cuda.Stream
            The stream to be made current.
        """
        self.ctx = to_dgl_context(cuda_stream.device)
        self.curr_cuda_stream = cuda_stream.cuda_stream

    def __enter__(self):
        """Save the previous stream and make the target stream current."""
        self.prev_cuda_stream = DGLStreamHandle()
        check_call(_LIB.DGLGetStream(
            self.ctx.device_type, self.ctx.device_id, ctypes.byref(self.prev_cuda_stream)))
        check_call(_LIB.DGLSetStream(
            self.ctx.device_type, self.ctx.device_id, ctypes.c_void_p(self.curr_cuda_stream)))

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Restore the previous stream when exiting."""
        check_call(_LIB.DGLSetStream(
            self.ctx.device_type, self.ctx.device_id, self.prev_cuda_stream))

def stream(cuda_stream):
    """Convenience wrapper that returns a StreamContext.

    Parameters
    ----------
    cuda_stream : torch.cuda.Stream
        The stream to be made current.
    """
    return StreamContext(cuda_stream)
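Since __enter__ saves the handle returned by DGLGetStream and __exit__ writes it back, stream contexts nest cleanly. A small sketch with two hypothetical streams (assumes PyTorch with CUDA):

import torch
import dgl._ffi.streams as FS

s1 = torch.cuda.Stream(device='cuda:0')
s2 = torch.cuda.Stream(device='cuda:0')
with FS.stream(s1):
    # DGL's internal CUDA launches target s1 here.
    with FS.stream(s2):
        pass                   # ...and s2 inside the inner context.
    # Back to s1: __exit__ restored the saved handle.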
...@@ -9,6 +9,7 @@
#include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/runtime/c_runtime_api.h>
#include <set>
#include "../c_api_common.h"
...@@ -466,7 +467,9 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
    DLContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(device_type);
    ctx.device_id = device_id;
    DGLStreamHandle stream = nullptr;
    DGLGetStream(device_type, device_id, &stream);
    HeteroGraphPtr hg_new = HeteroGraph::CopyTo(hg.sptr(), ctx, stream);
    *rv = HeteroGraphRef(hg_new);
  });
...
...@@ -356,6 +356,15 @@ int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
  API_END();
}

int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) {
  API_BEGIN();
  DGLContext ctx;
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  *stream = DeviceAPIManager::Get(ctx)->GetStream();
  API_END();
}

int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
  API_BEGIN();
  DGLContext ctx;
...
...@@ -170,6 +170,10 @@ class CUDADeviceAPI final : public DeviceAPI {
        ->stream = static_cast<cudaStream_t>(stream);
  }

  DGLStreamHandle GetStream() const final {
    return static_cast<DGLStreamHandle>(CUDAThreadEntry::ThreadLocal()->stream);
  }

  void PinData(DGLContext ctx, void* ptr, size_t nbytes) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));
...
from dgl import rand_graph
import dgl._ffi.streams as FS
import dgl.ops as OPS
import unittest
import backend as F
import torch

@unittest.skipIf(F._default_context_str == 'cpu', reason="stream only runs on GPU.")
def test_basics():
    g = rand_graph(10, 20, device=F.cpu())
    x = torch.ones(g.num_nodes(), 10)

    # launch on default stream fetched via torch.cuda
    s = torch.cuda.default_stream(device=F.ctx())
    with torch.cuda.stream(s):
        xx = x.to(device=F.ctx(), non_blocking=True)
        with FS.stream(s):
            gg = g.to(device=F.ctx())
    s.synchronize()
    OPS.copy_u_sum(gg, xx)

    # launch on new stream created via torch.cuda
    s = torch.cuda.Stream(device=F.ctx())
    with torch.cuda.stream(s):
        xx = x.to(device=F.ctx(), non_blocking=True)
        with FS.stream(s):
            gg = g.to(device=F.ctx())
    s.synchronize()
    OPS.copy_u_sum(gg, xx)

    # launch on default stream used in DGL
    xx = x.to(device=F.ctx())
    gg = g.to(device=F.ctx())
    OPS.copy_u_sum(gg, xx)

if __name__ == '__main__':
    test_basics()