Unverified Commit 5a245104 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Feature] enable to specify stream in UnitGraph::CopyTo() which could lead to async copy (#3297)


Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent f4fe518f
......@@ -120,12 +120,13 @@ struct COOMatrix {
}
/*! \brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DLContext& ctx) const {
inline COOMatrix CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (ctx == row->ctx)
return *this;
return COOMatrix(num_rows, num_cols,
row.CopyTo(ctx), col.CopyTo(ctx),
aten::IsNullArray(data)? data : data.CopyTo(ctx),
return COOMatrix(num_rows, num_cols, row.CopyTo(ctx, stream),
col.CopyTo(ctx, stream),
aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream),
row_sorted, col_sorted);
}
};
......
......@@ -113,12 +113,13 @@ struct CSRMatrix {
}
/*! \brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DLContext& ctx) const {
inline CSRMatrix CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (ctx == indptr->ctx)
return *this;
return CSRMatrix(num_rows, num_cols,
indptr.CopyTo(ctx), indices.CopyTo(ctx),
aten::IsNullArray(data)? data : data.CopyTo(ctx),
return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx, stream),
indices.CopyTo(ctx, stream),
aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream),
sorted);
}
};
......
......@@ -154,18 +154,21 @@ class NDArray {
* \note The copy may happen asynchrously if it involves a GPU context.
* DGLSynchronize is necessary.
*/
inline void CopyTo(DLTensor* other) const;
inline void CopyTo(const NDArray& other) const;
inline void CopyTo(DLTensor *other,
const DGLStreamHandle &stream = nullptr) const;
inline void CopyTo(const NDArray &other,
const DGLStreamHandle &stream = nullptr) const;
/*!
* \brief Copy the data to another context.
* \param ctx The target context.
* \return The array under another context.
*/
inline NDArray CopyTo(const DLContext& ctx) const;
inline NDArray CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const;
/*!
* \brief Return a new array with a copy of the content.
*/
inline NDArray Clone() const;
inline NDArray Clone(const DGLStreamHandle &stream = nullptr) const;
/*!
* \brief Load NDArray from stream
* \param stream The input data stream
......@@ -401,30 +404,33 @@ inline void NDArray::CopyFrom(const NDArray& other,
CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor), stream);
}
inline void NDArray::CopyTo(DLTensor* other) const {
inline void NDArray::CopyTo(DLTensor *other,
const DGLStreamHandle &stream) const {
CHECK(data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), other);
CopyFromTo(&(data_->dl_tensor), other, stream);
}
inline void NDArray::CopyTo(const NDArray& other) const {
inline void NDArray::CopyTo(const NDArray &other,
const DGLStreamHandle &stream) const {
CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor));
CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor), stream);
}
inline NDArray NDArray::CopyTo(const DLContext& ctx) const {
inline NDArray NDArray::CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream) const {
CHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
dptr->dtype, ctx);
this->CopyTo(ret);
this->CopyTo(ret, stream);
return ret;
}
inline NDArray NDArray::Clone() const {
inline NDArray NDArray::Clone(const DGLStreamHandle &stream) const {
CHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
return this->CopyTo(dptr->ctx);
return this->CopyTo(dptr->ctx, stream);
}
inline int NDArray::use_count() const {
......
......@@ -254,7 +254,8 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
hgindex->num_verts_per_type_));
}
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream) {
if (ctx == g->Context()) {
return g;
}
......@@ -262,7 +263,7 @@ HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
CHECK_NOTNULL(hgindex);
std::vector<HeteroGraphPtr> rel_graphs;
for (auto g : hgindex->relation_graphs_) {
rel_graphs.push_back(UnitGraph::CopyTo(g, ctx));
rel_graphs.push_back(UnitGraph::CopyTo(g, ctx, stream));
}
return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs,
hgindex->num_verts_per_type_));
......
......@@ -225,7 +225,8 @@ class HeteroGraph : public BaseHeteroGraph {
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream = nullptr);
/*! \brief Copy the data to shared memory.
*
......
......@@ -149,10 +149,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
return ret;
}
COO CopyTo(const DLContext& ctx) const {
COO CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (Context() == ctx)
return *this;
return COO(meta_graph_, adj_.CopyTo(ctx));
return COO(meta_graph_, adj_.CopyTo(ctx, stream));
}
bool IsMultigraph() const override {
......@@ -537,11 +538,12 @@ class UnitGraph::CSR : public BaseHeteroGraph {
}
}
CSR CopyTo(const DLContext& ctx) const {
CSR CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (Context() == ctx) {
return *this;
} else {
return CSR(meta_graph_, adj_.CopyTo(ctx));
return CSR(meta_graph_, adj_.CopyTo(ctx, stream));
}
}
......@@ -1232,18 +1234,22 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
}
}
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream) {
if (ctx == g->Context()) {
return g;
} else {
auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg);
CSRPtr new_incsr =
(bg->in_csr_->defined())? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx))) : nullptr;
CSRPtr new_outcsr =
(bg->out_csr_->defined())? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx))) : nullptr;
COOPtr new_coo =
(bg->coo_->defined())? COOPtr(new COO(bg->coo_->CopyTo(ctx))) : nullptr;
CSRPtr new_incsr = (bg->in_csr_->defined())
? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx, stream)))
: nullptr;
CSRPtr new_outcsr = (bg->out_csr_->defined())
? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx, stream)))
: nullptr;
COOPtr new_coo = (bg->coo_->defined())
? COOPtr(new COO(bg->coo_->CopyTo(ctx, stream)))
: nullptr;
return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
}
......
......@@ -205,7 +205,8 @@ class UnitGraph : public BaseHeteroGraph {
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream = nullptr);
/*!
* \brief Create in-edge CSR format of the unit graph.
......
......@@ -60,6 +60,8 @@ class CPUDeviceAPI final : public DeviceAPI {
size);
}
DGLStreamHandle CreateStream(DGLContext) final { return nullptr; }
void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {
}
......
......@@ -3,14 +3,15 @@
* \file test_unit_graph.cc
* \brief Test UnitGraph
*/
#include <gtest/gtest.h>
#include "../../src/graph/unit_graph.h"
#include "./../src/graph/heterograph.h"
#include "./common.h"
#include <dgl/array.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/device_api.h>
#include <gtest/gtest.h>
#include <memory>
#include <vector>
#include <dgl/immutable_graph.h>
#include "./common.h"
#include "./../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h"
using namespace dgl;
using namespace dgl::runtime;
......@@ -298,6 +299,47 @@ void _TestUnitGraph_Reserve(DLContext ctx) {
ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);
}
template <typename IdType>
void _TestUnitGraph_CopyTo(const DLContext &src_ctx,
const DGLContext &dst_ctx) {
const aten::CSRMatrix &csr = CSR1<IdType>(src_ctx);
const aten::COOMatrix &coo = COO1<IdType>(src_ctx);
auto device = dgl::runtime::DeviceAPI::Get(dst_ctx);
auto stream = device->CreateStream(dst_ctx);
auto g = dgl::UnitGraph::CreateFromCSC(2, csr);
ASSERT_EQ(g->GetCreatedFormats(), 4);
auto cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
device->StreamSync(dst_ctx, stream);
ASSERT_EQ(cg->GetCreatedFormats(), 4);
g = dgl::UnitGraph::CreateFromCSR(2, csr);
ASSERT_EQ(g->GetCreatedFormats(), 2);
cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
device->StreamSync(dst_ctx, stream);
ASSERT_EQ(cg->GetCreatedFormats(), 2);
g = dgl::UnitGraph::CreateFromCOO(2, coo);
ASSERT_EQ(g->GetCreatedFormats(), 1);
cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
device->StreamSync(dst_ctx, stream);
ASSERT_EQ(cg->GetCreatedFormats(), 1);
}
TEST(UniGraphTest, TestUnitGraph_CopyTo) {
_TestUnitGraph_CopyTo<int32_t>(CPU, CPU);
_TestUnitGraph_CopyTo<int64_t>(CPU, CPU);
#ifdef DGL_USE_CUDA
_TestUnitGraph_CopyTo<int32_t>(CPU, GPU);
_TestUnitGraph_CopyTo<int32_t>(GPU, GPU);
_TestUnitGraph_CopyTo<int32_t>(GPU, CPU);
_TestUnitGraph_CopyTo<int64_t>(CPU, GPU);
_TestUnitGraph_CopyTo<int64_t>(GPU, GPU);
_TestUnitGraph_CopyTo<int64_t>(GPU, CPU);
#endif
}
TEST(UniGraphTest, TestUnitGraph_Create) {
_TestUnitGraph<int32_t>(CPU);
_TestUnitGraph<int64_t>(CPU);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment