Unverified Commit 5a245104 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Feature] enable to specify stream in UnitGraph::CopyTo() which could lead to async copy (#3297)


Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent f4fe518f
...@@ -120,12 +120,13 @@ struct COOMatrix { ...@@ -120,12 +120,13 @@ struct COOMatrix {
} }
/*! \brief Return a copy of this matrix on the give device context. */ /*! \brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DLContext& ctx) const { inline COOMatrix CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (ctx == row->ctx) if (ctx == row->ctx)
return *this; return *this;
return COOMatrix(num_rows, num_cols, return COOMatrix(num_rows, num_cols, row.CopyTo(ctx, stream),
row.CopyTo(ctx), col.CopyTo(ctx), col.CopyTo(ctx, stream),
aten::IsNullArray(data)? data : data.CopyTo(ctx), aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream),
row_sorted, col_sorted); row_sorted, col_sorted);
} }
}; };
......
...@@ -113,12 +113,13 @@ struct CSRMatrix { ...@@ -113,12 +113,13 @@ struct CSRMatrix {
} }
/*! \brief Return a copy of this matrix on the give device context. */ /*! \brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DLContext& ctx) const { inline CSRMatrix CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (ctx == indptr->ctx) if (ctx == indptr->ctx)
return *this; return *this;
return CSRMatrix(num_rows, num_cols, return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx, stream),
indptr.CopyTo(ctx), indices.CopyTo(ctx), indices.CopyTo(ctx, stream),
aten::IsNullArray(data)? data : data.CopyTo(ctx), aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream),
sorted); sorted);
} }
}; };
......
...@@ -154,18 +154,21 @@ class NDArray { ...@@ -154,18 +154,21 @@ class NDArray {
* \note The copy may happen asynchrously if it involves a GPU context. * \note The copy may happen asynchrously if it involves a GPU context.
* DGLSynchronize is necessary. * DGLSynchronize is necessary.
*/ */
inline void CopyTo(DLTensor* other) const; inline void CopyTo(DLTensor *other,
inline void CopyTo(const NDArray& other) const; const DGLStreamHandle &stream = nullptr) const;
inline void CopyTo(const NDArray &other,
const DGLStreamHandle &stream = nullptr) const;
/*! /*!
* \brief Copy the data to another context. * \brief Copy the data to another context.
* \param ctx The target context. * \param ctx The target context.
* \return The array under another context. * \return The array under another context.
*/ */
inline NDArray CopyTo(const DLContext& ctx) const; inline NDArray CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const;
/*! /*!
* \brief Return a new array with a copy of the content. * \brief Return a new array with a copy of the content.
*/ */
inline NDArray Clone() const; inline NDArray Clone(const DGLStreamHandle &stream = nullptr) const;
/*! /*!
* \brief Load NDArray from stream * \brief Load NDArray from stream
* \param stream The input data stream * \param stream The input data stream
...@@ -401,30 +404,33 @@ inline void NDArray::CopyFrom(const NDArray& other, ...@@ -401,30 +404,33 @@ inline void NDArray::CopyFrom(const NDArray& other,
CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor), stream); CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor), stream);
} }
inline void NDArray::CopyTo(DLTensor* other) const { inline void NDArray::CopyTo(DLTensor *other,
const DGLStreamHandle &stream) const {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), other); CopyFromTo(&(data_->dl_tensor), other, stream);
} }
inline void NDArray::CopyTo(const NDArray& other) const { inline void NDArray::CopyTo(const NDArray &other,
const DGLStreamHandle &stream) const {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr); CHECK(other.data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor)); CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor), stream);
} }
inline NDArray NDArray::CopyTo(const DLContext& ctx) const { inline NDArray NDArray::CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream) const {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
const DLTensor* dptr = operator->(); const DLTensor* dptr = operator->();
NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim), NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
dptr->dtype, ctx); dptr->dtype, ctx);
this->CopyTo(ret); this->CopyTo(ret, stream);
return ret; return ret;
} }
inline NDArray NDArray::Clone() const { inline NDArray NDArray::Clone(const DGLStreamHandle &stream) const {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
const DLTensor* dptr = operator->(); const DLTensor* dptr = operator->();
return this->CopyTo(dptr->ctx); return this->CopyTo(dptr->ctx, stream);
} }
inline int NDArray::use_count() const { inline int NDArray::use_count() const {
......
...@@ -254,7 +254,8 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { ...@@ -254,7 +254,8 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
hgindex->num_verts_per_type_)); hgindex->num_verts_per_type_));
} }
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) { HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream) {
if (ctx == g->Context()) { if (ctx == g->Context()) {
return g; return g;
} }
...@@ -262,7 +263,7 @@ HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) { ...@@ -262,7 +263,7 @@ HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
CHECK_NOTNULL(hgindex); CHECK_NOTNULL(hgindex);
std::vector<HeteroGraphPtr> rel_graphs; std::vector<HeteroGraphPtr> rel_graphs;
for (auto g : hgindex->relation_graphs_) { for (auto g : hgindex->relation_graphs_) {
rel_graphs.push_back(UnitGraph::CopyTo(g, ctx)); rel_graphs.push_back(UnitGraph::CopyTo(g, ctx, stream));
} }
return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs, return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs,
hgindex->num_verts_per_type_)); hgindex->num_verts_per_type_));
......
...@@ -225,7 +225,8 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -225,7 +225,8 @@ class HeteroGraph : public BaseHeteroGraph {
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/*! \brief Copy the data to another context */ /*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx); static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream = nullptr);
/*! \brief Copy the data to shared memory. /*! \brief Copy the data to shared memory.
* *
......
...@@ -149,10 +149,11 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -149,10 +149,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
return ret; return ret;
} }
COO CopyTo(const DLContext& ctx) const { COO CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (Context() == ctx) if (Context() == ctx)
return *this; return *this;
return COO(meta_graph_, adj_.CopyTo(ctx)); return COO(meta_graph_, adj_.CopyTo(ctx, stream));
} }
bool IsMultigraph() const override { bool IsMultigraph() const override {
...@@ -537,11 +538,12 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -537,11 +538,12 @@ class UnitGraph::CSR : public BaseHeteroGraph {
} }
} }
CSR CopyTo(const DLContext& ctx) const { CSR CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (Context() == ctx) { if (Context() == ctx) {
return *this; return *this;
} else { } else {
return CSR(meta_graph_, adj_.CopyTo(ctx)); return CSR(meta_graph_, adj_.CopyTo(ctx, stream));
} }
} }
...@@ -1232,18 +1234,22 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { ...@@ -1232,18 +1234,22 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
} }
} }
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) { HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream) {
if (ctx == g->Context()) { if (ctx == g->Context()) {
return g; return g;
} else { } else {
auto bg = std::dynamic_pointer_cast<UnitGraph>(g); auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg); CHECK_NOTNULL(bg);
CSRPtr new_incsr = CSRPtr new_incsr = (bg->in_csr_->defined())
(bg->in_csr_->defined())? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx))) : nullptr; ? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx, stream)))
CSRPtr new_outcsr = : nullptr;
(bg->out_csr_->defined())? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx))) : nullptr; CSRPtr new_outcsr = (bg->out_csr_->defined())
COOPtr new_coo = ? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx, stream)))
(bg->coo_->defined())? COOPtr(new COO(bg->coo_->CopyTo(ctx))) : nullptr; : nullptr;
COOPtr new_coo = (bg->coo_->defined())
? COOPtr(new COO(bg->coo_->CopyTo(ctx, stream)))
: nullptr;
return HeteroGraphPtr( return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_)); new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
} }
......
...@@ -205,7 +205,8 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -205,7 +205,8 @@ class UnitGraph : public BaseHeteroGraph {
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/*! \brief Copy the data to another context */ /*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx); static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream = nullptr);
/*! /*!
* \brief Create in-edge CSR format of the unit graph. * \brief Create in-edge CSR format of the unit graph.
......
...@@ -60,6 +60,8 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -60,6 +60,8 @@ class CPUDeviceAPI final : public DeviceAPI {
size); size);
} }
DGLStreamHandle CreateStream(DGLContext) final { return nullptr; }
void StreamSync(DGLContext ctx, DGLStreamHandle stream) final { void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {
} }
......
...@@ -3,14 +3,15 @@ ...@@ -3,14 +3,15 @@
* \file test_unit_graph.cc * \file test_unit_graph.cc
* \brief Test UnitGraph * \brief Test UnitGraph
*/ */
#include <gtest/gtest.h> #include "../../src/graph/unit_graph.h"
#include "./../src/graph/heterograph.h"
#include "./common.h"
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/device_api.h>
#include <gtest/gtest.h>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <dgl/immutable_graph.h>
#include "./common.h"
#include "./../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h"
using namespace dgl; using namespace dgl;
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -298,6 +299,47 @@ void _TestUnitGraph_Reserve(DLContext ctx) { ...@@ -298,6 +299,47 @@ void _TestUnitGraph_Reserve(DLContext ctx) {
ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data); ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);
} }
template <typename IdType>
void _TestUnitGraph_CopyTo(const DLContext &src_ctx,
const DGLContext &dst_ctx) {
const aten::CSRMatrix &csr = CSR1<IdType>(src_ctx);
const aten::COOMatrix &coo = COO1<IdType>(src_ctx);
auto device = dgl::runtime::DeviceAPI::Get(dst_ctx);
auto stream = device->CreateStream(dst_ctx);
auto g = dgl::UnitGraph::CreateFromCSC(2, csr);
ASSERT_EQ(g->GetCreatedFormats(), 4);
auto cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
device->StreamSync(dst_ctx, stream);
ASSERT_EQ(cg->GetCreatedFormats(), 4);
g = dgl::UnitGraph::CreateFromCSR(2, csr);
ASSERT_EQ(g->GetCreatedFormats(), 2);
cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
device->StreamSync(dst_ctx, stream);
ASSERT_EQ(cg->GetCreatedFormats(), 2);
g = dgl::UnitGraph::CreateFromCOO(2, coo);
ASSERT_EQ(g->GetCreatedFormats(), 1);
cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
device->StreamSync(dst_ctx, stream);
ASSERT_EQ(cg->GetCreatedFormats(), 1);
}
TEST(UniGraphTest, TestUnitGraph_CopyTo) {
_TestUnitGraph_CopyTo<int32_t>(CPU, CPU);
_TestUnitGraph_CopyTo<int64_t>(CPU, CPU);
#ifdef DGL_USE_CUDA
_TestUnitGraph_CopyTo<int32_t>(CPU, GPU);
_TestUnitGraph_CopyTo<int32_t>(GPU, GPU);
_TestUnitGraph_CopyTo<int32_t>(GPU, CPU);
_TestUnitGraph_CopyTo<int64_t>(CPU, GPU);
_TestUnitGraph_CopyTo<int64_t>(GPU, GPU);
_TestUnitGraph_CopyTo<int64_t>(GPU, CPU);
#endif
}
TEST(UniGraphTest, TestUnitGraph_Create) { TEST(UniGraphTest, TestUnitGraph_Create) {
_TestUnitGraph<int32_t>(CPU); _TestUnitGraph<int32_t>(CPU);
_TestUnitGraph<int64_t>(CPU); _TestUnitGraph<int64_t>(CPU);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment