"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3e2cff4da25642e964c48fa44d7c00d3314b1ce8"
Unverified commit 9a00cf19, authored by Xin Yao and committed by GitHub

[Feature] Import PyTorch's CUDA stream management (#4503)

* add set_stream

* add .record_stream for NDArray and HeteroGraph

* refactor dgl stream Python APIs

* test record_stream

* add unit test for record stream

* use pytorch's stream

* fix lint

* fix cpu build

* address comments

* address comments

* add record stream tests for dgl.graph

* record frames and update dataloader

* add docstring

* update frame

* add backend check for record_stream

* remove CUDAThreadEntry::stream

* record stream for newly created formats

* fix bug

* fix cpp test

* fix None c_void_p to c_handle
parent 099b173f
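For context, the user-visible effect of this change: DGLGraph (and DGL's NDArray) gain a `record_stream` method mirroring `torch.Tensor.record_stream`, and DGL kernels now launch on PyTorch's current CUDA stream instead of a DGL-internal one. A minimal sketch of the intended pattern (the graph construction here is illustrative, not from this diff):

```python
import dgl
import torch as th

# A toy graph, moved to the GPU.
g = dgl.graph(([0, 1, 2], [1, 2, 0])).to('cuda')
s = th.cuda.Stream()

with th.cuda.stream(s):
    # DGL kernels launched here now run on PyTorch's current stream `s`,
    # not on a separate DGL-internal stream.
    deg = g.in_degrees()

# Graphs can now be marked as "in use by s", just like tensors.
g.record_stream(s)
deg.record_stream(s)
th.cuda.current_stream().wait_stream(s)
```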
@@ -167,6 +167,18 @@ struct COOMatrix {
     }
     is_pinned = false;
   }
+
+  /*!
+   * \brief Record stream for the row, col and data (if not Null) of the matrix.
+   * \param stream The stream that is using the graph.
+   */
+  inline void RecordStream(DGLStreamHandle stream) const {
+    row.RecordStream(stream);
+    col.RecordStream(stream);
+    if (!aten::IsNullArray(data)) {
+      data.RecordStream(stream);
+    }
+  }
 };

 ///////////////////////// COO routines //////////////////////////
......
@@ -160,6 +160,18 @@ struct CSRMatrix {
     }
     is_pinned = false;
   }
+
+  /*!
+   * \brief Record stream for the indptr, indices and data (if not Null) of the matrix.
+   * \param stream The stream that is using the graph.
+   */
+  inline void RecordStream(DGLStreamHandle stream) const {
+    indptr.RecordStream(stream);
+    indices.RecordStream(stream);
+    if (!aten::IsNullArray(data)) {
+      data.RecordStream(stream);
+    }
+  }
 };

 ///////////////////////// CSR routines //////////////////////////
......
@@ -121,6 +121,12 @@ class BaseHeteroGraph : public runtime::Object {
    */
   virtual bool IsPinned() const = 0;
+
+  /*!
+   * \brief Record stream for this graph.
+   * \param stream The stream that is using the graph.
+   */
+  virtual void RecordStream(DGLStreamHandle stream) = 0;

   /*!
    * \brief Get the number of integer bits used to store node/edge ids (32 or 64).
    */
......
@@ -564,6 +564,11 @@ int DGLArrayPinData(DGLArrayHandle handle, DLContext ctx);
  */
 int DGLArrayUnpinData(DGLArrayHandle handle, DLContext ctx);
+
+/*!
+ * \brief Record the stream that is using this tensor.
+ */
+int DGLArrayRecordStream(DGLArrayHandle handle, DGLStreamHandle stream);

 /*!
  * \brief Bug report macro.
  *
......
@@ -188,6 +188,11 @@ class NDArray {
    * \brief Check if the array is pinned.
    */
   inline bool IsPinned() const;
+  /*!
+   * \brief Record streams that are using the underlying tensor.
+   * \param stream The stream that is using the underlying tensor.
+   */
+  inline void RecordStream(DGLStreamHandle stream) const;
   /*!
    * \brief Load NDArray from stream
    * \param stream The input data stream
@@ -319,6 +324,13 @@ class NDArray {
    */
   DGL_DLL static bool IsContainerPinned(Container* ptr);
+
+  /*!
+   * \brief Record streams that are using this tensor.
+   * \param tensor Pointer of the tensor to be recorded.
+   * \param stream The stream that is using this tensor.
+   */
+  DGL_DLL static void RecordStream(DGLArray* tensor, DGLStreamHandle stream);

   // internal namespace
   struct Internal;
  private:
@@ -488,6 +500,11 @@ inline bool NDArray::IsPinned() const {
   return IsContainerPinned(data_);
 }
+
+inline void NDArray::RecordStream(DGLStreamHandle stream) const {
+  CHECK(data_ != nullptr);
+  RecordStream(&(data_->dl_tensor), stream);
+}

 inline int NDArray::use_count() const {
   if (data_ == nullptr) return 0;
   return data_->ref_counter_.load(std::memory_order_relaxed);
......
@@ -121,7 +121,7 @@ class TensorDispatcher {
   /*!
    * \brief Find the current PyTorch CUDA stream
-   * Used in CUDAThreadEntry::ThreadLocal->stream.
+   * Used in runtime::getCurrentCUDAStream().
    *
    * \note PyTorch pre-allocates/sets the current CUDA stream
    * on current device via cudaGetDevice(). Make sure to call cudaSetDevice()
@@ -135,6 +135,22 @@ class TensorDispatcher {
   }
 #endif  // DGL_USE_CUDA
+
+  /*!
+   * \brief Record streams that are using this tensor.
+   * Used in NDArray::RecordStream().
+   *
+   * \param ptr Pointer of the tensor to be recorded.
+   * \param stream The stream that is using this tensor.
+   * \param device_id Device of the tensor.
+   */
+  inline void RecordStream(void* ptr, DGLStreamHandle stream, int device_id) {
+#ifdef DGL_USE_CUDA
+    auto entry = entrypoints_[Op::kRecordStream];
+    FUNCCAST(tensoradapter::RecordStream, entry)(
+        ptr, static_cast<cudaStream_t>(stream), device_id);
+#endif  // DGL_USE_CUDA
+  }

  private:
   /*! \brief ctor */
   TensorDispatcher() = default;
@@ -153,6 +169,7 @@ class TensorDispatcher {
       "CUDARawAlloc",
       "CUDARawDelete",
       "CUDACurrentStream",
+      "RecordStream",
 #endif  // DGL_USE_CUDA
   };
@@ -165,6 +182,7 @@ class TensorDispatcher {
   static constexpr int kCUDARawAlloc = 2;
   static constexpr int kCUDARawDelete = 3;
   static constexpr int kCUDACurrentStream = 4;
+  static constexpr int kRecordStream = 5;
 #endif  // DGL_USE_CUDA
 };
@@ -179,6 +197,7 @@ class TensorDispatcher {
     nullptr,
     nullptr,
     nullptr,
+    nullptr,
 #endif  // DGL_USE_CUDA
 };
......
@@ -14,7 +14,6 @@ from ...ndarray import NDArray as DGLNDArray
 from ... import backend as F
 from ...base import DGLError, dgl_warning
 from ...utils import to_dgl_context, check_device
-from ..._ffi import streams as FS

 __all__ = ['NodeDataLoader', 'EdgeDataLoader', 'GraphDataLoader',
            # Temporary exposure.
@@ -353,9 +352,8 @@ def _next(dl_iter, graph, device, load_input, load_output, stream=None):
     result_ = (input_nodes, output_nodes, blocks, input_data, output_data)
     if stream is not None:
         with th.cuda.stream(stream):
-            with FS.stream(stream):
-                result = [_to_device(data, device)
-                          for data in result_], result_, stream.record_event()
+            result = [_to_device(data, device)
+                      for data in result_], result_, stream.record_event()
     else:
         result = [_to_device(data, device) for data in result_]
     return result
......
@@ -162,5 +162,8 @@ cdef inline object ctypes_handle(void* chandle):
 cdef inline void* c_handle(object handle):
     """Cast C types handle to c handle."""
     cdef unsigned long long v_ptr
-    v_ptr = handle.value
-    return <void*>(v_ptr)
+    if handle.value is None:
+        return NULL
+    else:
+        v_ptr = handle.value
+        return <void*>(v_ptr)
@@ -326,6 +326,19 @@ class NDArrayBase(_NDArrayBase):
         """
         check_call(_LIB.DGLArrayUnpinData(self.handle))
+
+    def record_stream(self, stream):
+        """Record the stream that is using this tensor.
+
+        Note
+        ----
+        This API is mainly for testing. Users should call ``record_stream``
+        on a torch.Tensor or a DGLGraph directly.
+
+        Parameters
+        ----------
+        stream : DGLStreamHandle
+        """
+        check_call(_LIB.DGLArrayRecordStream(self.handle, stream))

 def free_extension_handle(handle, type_code):
     """Free c++ extension type handle
......
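A sketch of how this testing hook might be exercised. It assumes a CUDA build with the PyTorch backend, and assumes `dgl.ndarray` exposes the TVM-style `array`/`gpu` helpers; `to_dgl_stream_handle` is the helper added in this diff:

```python
import numpy as np
import torch as th
import dgl.ndarray as nd
from dgl._ffi.streams import to_dgl_stream_handle

# A DGL-managed device array (values are illustrative).
arr = nd.array(np.arange(8, dtype=np.int64), ctx=nd.gpu(0))
s = th.cuda.Stream()

# Record the side stream on the underlying buffer through the new C API.
arr.record_stream(to_dgl_stream_handle(s))
```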
 # pylint: disable=invalid-name, unused-import
-"""Runtime stream api which is maily for internal use only."""
+"""Runtime stream APIs which are mainly for internal test use only.
+
+For applications, please use PyTorch's stream management, of which DGL is aware.
+"""
 from __future__ import absolute_import

 import ctypes
 from .base import _LIB, check_call, _FFI_MODE
 from .runtime_ctypes import DGLStreamHandle
-from .ndarray import context
-from ..utils import to_dgl_context

 IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError

-class StreamContext(object):
-    """Context-manager that selects a given stream.
-
-    All CUDA kernels queued within its context will be enqueued
-    on a selected stream.
-    """
-    def __init__(self, cuda_stream):
-        """Create stream context instance.
-
-        Parameters
-        ----------
-        cuda_stream : torch.cuda.Stream. This manager is a no-op if it's ``None``.
-            target stream will be set.
-        """
-        if cuda_stream is None:
-            self.curr_cuda_stream = None
-        else:
-            self.ctx = to_dgl_context(cuda_stream.device)
-            self.curr_cuda_stream = cuda_stream.cuda_stream
-
-    def __enter__(self):
-        """Get previous stream and set target stream as current."""
-        if self.curr_cuda_stream is None:
-            return
-        self.prev_cuda_stream = DGLStreamHandle()
-        check_call(_LIB.DGLGetStream(
-            self.ctx.device_type, self.ctx.device_id, ctypes.byref(self.prev_cuda_stream)))
-        check_call(_LIB.DGLSetStream(
-            self.ctx.device_type, self.ctx.device_id, ctypes.c_void_p(self.curr_cuda_stream)))
-
-    def __exit__(self, exc_type, exc_value, exc_traceback):
-        """Restore previous stream when exiting."""
-        if self.curr_cuda_stream is None:
-            return
-        check_call(_LIB.DGLSetStream(
-            self.ctx.device_type, self.ctx.device_id, self.prev_cuda_stream))
-
-def stream(cuda_stream):
-    """Wrapper of StreamContext.
-
-    Parameters
-    ----------
-    cuda_stream : torch.cuda.Stream. This manager is a no-op if it's ``None``.
-        target stream will be set.
-    """
-    return StreamContext(cuda_stream)
+def to_dgl_stream_handle(cuda_stream):
+    """Convert torch.cuda.Stream to a DGL stream handle.
+
+    Parameters
+    ----------
+    cuda_stream : torch.cuda.Stream
+
+    Returns
+    -------
+    DGLStreamHandle
+        DGLStreamHandle of the input ``cuda_stream``.
+    """
+    return ctypes.c_void_p(cuda_stream.cuda_stream)
+
+def _dgl_get_stream(ctx):
+    """Get the current CUDA stream of the given DGL context.
+
+    Parameters
+    ----------
+    ctx : DGL context.
+
+    Returns
+    -------
+    DGLStreamHandle
+        DGLStreamHandle of the current CUDA stream.
+    """
+    current_cuda_stream = DGLStreamHandle()
+    check_call(_LIB.DGLGetStream(
+        ctx.device_type, ctx.device_id, ctypes.byref(current_cuda_stream)))
+    return current_cuda_stream
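A quick sanity-check sketch of these two helpers. This assumes a CUDA-enabled build with the PyTorch backend, where DGL now queries PyTorch for the current stream, so the handle DGL reports inside a `torch.cuda.stream` block should match the side stream's raw handle (`to_dgl_context` comes from `dgl.utils`, as imported elsewhere in this diff):

```python
import torch as th
from dgl._ffi.streams import to_dgl_stream_handle, _dgl_get_stream
from dgl.utils import to_dgl_context

s = th.cuda.Stream()
ctx = to_dgl_context(th.device('cuda:0'))
with th.cuda.stream(s):
    # DGL tracks PyTorch's current stream, so both handles should agree.
    assert _dgl_get_stream(ctx).value == to_dgl_stream_handle(s).value
```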
""" CUDA wrappers """ """ CUDA wrappers """
from . import nccl from . import nccl
from .._ffi.streams import stream
@@ -303,7 +303,7 @@ def _await_or_return(x):
 def _record_stream(x, stream):
     if stream is None:
         return x
-    if isinstance(x, torch.Tensor):
+    if hasattr(x, 'record_stream'):
         x.record_stream(stream)
         return x
     elif isinstance(x, _PrefetchedGraphFeatures):
@@ -333,9 +333,7 @@ def _prefetch(batch, dataloader, stream):
     feats = recursive_apply(batch, _prefetch_for, dataloader)
     feats = recursive_apply(feats, _await_or_return)
     feats = recursive_apply(feats, _record_stream, current_stream)
-    # transfer input nodes/seed nodes
-    # TODO(Xin): sampled subgraph is transferred in the default stream
-    # because heterograph doesn't support .record_stream() for now
+    # transfer input nodes/seed nodes/subgraphs
     batch = recursive_apply(batch, lambda x: x.to(dataloader.device, non_blocking=True))
     batch = recursive_apply(batch, _record_stream, current_stream)
     stream_event = stream.record_event() if stream is not None else None
......
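The switch from `isinstance(x, torch.Tensor)` to `hasattr(x, 'record_stream')` lets the same branch cover both tensors and the graphs that gain `record_stream` in this change. A toy illustration of the duck typing (the class and values here are hypothetical, not from the diff):

```python
class _Recordable:
    """Stands in for anything exposing record_stream (tensor or DGLGraph)."""
    def __init__(self):
        self.streams = []

    def record_stream(self, stream):
        self.streams.append(stream)

def _record_stream(x, stream):
    # Same shape as the dataloader helper: duck-type on record_stream.
    if stream is None:
        return x
    if hasattr(x, 'record_stream'):
        x.record_stream(stream)
    return x

obj = _Recordable()
_record_stream(obj, 'stream-0')   # any object with the method works
assert obj.streams == ['stream-0']
```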
@@ -101,6 +101,9 @@ class LazyFeature(object):
     def unpin_memory_(self):
         """No-op. For compatibility of :meth:`Frame.unpin_memory_` method."""
+
+    def record_stream(self, stream):
+        """No-op. For compatibility of :meth:`Frame.record_stream` method."""

 class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
     """The column scheme.
@@ -492,6 +495,18 @@ class Column(TensorStorage):
         self._data_nd = None
         self.pinned_by_dgl = False
+
+    def record_stream(self, stream):
+        """Record the stream that is using this column's storage.
+
+        Raises a DGLError if the backend is not PyTorch.
+
+        Parameters
+        ----------
+        stream : torch.cuda.Stream
+        """
+        if F.get_preferred_backend() != 'pytorch':
+            raise DGLError("record_stream only supports the PyTorch backend.")
+        self.data.record_stream(stream)

 class Frame(MutableMapping):
     """The columnar storage for node/edge features.
@@ -877,6 +892,12 @@ class Frame(MutableMapping):
         for column in self._columns.values():
             column.unpin_memory_()
+
+    def record_stream(self, stream):
+        """Record the stream that is using the data of every column,
+        materializing them if necessary."""
+        for column in self._columns.values():
+            column.record_stream(stream)

     def _astype_float(self, new_type):
         assert new_type in [F.float64, F.float32, F.float16], \
             "'new_type' must be floating-point type: %s" % str(new_type)
@@ -5584,6 +5584,31 @@ class DGLHeteroGraph(object):
         """
         return self._graph.is_pinned()
+
+    def record_stream(self, stream):
+        """Record the stream that is using this graph.
+
+        This method only supports the PyTorch backend and requires the graph
+        to be on the GPU.
+
+        Parameters
+        ----------
+        stream : torch.cuda.Stream
+            The stream that is using this graph.
+
+        Returns
+        -------
+        DGLGraph
+            self.
+        """
+        if F.get_preferred_backend() != 'pytorch':
+            raise DGLError("record_stream only supports the PyTorch backend.")
+        if F.device_type(self.device) != 'cuda':
+            raise DGLError("The graph must be on the GPU to be recorded.")
+        self._graph.record_stream(stream)
+        for frame in itertools.chain(self._node_frames, self._edge_frames):
+            for col in frame._columns.values():
+                col.record_stream(stream)
+        return self

     def clone(self):
         """Return a heterograph object that is a clone of current graph.
......
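For intuition on why graphs need this: PyTorch's caching allocator considers a buffer free for reuse once it is deallocated on its home stream, and `record_stream` delays that reuse until the recorded stream's pending work finishes. A hedged sketch of the hazard this prevents (the graph and sampling call are illustrative):

```python
import dgl
import torch as th

g = dgl.rand_graph(100, 500).to('cuda')  # random graph on the GPU
s = th.cuda.Stream()

with th.cuda.stream(s):
    # Kernels on the side stream may still be reading g's buffers
    # after this block ends.
    sg = dgl.sampling.sample_neighbors(
        g, th.arange(10, device='cuda'), fanout=5)

# Without this call, dropping the last reference to `g` on the default
# stream could let the allocator hand g's memory to a new tensor while
# the side-stream kernels are still running.
g.record_stream(s)
th.cuda.current_stream().wait_stream(s)
```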
@@ -8,6 +8,7 @@ import scipy
 from ._ffi.object import register_object, ObjectBase
 from ._ffi.function import _init_api
+from ._ffi.streams import to_dgl_stream_handle
 from .base import DGLError, dgl_warning
 from .graph_index import from_coo
 from . import backend as F
@@ -273,6 +274,21 @@ class HeteroGraphIndex(ObjectBase):
         """
         return bool(_CAPI_DGLHeteroIsPinned(self))
+
+    def record_stream(self, stream):
+        """Record the stream that is using this graph.
+
+        Parameters
+        ----------
+        stream : torch.cuda.Stream
+            The stream that is using this graph.
+
+        Returns
+        -------
+        HeteroGraphIndex
+            self.
+        """
+        return _CAPI_DGLHeteroRecordStream(self, to_dgl_stream_handle(stream))

     def shared_memory(self, name, ntypes=None, etypes=None, formats=('coo', 'csr', 'csc')):
         """Return a copy of this graph in shared memory
......
@@ -19,8 +19,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
   if (len == 0)
     return !prepend_zero ? array : aten::Full(0, 1, array->dtype.bits, array->ctx);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   auto device = runtime::DeviceAPI::Get(array->ctx);
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const IdType* in_d = array.Ptr<IdType>();
   IdArray ret;
   IdType* out_d = nullptr;
@@ -34,12 +34,12 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
   // Allocate workspace
   size_t workspace_size = 0;
   CUDA_CALL(cub::DeviceScan::InclusiveSum(
-      nullptr, workspace_size, in_d, out_d, len, thr_entry->stream));
+      nullptr, workspace_size, in_d, out_d, len, stream));
   void* workspace = device->AllocWorkspace(array->ctx, workspace_size);
   // Compute cumsum
   CUDA_CALL(cub::DeviceScan::InclusiveSum(
-      workspace, workspace_size, in_d, out_d, len, thr_entry->stream));
+      workspace, workspace_size, in_d, out_d, len, stream));
   device->FreeWorkspace(array->ctx, workspace);
......
@@ -15,7 +15,7 @@ namespace impl {
 template<DLDeviceType XPU, typename DType, typename IdType>
 NDArray IndexSelect(NDArray array, IdArray index) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const DType* array_data = static_cast<DType*>(array->data);
   const IdType* idx_data = static_cast<IdType*>(index->data);
   const int64_t arr_len = array->shape[0];
@@ -36,7 +36,7 @@ NDArray IndexSelect(NDArray array, IdArray index) {
   if (num_feat == 1) {
     const int nt = cuda::FindNumThreads(len);
     const int nb = (len + nt - 1) / nt;
-    CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0, thr_entry->stream,
+    CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0, stream,
                      array_data, idx_data, len, arr_len, ret_data);
   } else {
     dim3 block(256, 1);
@@ -45,7 +45,7 @@ NDArray IndexSelect(NDArray array, IdArray index) {
       block.y *= 2;
     }
     const dim3 grid((len+block.y-1)/block.y);
-    CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0, thr_entry->stream,
+    CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0, stream,
                      array_data, num_feat, idx_data, len, arr_len, ret_data);
   }
   return ret;
......
@@ -34,7 +34,7 @@ IdArray NonZero(IdArray array) {
   const int64_t len = array->shape[0];
   IdArray ret = NewIdArray(len, ctx, 64);
-  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const IdType * const in_data = static_cast<const IdType*>(array->data);
   int64_t * const out_data = static_cast<int64_t*>(ret->data);
......
@@ -35,11 +35,11 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
   const IdType* lhs_data = static_cast<IdType*>(lhs->data);
   const IdType* rhs_data = static_cast<IdType*>(rhs->data);
   IdType* ret_data = static_cast<IdType*>(ret->data);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int nt = cuda::FindNumThreads(len);
   int nb = (len + nt - 1) / nt;
   CUDA_KERNEL_CALL((_BinaryElewiseKernel<IdType, Op>),
-                   nb, nt, 0, thr_entry->stream,
+                   nb, nt, 0, stream,
                    lhs_data, rhs_data, ret_data, len);
   return ret;
 }
@@ -85,11 +85,11 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
   IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
   const IdType* lhs_data = static_cast<IdType*>(lhs->data);
   IdType* ret_data = static_cast<IdType*>(ret->data);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int nt = cuda::FindNumThreads(len);
   int nb = (len + nt - 1) / nt;
   CUDA_KERNEL_CALL((_BinaryElewiseKernel<IdType, Op>),
-                   nb, nt, 0, thr_entry->stream,
+                   nb, nt, 0, stream,
                    lhs_data, rhs, ret_data, len);
   return ret;
 }
@@ -136,11 +136,11 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
   IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
   const IdType* rhs_data = static_cast<IdType*>(rhs->data);
   IdType* ret_data = static_cast<IdType*>(ret->data);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int nt = cuda::FindNumThreads(len);
   int nb = (len + nt - 1) / nt;
   CUDA_KERNEL_CALL((_BinaryElewiseKernel<IdType, Op>),
-                   nb, nt, 0, thr_entry->stream,
+                   nb, nt, 0, stream,
                    lhs, rhs_data, ret_data, len);
   return ret;
 }
@@ -185,11 +185,11 @@ IdArray UnaryElewise(IdArray lhs) {
   IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
   const IdType* lhs_data = static_cast<IdType*>(lhs->data);
   IdType* ret_data = static_cast<IdType*>(ret->data);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int nt = cuda::FindNumThreads(len);
   int nb = (len + nt - 1) / nt;
   CUDA_KERNEL_CALL((_UnaryElewiseKernel<IdType, Op>),
-                   nb, nt, 0, thr_entry->stream,
+                   nb, nt, 0, stream,
                    lhs_data, ret_data, len);
   return ret;
 }
@@ -214,10 +214,10 @@ template <DLDeviceType XPU, typename DType>
 NDArray Full(DType val, int64_t length, DLContext ctx) {
   NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx);
   DType* ret_data = static_cast<DType*>(ret->data);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int nt = cuda::FindNumThreads(length);
   int nb = (length + nt - 1) / nt;
-  CUDA_KERNEL_CALL((_FullKernel<DType>), nb, nt, 0, thr_entry->stream,
+  CUDA_KERNEL_CALL((_FullKernel<DType>), nb, nt, 0, stream,
                    ret_data, length, val);
   return ret;
 }
@@ -251,11 +251,11 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
   if (length == 0)
     return ret;
   IdType* ret_data = static_cast<IdType*>(ret->data);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   int nt = cuda::FindNumThreads(length);
   int nb = (length + nt - 1) / nt;
   CUDA_KERNEL_CALL((_RangeKernel<IdType>),
-                   nb, nt, 0, thr_entry->stream,
+                   nb, nt, 0, stream,
                    ret_data, low, length);
   return ret;
 }
@@ -289,10 +289,10 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
   const auto& ctx = arrays[0]->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   // build node maps and get the induced nodes
-  OrderedHashTable<IdType> node_map(total_length, ctx, thr_entry->stream);
+  OrderedHashTable<IdType> node_map(total_length, ctx, stream);
   int64_t num_induced = 0;
   int64_t * num_induced_device = static_cast<int64_t*>(
     device->AllocWorkspace(ctx, sizeof(int64_t)));
@@ -302,15 +302,15 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
       num_induced_device,
       0,
       sizeof(*num_induced_device),
-      thr_entry->stream));
+      stream));
   node_map.FillWithDuplicates(
       all_nodes.Ptr<IdType>(),
       all_nodes->shape[0],
       induced_nodes.Ptr<IdType>(),
       num_induced_device,
-      thr_entry->stream);
+      stream);
-  // copy using the internal stream: thr_entry->stream
+  // copy using the internal current stream
   device->CopyDataFromTo(
       num_induced_device, 0,
       &num_induced, 0,
@@ -319,7 +319,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
       DGLContext{kDLCPU, 0},
       DGLType{kDLInt, 64, 1});
-  device->StreamSync(ctx, thr_entry->stream);
+  device->StreamSync(ctx, stream);
   device->FreeWorkspace(ctx, num_induced_device);
   // resize the induced nodes
@@ -331,7 +331,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
     const int64_t length = arr->shape[0];
     int nb = (length + nt - 1) / nt;
     CUDA_KERNEL_CALL((_RelabelKernel<IdType>),
-                     nb, nt, 0, thr_entry->stream,
+                     nb, nt, 0, stream,
                      arr.Ptr<IdType>(), length, node_map.DeviceHandle());
   }
......
@@ -30,10 +30,10 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
   const DType* val = value.Ptr<DType>();
   DType* outd = out.Ptr<DType>();
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int nt = cuda::FindNumThreads(len);
   const int nb = (len + nt - 1) / nt;
-  CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, thr_entry->stream,
+  CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, stream,
                    idx, val, len, outd);
 }
......