Unverified Commit 870da747 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[CUDA][Kernel] More CUDA kernels; Standardize the behavior for sorted COO/CSR (#1704)

* add cub; array cumsum

* CSRSliceRows

* fix warning

* operator << for ndarray; CSRSliceRows

* add CSRIsSorted

* add csr_sort

* inplace coosort and outplace csrsort

* WIP: coo is sorted

* mv cuda_utils

* add AllTrue utility

* csr sort

* coo sort

* coo2csr for sorted coo arrays

* CSRToCOO from sorted

* pass tests for the new kernel changes

* cannot use inplace sort

* lint

* try fix msvc error

* Fix g.copy_to and g.asnumbits; ToBlock no longer uses CSC

* stash

* revert some hack

* revert some changes

* address comments

* fix

* fix to_block unittest

* add todo note
parent da8632ca
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include <memory>
#include "../../c_api_common.h" #include "../../c_api_common.h"
using dgl::runtime::NDArray; using dgl::runtime::NDArray;
......
...@@ -51,9 +51,11 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool includ ...@@ -51,9 +51,11 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool includ
const auto src_dst_types = graph->GetEndpointTypes(etype); const auto src_dst_types = graph->GetEndpointTypes(etype);
const dgl_type_t srctype = src_dst_types.first; const dgl_type_t srctype = src_dst_types.first;
const dgl_type_t dsttype = src_dst_types.second; const dgl_type_t dsttype = src_dst_types.second;
const EdgeArray edges = graph->InEdges(etype, rhs_nodes[dsttype]); if (!aten::IsNullArray(rhs_nodes[dsttype])) {
lhs_node_mappings[srctype].Update(edges.src); const EdgeArray& edges = graph->Edges(etype);
edge_arrays[etype] = edges; lhs_node_mappings[srctype].Update(edges.src);
edge_arrays[etype] = edges;
}
} }
const auto meta_graph = graph->meta_graph(); const auto meta_graph = graph->meta_graph();
...@@ -75,11 +77,26 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool includ ...@@ -75,11 +77,26 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool includ
const dgl_type_t dsttype = src_dst_types.second; const dgl_type_t dsttype = src_dst_types.second;
const IdHashMap<IdType> &lhs_map = lhs_node_mappings[srctype]; const IdHashMap<IdType> &lhs_map = lhs_node_mappings[srctype];
const IdHashMap<IdType> &rhs_map = rhs_node_mappings[dsttype]; const IdHashMap<IdType> &rhs_map = rhs_node_mappings[dsttype];
rel_graphs.push_back(CreateFromCOO( if (rhs_map.Size() == 0) {
2, lhs_map.Size(), rhs_map.Size(), // No rhs nodes are given for this edge type. Create an empty graph.
lhs_map.Map(edge_arrays[etype].src, -1), rel_graphs.push_back(CreateFromCOO(
rhs_map.Map(edge_arrays[etype].dst, -1))); 2, lhs_map.Size(), rhs_map.Size(),
induced_edges.push_back(edge_arrays[etype].id); aten::NullArray(), aten::NullArray()));
induced_edges.push_back(aten::NullArray());
} else {
IdArray new_src = lhs_map.Map(edge_arrays[etype].src, -1);
IdArray new_dst = rhs_map.Map(edge_arrays[etype].dst, -1);
// Check whether there are unmapped IDs and raise error.
for (int64_t i = 0; i < new_dst->shape[0]; ++i)
CHECK_NE(new_dst.Ptr<IdType>()[i], -1)
<< "Node " << edge_arrays[etype].dst.Ptr<IdType>()[i] << " does not exist"
<< " in `rhs_nodes`. Argument `rhs_nodes` must contain all the edge"
<< " destination nodes.";
rel_graphs.push_back(CreateFromCOO(
2, lhs_map.Size(), rhs_map.Size(),
new_src, new_dst));
induced_edges.push_back(edge_arrays[etype].id);
}
} }
const HeteroGraphPtr new_graph = CreateHeteroGraph( const HeteroGraphPtr new_graph = CreateHeteroGraph(
......
...@@ -138,13 +138,7 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -138,13 +138,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
COO CopyTo(const DLContext& ctx) const { COO CopyTo(const DLContext& ctx) const {
if (Context() == ctx) if (Context() == ctx)
return *this; return *this;
return COO(meta_graph_, adj_.CopyTo(ctx));
COO ret(
meta_graph_,
adj_.num_rows, adj_.num_cols,
adj_.row.CopyTo(ctx),
adj_.col.CopyTo(ctx));
return ret;
} }
bool IsMultigraph() const override { bool IsMultigraph() const override {
...@@ -516,13 +510,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -516,13 +510,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
if (Context() == ctx) { if (Context() == ctx) {
return *this; return *this;
} else { } else {
CSR ret( return CSR(meta_graph_, adj_.CopyTo(ctx));
meta_graph_,
adj_.num_rows, adj_.num_cols,
adj_.indptr.CopyTo(ctx),
adj_.indices.CopyTo(ctx),
adj_.data.CopyTo(ctx));
return ret;
} }
} }
...@@ -1181,35 +1169,28 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { ...@@ -1181,35 +1169,28 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
if (g->NumBits() == bits) { if (g->NumBits() == bits) {
return g; return g;
} else { } else {
// TODO(minjie): since we don't have int32 operations,
// we make sure that this graph (on CPU) has materialized CSR,
// and then copy them to other context (usually GPU). This should
// be fixed later.
auto bg = std::dynamic_pointer_cast<UnitGraph>(g); auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg); CHECK_NOTNULL(bg);
CSRPtr new_incsr = (bg->in_csr_)? CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits))) : nullptr;
CSRPtr new_incsr = CSRPtr(new CSR(bg->GetInCSR()->AsNumBits(bits))); CSRPtr new_outcsr = (bg->out_csr_)? CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits))) : nullptr;
CSRPtr new_outcsr = CSRPtr(new CSR(bg->GetOutCSR()->AsNumBits(bits))); COOPtr new_coo = (bg->coo_)? COOPtr(new COO(bg->coo_->AsNumBits(bits))) : nullptr;
return HeteroGraphPtr( return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, nullptr, bg->restrict_format_)); new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->restrict_format_));
} }
} }
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) { HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
if (ctx == g->Context()) { if (ctx == g->Context()) {
return g; return g;
} else {
auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg);
CSRPtr new_incsr = (bg->in_csr_)? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx))) : nullptr;
CSRPtr new_outcsr = (bg->out_csr_)? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx))) : nullptr;
COOPtr new_coo = (bg->coo_)? COOPtr(new COO(bg->coo_->CopyTo(ctx))) : nullptr;
return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->restrict_format_));
} }
// TODO(minjie): since we don't have GPU implementation of COO<->CSR,
// we make sure that this graph (on CPU) has materialized CSR,
// and then copy them to other context (usually GPU). This should
// be fixed later.
auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg);
CSRPtr new_incsr = CSRPtr(new CSR(bg->GetInCSR()->CopyTo(ctx)));
CSRPtr new_outcsr = CSRPtr(new CSR(bg->GetOutCSR()->CopyTo(ctx)));
return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, nullptr, bg->restrict_format_));
} }
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
...@@ -1278,9 +1259,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { ...@@ -1278,9 +1259,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
const_cast<UnitGraph*>(this)->in_csr_ = ret; const_cast<UnitGraph*>(this)->in_csr_ = ret;
} else { } else {
CHECK(coo_) << "None of CSR, COO exist"; CHECK(coo_) << "None of CSR, COO exist";
const auto& adj = coo_->adj(); const auto& newadj = aten::CSRSort(aten::COOToCSR(
const auto& newadj = aten::COOToCSR( aten::COOTranspose(coo_->adj())));
aten::COOMatrix{adj.num_cols, adj.num_rows, adj.col, adj.row});
ret = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace) if (inplace)
const_cast<UnitGraph*>(this)->in_csr_ = ret; const_cast<UnitGraph*>(this)->in_csr_ = ret;
...@@ -1299,13 +1279,13 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { ...@@ -1299,13 +1279,13 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
CSRPtr ret = out_csr_; CSRPtr ret = out_csr_;
if (!out_csr_) { if (!out_csr_) {
if (in_csr_) { if (in_csr_) {
const auto& newadj = aten::CSRTranspose(in_csr_->adj()); const auto& newadj = aten::CSRSort(aten::CSRTranspose(in_csr_->adj()));
ret = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace) if (inplace)
const_cast<UnitGraph*>(this)->out_csr_ = ret; const_cast<UnitGraph*>(this)->out_csr_ = ret;
} else { } else {
CHECK(coo_) << "None of CSR, COO exist"; CHECK(coo_) << "None of CSR, COO exist";
const auto& newadj = aten::COOToCSR(coo_->adj()); const auto& newadj = aten::CSRSort(aten::COOToCSR(coo_->adj()));
ret = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace) if (inplace)
const_cast<UnitGraph*>(this)->out_csr_ = ret; const_cast<UnitGraph*>(this)->out_csr_ = ret;
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <string.h> #include <string.h>
#include <stdlib.h> #include <stdlib.h>
#include <time.h> #include <time.h>
#include <memory>
#include "socket_communicator.h" #include "socket_communicator.h"
#include "../../c_api_common.h" #include "../../c_api_common.h"
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <memory>
#include "communicator.h" #include "communicator.h"
#include "msg_queue.h" #include "msg_queue.h"
...@@ -19,9 +20,9 @@ ...@@ -19,9 +20,9 @@
namespace dgl { namespace dgl {
namespace network { namespace network {
static int kMaxTryCount = 1024; // maximal connection: 1024 static constexpr int kMaxTryCount = 1024; // maximal connection: 1024
static int kTimeOut = 10; // 10 minutes for socket timeout static constexpr int kTimeOut = 10; // 10 minutes for socket timeout
static int kMaxConnection = 1024; // maximal connection: 1024 static constexpr int kMaxConnection = 1024; // maximal connection: 1024
/*! /*!
* \breif Networking address * \breif Networking address
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <dgl/runtime/serializer.h> #include <dgl/runtime/serializer.h>
#include <fstream> #include <fstream>
#include <vector> #include <vector>
#include <unordered_map>
#include "file_util.h" #include "file_util.h"
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define DGL_RUNTIME_FILE_UTIL_H_ #define DGL_RUNTIME_FILE_UTIL_H_
#include <string> #include <string>
#include <unordered_map>
#include "meta_data.h" #include "meta_data.h"
namespace dgl { namespace dgl {
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <dgl/runtime/module.h> #include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h> #include <dgl/runtime/registry.h>
#include <string> #include <string>
#include <memory>
#include "module_util.h" #include "module_util.h"
namespace dgl { namespace dgl {
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h> #include <dgl/runtime/c_backend_api.h>
#include <vector> #include <vector>
#include <memory>
extern "C" { extern "C" {
// Function signature for generated packed function in shared library // Function signature for generated packed function in shared library
......
...@@ -124,6 +124,8 @@ size_t NDArray::GetSize() const { ...@@ -124,6 +124,8 @@ size_t NDArray::GetSize() const {
} }
int64_t NDArray::NumElements() const { int64_t NDArray::NumElements() const {
if (data_->dl_tensor.ndim == 0)
return 0;
int64_t size = 1; int64_t size = 1;
for (int i = 0; i < data_->dl_tensor.ndim; ++i) { for (int i = 0; i < data_->dl_tensor.ndim; ++i) {
size *= data_->dl_tensor.shape[i]; size *= data_->dl_tensor.shape[i];
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief Workspace pool utility. * \brief Workspace pool utility.
*/ */
#include "workspace_pool.h" #include "workspace_pool.h"
#include <memory>
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <vector> #include <vector>
#include <memory>
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
......
...@@ -1883,4 +1883,4 @@ if __name__ == '__main__': ...@@ -1883,4 +1883,4 @@ if __name__ == '__main__':
# test_isolated_ntype() # test_isolated_ntype()
# test_bipartite() # test_bipartite()
# test_dtype_cast() # test_dtype_cast()
test_format() pass
...@@ -603,10 +603,6 @@ def test_to_block(index_dtype): ...@@ -603,10 +603,6 @@ def test_to_block(index_dtype):
assert bg.number_of_src_nodes() == 4 assert bg.number_of_src_nodes() == 4
assert bg.number_of_dst_nodes() == 4 assert bg.number_of_dst_nodes() == 4
dst_nodes = F.tensor([3, 4], dtype=getattr(F, index_dtype))
bg = dgl.to_block(g_a, dst_nodes)
check(g_a, bg, 'A', 'AA', dst_nodes)
dst_nodes = F.tensor([4, 3, 2, 1], dtype=getattr(F, index_dtype)) dst_nodes = F.tensor([4, 3, 2, 1], dtype=getattr(F, index_dtype))
bg = dgl.to_block(g_a, dst_nodes) bg = dgl.to_block(g_a, dst_nodes)
check(g_a, bg, 'A', 'AA', dst_nodes) check(g_a, bg, 'A', 'AA', dst_nodes)
...@@ -620,17 +616,13 @@ def test_to_block(index_dtype): ...@@ -620,17 +616,13 @@ def test_to_block(index_dtype):
assert bg.number_of_nodes('DST/A') == 0 assert bg.number_of_nodes('DST/A') == 0
checkall(g_ab, bg, None) checkall(g_ab, bg, None)
dst_nodes = {'B': F.tensor([5, 6], dtype=getattr(F, index_dtype))} dst_nodes = {'B': F.tensor([5, 6, 3, 1], dtype=getattr(F, index_dtype))}
bg = dgl.to_block(g, dst_nodes) bg = dgl.to_block(g, dst_nodes)
assert bg.number_of_nodes('SRC/B') == 2 assert bg.number_of_nodes('SRC/B') == 4
assert F.array_equal(bg.srcnodes['B'].data[dgl.NID], bg.dstnodes['B'].data[dgl.NID]) assert F.array_equal(bg.srcnodes['B'].data[dgl.NID], bg.dstnodes['B'].data[dgl.NID])
assert bg.number_of_nodes('DST/A') == 0 assert bg.number_of_nodes('DST/A') == 0
checkall(g, bg, dst_nodes) checkall(g, bg, dst_nodes)
dst_nodes = {'A': F.tensor([3, 4], dtype=getattr(F, index_dtype)), 'B': F.tensor([5, 6], dtype=getattr(F, index_dtype))}
bg = dgl.to_block(g, dst_nodes)
checkall(g, bg, dst_nodes)
dst_nodes = {'A': F.tensor([4, 3, 2, 1], dtype=getattr(F, index_dtype)), 'B': F.tensor([3, 5, 6, 1], dtype=getattr(F, index_dtype))} dst_nodes = {'A': F.tensor([4, 3, 2, 1], dtype=getattr(F, index_dtype)), 'B': F.tensor([3, 5, 6, 1], dtype=getattr(F, index_dtype))}
bg = dgl.to_block(g, dst_nodes=dst_nodes) bg = dgl.to_block(g, dst_nodes=dst_nodes)
checkall(g, bg, dst_nodes) checkall(g, bg, dst_nodes)
......
...@@ -29,6 +29,10 @@ inline int64_t Len(dgl::runtime::NDArray nd) { ...@@ -29,6 +29,10 @@ inline int64_t Len(dgl::runtime::NDArray nd) {
template <typename T> template <typename T>
inline bool ArrayEQ(dgl::runtime::NDArray a1, dgl::runtime::NDArray a2) { inline bool ArrayEQ(dgl::runtime::NDArray a1, dgl::runtime::NDArray a2) {
if (a1->ndim != a2->ndim) return false; if (a1->ndim != a2->ndim) return false;
if (a1->dtype != a2->dtype) return false;
if (a1->ctx != a2->ctx) return false;
if (a1.NumElements() != a2.NumElements()) return false;
if (a1.NumElements() == 0) return true;
int64_t num = 1; int64_t num = 1;
for (int i = 0; i < a1->ndim; ++i) { for (int i = 0; i < a1->ndim; ++i) {
if (a1->shape[i] != a2->shape[i]) if (a1->shape[i] != a2->shape[i])
......
...@@ -208,6 +208,8 @@ template <typename IDX> ...@@ -208,6 +208,8 @@ template <typename IDX>
void _TestIndexSelect(DLContext ctx) { void _TestIndexSelect(DLContext ctx) {
IdArray a = aten::Range(0, 100, sizeof(IDX)*8, ctx); IdArray a = aten::Range(0, 100, sizeof(IDX)*8, ctx);
ASSERT_EQ(aten::IndexSelect<int>(a, 50), 50); ASSERT_EQ(aten::IndexSelect<int>(a, 50), 50);
ASSERT_TRUE(ArrayEQ<IDX>(aten::IndexSelect(a, 10, 20),
aten::Range(10, 20, sizeof(IDX)*8, ctx)));
IdArray b = aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX)*8, ctx); IdArray b = aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX)*8, ctx);
IdArray c = aten::IndexSelect(a, b); IdArray c = aten::IndexSelect(a, b);
ASSERT_TRUE(ArrayEQ<IDX>(b, c)); ASSERT_TRUE(ArrayEQ<IDX>(b, c));
...@@ -239,3 +241,41 @@ TEST(ArrayTest, TestRelabel_) { ...@@ -239,3 +241,41 @@ TEST(ArrayTest, TestRelabel_) {
_TestRelabel_<int32_t>(); _TestRelabel_<int32_t>();
_TestRelabel_<int64_t>(); _TestRelabel_<int64_t>();
} }
template <typename IDX>
void _TestCumSum(DLContext ctx) {
IdArray a = aten::VecToIdArray(std::vector<IDX>({8, 6, 7, 5, 3, 0, 9}),
sizeof(IDX)*8, ctx);
{
IdArray tb = aten::VecToIdArray(std::vector<IDX>({8, 14, 21, 26, 29, 29, 38}),
sizeof(IDX)*8, ctx);
IdArray b = aten::CumSum(a);
ASSERT_TRUE(ArrayEQ<IDX>(b, tb));
}
{
IdArray tb = aten::VecToIdArray(std::vector<IDX>({0, 8, 14, 21, 26, 29, 29, 38}),
sizeof(IDX)*8, ctx);
IdArray b = aten::CumSum(a, true);
ASSERT_TRUE(ArrayEQ<IDX>(b, tb));
}
a = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
{
IdArray tb = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
IdArray b = aten::CumSum(a);
ASSERT_TRUE(ArrayEQ<IDX>(b, tb));
}
{
IdArray tb = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
IdArray b = aten::CumSum(a);
ASSERT_TRUE(ArrayEQ<IDX>(b, tb));
}
}
TEST(ArrayTest, CumSum) {
_TestCumSum<int32_t>(CPU);
_TestCumSum<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestCumSum<int32_t>(GPU);
_TestCumSum<int64_t>(GPU);
#endif
}
...@@ -17,8 +17,8 @@ aten::CSRMatrix CSR1(DLContext ctx = CTX) { ...@@ -17,8 +17,8 @@ aten::CSRMatrix CSR1(DLContext ctx = CTX) {
return aten::CSRMatrix( return aten::CSRMatrix(
4, 5, 4, 5,
aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX)*8, ctx), aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX)*8, ctx), aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 3, 2}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 1, 4}), sizeof(IDX)*8, ctx), aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 4, 1}), sizeof(IDX)*8, ctx),
false); false);
} }
...@@ -277,12 +277,23 @@ void _TestCSRToCOO(DLContext ctx) { ...@@ -277,12 +277,23 @@ void _TestCSRToCOO(DLContext ctx) {
auto coo = CSRToCOO(csr, false); auto coo = CSRToCOO(csr, false);
ASSERT_EQ(coo.num_rows, 4); ASSERT_EQ(coo.num_rows, 4);
ASSERT_EQ(coo.num_cols, 5); ASSERT_EQ(coo.num_cols, 5);
ASSERT_TRUE(coo.row_sorted);
auto tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, ctx); auto tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, ctx);
auto tc = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx);
auto td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr)); ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr));
ASSERT_TRUE(ArrayEQ<IDX>(coo.col, tc)); ASSERT_TRUE(ArrayEQ<IDX>(coo.col, csr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(coo.data, td)); ASSERT_TRUE(ArrayEQ<IDX>(coo.data, csr.data));
// convert from sorted csr
auto s_csr = CSRSort(csr);
coo = CSRToCOO(s_csr, false);
ASSERT_EQ(coo.num_rows, 4);
ASSERT_EQ(coo.num_cols, 5);
ASSERT_TRUE(coo.row_sorted);
ASSERT_TRUE(coo.col_sorted);
tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr));
ASSERT_TRUE(ArrayEQ<IDX>(coo.col, s_csr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(coo.data, s_csr.data));
} }
{ {
auto coo = CSRToCOO(csr, true); auto coo = CSRToCOO(csr, true);
...@@ -294,7 +305,7 @@ void _TestCSRToCOO(DLContext ctx) { ...@@ -294,7 +305,7 @@ void _TestCSRToCOO(DLContext ctx) {
} }
} }
TEST(SpmatTest, TestCSRToCOO) { TEST(SpmatTest, CSRToCOO) {
_TestCSRToCOO<int32_t>(CPU); _TestCSRToCOO<int32_t>(CPU);
_TestCSRToCOO<int64_t>(CPU); _TestCSRToCOO<int64_t>(CPU);
#if DGL_USE_CUDA #if DGL_USE_CUDA
...@@ -303,8 +314,8 @@ TEST(SpmatTest, TestCSRToCOO) { ...@@ -303,8 +314,8 @@ TEST(SpmatTest, TestCSRToCOO) {
} }
template <typename IDX> template <typename IDX>
void _TestCSRSliceRows() { void _TestCSRSliceRows(DLContext ctx) {
auto csr = CSR2<IDX>(); auto csr = CSR2<IDX>(ctx);
auto x = aten::CSRSliceRows(csr, 1, 4); auto x = aten::CSRSliceRows(csr, 1, 4);
// [1, 0, 0, 0, 0], // [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0], // [0, 0, 1, 1, 0],
...@@ -312,30 +323,34 @@ void _TestCSRSliceRows() { ...@@ -312,30 +323,34 @@ void _TestCSRSliceRows() {
// data: [3, 1, 4] // data: [3, 1, 4]
ASSERT_EQ(x.num_rows, 3); ASSERT_EQ(x.num_rows, 3);
ASSERT_EQ(x.num_cols, 5); ASSERT_EQ(x.num_cols, 5);
auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 3, 3}), sizeof(IDX)*8, CTX); auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 3, 3}), sizeof(IDX)*8, ctx);
auto ti = aten::VecToIdArray(std::vector<IDX>({0, 2, 3}), sizeof(IDX)*8, CTX); auto ti = aten::VecToIdArray(std::vector<IDX>({0, 2, 3}), sizeof(IDX)*8, ctx);
auto td = aten::VecToIdArray(std::vector<IDX>({3, 1, 4}), sizeof(IDX)*8, CTX); auto td = aten::VecToIdArray(std::vector<IDX>({3, 1, 4}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp)); ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti)); ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(x.data, td)); ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
auto r = aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX)*8, CTX); auto r = aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX)*8, ctx);
x = aten::CSRSliceRows(csr, r); x = aten::CSRSliceRows(csr, r);
// [[0, 1, 2, 0, 0], // [[0, 1, 2, 0, 0],
// [1, 0, 0, 0, 0], // [1, 0, 0, 0, 0],
// [0, 0, 0, 0, 0]] // [0, 0, 0, 0, 0]]
// data: [0, 2, 5, 3] // data: [0, 2, 5, 3]
tp = aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 4}), sizeof(IDX)*8, CTX); tp = aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 4}), sizeof(IDX)*8, ctx);
ti = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0}), sizeof(IDX)*8, CTX); ti = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0}), sizeof(IDX)*8, ctx);
td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3}), sizeof(IDX)*8, CTX); td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp)); ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti)); ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(x.data, td)); ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
} }
TEST(SpmatTest, TestCSRSliceRows) { TEST(SpmatTest, TestCSRSliceRows) {
_TestCSRSliceRows<int32_t>(); _TestCSRSliceRows<int32_t>(CPU);
_TestCSRSliceRows<int64_t>(); _TestCSRSliceRows<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestCSRSliceRows<int32_t>(GPU);
_TestCSRSliceRows<int64_t>(GPU);
#endif
} }
template <typename IDX> template <typename IDX>
...@@ -376,6 +391,29 @@ TEST(SpmatTest, TestCSRHasDuplicate) { ...@@ -376,6 +391,29 @@ TEST(SpmatTest, TestCSRHasDuplicate) {
_TestCSRHasDuplicate<int64_t>(); _TestCSRHasDuplicate<int64_t>();
} }
template <typename IDX>
void _TestCSRSort(DLContext ctx) {
auto csr = CSR1<IDX>(ctx);
ASSERT_FALSE(aten::CSRIsSorted(csr));
auto csr1 = aten::CSRSort(csr);
ASSERT_FALSE(aten::CSRIsSorted(csr));
ASSERT_TRUE(aten::CSRIsSorted(csr1));
ASSERT_TRUE(csr1.sorted);
aten::CSRSort_(&csr);
ASSERT_TRUE(aten::CSRIsSorted(csr));
ASSERT_TRUE(csr.sorted);
csr = CSR2<IDX>(ctx);
ASSERT_TRUE(aten::CSRIsSorted(csr));
}
TEST(SpmatTest, CSRSort) {
_TestCSRSort<int32_t>(CPU);
_TestCSRSort<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestCSRSort<int32_t>(GPU);
#endif
}
template <typename IDX> template <typename IDX>
void _TestCOOToCSR(DLContext ctx) { void _TestCOOToCSR(DLContext ctx) {
auto coo = COO1<IDX>(ctx); auto coo = COO1<IDX>(ctx);
...@@ -392,6 +430,7 @@ void _TestCOOToCSR(DLContext ctx) { ...@@ -392,6 +430,7 @@ void _TestCOOToCSR(DLContext ctx) {
ASSERT_EQ(coo.num_cols, csr.num_cols); ASSERT_EQ(coo.num_cols, csr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr)); ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
// Convert from row sorted coo
coo = COO1<IDX>(ctx); coo = COO1<IDX>(ctx);
auto rs_coo = aten::COOSort(coo, false); auto rs_coo = aten::COOSort(coo, false);
auto rs_csr = CSR1<IDX>(ctx); auto rs_csr = CSR1<IDX>(ctx);
...@@ -399,6 +438,8 @@ void _TestCOOToCSR(DLContext ctx) { ...@@ -399,6 +438,8 @@ void _TestCOOToCSR(DLContext ctx) {
ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows); ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols); ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr)); ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));
ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));
coo = COO3<IDX>(ctx); coo = COO3<IDX>(ctx);
rs_coo = aten::COOSort(coo, false); rs_coo = aten::COOSort(coo, false);
...@@ -407,16 +448,20 @@ void _TestCOOToCSR(DLContext ctx) { ...@@ -407,16 +448,20 @@ void _TestCOOToCSR(DLContext ctx) {
ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows); ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols); ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr)); ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));
ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));
// Convert from col sorted coo
coo = COO1<IDX>(ctx); coo = COO1<IDX>(ctx);
auto src_coo = aten::COOSort(coo, true); auto src_coo = aten::COOSort(coo, true);
auto src_csr = CSR1<IDX>(ctx); auto src_csr = CSR1<IDX>(ctx);
auto src_tcsr = aten::COOToCSR(src_coo); auto src_tcsr = aten::COOToCSR(src_coo);
ASSERT_EQ(coo.num_rows, src_tcsr.num_rows); ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, src_tcsr.num_cols); ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.indptr, src_tcsr.indptr)); ASSERT_TRUE(src_tcsr.sorted);
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.indices, src_tcsr.indices)); ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indptr, src_csr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.data, src_tcsr.data)); ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indices, src_coo.col));
ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.data, src_coo.data));
coo = COO3<IDX>(ctx); coo = COO3<IDX>(ctx);
src_coo = aten::COOSort(coo, true); src_coo = aten::COOSort(coo, true);
...@@ -424,12 +469,13 @@ void _TestCOOToCSR(DLContext ctx) { ...@@ -424,12 +469,13 @@ void _TestCOOToCSR(DLContext ctx) {
src_tcsr = aten::COOToCSR(src_coo); src_tcsr = aten::COOToCSR(src_coo);
ASSERT_EQ(coo.num_rows, src_tcsr.num_rows); ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, src_tcsr.num_cols); ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.indptr, src_tcsr.indptr)); ASSERT_TRUE(src_tcsr.sorted);
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.indices, src_tcsr.indices)); ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indptr, src_csr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.data, src_tcsr.data)); ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indices, src_coo.col));
ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.data, src_coo.data));
} }
TEST(SpmatTest, TestCOOToCSR) { TEST(SpmatTest, COOToCSR) {
_TestCOOToCSR<int32_t>(CPU); _TestCOOToCSR<int32_t>(CPU);
_TestCOOToCSR<int64_t>(CPU); _TestCOOToCSR<int64_t>(CPU);
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
...@@ -453,12 +499,37 @@ TEST(SpmatTest, TestCOOHasDuplicate) { ...@@ -453,12 +499,37 @@ TEST(SpmatTest, TestCOOHasDuplicate) {
template <typename IDX> template <typename IDX>
void _TestCOOSort(DLContext ctx) { void _TestCOOSort(DLContext ctx) {
auto coo = COO3<IDX>(ctx); auto coo = COO3<IDX>(ctx);
auto sr_coo = COOSort(coo, false); auto sr_coo = COOSort(coo, false);
ASSERT_EQ(coo.num_rows, sr_coo.num_rows); ASSERT_EQ(coo.num_rows, sr_coo.num_rows);
ASSERT_EQ(coo.num_cols, sr_coo.num_cols); ASSERT_EQ(coo.num_cols, sr_coo.num_cols);
ASSERT_TRUE(sr_coo.row_sorted);
auto flags = COOIsSorted(sr_coo);
ASSERT_TRUE(flags.first);
flags = COOIsSorted(coo); // original coo should stay the same
ASSERT_FALSE(flags.first);
ASSERT_FALSE(flags.second);
auto src_coo = COOSort(coo, true); auto src_coo = COOSort(coo, true);
ASSERT_EQ(coo.num_rows, src_coo.num_rows); ASSERT_EQ(coo.num_rows, src_coo.num_rows);
ASSERT_EQ(coo.num_cols, src_coo.num_cols); ASSERT_EQ(coo.num_cols, src_coo.num_cols);
ASSERT_TRUE(src_coo.row_sorted);
ASSERT_TRUE(src_coo.col_sorted);
flags = COOIsSorted(src_coo);
ASSERT_TRUE(flags.first);
ASSERT_TRUE(flags.second);
// sort inplace
COOSort_(&coo);
ASSERT_TRUE(coo.row_sorted);
flags = COOIsSorted(coo);
ASSERT_TRUE(flags.first);
COOSort_(&coo, true);
ASSERT_TRUE(coo.row_sorted);
ASSERT_TRUE(coo.col_sorted);
flags = COOIsSorted(coo);
ASSERT_TRUE(flags.first);
ASSERT_TRUE(flags.second);
// COO3 // COO3
// [[0, 1, 2, 0, 0], // [[0, 1, 2, 0, 0],
...@@ -489,7 +560,7 @@ void _TestCOOSort(DLContext ctx) { ...@@ -489,7 +560,7 @@ void _TestCOOSort(DLContext ctx) {
ASSERT_TRUE(ArrayEQ<IDX>(src_coo.data, sort_col_data)); ASSERT_TRUE(ArrayEQ<IDX>(src_coo.data, sort_col_data));
} }
TEST(SpmatTest, TestCOOSort) { TEST(SpmatTest, COOSort) {
_TestCOOSort<int32_t>(CPU); _TestCOOSort<int32_t>(CPU);
_TestCOOSort<int64_t>(CPU); _TestCOOSort<int64_t>(CPU);
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
......
Subproject commit c3cceac115c072fb63df1836ff46d8c60d9eb304
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment