Unverified Commit 9a00cf19 authored by Xin Yao, committed by GitHub

[Feature] Import PyTorch's CUDA stream management (#4503)

* add set_stream

* add .record_stream for NDArray and HeteroGraph

* refactor dgl stream Python APIs

* test record_stream

* add unit test for record stream

* use pytorch's stream

* fix lint

* fix cpu build

* address comments

* address comments

* add record stream tests for dgl.graph

* record frames and update dataloader

* add docstring

* update frame

* add backend check for record_stream

* remove CUDAThreadEntry::stream

* record stream for newly created formats

* fix bug

* fix cpp test

* fix None c_void_p to c_handle
parent 099b173f
......@@ -167,6 +167,18 @@ struct COOMatrix {
}
is_pinned = false;
}
/*!
* \brief Record stream for the row, col and data (if not Null) of the matrix.
* \param stream The stream that is using the graph
*/
inline void RecordStream(DGLStreamHandle stream) const {
row.RecordStream(stream);
col.RecordStream(stream);
if (!aten::IsNullArray(data)) {
data.RecordStream(stream);
}
}
};
///////////////////////// COO routines //////////////////////////
......
......@@ -160,6 +160,18 @@ struct CSRMatrix {
}
is_pinned = false;
}
/*!
* \brief Record stream for the indptr, indices and data (if not Null) of the matrix.
* \param stream The stream that is using the graph
*/
inline void RecordStream(DGLStreamHandle stream) const {
indptr.RecordStream(stream);
indices.RecordStream(stream);
if (!aten::IsNullArray(data)) {
data.RecordStream(stream);
}
}
};
///////////////////////// CSR routines //////////////////////////
......
......@@ -121,6 +121,12 @@ class BaseHeteroGraph : public runtime::Object {
*/
virtual bool IsPinned() const = 0;
/*!
* \brief Record stream for this graph.
* \param stream The stream that is using the graph
*/
virtual void RecordStream(DGLStreamHandle stream) = 0;
/*!
* \brief Get the number of integer bits used to store node/edge ids (32 or 64).
*/
......
......@@ -564,6 +564,11 @@ int DGLArrayPinData(DGLArrayHandle handle, DLContext ctx);
*/
int DGLArrayUnpinData(DGLArrayHandle handle, DLContext ctx);
/*!
* \brief Record the stream that's using this tensor.
*/
int DGLArrayRecordStream(DGLArrayHandle handle, DGLStreamHandle stream);
/*!
* \brief Bug report macro.
*
......
......@@ -188,6 +188,11 @@ class NDArray {
* \brief Check if the array is pinned.
*/
inline bool IsPinned() const;
/*!
* \brief Record streams that are using the underlying tensor.
* \param stream The stream that is using the underlying tensor.
*/
inline void RecordStream(DGLStreamHandle stream) const;
/*!
* \brief Load NDArray from stream
* \param stream The input data stream
......@@ -319,6 +324,13 @@ class NDArray {
*/
DGL_DLL static bool IsContainerPinned(Container* ptr);
/*!
* \brief Record streams that are using this tensor.
 * \param tensor Pointer of the tensor to be recorded.
* \param stream The stream that is using this tensor.
*/
DGL_DLL static void RecordStream(DGLArray* tensor, DGLStreamHandle stream);
// internal namespace
struct Internal;
private:
......@@ -488,6 +500,11 @@ inline bool NDArray::IsPinned() const {
return IsContainerPinned(data_);
}
inline void NDArray::RecordStream(DGLStreamHandle stream) const {
CHECK(data_ != nullptr);
RecordStream(&(data_->dl_tensor), stream);
}
inline int NDArray::use_count() const {
if (data_ == nullptr) return 0;
return data_->ref_counter_.load(std::memory_order_relaxed);
......
......@@ -121,7 +121,7 @@ class TensorDispatcher {
/*!
* \brief Find the current PyTorch CUDA stream
-   * Used in CUDAThreadEntry::ThreadLocal->stream.
+   * Used in runtime::getCurrentCUDAStream().
*
* \note PyTorch pre-allocates/sets the current CUDA stream
* on current device via cudaGetDevice(). Make sure to call cudaSetDevice()
......@@ -135,6 +135,22 @@ class TensorDispatcher {
}
#endif // DGL_USE_CUDA
/*!
* \brief Record streams that are using this tensor.
* Used in NDArray::RecordStream().
*
* \param ptr Pointer of the tensor to be recorded.
* \param stream The stream that is using this tensor.
* \param device_id Device of the tensor.
*/
inline void RecordStream(void* ptr, DGLStreamHandle stream, int device_id) {
#ifdef DGL_USE_CUDA
auto entry = entrypoints_[Op::kRecordStream];
FUNCCAST(tensoradapter::RecordStream, entry)(
ptr, static_cast<cudaStream_t>(stream), device_id);
#endif // DGL_USE_CUDA
}
private:
/*! \brief ctor */
TensorDispatcher() = default;
......@@ -153,6 +169,7 @@ class TensorDispatcher {
"CUDARawAlloc",
"CUDARawDelete",
"CUDACurrentStream",
"RecordStream",
#endif // DGL_USE_CUDA
};
......@@ -165,6 +182,7 @@ class TensorDispatcher {
static constexpr int kCUDARawAlloc = 2;
static constexpr int kCUDARawDelete = 3;
static constexpr int kCUDACurrentStream = 4;
static constexpr int kRecordStream = 5;
#endif // DGL_USE_CUDA
};
......@@ -179,6 +197,7 @@ class TensorDispatcher {
nullptr,
nullptr,
nullptr,
nullptr,
#endif // DGL_USE_CUDA
};
......
......@@ -14,7 +14,6 @@ from ...ndarray import NDArray as DGLNDArray
from ... import backend as F
from ...base import DGLError, dgl_warning
from ...utils import to_dgl_context, check_device
-from ..._ffi import streams as FS
__all__ = ['NodeDataLoader', 'EdgeDataLoader', 'GraphDataLoader',
# Temporary exposure.
......@@ -353,9 +352,8 @@ def _next(dl_iter, graph, device, load_input, load_output, stream=None):
result_ = (input_nodes, output_nodes, blocks, input_data, output_data)
if stream is not None:
with th.cuda.stream(stream):
-            with FS.stream(stream):
-                result = [_to_device(data, device)
-                          for data in result_], result_, stream.record_event()
+            result = [_to_device(data, device)
+                      for data in result_], result_, stream.record_event()
else:
result = [_to_device(data, device) for data in result_]
return result
......
......@@ -162,5 +162,8 @@ cdef inline object ctypes_handle(void* chandle):
cdef inline void* c_handle(object handle):
"""Cast C types handle to c handle."""
cdef unsigned long long v_ptr
-    v_ptr = handle.value
-    return <void*>(v_ptr)
+    if handle.value is None:
+        return NULL
+    else:
+        v_ptr = handle.value
+        return <void*>(v_ptr)
......@@ -326,6 +326,19 @@ class NDArrayBase(_NDArrayBase):
"""
check_call(_LIB.DGLArrayUnpinData(self.handle))
def record_stream(self, stream):
"""Record the stream that is using this tensor.
Note
----
    This API is mainly for testing. Users should call ``record_stream``
    on ``torch.Tensor`` or ``dgl.graph`` objects directly.
Parameters
----------
stream : DGLStreamHandle
"""
check_call(_LIB.DGLArrayRecordStream(self.handle, stream))
def free_extension_handle(handle, type_code):
"""Free c++ extension type handle
......
# pylint: disable=invalid-name, unused-import
"""Runtime stream api which is maily for internal use only."""
"""Runtime stream APIs which are mainly for internal test use only.
For applications, please use PyTorch's stream management, of which DGL is aware.
"""
from __future__ import absolute_import
import ctypes
from .base import _LIB, check_call, _FFI_MODE
from .runtime_ctypes import DGLStreamHandle
from .ndarray import context
from ..utils import to_dgl_context
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
+def to_dgl_stream_handle(cuda_stream):
+    """ Convert torch.cuda.Stream to DGL stream handle
+    Parameters
+    ----------
+    cuda_stream : torch.cuda.Stream.
+    Returns
+    -------
+    DGLStreamHandle
+        DGLStreamHandle of the input ``cuda_stream``.
+    """
+    return ctypes.c_void_p(cuda_stream.cuda_stream)
-class StreamContext(object):
-    """ Context-manager that selects a given stream.
-    All CUDA kernels queued within its context will be enqueued
-    on a selected stream.
-    """
-    def __init__(self, cuda_stream):
-        """ create stream context instance
-        Parameters
-        ----------
-        cuda_stream : torch.cuda.Stream. This manager is a no-op if it's ``None``.
-            target stream will be set.
-        """
-        if cuda_stream is None:
-            self.curr_cuda_stream = None
-        else:
-            self.ctx = to_dgl_context(cuda_stream.device)
-            self.curr_cuda_stream = cuda_stream.cuda_stream
-    def __enter__(self):
-        """ get previous stream and set target stream as current. """
-        if self.curr_cuda_stream is None:
-            return
-        self.prev_cuda_stream = DGLStreamHandle()
-        check_call(_LIB.DGLGetStream(
-            self.ctx.device_type, self.ctx.device_id, ctypes.byref(self.prev_cuda_stream)))
-        check_call(_LIB.DGLSetStream(
-            self.ctx.device_type, self.ctx.device_id, ctypes.c_void_p(self.curr_cuda_stream)))
-    def __exit__(self, exc_type, exc_value, exc_traceback):
-        """ restore previous stream when exiting. """
-        if self.curr_cuda_stream is None:
-            return
-        check_call(_LIB.DGLSetStream(
-            self.ctx.device_type, self.ctx.device_id, self.prev_cuda_stream))
-def stream(cuda_stream):
-    """ Wrapper of StreamContext
-    Parameters
-    ----------
-    stream : torch.cuda.Stream. This manager is a no-op if it's ``None``.
-        target stream will be set.
-    """
-    return StreamContext(cuda_stream)
+def _dgl_get_stream(ctx):
+    """Get the current CUDA stream of the given DGL context.
+    Parameters
+    ----------
+    ctx : DGL context.
+    Returns
+    -------
+    DGLStreamHandle
+        DGLStreamHandle of the current CUDA stream.
+    """
+    current_cuda_stream = DGLStreamHandle()
+    check_call(_LIB.DGLGetStream(
+        ctx.device_type, ctx.device_id, ctypes.byref(current_cuda_stream)))
+    return current_cuda_stream
""" CUDA wrappers """
from . import nccl
-from .._ffi.streams import stream
......@@ -303,7 +303,7 @@ def _await_or_return(x):
def _record_stream(x, stream):
if stream is None:
return x
-    if isinstance(x, torch.Tensor):
+    if hasattr(x, 'record_stream'):
x.record_stream(stream)
return x
elif isinstance(x, _PrefetchedGraphFeatures):
......@@ -333,9 +333,7 @@ def _prefetch(batch, dataloader, stream):
feats = recursive_apply(batch, _prefetch_for, dataloader)
feats = recursive_apply(feats, _await_or_return)
feats = recursive_apply(feats, _record_stream, current_stream)
-    # transfer input nodes/seed nodes
-    # TODO(Xin): sampled subgraph is transferred in the default stream
-    # because heterograph doesn't support .record_stream() for now
+    # transfer input nodes/seed nodes/subgraphs
batch = recursive_apply(batch, lambda x: x.to(dataloader.device, non_blocking=True))
batch = recursive_apply(batch, _record_stream, current_stream)
stream_event = stream.record_event() if stream is not None else None
......
......@@ -101,6 +101,9 @@ class LazyFeature(object):
def unpin_memory_(self):
"""No-op. For compatibility of :meth:`Frame.unpin_memory_` method."""
def record_stream(self, stream):
"""No-op. For compatibility of :meth:`Frame.record_stream` method."""
class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
"""The column scheme.
......@@ -492,6 +495,18 @@ class Column(TensorStorage):
self._data_nd = None
self.pinned_by_dgl = False
    def record_stream(self, stream):
        """Record the stream that is using the storage.
        Raises a DGLError if the backend is not PyTorch.
        Parameters
        ----------
        stream : torch.cuda.Stream
            The stream that is using the storage.
        """
        if F.get_preferred_backend() != 'pytorch':
            raise DGLError("record_stream only supports the PyTorch backend.")
        self.data.record_stream(stream)
class Frame(MutableMapping):
"""The columnar storage for node/edge features.
......@@ -877,6 +892,12 @@ class Frame(MutableMapping):
for column in self._columns.values():
column.unpin_memory_()
def record_stream(self, stream):
"""Record stream that is using the data of every column, materializing them
if necessary."""
for column in self._columns.values():
column.record_stream(stream)
def _astype_float(self, new_type):
assert new_type in [F.float64, F.float32, F.float16], \
"'new_type' must be floating-point type: %s" % str(new_type)
......
......@@ -5584,6 +5584,31 @@ class DGLHeteroGraph(object):
"""
return self._graph.is_pinned()
def record_stream(self, stream):
"""Record the stream that is using this graph.
This method only supports the PyTorch backend and requires graphs on the GPU.
Parameters
----------
stream : torch.cuda.Stream
The stream that is using this graph.
Returns
-------
DGLGraph
self.
"""
if F.get_preferred_backend() != 'pytorch':
raise DGLError("record_stream only support the PyTorch backend.")
if F.device_type(self.device) != 'cuda':
raise DGLError("The graph must be on GPU to be recorded.")
self._graph.record_stream(stream)
for frame in itertools.chain(self._node_frames, self._edge_frames):
for col in frame._columns.values():
col.record_stream(stream)
return self
def clone(self):
"""Return a heterograph object that is a clone of current graph.
......
......@@ -8,6 +8,7 @@ import scipy
from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from ._ffi.streams import to_dgl_stream_handle
from .base import DGLError, dgl_warning
from .graph_index import from_coo
from . import backend as F
......@@ -273,6 +274,21 @@ class HeteroGraphIndex(ObjectBase):
"""
return bool(_CAPI_DGLHeteroIsPinned(self))
def record_stream(self, stream):
"""Record the stream that is using this graph.
Parameters
----------
stream : torch.cuda.Stream
The stream that is using this graph.
Returns
-------
HeteroGraphIndex
self.
"""
return _CAPI_DGLHeteroRecordStream(self, to_dgl_stream_handle(stream))
def shared_memory(self, name, ntypes=None, etypes=None, formats=('coo', 'csr', 'csc')):
"""Return a copy of this graph in shared memory
......
......@@ -19,8 +19,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
if (len == 0)
return !prepend_zero ? array : aten::Full(0, 1, array->dtype.bits, array->ctx);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   auto device = runtime::DeviceAPI::Get(array->ctx);
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
const IdType* in_d = array.Ptr<IdType>();
IdArray ret;
IdType* out_d = nullptr;
......@@ -34,12 +34,12 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
// Allocate workspace
size_t workspace_size = 0;
CUDA_CALL(cub::DeviceScan::InclusiveSum(
-      nullptr, workspace_size, in_d, out_d, len, thr_entry->stream));
+      nullptr, workspace_size, in_d, out_d, len, stream));
void* workspace = device->AllocWorkspace(array->ctx, workspace_size);
// Compute cumsum
CUDA_CALL(cub::DeviceScan::InclusiveSum(
-      workspace, workspace_size, in_d, out_d, len, thr_entry->stream));
+      workspace, workspace_size, in_d, out_d, len, stream));
device->FreeWorkspace(array->ctx, workspace);
......
......@@ -15,7 +15,7 @@ namespace impl {
template<DLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType* array_data = static_cast<DType*>(array->data);
const IdType* idx_data = static_cast<IdType*>(index->data);
const int64_t arr_len = array->shape[0];
......@@ -36,7 +36,7 @@ NDArray IndexSelect(NDArray array, IdArray index) {
if (num_feat == 1) {
const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
-    CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0, thr_entry->stream,
+    CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0, stream,
array_data, idx_data, len, arr_len, ret_data);
} else {
dim3 block(256, 1);
......@@ -45,7 +45,7 @@ NDArray IndexSelect(NDArray array, IdArray index) {
block.y *= 2;
}
const dim3 grid((len+block.y-1)/block.y);
-    CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0, thr_entry->stream,
+    CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0, stream,
array_data, num_feat, idx_data, len, arr_len, ret_data);
}
return ret;
......
......@@ -34,7 +34,7 @@ IdArray NonZero(IdArray array) {
const int64_t len = array->shape[0];
IdArray ret = NewIdArray(len, ctx, 64);
-  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
const IdType * const in_data = static_cast<const IdType*>(array->data);
int64_t * const out_data = static_cast<int64_t*>(ret->data);
......
......@@ -35,11 +35,11 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
const IdType* rhs_data = static_cast<IdType*>(rhs->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(len);
int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL((_BinaryElewiseKernel<IdType, Op>),
-                   nb, nt, 0, thr_entry->stream,
+                   nb, nt, 0, stream,
lhs_data, rhs_data, ret_data, len);
return ret;
}
......@@ -85,11 +85,11 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(len);
int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL((_BinaryElewiseKernel<IdType, Op>),
-                   nb, nt, 0, thr_entry->stream,
+                   nb, nt, 0, stream,
lhs_data, rhs, ret_data, len);
return ret;
}
......@@ -136,11 +136,11 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
const IdType* rhs_data = static_cast<IdType*>(rhs->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(len);
int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL((_BinaryElewiseKernel<IdType, Op>),
-                   nb, nt, 0, thr_entry->stream,
+                   nb, nt, 0, stream,
lhs, rhs_data, ret_data, len);
return ret;
}
......@@ -185,11 +185,11 @@ IdArray UnaryElewise(IdArray lhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(len);
int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL((_UnaryElewiseKernel<IdType, Op>),
-                   nb, nt, 0, thr_entry->stream,
+                   nb, nt, 0, stream,
lhs_data, ret_data, len);
return ret;
}
......@@ -214,10 +214,10 @@ template <DLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) {
NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx);
DType* ret_data = static_cast<DType*>(ret->data);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(length);
int nb = (length + nt - 1) / nt;
-  CUDA_KERNEL_CALL((_FullKernel<DType>), nb, nt, 0, thr_entry->stream,
+  CUDA_KERNEL_CALL((_FullKernel<DType>), nb, nt, 0, stream,
ret_data, length, val);
return ret;
}
......@@ -251,11 +251,11 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
if (length == 0)
return ret;
IdType* ret_data = static_cast<IdType*>(ret->data);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(length);
int nb = (length + nt - 1) / nt;
CUDA_KERNEL_CALL((_RangeKernel<IdType>),
-                   nb, nt, 0, thr_entry->stream,
+                   nb, nt, 0, stream,
ret_data, low, length);
return ret;
}
......@@ -289,10 +289,10 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
const auto& ctx = arrays[0]->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
// build node maps and get the induced nodes
-  OrderedHashTable<IdType> node_map(total_length, ctx, thr_entry->stream);
+  OrderedHashTable<IdType> node_map(total_length, ctx, stream);
int64_t num_induced = 0;
int64_t * num_induced_device = static_cast<int64_t*>(
device->AllocWorkspace(ctx, sizeof(int64_t)));
......@@ -302,15 +302,15 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
num_induced_device,
0,
sizeof(*num_induced_device),
-      thr_entry->stream));
+      stream));
node_map.FillWithDuplicates(
all_nodes.Ptr<IdType>(),
all_nodes->shape[0],
induced_nodes.Ptr<IdType>(),
num_induced_device,
-      thr_entry->stream);
-  // copy using the internal stream: thr_entry->stream
+      stream);
+  // copy using the internal current stream
device->CopyDataFromTo(
num_induced_device, 0,
&num_induced, 0,
......@@ -319,7 +319,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
DGLContext{kDLCPU, 0},
DGLType{kDLInt, 64, 1});
-  device->StreamSync(ctx, thr_entry->stream);
+  device->StreamSync(ctx, stream);
device->FreeWorkspace(ctx, num_induced_device);
// resize the induced nodes
......@@ -331,7 +331,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
const int64_t length = arr->shape[0];
int nb = (length + nt - 1) / nt;
CUDA_KERNEL_CALL((_RelabelKernel<IdType>),
-                     nb, nt, 0, thr_entry->stream,
+                     nb, nt, 0, stream,
arr.Ptr<IdType>(), length, node_map.DeviceHandle());
}
......@@ -358,16 +358,16 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
const std::vector<int64_t> shape(arr->shape, arr->shape + arr->ndim);
IdArray ret = IdArray::Empty(shape, DLDataType{kDLInt, bits, 1}, arr->ctx);
const int64_t length = ret.NumElements();
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(length);
int nb = (length + nt - 1) / nt;
if (bits == 32) {
CUDA_KERNEL_CALL((_CastKernel<IdType, int32_t>),
-                     nb, nt, 0, thr_entry->stream,
+                     nb, nt, 0, stream,
static_cast<IdType*>(arr->data), static_cast<int32_t*>(ret->data), length);
} else {
CUDA_KERNEL_CALL((_CastKernel<IdType, int64_t>),
-                     nb, nt, 0, thr_entry->stream,
+                     nb, nt, 0, stream,
static_cast<IdType*>(arr->data), static_cast<int64_t*>(ret->data), length);
}
return ret;
......
......@@ -30,10 +30,10 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
const DType* val = value.Ptr<DType>();
DType* outd = out.Ptr<DType>();
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
-  CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, thr_entry->stream,
+  CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, stream,
idx, val, len, outd);
}
......