Unverified Commit 1c9d2a03 authored by Chang Liu, committed by GitHub

[Feature] Unify the CUDA stream used in the core library (#4480)



* Use an internal CUDA stream for CopyDataFromTo

* Small whitespace fix

* Fix compilation

* Make the stream optional in CopyDataFromTo so it compiles

* Fix lint issue

* Update CUB functions to use the internal stream

* Lint check

* Update CopyTo/CopyFrom/CopyFromTo to use the internal stream

* Address comments

* Fix backward CUDA stream

* Avoid overloading CopyFromTo()

* Minor comment update

* Overload CopyDataFromTo in the CUDA device API

Co-authored-by: xiny <xiny@nvidia.com>
parent 62af41c2
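Every hunk below makes one of two mechanical changes: a CUB primitive or device copy gains the runtime's internal (thread-local) stream as an explicit trailing argument, or a stream parameter disappears from a signature because the device API now supplies that stream itself. A minimal sketch of the recurring CUB call pattern, assuming DGL's runtime headers and the `runtime::CUDAThreadEntry` thread-local shown later in this diff (the function name and the `d_in`/`d_out` buffers are illustrative):

```cpp
#include <cub/cub.cuh>

// Sketch only: the size-then-run CUB idiom on the internal stream instead
// of the default stream 0, mirroring the call sites changed below.
void ExclusiveSumOnInternalStream(const int* d_in, int* d_out, int n,
                                  DGLContext ctx, runtime::DeviceAPI* device) {
  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // First call: d_temp == nullptr, so CUB only reports the workspace size.
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_in, d_out, n, stream));
  d_temp = device->AllocWorkspace(ctx, temp_bytes);
  // Second call: identical arguments and the same stream actually run the scan.
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_in, d_out, n, stream));
  device->FreeWorkspace(ctx, d_temp);
}
```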
@@ -463,10 +463,12 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
   auto ptr_cols = cols.Ptr<IdType>();
   size_t workspace_size = 0;
   CUDA_CALL(cub::DeviceRadixSort::SortKeys(
-      nullptr, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0]));
+      nullptr, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0],
+      0, sizeof(IdType)*8, thr_entry->stream));
   void *workspace = device->AllocWorkspace(ctx, workspace_size);
   CUDA_CALL(cub::DeviceRadixSort::SortKeys(
-      workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0]));
+      workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0],
+      0, sizeof(IdType)*8, thr_entry->stream));
   device->FreeWorkspace(ctx, workspace);
   // Execute SegmentMaskColKernel
......
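A note on the extra `0, sizeof(IdType)*8` arguments above: in `cub::DeviceRadixSort::SortKeys` the stream parameter comes after the optional `begin_bit`/`end_bit` pair, so those defaults have to be spelled out before a non-default stream can be passed. Annotated form of the second call, using the same names as the hunk:

```cpp
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
    workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0],
    /*begin_bit=*/0, /*end_bit=*/sizeof(IdType)*8,  // the defaults, written out
    thr_entry->stream));                            // the internal stream
```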
@@ -16,10 +16,11 @@ bool AllTrue(int8_t* flags, int64_t length, const DLContext& ctx) {
   int8_t* rst = static_cast<int8_t*>(device->AllocWorkspace(ctx, 1));
   // Call CUB's reduction
   size_t workspace_size = 0;
-  CUDA_CALL(cub::DeviceReduce::Min(nullptr, workspace_size, flags, rst, length));
+  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+  CUDA_CALL(cub::DeviceReduce::Min(nullptr, workspace_size, flags, rst, length, stream));
   void* workspace = device->AllocWorkspace(ctx, workspace_size);
-  CUDA_CALL(cub::DeviceReduce::Min(workspace, workspace_size, flags, rst, length));
-  int8_t cpu_rst = GetCUDAScalar(device, ctx, rst, static_cast<cudaStream_t>(0));
+  CUDA_CALL(cub::DeviceReduce::Min(workspace, workspace_size, flags, rst, length, stream));
+  int8_t cpu_rst = GetCUDAScalar(device, ctx, rst);
   device->FreeWorkspace(ctx, workspace);
   device->FreeWorkspace(ctx, rst);
   return cpu_rst == 1;
......
@@ -188,8 +188,7 @@ template <typename DType>
 inline DType GetCUDAScalar(
     runtime::DeviceAPI* device_api,
     DLContext ctx,
-    const DType* cuda_ptr,
-    cudaStream_t stream) {
+    const DType* cuda_ptr) {
   DType result;
   device_api->CopyDataFromTo(
       cuda_ptr, 0,
@@ -197,8 +196,7 @@ inline DType GetCUDAScalar(
       sizeof(result),
       ctx,
       DLContext{kDLCPU, 0},
-      DLDataTypeTraits<DType>::dtype,
-      stream);
+      DLDataTypeTraits<DType>::dtype);
   return result;
 }
......
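With the stream gone from the signature, the copy inside `GetCUDAScalar` lands on the internal stream implicitly and call sites shrink accordingly; a sketch mirroring the `AllTrue` hunk above, with `device`, `ctx`, and `rst` as in that hunk:

```cpp
// Before: the caller pinned the copy to the default stream explicitly.
//   int8_t cpu_rst = GetCUDAScalar(device, ctx, rst, static_cast<cudaStream_t>(0));
// After: the device API picks the thread-local stream on its own.
int8_t cpu_rst = GetCUDAScalar(device, ctx, rst);
```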
@@ -252,8 +252,7 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
       hgindex->num_verts_per_type_));
 }
-HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
-                                   const DGLStreamHandle &stream) {
+HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx) {
   if (ctx == g->Context()) {
     return g;
   }
@@ -261,7 +260,7 @@ HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
   CHECK_NOTNULL(hgindex);
   std::vector<HeteroGraphPtr> rel_graphs;
   for (auto g : hgindex->relation_graphs_) {
-    rel_graphs.push_back(UnitGraph::CopyTo(g, ctx, stream));
+    rel_graphs.push_back(UnitGraph::CopyTo(g, ctx));
   }
   return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs,
                                         hgindex->num_verts_per_type_));
......
@@ -229,8 +229,8 @@ class HeteroGraph : public BaseHeteroGraph {
   static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
   /*! \brief Copy the data to another context */
-  static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx,
-                               const DGLStreamHandle &stream = nullptr);
+  static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx);
   /*!
    * \brief Pin all relation graphs of the current graph.
......
@@ -473,9 +473,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
     DLContext ctx;
     ctx.device_type = static_cast<DLDeviceType>(device_type);
     ctx.device_id = device_id;
-    DGLStreamHandle stream = nullptr;
-    DGLGetStream(device_type, device_id, &stream);
-    HeteroGraphPtr hg_new = HeteroGraph::CopyTo(hg.sptr(), ctx, stream);
+    HeteroGraphPtr hg_new = HeteroGraph::CopyTo(hg.sptr(), ctx);
     *rv = HeteroGraphRef(hg_new);
   });
......
@@ -316,16 +316,16 @@ std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
   void *d_temp_storage = nullptr;
   size_t temp_storage_bytes = 0;
   CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
-      edge_blocks_prefix, edge_blocks_prefix_alternate, num_edge_blocks + 1));
+      edge_blocks_prefix, edge_blocks_prefix_alternate, num_edge_blocks + 1, _stream));
   d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
   CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
-      edge_blocks_prefix, edge_blocks_prefix_alternate, num_edge_blocks + 1));
+      edge_blocks_prefix, edge_blocks_prefix_alternate, num_edge_blocks + 1, _stream));
   device->FreeWorkspace(_ctx, d_temp_storage);
   std::swap(edge_blocks_prefix, edge_blocks_prefix_alternate);
   device->CopyDataFromTo(&edge_blocks_prefix[num_edge_blocks], 0, &num_unique_edges, 0,
       sizeof(num_unique_edges),
       _ctx, DGLContext{kDLCPU, 0},
-      dtype, _stream);
+      dtype);
   device->StreamSync(_ctx, _stream);
   // 2.2 Allocate the data of unique edges and frequency
   // double space to use SegmentedRadixSort
@@ -350,10 +350,10 @@ std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
   d_temp_storage = nullptr;
   temp_storage_bytes = 0;
   CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
-      num_unique_each_node, num_unique_each_node_alternate, num_dst_nodes + 1));
+      num_unique_each_node, num_unique_each_node_alternate, num_dst_nodes + 1, _stream));
   d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
   CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
-      num_unique_each_node, num_unique_each_node_alternate, num_dst_nodes + 1));
+      num_unique_each_node, num_unique_each_node_alternate, num_dst_nodes + 1, _stream));
   device->FreeWorkspace(_ctx, d_temp_storage);
   // 3.2 SegmentedRadixSort the unique_src_edges and unique_frequency
   // Create a set of DoubleBuffers to wrap pairs of device pointers
@@ -366,20 +366,24 @@ std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
   // especially when num_dst_nodes is large (about ~10000)
   if (dtype.bits == 32) {
     CUDA_CALL(cub::DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
-        d_unique_frequency, d_unique_src_edges, num_unique_edges));
+        d_unique_frequency, d_unique_src_edges, num_unique_edges,
+        0, sizeof(Idx64Type)*8, _stream));
   } else {
     CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
         d_unique_frequency, d_unique_src_edges, num_unique_edges, num_dst_nodes,
-        num_unique_each_node_alternate, num_unique_each_node_alternate + 1));
+        num_unique_each_node_alternate, num_unique_each_node_alternate + 1,
+        0, sizeof(Idx64Type)*8, _stream));
   }
   d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
   if (dtype.bits == 32) {
     CUDA_CALL(cub::DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
-        d_unique_frequency, d_unique_src_edges, num_unique_edges));
+        d_unique_frequency, d_unique_src_edges, num_unique_edges,
+        0, sizeof(Idx64Type)*8, _stream));
   } else {
     CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
         d_unique_frequency, d_unique_src_edges, num_unique_edges, num_dst_nodes,
-        num_unique_each_node_alternate, num_unique_each_node_alternate + 1));
+        num_unique_each_node_alternate, num_unique_each_node_alternate + 1,
+        0, sizeof(Idx64Type)*8, _stream));
   }
   device->FreeWorkspace(_ctx, d_temp_storage);
@@ -395,10 +399,10 @@ std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
   d_temp_storage = nullptr;
   temp_storage_bytes = 0;
   CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
-      num_unique_each_node, unique_output_offsets, num_dst_nodes + 1));
+      num_unique_each_node, unique_output_offsets, num_dst_nodes + 1, _stream));
   d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
   CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
-      num_unique_each_node, unique_output_offsets, num_dst_nodes + 1));
+      num_unique_each_node, unique_output_offsets, num_dst_nodes + 1, _stream));
   device->FreeWorkspace(_ctx, d_temp_storage);
   // 5. Pick the data to result
@@ -406,7 +410,7 @@ std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
   device->CopyDataFromTo(&unique_output_offsets[num_dst_nodes], 0, &num_output, 0,
       sizeof(num_output),
       _ctx, DGLContext{kDLCPU, 0},
-      dtype, _stream);
+      dtype);
   device->StreamSync(_ctx, _stream);
   IdArray res_src = IdArray::Empty({static_cast<int64_t>(num_output)},
......
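The scalar read-backs above (`num_unique_edges`, `num_output`) stay correct because each `CopyDataFromTo` is still followed by a `StreamSync`: the copy is asynchronous on the internal stream, so the host must wait before using the value. The pattern in isolation, with illustrative names (`d_count`, `h_count`):

```cpp
// Sketch only: async device-to-host copy on the internal stream, then a
// blocking sync before the host-side value is read.
int64_t h_count = 0;
device->CopyDataFromTo(d_count, 0, &h_count, 0, sizeof(h_count),
                       ctx, DGLContext{kDLCPU, 0}, dtype);
device->StreamSync(ctx, stream);  // wait for the copy before using h_count
```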
@@ -29,14 +29,13 @@ TypeArray GetNodeTypesFromMetapath(
   auto cpu_ctx = DGLContext{kDLCPU, 0};
   auto metapath_ctx = metapath->ctx;
-  // use default stream
-  cudaStream_t stream = 0;
+  auto stream = DeviceAPI::Get(metapath_ctx)->GetStream();
   TypeArray h_result = TypeArray::Empty(
       {metapath->shape[0] + 1}, metapath->dtype, cpu_ctx);
   auto h_result_data = h_result.Ptr<IdxType>();
-  auto h_metapath = metapath.CopyTo(cpu_ctx, stream);
+  auto h_metapath = metapath.CopyTo(cpu_ctx);
   DeviceAPI::Get(metapath_ctx)->StreamSync(metapath_ctx, stream);
   const IdxType *h_metapath_data = h_metapath.Ptr<IdxType>();
@@ -56,7 +55,7 @@ TypeArray GetNodeTypesFromMetapath(
     h_result_data[i + 1] = dsttype;
   }
-  auto result = h_result.CopyTo(metapath->ctx, stream);
+  auto result = h_result.CopyTo(metapath->ctx);
   DeviceAPI::Get(metapath_ctx)->StreamSync(metapath_ctx, stream);
   return result;
 }
......
@@ -91,17 +91,17 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
       res_src.Ptr<IdxType>(), 0,
       sizeof(IdxType) * res_src_vec.size(),
       DGLContext{kDLCPU, 0}, res_src->ctx,
-      res_src->dtype, 0);
+      res_src->dtype);
   device->CopyDataFromTo(static_cast<IdxType*>(res_dst_vec.data()), 0,
       res_dst.Ptr<IdxType>(), 0,
       sizeof(IdxType) * res_dst_vec.size(),
       DGLContext{kDLCPU, 0}, res_dst->ctx,
-      res_dst->dtype, 0);
+      res_dst->dtype);
   device->CopyDataFromTo(static_cast<IdxType*>(res_cnt_vec.data()), 0,
       res_cnt.Ptr<IdxType>(), 0,
       sizeof(IdxType) * res_cnt_vec.size(),
       DGLContext{kDLCPU, 0}, res_cnt->ctx,
-      res_cnt->dtype, 0);
+      res_cnt->dtype);
   return std::make_tuple(res_src, res_dst, res_cnt);
 }
......
@@ -197,8 +197,8 @@ std::pair<IdArray, IdArray> RandomWalkUniform(
     h_graphs[etype].in_cols = static_cast<const IdType*>(csr.indices->data);
     h_graphs[etype].data = (CSRHasData(csr) ? static_cast<const IdType*>(csr.data->data) : nullptr);
   }
-  // use default stream
-  cudaStream_t stream = 0;
+  // use cuda stream from local thread
+  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
   auto device = DeviceAPI::Get(ctx);
   auto d_graphs = static_cast<GraphKernelData<IdType>*>(
       device->AllocWorkspace(ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
@@ -207,8 +207,7 @@ std::pair<IdArray, IdArray> RandomWalkUniform(
       (num_etypes) * sizeof(GraphKernelData<IdType>),
       DGLContext{kDLCPU, 0},
       ctx,
-      hg->GetCSRMatrix(0).indptr->dtype,
-      stream);
+      hg->GetCSRMatrix(0).indptr->dtype);
   // copy metapath to GPU
   auto d_metapath = metapath.CopyTo(ctx);
   const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);
@@ -270,7 +269,7 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
   IdType *traces_data = traces.Ptr<IdType>();
   IdType *eids_data = eids.Ptr<IdType>();
-  cudaStream_t stream = 0;
+  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
   auto device = DeviceAPI::Get(ctx);
   // new probs and prob sums pointers
   assert(num_etypes == static_cast<int64_t>(prob.size()));
@@ -306,14 +305,14 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
         prob_sums[etype],
         num_segments,
         d_offsets,
-        d_offsets + 1));
+        d_offsets + 1, stream));
     void *temp_storage = device->AllocWorkspace(ctx, temp_storage_size);
     CUDA_CALL(cub::DeviceSegmentedReduce::Sum(temp_storage, temp_storage_size,
         probs[etype],
         prob_sums[etype],
         num_segments,
         d_offsets,
-        d_offsets + 1));
+        d_offsets + 1, stream));
     device->FreeWorkspace(ctx, temp_storage);
   }
@@ -324,8 +323,7 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
       (num_etypes) * sizeof(GraphKernelData<IdType>),
       DGLContext{kDLCPU, 0},
       ctx,
-      hg->GetCSRMatrix(0).indptr->dtype,
-      stream);
+      hg->GetCSRMatrix(0).indptr->dtype);
   // copy probs pointers to GPU
   const FloatType **probs_dev = static_cast<const FloatType **>(
       device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
@@ -333,8 +331,7 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
       (num_etypes) * sizeof(FloatType *),
       DGLContext{kDLCPU, 0},
       ctx,
-      prob[0]->dtype,
-      stream);
+      prob[0]->dtype);
   // copy probs_sum pointers to GPU
   const FloatType **prob_sums_dev = static_cast<const FloatType **>(
       device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
@@ -342,8 +339,7 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
       (num_etypes) * sizeof(FloatType *),
       DGLContext{kDLCPU, 0},
      ctx,
-      prob[0]->dtype,
-      stream);
+      prob[0]->dtype);
   // copy metapath to GPU
   auto d_metapath = metapath.CopyTo(ctx);
   const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);
@@ -429,13 +425,13 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart(
       {1}, DLDataType{kDLFloat, 64, 1}, device_ctx);
   auto device = dgl::runtime::DeviceAPI::Get(device_ctx);
-  // use default stream
-  cudaStream_t stream = 0;
+  // use cuda stream from local thread
+  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
   device->CopyDataFromTo(
       &restart_prob, 0, restart_prob_array.Ptr<double>(), 0,
       sizeof(double),
       DGLContext{kDLCPU, 0}, device_ctx,
-      restart_prob_array->dtype, stream);
+      restart_prob_array->dtype);
   device->StreamSync(device_ctx, stream);
   if (!isUniform) {
@@ -489,8 +485,8 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
   const IdxType* dst_data = dst.Ptr<IdxType>();
   const int64_t num_dst_nodes = (dst->shape[0] / num_samples_per_node);
   auto ctx = src->ctx;
-  // use default stream
-  cudaStream_t stream = 0;
+  // use cuda stream from local thread
+  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
   auto frequency_hashmap = FrequencyHashmap<IdxType>(num_dst_nodes,
       num_samples_per_node, ctx, stream);
   auto ret = frequency_hashmap.Topk(src_data, dst_data, src->dtype,
......
@@ -88,9 +88,10 @@ CompactGraphsGPU(
     const std::vector<HeteroGraphPtr> &graphs,
     const std::vector<IdArray> &always_preserve) {
-  cudaStream_t stream = 0;
   const auto& ctx = graphs[0]->Context();
   auto device = runtime::DeviceAPI::Get(ctx);
+  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
   CHECK_EQ(ctx.device_type, kDLGPU);
@@ -134,8 +135,7 @@ CompactGraphsGPU(
           sizeof(IdType)*always_preserve[ntype]->shape[0],
           always_preserve[ntype]->ctx,
           all_nodes[ntype]->ctx,
-          always_preserve[ntype]->dtype,
-          stream);
+          always_preserve[ntype]->dtype);
       node_offsets[ntype] += sizeof(IdType)*always_preserve[ntype]->shape[0];
     }
   }
@@ -159,8 +159,7 @@ CompactGraphsGPU(
           sizeof(IdType)*edges.src->shape[0],
           edges.src->ctx,
           all_nodes[srctype]->ctx,
-          edges.src->dtype,
-          stream);
+          edges.src->dtype);
       node_offsets[srctype] += sizeof(IdType)*edges.src->shape[0];
     }
     if (edges.dst.defined()) {
@@ -171,8 +170,7 @@ CompactGraphsGPU(
           sizeof(IdType)*edges.dst->shape[0],
           edges.dst->ctx,
           all_nodes[dsttype]->ctx,
-          edges.dst->dtype,
-          stream);
+          edges.dst->dtype);
       node_offsets[dsttype] += sizeof(IdType)*edges.dst->shape[0];
     }
     all_edges[i].push_back(edges);
@@ -210,8 +208,7 @@ CompactGraphsGPU(
       sizeof(*num_induced_nodes.data())*num_ntypes,
       ctx,
       DGLContext{kDLCPU, 0},
-      DGLType{kDLInt, 64, 1},
-      stream);
+      DGLType{kDLInt, 64, 1});
   device->StreamSync(ctx, stream);
   // wait for the node counts to finish transferring
......
@@ -165,9 +165,10 @@ ToBlockGPU(
   std::vector<IdArray>& lhs_nodes = *lhs_nodes_ptr;
   const bool generate_lhs_nodes = lhs_nodes.empty();
-  cudaStream_t stream = 0;
   const auto& ctx = graph->Context();
   auto device = runtime::DeviceAPI::Get(ctx);
+  cudaStream_t stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
   CHECK_EQ(ctx.device_type, kDLGPU);
   for (const auto& nodes : rhs_nodes) {
@@ -233,8 +234,7 @@ ToBlockGPU(
           src_nodes[ntype].Ptr<IdType>(), src_node_offsets[ntype],
           sizeof(IdType)*rhs_nodes[ntype]->shape[0],
           rhs_nodes[ntype]->ctx, src_nodes[ntype]->ctx,
-          rhs_nodes[ntype]->dtype,
-          stream);
+          rhs_nodes[ntype]->dtype);
       src_node_offsets[ntype] += sizeof(IdType)*rhs_nodes[ntype]->shape[0];
     }
   }
@@ -249,8 +249,7 @@ ToBlockGPU(
           sizeof(IdType)*edge_arrays[etype].src->shape[0],
           rhs_nodes[srctype]->ctx,
           src_nodes[srctype]->ctx,
-          rhs_nodes[srctype]->dtype,
-          stream);
+          rhs_nodes[srctype]->dtype);
       src_node_offsets[srctype] += sizeof(IdType)*edge_arrays[etype].src->shape[0];
     }
@@ -298,8 +297,7 @@ ToBlockGPU(
       sizeof(*num_nodes_per_type.data())*num_ntypes,
       ctx,
       DGLContext{kDLCPU, 0},
-      DGLType{kDLInt, 64, 1},
-      stream);
+      DGLType{kDLInt, 64, 1});
   device->StreamSync(ctx, stream);
   // wait for the node counts to finish transferring
......
@@ -518,7 +518,7 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off
   size_t prefix_temp_size = 0;
   CUDA_CALL(cub::DeviceScan::ExclusiveSum(
       nullptr, prefix_temp_size, num_block_per_segment,
-      num_block_prefixsum, batch_size));
+      num_block_prefixsum, batch_size, thr_entry->stream));
   void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
   CUDA_CALL(cub::DeviceScan::ExclusiveSum(
       prefix_temp, prefix_temp_size, num_block_per_segment,
@@ -529,11 +529,11 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off
   device->CopyDataFromTo(
       num_block_prefixsum, copyoffset, &num_blocks, 0,
       sizeof(IdType), ctx, DLContext{kDLCPU, 0},
-      query_offsets->dtype, thr_entry->stream);
+      query_offsets->dtype);
   device->CopyDataFromTo(
       num_block_per_segment, copyoffset, &final_elem, 0,
       sizeof(IdType), ctx, DLContext{kDLCPU, 0},
-      query_offsets->dtype, thr_entry->stream);
+      query_offsets->dtype);
   num_blocks += final_elem;
   device->FreeWorkspace(ctx, num_block_per_segment);
   device->FreeWorkspace(ctx, num_block_prefixsum);
@@ -872,7 +872,7 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
       device->AllocWorkspace(ctx, sizeof(IdType)));
   CUDA_CALL(cub::DeviceReduce::Sum(
-      nullptr, sum_temp_size, num_updates, total_num_updates_d, num_nodes));
+      nullptr, sum_temp_size, num_updates, total_num_updates_d, num_nodes, thr_entry->stream));
   IdType* sum_temp_storage = static_cast<IdType*>(
       device->AllocWorkspace(ctx, sum_temp_size));
@@ -901,11 +901,12 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
     total_num_updates = 0;
     CUDA_CALL(cub::DeviceReduce::Sum(
-        sum_temp_storage, sum_temp_size, num_updates, total_num_updates_d, num_nodes));
+        sum_temp_storage, sum_temp_size, num_updates, total_num_updates_d, num_nodes,
+        thr_entry->stream));
     device->CopyDataFromTo(
         total_num_updates_d, 0, &total_num_updates, 0,
         sizeof(IdType), ctx, DLContext{kDLCPU, 0},
-        offsets->dtype, thr_entry->stream);
+        offsets->dtype);
     if (total_num_updates <= static_cast<IdType>(delta * k * num_nodes)) {
       break;
......
@@ -153,11 +153,10 @@ class UnitGraph::COO : public BaseHeteroGraph {
     return ret;
   }
-  COO CopyTo(const DLContext &ctx,
-             const DGLStreamHandle &stream = nullptr) const {
+  COO CopyTo(const DLContext &ctx) const {
     if (Context() == ctx)
       return *this;
-    return COO(meta_graph_, adj_.CopyTo(ctx, stream));
+    return COO(meta_graph_, adj_.CopyTo(ctx));
   }
@@ -558,12 +557,11 @@ class UnitGraph::CSR : public BaseHeteroGraph {
     }
   }
-  CSR CopyTo(const DLContext &ctx,
-             const DGLStreamHandle &stream = nullptr) const {
+  CSR CopyTo(const DLContext &ctx) const {
     if (Context() == ctx) {
       return *this;
     } else {
-      return CSR(meta_graph_, adj_.CopyTo(ctx, stream));
+      return CSR(meta_graph_, adj_.CopyTo(ctx));
     }
   }
@@ -1277,21 +1275,20 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
   }
 }
-HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
-                                 const DGLStreamHandle &stream) {
+HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx) {
   if (ctx == g->Context()) {
     return g;
   } else {
     auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
     CHECK_NOTNULL(bg);
     CSRPtr new_incsr = (bg->in_csr_->defined())
-        ? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx, stream)))
+        ? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx)))
         : nullptr;
     CSRPtr new_outcsr = (bg->out_csr_->defined())
-        ? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx, stream)))
+        ? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx)))
         : nullptr;
     COOPtr new_coo = (bg->coo_->defined())
-        ? COOPtr(new COO(bg->coo_->CopyTo(ctx, stream)))
+        ? COOPtr(new COO(bg->coo_->CopyTo(ctx)))
        : nullptr;
     return HeteroGraphPtr(
         new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
......
@@ -207,8 +207,7 @@ class UnitGraph : public BaseHeteroGraph {
   static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
   /*! \brief Copy the data to another context */
-  static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx,
-                               const DGLStreamHandle &stream = nullptr);
+  static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx);
   /*!
    * \brief Pin the in_csr_, out_scr_ and coo_ of the current graph.
......
@@ -62,8 +62,7 @@ class CPUDeviceAPI final : public DeviceAPI {
                       size_t size,
                       DGLContext ctx_from,
                       DGLContext ctx_to,
-                      DGLType type_hint,
-                      DGLStreamHandle stream) final {
+                      DGLType type_hint) final {
     memcpy(static_cast<char*>(to) + to_offset,
            static_cast<const char*>(from) + from_offset,
            size);
......
@@ -137,7 +137,7 @@ class CUDADeviceAPI final : public DeviceAPI {
                       DGLContext ctx_from,
                       DGLContext ctx_to,
                       DGLType type_hint,
-                      DGLStreamHandle stream) final {
+                      DGLStreamHandle stream) {
     cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
     from = static_cast<const char*>(from) + from_offset;
     to = static_cast<char*>(to) + to_offset;
@@ -161,6 +161,18 @@ class CUDADeviceAPI final : public DeviceAPI {
     }
   }
+  void CopyDataFromTo(const void* from,
+                      size_t from_offset,
+                      void* to,
+                      size_t to_offset,
+                      size_t size,
+                      DGLContext ctx_from,
+                      DGLContext ctx_to,
+                      DGLType type_hint) final {
+    auto stream = static_cast<DGLStreamHandle>(CUDAThreadEntry::ThreadLocal()->stream);
+    CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint, stream);
+  }
   DGLStreamHandle CreateStream(DGLContext ctx) {
     CUDA_CALL(cudaSetDevice(ctx.device_id));
     cudaStream_t retval;
@@ -297,8 +309,13 @@ class CUDADeviceAPI final : public DeviceAPI {
 typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
+// TODO(cliu): cuda streams should depend on the current device, therefore we should set device
+// before setting stream.
 CUDAThreadEntry::CUDAThreadEntry()
   : pool(kDLGPU, CUDADeviceAPI::Global()) {
+  TensorDispatcher* td = TensorDispatcher::Global();
+  if (td->IsAvailable())
+    stream = td->CUDAGetCurrentStream();
 }
 CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
......
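The new stream-less overload above is what lets every other hunk drop its stream argument: it forwards to the stream-taking overload with `CUDAThreadEntry::ThreadLocal()->stream`, and the `TensorDispatcher` hook in the constructor initializes that stream from the tensor framework's current stream when an adapter is loaded. From a caller's perspective (`api`, `src`, `dst`, `nbytes`, `ctx`, and `dtype` are illustrative):

```cpp
auto* api = runtime::DeviceAPI::Get(ctx);
// Old: every caller threaded a DGLStreamHandle through explicitly.
//   api->CopyDataFromTo(src, 0, dst, 0, nbytes, ctx, ctx, dtype, stream);
// New: omitting the stream routes the copy onto the thread-local stream.
api->CopyDataFromTo(src, 0, dst, 0, nbytes, ctx, ctx, dtype);
```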
@@ -439,7 +439,7 @@ void OrderedHashTable<IdType>::FillWithDuplicates(
       workspace_bytes,
       static_cast<IdType*>(nullptr),
       static_cast<IdType*>(nullptr),
-      grid.x+1));
+      grid.x+1, stream));
   void * workspace = device->AllocWorkspace(ctx_, workspace_bytes);
   CUDA_CALL(cub::DeviceScan::ExclusiveSum(
......
@@ -153,8 +153,7 @@ std::pair<IdArray, NDArray> SparsePush(
       "device";
   auto device = DeviceAPI::Get(ctx);
-  // TODO(dlasalle): Get the stream from the device context.
-  cudaStream_t stream = 0;
+  cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
   CHECK_LE(in_idx->ndim, 1) << "The tensor of sending indices must be of "
       "dimension one (or empty).";
@@ -215,6 +214,7 @@ std::pair<IdArray, NDArray> SparsePush(
   }
   std::vector<int64_t> send_prefix_host(comm_size+1);
+  // copy using the same stream (CUDAThreadEntry->ThreadLocal()->stream), no need to sync
   device->CopyDataFromTo(
       send_prefix.get(),
       0,
@@ -223,8 +223,7 @@ std::pair<IdArray, NDArray> SparsePush(
       send_prefix_host.size()*sizeof(*send_prefix.get()),
       ctx,
       DGLContext{kDLCPU, 0},
-      DGLType{kDLInt, sizeof(*send_prefix.get())*8, 1},
-      stream);
+      DGLType{kDLInt, sizeof(*send_prefix.get())*8, 1});
   send_prefix.free();
   CHECK_EQ(send_prefix_host.back(), num_in) << "Internal Error: "
@@ -243,16 +242,17 @@ std::pair<IdArray, NDArray> SparsePush(
   {
     size_t prefix_workspace_size;
     CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_workspace_size,
-        recv_sum.get(), recv_prefix.get(), comm_size+1));
+        recv_sum.get(), recv_prefix.get(), comm_size+1, stream));
     Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
     CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_workspace.get(),
-        prefix_workspace_size, recv_sum.get(), recv_prefix.get(), comm_size+1));
+        prefix_workspace_size, recv_sum.get(), recv_prefix.get(), comm_size+1, stream));
   }
   recv_sum.free();
   // finally copy the prefixsum sum down to the host
   std::vector<int64_t> recv_prefix_host(comm_size+1);
+  // copy using the same stream (CUDAThreadEntry->ThreadLocal()->stream), no need to sync
   device->CopyDataFromTo(
       recv_prefix.get(),
       0,
@@ -261,8 +261,7 @@ std::pair<IdArray, NDArray> SparsePush(
       recv_prefix_host.size()*sizeof(*recv_prefix.get()),
       ctx,
       DGLContext{kDLCPU, 0},
-      DGLType{kDLInt, sizeof(*recv_prefix.get())*8, 1},
-      stream);
+      DGLType{kDLInt, sizeof(*recv_prefix.get())*8, 1});
   recv_prefix.free();
   // use an event to track when copying is done
@@ -369,6 +368,7 @@ NDArray SparsePull(
   CUDA_CALL(cudaEventCreate(&d2h));
   std::vector<int64_t> request_prefix_host(comm_size+1);
+  // copy using the same stream (CUDAThreadEntry->ThreadLocal()->stream), no need to sync
   device->CopyDataFromTo(
       request_prefix.get(),
       0,
@@ -377,8 +377,7 @@ NDArray SparsePull(
       request_prefix_host.size()*sizeof(*request_prefix.get()),
       ctx,
       DGLContext{kDLCPU, 0},
-      DGLType{kDLInt, sizeof(*request_prefix.get())*8, 1},
-      stream);
+      DGLType{kDLInt, sizeof(*request_prefix.get())*8, 1});
   request_prefix.free();
   CHECK_EQ(request_prefix_host.back(), num_in) << "Internal Error: "
       "request_prefix_host.back() = " << request_prefix_host.back() <<
@@ -404,6 +403,7 @@ NDArray SparsePull(
   // finally copy the prefixsum sum down to the host
   std::vector<int64_t> response_prefix_host(comm_size+1);
+  // copy using the same stream (CUDAThreadEntry->ThreadLocal()->stream), no need to sync
   device->CopyDataFromTo(
       response_prefix.get(),
       0,
@@ -412,8 +412,7 @@ NDArray SparsePull(
       response_prefix_host.size()*sizeof(*response_prefix.get()),
       ctx,
       DGLContext{kDLCPU, 0},
-      DGLType{kDLInt, sizeof(*response_prefix.get())*8, 1},
-      stream);
+      DGLType{kDLInt, sizeof(*response_prefix.get())*8, 1});
   response_prefix.free();
   // use an event to track when copying is done
@@ -623,12 +622,12 @@ void NCCLCommunicator::AllToAllV(
   auto device = runtime::DeviceAPI::Get(ctx);
   auto dtype = DLDataTypeTraits<DType>::dtype;
+  // copy using the same stream (CUDAThreadEntry->ThreadLocal()->stream), no need to sync
   device->CopyDataFromTo(send, send_prefix[0],
                          recv, recv_prefix[0],
                          sizeof(DType)*send_prefix[1]-send_prefix[0],
                          ctx, ctx,
-                         dtype,
-                         stream);
+                         dtype);
 #endif
 }
@@ -685,7 +684,8 @@ void NCCLCommunicator::AllToAll(
   auto device = runtime::DeviceAPI::Get(ctx);
   auto dtype = DLDataTypeTraits<IdType>::dtype;
-  device->CopyDataFromTo(send, 0, recv, 0, count, ctx, ctx, dtype, stream);
+  // copy using the same stream (CUDAThreadEntry->ThreadLocal()->stream), no need to sync
+  device->CopyDataFromTo(send, 0, recv, 0, count, ctx, ctx, dtype);
 #endif
 }
......
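The recurring "no need to sync" comments above lean on CUDA stream ordering: work enqueued on a single stream executes in issue order, so a later operation on the same stream already observes the result of an earlier copy without an event or synchronize in between. Roughly (the kernel name and buffers are hypothetical):

```cpp
// Sketch only: same-stream operations are ordered, so the kernel launched
// after the async copy sees the copied data with no explicit sync.
CUDA_CALL(cudaMemcpyAsync(d_dst, d_src, nbytes,
                          cudaMemcpyDeviceToDevice, stream));
consume_kernel<<<grid, block, 0, stream>>>(d_dst, n);
```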
@@ -235,8 +235,7 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) {
 }
 void NDArray::CopyFromTo(DLTensor* from,
-                         DLTensor* to,
-                         DGLStreamHandle stream) {
+                         DLTensor* to) {
   size_t from_size = GetDataSize(*from);
   size_t to_size = GetDataSize(*to);
   CHECK_EQ(from_size, to_size)
@@ -251,10 +250,11 @@ void NDArray::CopyFromTo(DLTensor* from,
   // api manager.
   DGLContext ctx = from->ctx.device_type != kDLCPU ? from->ctx : to->ctx;
+  // default: local cuda stream: CUDAThreadEntry->ThreadLocal()->stream
   DeviceAPI::Get(ctx)->CopyDataFromTo(
       from->data, static_cast<size_t>(from->byte_offset),
       to->data, static_cast<size_t>(to->byte_offset),
-      from_size, from->ctx, to->ctx, from->dtype, stream);
+      from_size, from->ctx, to->ctx, from->dtype);
 }
 void NDArray::PinContainer(NDArray::Container* ptr) {
@@ -292,8 +292,7 @@ NDArray NDArray::FromVector(const std::vector<T>& vec, DLContext ctx) {
       size * sizeof(T),
       DLContext{kDLCPU, 0},
       ctx,
-      dtype,
-      nullptr);
+      dtype);
   return ret;
 }
@@ -322,8 +321,7 @@ std::vector<T> NDArray::ToVector() const {
       size * sizeof(T),
       ctx,
       DLContext{kDLCPU, 0},
-      dtype,
-      nullptr);
+      dtype);
   return vec;
 }
@@ -471,10 +469,9 @@ int DGLArrayFree(DGLArrayHandle handle) {
 }
 int DGLArrayCopyFromTo(DGLArrayHandle from,
-                       DGLArrayHandle to,
-                       DGLStreamHandle stream) {
+                       DGLArrayHandle to) {
   API_BEGIN();
-  NDArray::CopyFromTo(from, to, stream);
+  NDArray::CopyFromTo(from, to);
   API_END();
 }
@@ -523,7 +520,7 @@ int DGLArrayCopyFromBytes(DGLArrayHandle handle,
   DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
       data, 0,
       handle->data, static_cast<size_t>(handle->byte_offset),
-      nbytes, cpu_ctx, handle->ctx, handle->dtype, nullptr);
+      nbytes, cpu_ctx, handle->ctx, handle->dtype);
   API_END();
 }
@@ -540,7 +537,7 @@ int DGLArrayCopyToBytes(DGLArrayHandle handle,
   DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
       handle->data, static_cast<size_t>(handle->byte_offset),
       data, 0,
-      nbytes, handle->ctx, cpu_ctx, handle->dtype, nullptr);
+      nbytes, handle->ctx, cpu_ctx, handle->dtype);
   API_END();
 }
......