Unverified Commit 1c9d2a03 authored by Chang Liu, committed by GitHub

[Feature] Unify the cuda stream used in core library (#4480)



* Use an internal cuda stream for CopyDataFromTo

* small fix white space

* Fix to compile

* Make stream optional in copydata for compile

* fix lint issue

* Update cub functions to use internal stream

* Lint check

* Update CopyTo/CopyFrom/CopyFromTo to use internal stream

* Address comments

* Fix backward CUDA stream

* Avoid overloading CopyFromTo()

* Minor comment update

* Overload copydatafromto in cuda device api
Co-authored-by: xiny <xiny@nvidia.com>
parent 62af41c2
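
The net effect at call sites: the copy APIs no longer take a DGLStreamHandle, and GPU copies run on DGL's internal stream (CUDAThreadEntry::ThreadLocal()->stream). A minimal before/after sketch; the ToGpu helper and its arguments are hypothetical:

#include <dgl/runtime/ndarray.h>

using dgl::runtime::NDArray;

// Hypothetical call site: move an array onto a GPU context.
NDArray ToGpu(const NDArray& cpu_arr, const DLContext& gpu_ctx) {
  // Before #4480: cpu_arr.CopyTo(gpu_ctx, stream) -- the caller threaded a
  // DGLStreamHandle by hand. After #4480 the parameter is gone and the copy
  // is issued on the internal stream whenever a GPU context is involved.
  return cpu_arr.CopyTo(gpu_ctx);
}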
@@ -122,13 +122,12 @@ struct COOMatrix {
   }
   /*! \brief Return a copy of this matrix on the given device context. */
-  inline COOMatrix CopyTo(const DLContext &ctx,
-                          const DGLStreamHandle &stream = nullptr) const {
+  inline COOMatrix CopyTo(const DLContext &ctx) const {
     if (ctx == row->ctx)
       return *this;
-    return COOMatrix(num_rows, num_cols, row.CopyTo(ctx, stream),
-                     col.CopyTo(ctx, stream),
-                     aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream),
+    return COOMatrix(num_rows, num_cols, row.CopyTo(ctx),
+                     col.CopyTo(ctx),
+                     aten::IsNullArray(data) ? data : data.CopyTo(ctx),
                      row_sorted, col_sorted);
   }
......
@@ -115,13 +115,12 @@ struct CSRMatrix {
   }
   /*! \brief Return a copy of this matrix on the given device context. */
-  inline CSRMatrix CopyTo(const DLContext &ctx,
-                          const DGLStreamHandle &stream = nullptr) const {
+  inline CSRMatrix CopyTo(const DLContext &ctx) const {
     if (ctx == indptr->ctx)
       return *this;
-    return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx, stream),
-                     indices.CopyTo(ctx, stream),
-                     aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream),
+    return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx),
+                     indices.CopyTo(ctx),
+                     aten::IsNullArray(data) ? data : data.CopyTo(ctx),
                      sorted);
   }
......
@@ -450,12 +450,10 @@ DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle,
  * \brief Copy the array, both from and to must be valid during the copy.
  * \param from The array to be copied from.
  * \param to The target space.
- * \param stream The stream where the copy happens, can be NULL.
  * \return 0 when success, -1 when failure happens
  */
 DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from,
-                               DGLArrayHandle to,
-                               DGLStreamHandle stream);
+                               DGLArrayHandle to);
 /*!
  * \brief Produce an array from the DLManagedTensor that shares data memory
......
@@ -93,7 +93,6 @@ class DeviceAPI {
   * \param ctx_to The target context
   * \param type_hint The type of elements, only needed by certain backends.
   *        Can be useful for cross-device endian conversion.
-  * \param stream Optional stream object.
   */
  virtual void CopyDataFromTo(const void* from,
                              size_t from_offset,
@@ -102,9 +101,8 @@ class DeviceAPI {
                              size_t num_bytes,
                              DGLContext ctx_from,
                              DGLContext ctx_to,
-                             DGLType type_hint,
-                             DGLStreamHandle stream) = 0;
+                             DGLType type_hint) = 0;
  /*!
   * \brief Create a new stream of execution.
   *
   * \param ctx The context of allocation.
......
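
With the stream parameter dropped from the pure-virtual interface, each backend supplies the stream itself. A sketch of how the CUDA backend can satisfy this, per the "Overload copydatafromto in cuda device api" commit above (a sketch, not the verbatim DGL source):

void CUDADeviceAPI::CopyDataFromTo(const void* from, size_t from_offset,
                                   void* to, size_t to_offset,
                                   size_t num_bytes,
                                   DGLContext ctx_from, DGLContext ctx_to,
                                   DGLType type_hint) {
  // Forward to a stream-taking overload using the unified internal stream.
  cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
  CopyDataFromTo(from, from_offset, to, to_offset, num_bytes,
                 ctx_from, ctx_to, type_hint, stream);
}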
@@ -147,39 +147,26 @@ class NDArray {
     else
       return static_cast<T*>(operator->()->data);
   }
   /*!
-   * \brief Copy data content from another array.
-   * \param other The source array to be copied from.
-   * \param stream The stream to perform the copy on if it involves a GPU
-   *        context, otherwise this parameter is ignored.
-   * \note The copy may happen asynchrously if it involves a GPU context.
-   *       DGLSynchronize is necessary.
-   */
-  inline void CopyFrom(DLTensor* other,
-                       DGLStreamHandle stream = nullptr);
-  inline void CopyFrom(const NDArray& other,
-                       DGLStreamHandle stream = nullptr);
-  /*!
-   * \brief Copy data content into another array.
+   * \brief Copy data content from/into another array.
    * \param other The source array to be copied from.
-   * \note The copy may happen asynchrously if it involves a GPU context.
-   *       DGLSynchronize is necessary.
+   * \note The copy runs on the dgl internal stream if it involves a GPU context.
    */
-  inline void CopyTo(DLTensor *other,
-                     const DGLStreamHandle &stream = nullptr) const;
-  inline void CopyTo(const NDArray &other,
-                     const DGLStreamHandle &stream = nullptr) const;
+  inline void CopyFrom(DLTensor* other);
+  inline void CopyFrom(const NDArray& other);
+  inline void CopyTo(DLTensor *other) const;
+  inline void CopyTo(const NDArray &other) const;
   /*!
    * \brief Copy the data to another context.
    * \param ctx The target context.
    * \return The array under another context.
    */
-  inline NDArray CopyTo(const DLContext &ctx,
-                        const DGLStreamHandle &stream = nullptr) const;
+  inline NDArray CopyTo(const DLContext &ctx) const;
   /*!
    * \brief Return a new array with a copy of the content.
    */
-  inline NDArray Clone(const DGLStreamHandle &stream = nullptr) const;
+  inline NDArray Clone() const;
   /*!
    * \brief In-place method to pin the current array by calling PinContainer
    *        on the underlying NDArray:Container.
@@ -297,10 +284,12 @@ class NDArray {
   * \brief Function to copy data from one array to another.
   * \param from The source array.
   * \param to The target array.
-  * \param stream The stream used in copy.
+  * \param (optional) stream The stream used in copy.
   */
  DGL_DLL static void CopyFromTo(
-      DLTensor* from, DLTensor* to, DGLStreamHandle stream = nullptr);
+      DLTensor* from, DLTensor* to);
+  DGL_DLL static void CopyFromTo(
+      DLTensor* from, DLTensor* to, DGLStreamHandle stream);
  /*!
   * \brief Function to pin the DLTensor of a Container.
@@ -449,46 +438,39 @@ inline void NDArray::reset() {
   }
 }

-inline void NDArray::CopyFrom(DLTensor* other,
-                              DGLStreamHandle stream) {
+inline void NDArray::CopyFrom(DLTensor* other) {
   CHECK(data_ != nullptr);
-  CopyFromTo(other, &(data_->dl_tensor), stream);
+  CopyFromTo(other, &(data_->dl_tensor));
 }

-inline void NDArray::CopyFrom(const NDArray& other,
-                              DGLStreamHandle stream) {
-  CHECK(data_ != nullptr);
+inline void NDArray::CopyFrom(const NDArray& other) {
   CHECK(other.data_ != nullptr);
-  CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor), stream);
+  CopyFrom(&(other.data_->dl_tensor));
 }

-inline void NDArray::CopyTo(DLTensor *other,
-                            const DGLStreamHandle &stream) const {
+inline void NDArray::CopyTo(DLTensor *other) const {
   CHECK(data_ != nullptr);
-  CopyFromTo(&(data_->dl_tensor), other, stream);
+  CopyFromTo(&(data_->dl_tensor), other);
 }

-inline void NDArray::CopyTo(const NDArray &other,
-                            const DGLStreamHandle &stream) const {
-  CHECK(data_ != nullptr);
+inline void NDArray::CopyTo(const NDArray &other) const {
   CHECK(other.data_ != nullptr);
-  CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor), stream);
+  CopyTo(&(other.data_->dl_tensor));
 }

-inline NDArray NDArray::CopyTo(const DLContext &ctx,
-                               const DGLStreamHandle &stream) const {
+inline NDArray NDArray::CopyTo(const DLContext &ctx) const {
   CHECK(data_ != nullptr);
   const DLTensor* dptr = operator->();
   NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
                       dptr->dtype, ctx);
-  this->CopyTo(ret, stream);
+  this->CopyTo(ret);
   return ret;
 }

-inline NDArray NDArray::Clone(const DGLStreamHandle &stream) const {
+inline NDArray NDArray::Clone() const {
   CHECK(data_ != nullptr);
   const DLTensor* dptr = operator->();
-  return this->CopyTo(dptr->ctx, stream);
+  return this->CopyTo(dptr->ctx);
 }

 inline void NDArray::PinMemory_() {
......
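
Usage sketch for the two static CopyFromTo overloads declared above (src, dst, and my_stream are hypothetical, assumed-valid tensors/handles):

// Default overload: the copy is issued on DGL's internal stream.
dgl::runtime::NDArray::CopyFromTo(src, dst);

// Explicit-stream overload: for callers that must order the copy against
// work already queued on their own stream.
dgl::runtime::NDArray::CopyFromTo(src, dst, my_stream);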
@@ -118,6 +118,21 @@ class TensorDispatcher {
     auto entry = entrypoints_[Op::kCUDARawDelete];
     FUNCCAST(tensoradapter::CUDARawDelete, entry)(ptr);
   }
+
+  /*!
+   * \brief Find the current PyTorch CUDA stream.
+   *        Used in CUDAThreadEntry::ThreadLocal()->stream.
+   *
+   * \note PyTorch pre-allocates/sets the current CUDA stream on the current
+   *       device (as returned by cudaGetDevice()). Make sure to call
+   *       cudaSetDevice() before invoking this function.
+   *
+   * \return cudaStream_t stream handle
+   */
+  inline cudaStream_t CUDAGetCurrentStream() {
+    auto entry = entrypoints_[Op::kCUDACurrentStream];
+    return FUNCCAST(tensoradapter::CUDACurrentStream, entry)();
+  }
 #endif  // DGL_USE_CUDA

  private:
@@ -137,6 +152,7 @@ class TensorDispatcher {
 #ifdef DGL_USE_CUDA
     "CUDARawAlloc",
     "CUDARawDelete",
+    "CUDACurrentStream",
 #endif  // DGL_USE_CUDA
   };
@@ -148,6 +164,7 @@ class TensorDispatcher {
 #ifdef DGL_USE_CUDA
   static constexpr int kCUDARawAlloc = 2;
   static constexpr int kCUDARawDelete = 3;
+  static constexpr int kCUDACurrentStream = 4;
 #endif  // DGL_USE_CUDA
 };
@@ -161,6 +178,7 @@ class TensorDispatcher {
 #ifdef DGL_USE_CUDA
   nullptr,
   nullptr,
+  nullptr,
 #endif  // DGL_USE_CUDA
 };
......
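
The new entrypoint lets the runtime seed CUDAThreadEntry's stream from PyTorch's current stream when the tensoradapter library is loaded. A sketch of such a query; the Global() accessor is an assumption for illustration:

cudaStream_t QueryInternalStream(int device_id) {
  CUDA_CALL(cudaSetDevice(device_id));  // required before the query, per the note above
  TensorDispatcher* td = TensorDispatcher::Global();  // assumed accessor
  return td->CUDAGetCurrentStream();    // PyTorch's current stream on device_id
}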
@@ -107,8 +107,7 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
                         CDGLArrayHandle* out)
     int DGLArrayFree(DLTensorHandle handle)
     int DGLArrayCopyFromTo(DLTensorHandle src,
-                           DLTensorHandle to,
-                           DGLStreamHandle stream)
+                           DLTensorHandle to)
     int DGLArrayFromDLPack(DLManagedTensor* arr_from,
                            DLTensorHandle* out)
     int DGLArrayToDLPack(DLTensorHandle arr_from,
......
@@ -311,7 +311,7 @@ class NDArrayBase(_NDArrayBase):
             target = empty(self.shape, self.dtype, target)
         if isinstance(target, NDArrayBase):
             check_call(_LIB.DGLArrayCopyFromTo(
-                self.handle, target.handle, None))
+                self.handle, target.handle))
         else:
             raise ValueError("Unsupported target type %s" % str(type(target)))
         return target
......
@@ -101,11 +101,11 @@ IdArray HStack(IdArray lhs, IdArray rhs) {
     device->CopyDataFromTo(lhs.Ptr<IdType>(), 0,
                            ret.Ptr<IdType>(), 0,
                            len * sizeof(IdType),
-                           ctx, ctx, lhs->dtype, nullptr);
+                           ctx, ctx, lhs->dtype);
     device->CopyDataFromTo(rhs.Ptr<IdType>(), 0,
                            ret.Ptr<IdType>(), len * sizeof(IdType),
                            len * sizeof(IdType),
-                           ctx, ctx, lhs->dtype, nullptr);
+                           ctx, ctx, lhs->dtype);
   });
   return ret;
 }
@@ -160,7 +160,7 @@ NDArray IndexSelect(NDArray array, int64_t start, int64_t end) {
   ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
     device->CopyDataFromTo(array->data, start * sizeof(DType),
                            ret->data, 0, len * sizeof(DType),
-                           array->ctx, ret->ctx, array->dtype, nullptr);
+                           array->ctx, ret->ctx, array->dtype);
   });
   return ret;
 }
@@ -240,8 +240,7 @@ NDArray Concat(const std::vector<IdArray>& arrays) {
                            arrays[i]->shape[0] * sizeof(DType),
                            arrays[i]->ctx,
                            ret_arr->ctx,
-                           arrays[i]->dtype,
-                           nullptr);
+                           arrays[i]->dtype);
     offset += arrays[i]->shape[0] * sizeof(DType);
   });
......
@@ -80,7 +80,7 @@ DType IndexSelect(NDArray array, int64_t index) {
   device->CopyDataFromTo(
       static_cast<DType*>(array->data) + index, 0, reinterpret_cast<DType*>(&ret), 0,
       sizeof(DType), array->ctx, DLContext{kDLCPU, 0},
-      array->dtype, nullptr);
+      array->dtype);
   return reinterpret_cast<DType&>(ret);
 }
......
@@ -34,7 +34,7 @@ IdArray NonZero(IdArray array) {
   const int64_t len = array->shape[0];
   IdArray ret = NewIdArray(len, ctx, 64);
-  cudaStream_t stream = 0;
+  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
   const IdType * const in_data = static_cast<const IdType*>(array->data);
   int64_t * const out_data = static_cast<int64_t*>(ret->data);
@@ -55,7 +55,7 @@ IdArray NonZero(IdArray array) {
   device->FreeWorkspace(ctx, temp);

   // copy number of selected elements from GPU to CPU
-  int64_t num_nonzeros = cuda::GetCUDAScalar(device, ctx, d_num_nonzeros, stream);
+  int64_t num_nonzeros = cuda::GetCUDAScalar(device, ctx, d_num_nonzeros);

   device->FreeWorkspace(ctx, d_num_nonzeros);
   device->StreamSync(ctx, stream);
......
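
GetCUDAScalar now supplies the stream itself. A hedged stand-in showing what such a helper amounts to (the real one lives in DGL's CUDA utilities; FetchScalarFromGPU and its dtype choice are illustrative only):

template <typename T>
T FetchScalarFromGPU(dgl::runtime::DeviceAPI* device, DGLContext ctx,
                     const T* d_scalar) {
  T h_scalar;
  // Device->host copy of a single scalar; DeviceAPI picks the internal stream.
  device->CopyDataFromTo(d_scalar, 0, &h_scalar, 0, sizeof(T),
                         ctx, DGLContext{kDLCPU, 0},
                         DGLType{kDLInt, sizeof(T) * 8, 1});
  // Wait for the async copy to land before reading the host value.
  device->StreamSync(ctx, dgl::runtime::CUDAThreadEntry::ThreadLocal()->stream);
  return h_scalar;
}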
@@ -310,15 +310,15 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
       induced_nodes.Ptr<IdType>(),
       num_induced_device,
       thr_entry->stream);
+  // copy using the internal stream: thr_entry->stream
   device->CopyDataFromTo(
       num_induced_device, 0,
       &num_induced, 0,
       sizeof(num_induced),
       ctx,
       DGLContext{kDLCPU, 0},
-      DGLType{kDLInt, 64, 1},
-      thr_entry->stream);
+      DGLType{kDLInt, 64, 1});
   device->StreamSync(ctx, thr_entry->stream);
   device->FreeWorkspace(ctx, num_induced_device);
......
@@ -27,6 +27,7 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
   IdType* keys_out = sorted_array.Ptr<IdType>();
   int64_t* values_out = sorted_idx.Ptr<int64_t>();
+  auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
   if (num_bits == 0) {
     num_bits = sizeof(IdType)*8;
   }
@@ -34,12 +35,12 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
   // Allocate workspace
   size_t workspace_size = 0;
   CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, workspace_size,
-      keys_in, keys_out, values_in, values_out, nitems, 0, num_bits));
+      keys_in, keys_out, values_in, values_out, nitems, 0, num_bits, stream));
   void* workspace = device->AllocWorkspace(ctx, workspace_size);
   // Compute
   CUDA_CALL(cub::DeviceRadixSort::SortPairs(workspace, workspace_size,
-      keys_in, keys_out, values_in, values_out, nitems, 0, num_bits));
+      keys_in, keys_out, values_in, values_out, nitems, 0, num_bits, stream));
   device->FreeWorkspace(ctx, workspace);
......
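
For reference, cub's two-phase call pattern with an explicit stream, as a self-contained sketch (plain CUDA, hypothetical device buffers; without the stream argument cub enqueues on the legacy default stream 0):

#include <cub/cub.cuh>

void SortPairsOnStream(const int* d_keys_in, int* d_keys_out,
                       const int64_t* d_vals_in, int64_t* d_vals_out,
                       int num_items, cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // First call with a null workspace only computes the required size.
  cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes,
      d_keys_in, d_keys_out, d_vals_in, d_vals_out,
      num_items, 0, sizeof(int) * 8, stream);
  cudaMalloc(&d_temp, temp_bytes);
  // Second call performs the sort, enqueued on the given stream.
  cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes,
      d_keys_in, d_keys_out, d_vals_in, d_vals_out,
      num_items, 0, sizeof(int) * 8, stream);
  cudaFree(d_temp);
}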
@@ -130,13 +130,13 @@ void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
   size_t workspace_size = 0;
   CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(nullptr, workspace_size,
       key_in, key_out, value_in, value_out,
-      nnz, csr->num_rows, offsets, offsets + 1));
+      nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, thr_entry->stream));
   void* workspace = device->AllocWorkspace(ctx, workspace_size);
   // Compute
   CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(workspace, workspace_size,
       key_in, key_out, value_in, value_out,
-      nnz, csr->num_rows, offsets, offsets + 1));
+      nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, thr_entry->stream));
   csr->sorted = true;
   csr->indices = new_indices;
......
@@ -18,9 +18,7 @@ namespace array {

 namespace {

-// TODO(nv-dlasalle): Replace with getting the stream from the context
-// when it's implemented.
-constexpr cudaStream_t cudaDefaultStream = 0;
+cudaStream_t cudaStream = runtime::CUDAThreadEntry::ThreadLocal()->stream;

 template<typename IdType, bool include>
 __global__ void _IsInKernel(
@@ -59,8 +57,6 @@ IdArray _PerformFilter(
     return test;
   }

-  cudaStream_t stream = cudaDefaultStream;
-
   // we need two arrays: 1) to act as a prefixsum
   // for the number of entries that will be inserted, and
   // 2) to collect the included items.
@@ -76,7 +72,7 @@ IdArray _PerformFilter(
     const dim3 grid((size+block.x-1)/block.x);
     CUDA_KERNEL_CALL((_IsInKernel<IdType, include>),
-        grid, block, 0, stream,
+        grid, block, 0, cudaStream,
         table.DeviceHandle(),
         static_cast<const IdType*>(test->data),
         size,
@@ -91,7 +87,7 @@ IdArray _PerformFilter(
         workspace_bytes,
         static_cast<IdType*>(nullptr),
         static_cast<IdType*>(nullptr),
-        size+1));
+        size+1, cudaStream));
     void * workspace = device->AllocWorkspace(ctx, workspace_bytes);
     CUDA_CALL(cub::DeviceScan::ExclusiveSum(
@@ -99,19 +95,18 @@ IdArray _PerformFilter(
         workspace_bytes,
         prefix,
         prefix,
-        size+1, stream));
+        size+1, cudaStream));
     device->FreeWorkspace(ctx, workspace);
   }

-  // copy number
+  // copy number using the internal stream CUDAThreadEntry::ThreadLocal()->stream
   IdType num_unique;
   device->CopyDataFromTo(prefix+size, 0,
                          &num_unique, 0,
                          sizeof(num_unique),
                          ctx,
                          DGLContext{kDLCPU, 0},
-                         test->dtype,
-                         stream);
+                         test->dtype);

   // insert items into set
   {
@@ -119,7 +114,7 @@ IdArray _PerformFilter(
     const dim3 grid((size+block.x-1)/block.x);
     CUDA_KERNEL_CALL(_InsertKernel,
-        grid, block, 0, stream,
+        grid, block, 0, cudaStream,
         prefix,
         size,
         static_cast<IdType*>(result->data));
@@ -134,11 +129,11 @@ template<typename IdType>
 class CudaFilterSet : public Filter {
  public:
   explicit CudaFilterSet(IdArray array) :
-      table_(array->shape[0], array->ctx, cudaDefaultStream) {
+      table_(array->shape[0], array->ctx, cudaStream) {
     table_.FillWithUnique(
         static_cast<const IdType*>(array->data),
         array->shape[0],
-        cudaDefaultStream);
+        cudaStream);
   }

   IdArray find_included_indices(IdArray test) override {
......
@@ -84,7 +84,7 @@ void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
   device->CopyDataFromTo(
       arrs, 0, arrs_dev, 0, sizeof(IdType*)*n_arrs,
-      DGLContext{kDLCPU, 0}, ctx, dtype, 0);
+      DGLContext{kDLCPU, 0}, ctx, dtype);

   CUDA_KERNEL_CALL(_DisjointUnionKernel,
                    nb, nt, 0, stream,
@@ -135,17 +135,17 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
   device->CopyDataFromTo(
       &prefix_elm[coos.size()], 0, &n_elements, 0,
       sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0},
-      coos[0].row->dtype, 0);
+      coos[0].row->dtype);

   device->CopyDataFromTo(
       &prefix_src[coos.size()], 0, &src_offset, 0,
       sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0},
-      coos[0].row->dtype, 0);
+      coos[0].row->dtype);

   device->CopyDataFromTo(
       &prefix_dst[coos.size()], 0, &dst_offset, 0,
       sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0},
-      coos[0].row->dtype, 0);
+      coos[0].row->dtype);

   // Union src array
   IdArray result_src = NewIdArray(
......
@@ -46,6 +46,6 @@
       device->FreeWorkspace(ctx, (LHS_OFF));  \
       device->FreeWorkspace(ctx, (RHS_OFF));  \
     }                                         \
-  } while (0)
+  } while (0)

 #endif
@@ -174,7 +174,7 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
   CUDA_CALL(cub::DeviceSelect::If(
       tmp, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples, op,
       thr_entry->stream));
-  num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda, static_cast<cudaStream_t>(0));
+  num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda);

   if (!replace) {
     IdArray unique_row = IdArray::Empty({num_out}, dtype, ctx);
@@ -198,7 +198,7 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
     CUDA_CALL(cub::DeviceSelect::Unique(
         tmp_unique, tmp_size_unique, out_begin, unique_begin, num_out_cuda, num_out,
         thr_entry->stream));
-    num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda, static_cast<cudaStream_t>(0));
+    num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda);
     num_out = std::min(num_samples, num_out);
     result = {unique_row.CreateView({num_out}, dtype), unique_col.CreateView({num_out}, dtype)};
......
@@ -248,9 +248,7 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
   const auto& ctx = rows->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);

-  // TODO(dlasalle): Once the device api supports getting the stream from the
-  // context, that should be used instead of the default stream here.
-  cudaStream_t stream = 0;
+  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;

   const int64_t num_rows = rows->shape[0];
   const IdType * const slice_rows = static_cast<const IdType*>(rows->data);
@@ -310,12 +308,12 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
   // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and wait on
   // a cudaevent
   IdType new_len;
+  // copy using the internal stream: CUDAThreadEntry::ThreadLocal->stream
   device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0,
                          sizeof(new_len),
                          ctx,
                          DGLContext{kDLCPU, 0},
-                         mat.indptr->dtype,
-                         stream);
+                         mat.indptr->dtype);
   CUDA_CALL(cudaEventRecord(copyEvent, stream));

   const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
......
@@ -425,9 +425,7 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
   const auto& ctx = rows->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);

-  // TODO(dlasalle): Once the device api supports getting the stream from the
-  // context, that should be used instead of the default stream here.
-  cudaStream_t stream = 0;
+  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;

   const int64_t num_rows = rows->shape[0];
   const IdType * const slice_rows = static_cast<const IdType*>(rows->data);
@@ -491,12 +489,12 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
   // TODO(Xin): The copy here is too small, and the overhead of creating
   // cuda events cannot be ignored. Just use synchronized copy.
   IdType temp_len;
+  // copy using the internal dgl stream: CUDAThreadEntry::ThreadLocal()->stream
   device->CopyDataFromTo(temp_ptr, num_rows * sizeof(temp_len), &temp_len, 0,
                          sizeof(temp_len),
                          ctx,
                          DGLContext{kDLCPU, 0},
-                         mat.indptr->dtype,
-                         stream);
+                         mat.indptr->dtype);
   device->StreamSync(ctx, stream);

   // fill out_ptr
@@ -522,12 +520,12 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
   // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and wait on
   // a cudaevent
   IdType new_len;
+  // copy using the internal dgl stream: CUDAThreadEntry::ThreadLocal()->stream
   device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0,
                          sizeof(new_len),
                          ctx,
                          DGLContext{kDLCPU, 0},
-                         mat.indptr->dtype,
-                         stream);
+                         mat.indptr->dtype);
   CUDA_CALL(cudaEventRecord(copyEvent, stream));

   // allocate workspace
@@ -604,7 +602,7 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
       temp_len,
       num_rows,
       temp_ptr,
-      temp_ptr + 1));
+      temp_ptr + 1, stream));
   d_temp_storage = device->AllocWorkspace(ctx, temp_storage_bytes);
   CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending(
       d_temp_storage,
@@ -614,7 +612,7 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
       temp_len,
       num_rows,
       temp_ptr,
-      temp_ptr + 1));
+      temp_ptr + 1, stream));
   device->FreeWorkspace(ctx, d_temp_storage);
   device->FreeWorkspace(ctx, temp);
   device->FreeWorkspace(ctx, temp_idxs);
......
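
The sampling paths above pair the async device-to-host length copy with a CUDA event rather than a full stream sync. A sketch of the record-then-wait pattern, with the event's creation and destruction added for completeness (names mirror the diff):

cudaEvent_t copyEvent;
CUDA_CALL(cudaEventCreate(&copyEvent));
// ... CopyDataFromTo(...) enqueues the device->host copy on `stream` ...
CUDA_CALL(cudaEventRecord(copyEvent, stream));
// ... independent work may be queued on `stream` in the meantime ...
CUDA_CALL(cudaEventSynchronize(copyEvent));  // host waits; new_len is now valid
CUDA_CALL(cudaEventDestroy(copyEvent));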