Unverified Commit 870da747 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[CUDA][Kernel] More CUDA kernels; Standardize the behavior for sorted COO/CSR (#1704)

* add cub; array cumsum

* CSRSliceRows

* fix warning

* operator << for ndarray; CSRSliceRows

* add CSRIsSorted

* add csr_sort

* inplace coosort and outplace csrsort

* WIP: coo is sorted

* mv cuda_utils

* add AllTrue utility

* csr sort

* coo sort

* coo2csr for sorted coo arrays

* CSRToCOO from sorted

* pass tests for the new kernel changes

* cannot use inplace sort

* lint

* try fix msvc error

* Fix g.copy_to and g.asnumbits; ToBlock no longer uses CSC

* stash

* revert some hack

* revert some changes

* address comments

* fix

* fix to_block unittest

* add todo note
parent da8632ca
......@@ -20,6 +20,7 @@
#include <vector>
#include <algorithm>
#include <utility>
#include <memory>
#include "../../c_api_common.h"
using dgl::runtime::NDArray;
......
......@@ -51,9 +51,11 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool includ
const auto src_dst_types = graph->GetEndpointTypes(etype);
const dgl_type_t srctype = src_dst_types.first;
const dgl_type_t dsttype = src_dst_types.second;
const EdgeArray edges = graph->InEdges(etype, rhs_nodes[dsttype]);
lhs_node_mappings[srctype].Update(edges.src);
edge_arrays[etype] = edges;
if (!aten::IsNullArray(rhs_nodes[dsttype])) {
const EdgeArray& edges = graph->Edges(etype);
lhs_node_mappings[srctype].Update(edges.src);
edge_arrays[etype] = edges;
}
}
const auto meta_graph = graph->meta_graph();
......@@ -75,11 +77,26 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool includ
const dgl_type_t dsttype = src_dst_types.second;
const IdHashMap<IdType> &lhs_map = lhs_node_mappings[srctype];
const IdHashMap<IdType> &rhs_map = rhs_node_mappings[dsttype];
rel_graphs.push_back(CreateFromCOO(
2, lhs_map.Size(), rhs_map.Size(),
lhs_map.Map(edge_arrays[etype].src, -1),
rhs_map.Map(edge_arrays[etype].dst, -1)));
induced_edges.push_back(edge_arrays[etype].id);
if (rhs_map.Size() == 0) {
// No rhs nodes are given for this edge type. Create an empty graph.
rel_graphs.push_back(CreateFromCOO(
2, lhs_map.Size(), rhs_map.Size(),
aten::NullArray(), aten::NullArray()));
induced_edges.push_back(aten::NullArray());
} else {
IdArray new_src = lhs_map.Map(edge_arrays[etype].src, -1);
IdArray new_dst = rhs_map.Map(edge_arrays[etype].dst, -1);
// Check whether there are unmapped IDs and raise error.
for (int64_t i = 0; i < new_dst->shape[0]; ++i)
CHECK_NE(new_dst.Ptr<IdType>()[i], -1)
<< "Node " << edge_arrays[etype].dst.Ptr<IdType>()[i] << " does not exist"
<< " in `rhs_nodes`. Argument `rhs_nodes` must contain all the edge"
<< " destination nodes.";
rel_graphs.push_back(CreateFromCOO(
2, lhs_map.Size(), rhs_map.Size(),
new_src, new_dst));
induced_edges.push_back(edge_arrays[etype].id);
}
}
const HeteroGraphPtr new_graph = CreateHeteroGraph(
......
......@@ -138,13 +138,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
COO CopyTo(const DLContext& ctx) const {
if (Context() == ctx)
return *this;
COO ret(
meta_graph_,
adj_.num_rows, adj_.num_cols,
adj_.row.CopyTo(ctx),
adj_.col.CopyTo(ctx));
return ret;
return COO(meta_graph_, adj_.CopyTo(ctx));
}
bool IsMultigraph() const override {
......@@ -516,13 +510,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
if (Context() == ctx) {
return *this;
} else {
CSR ret(
meta_graph_,
adj_.num_rows, adj_.num_cols,
adj_.indptr.CopyTo(ctx),
adj_.indices.CopyTo(ctx),
adj_.data.CopyTo(ctx));
return ret;
return CSR(meta_graph_, adj_.CopyTo(ctx));
}
}
......@@ -1181,35 +1169,28 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
if (g->NumBits() == bits) {
return g;
} else {
// TODO(minjie): since we don't have int32 operations,
// we make sure that this graph (on CPU) has materialized CSR,
// and then copy them to other context (usually GPU). This should
// be fixed later.
auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg);
CSRPtr new_incsr = CSRPtr(new CSR(bg->GetInCSR()->AsNumBits(bits)));
CSRPtr new_outcsr = CSRPtr(new CSR(bg->GetOutCSR()->AsNumBits(bits)));
CSRPtr new_incsr = (bg->in_csr_)? CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits))) : nullptr;
CSRPtr new_outcsr = (bg->out_csr_)? CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits))) : nullptr;
COOPtr new_coo = (bg->coo_)? COOPtr(new COO(bg->coo_->AsNumBits(bits))) : nullptr;
return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, nullptr, bg->restrict_format_));
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->restrict_format_));
}
}
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
if (ctx == g->Context()) {
return g;
} else {
auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg);
CSRPtr new_incsr = (bg->in_csr_)? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx))) : nullptr;
CSRPtr new_outcsr = (bg->out_csr_)? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx))) : nullptr;
COOPtr new_coo = (bg->coo_)? COOPtr(new COO(bg->coo_->CopyTo(ctx))) : nullptr;
return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->restrict_format_));
}
// TODO(minjie): since we don't have GPU implementation of COO<->CSR,
// we make sure that this graph (on CPU) has materialized CSR,
// and then copy them to other context (usually GPU). This should
// be fixed later.
auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg);
CSRPtr new_incsr = CSRPtr(new CSR(bg->GetInCSR()->CopyTo(ctx)));
CSRPtr new_outcsr = CSRPtr(new CSR(bg->GetOutCSR()->CopyTo(ctx)));
return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, nullptr, bg->restrict_format_));
}
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
......@@ -1278,9 +1259,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
const_cast<UnitGraph*>(this)->in_csr_ = ret;
} else {
CHECK(coo_) << "None of CSR, COO exist";
const auto& adj = coo_->adj();
const auto& newadj = aten::COOToCSR(
aten::COOMatrix{adj.num_cols, adj.num_rows, adj.col, adj.row});
const auto& newadj = aten::CSRSort(aten::COOToCSR(
aten::COOTranspose(coo_->adj())));
ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->in_csr_ = ret;
......@@ -1299,13 +1279,13 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
CSRPtr ret = out_csr_;
if (!out_csr_) {
if (in_csr_) {
const auto& newadj = aten::CSRTranspose(in_csr_->adj());
const auto& newadj = aten::CSRSort(aten::CSRTranspose(in_csr_->adj()));
ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->out_csr_ = ret;
} else {
CHECK(coo_) << "None of CSR, COO exist";
const auto& newadj = aten::COOToCSR(coo_->adj());
const auto& newadj = aten::CSRSort(aten::COOToCSR(coo_->adj()));
ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace)
const_cast<UnitGraph*>(this)->out_csr_ = ret;
......
......@@ -8,6 +8,7 @@
#include <string.h>
#include <stdlib.h>
#include <time.h>
#include <memory>
#include "socket_communicator.h"
#include "../../c_api_common.h"
......
......@@ -10,6 +10,7 @@
#include <vector>
#include <string>
#include <unordered_map>
#include <memory>
#include "communicator.h"
#include "msg_queue.h"
......@@ -19,9 +20,9 @@
namespace dgl {
namespace network {
static int kMaxTryCount = 1024; // maximal connection: 1024
static int kTimeOut = 10; // 10 minutes for socket timeout
static int kMaxConnection = 1024; // maximal connection: 1024
static constexpr int kMaxTryCount = 1024; // maximal connection: 1024
static constexpr int kTimeOut = 10; // 10 minutes for socket timeout
static constexpr int kMaxConnection = 1024; // maximal connection: 1024
/*!
* \breif Networking address
......
......@@ -7,6 +7,7 @@
#include <dgl/runtime/serializer.h>
#include <fstream>
#include <vector>
#include <unordered_map>
#include "file_util.h"
......
......@@ -7,6 +7,7 @@
#define DGL_RUNTIME_FILE_UTIL_H_
#include <string>
#include <unordered_map>
#include "meta_data.h"
namespace dgl {
......
......@@ -9,6 +9,7 @@
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <string>
#include <memory>
#include "module_util.h"
namespace dgl {
......
......@@ -10,6 +10,7 @@
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <vector>
#include <memory>
extern "C" {
// Function signature for generated packed function in shared library
......
......@@ -124,6 +124,8 @@ size_t NDArray::GetSize() const {
}
int64_t NDArray::NumElements() const {
if (data_->dl_tensor.ndim == 0)
return 0;
int64_t size = 1;
for (int i = 0; i < data_->dl_tensor.ndim; ++i) {
size *= data_->dl_tensor.shape[i];
......
......@@ -4,6 +4,7 @@
* \brief Workspace pool utility.
*/
#include "workspace_pool.h"
#include <memory>
namespace dgl {
namespace runtime {
......
......@@ -8,6 +8,7 @@
#include <dgl/runtime/device_api.h>
#include <vector>
#include <memory>
namespace dgl {
namespace runtime {
......
......@@ -1883,4 +1883,4 @@ if __name__ == '__main__':
# test_isolated_ntype()
# test_bipartite()
# test_dtype_cast()
test_format()
pass
......@@ -603,10 +603,6 @@ def test_to_block(index_dtype):
assert bg.number_of_src_nodes() == 4
assert bg.number_of_dst_nodes() == 4
dst_nodes = F.tensor([3, 4], dtype=getattr(F, index_dtype))
bg = dgl.to_block(g_a, dst_nodes)
check(g_a, bg, 'A', 'AA', dst_nodes)
dst_nodes = F.tensor([4, 3, 2, 1], dtype=getattr(F, index_dtype))
bg = dgl.to_block(g_a, dst_nodes)
check(g_a, bg, 'A', 'AA', dst_nodes)
......@@ -620,17 +616,13 @@ def test_to_block(index_dtype):
assert bg.number_of_nodes('DST/A') == 0
checkall(g_ab, bg, None)
dst_nodes = {'B': F.tensor([5, 6], dtype=getattr(F, index_dtype))}
dst_nodes = {'B': F.tensor([5, 6, 3, 1], dtype=getattr(F, index_dtype))}
bg = dgl.to_block(g, dst_nodes)
assert bg.number_of_nodes('SRC/B') == 2
assert bg.number_of_nodes('SRC/B') == 4
assert F.array_equal(bg.srcnodes['B'].data[dgl.NID], bg.dstnodes['B'].data[dgl.NID])
assert bg.number_of_nodes('DST/A') == 0
checkall(g, bg, dst_nodes)
dst_nodes = {'A': F.tensor([3, 4], dtype=getattr(F, index_dtype)), 'B': F.tensor([5, 6], dtype=getattr(F, index_dtype))}
bg = dgl.to_block(g, dst_nodes)
checkall(g, bg, dst_nodes)
dst_nodes = {'A': F.tensor([4, 3, 2, 1], dtype=getattr(F, index_dtype)), 'B': F.tensor([3, 5, 6, 1], dtype=getattr(F, index_dtype))}
bg = dgl.to_block(g, dst_nodes=dst_nodes)
checkall(g, bg, dst_nodes)
......
......@@ -29,6 +29,10 @@ inline int64_t Len(dgl::runtime::NDArray nd) {
template <typename T>
inline bool ArrayEQ(dgl::runtime::NDArray a1, dgl::runtime::NDArray a2) {
if (a1->ndim != a2->ndim) return false;
if (a1->dtype != a2->dtype) return false;
if (a1->ctx != a2->ctx) return false;
if (a1.NumElements() != a2.NumElements()) return false;
if (a1.NumElements() == 0) return true;
int64_t num = 1;
for (int i = 0; i < a1->ndim; ++i) {
if (a1->shape[i] != a2->shape[i])
......
......@@ -208,6 +208,8 @@ template <typename IDX>
void _TestIndexSelect(DLContext ctx) {
IdArray a = aten::Range(0, 100, sizeof(IDX)*8, ctx);
ASSERT_EQ(aten::IndexSelect<int>(a, 50), 50);
ASSERT_TRUE(ArrayEQ<IDX>(aten::IndexSelect(a, 10, 20),
aten::Range(10, 20, sizeof(IDX)*8, ctx)));
IdArray b = aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX)*8, ctx);
IdArray c = aten::IndexSelect(a, b);
ASSERT_TRUE(ArrayEQ<IDX>(b, c));
......@@ -239,3 +241,41 @@ TEST(ArrayTest, TestRelabel_) {
_TestRelabel_<int32_t>();
_TestRelabel_<int64_t>();
}
template <typename IDX>
void _TestCumSum(DLContext ctx) {
IdArray a = aten::VecToIdArray(std::vector<IDX>({8, 6, 7, 5, 3, 0, 9}),
sizeof(IDX)*8, ctx);
{
IdArray tb = aten::VecToIdArray(std::vector<IDX>({8, 14, 21, 26, 29, 29, 38}),
sizeof(IDX)*8, ctx);
IdArray b = aten::CumSum(a);
ASSERT_TRUE(ArrayEQ<IDX>(b, tb));
}
{
IdArray tb = aten::VecToIdArray(std::vector<IDX>({0, 8, 14, 21, 26, 29, 29, 38}),
sizeof(IDX)*8, ctx);
IdArray b = aten::CumSum(a, true);
ASSERT_TRUE(ArrayEQ<IDX>(b, tb));
}
a = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
{
IdArray tb = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
IdArray b = aten::CumSum(a);
ASSERT_TRUE(ArrayEQ<IDX>(b, tb));
}
{
IdArray tb = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
IdArray b = aten::CumSum(a);
ASSERT_TRUE(ArrayEQ<IDX>(b, tb));
}
}
TEST(ArrayTest, CumSum) {
_TestCumSum<int32_t>(CPU);
_TestCumSum<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestCumSum<int32_t>(GPU);
_TestCumSum<int64_t>(GPU);
#endif
}
......@@ -17,8 +17,8 @@ aten::CSRMatrix CSR1(DLContext ctx = CTX) {
return aten::CSRMatrix(
4, 5,
aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 1, 4}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 3, 2}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 4, 1}), sizeof(IDX)*8, ctx),
false);
}
......@@ -277,12 +277,23 @@ void _TestCSRToCOO(DLContext ctx) {
auto coo = CSRToCOO(csr, false);
ASSERT_EQ(coo.num_rows, 4);
ASSERT_EQ(coo.num_cols, 5);
ASSERT_TRUE(coo.row_sorted);
auto tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, ctx);
auto tc = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx);
auto td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr));
ASSERT_TRUE(ArrayEQ<IDX>(coo.col, tc));
ASSERT_TRUE(ArrayEQ<IDX>(coo.data, td));
ASSERT_TRUE(ArrayEQ<IDX>(coo.col, csr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(coo.data, csr.data));
// convert from sorted csr
auto s_csr = CSRSort(csr);
coo = CSRToCOO(s_csr, false);
ASSERT_EQ(coo.num_rows, 4);
ASSERT_EQ(coo.num_cols, 5);
ASSERT_TRUE(coo.row_sorted);
ASSERT_TRUE(coo.col_sorted);
tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr));
ASSERT_TRUE(ArrayEQ<IDX>(coo.col, s_csr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(coo.data, s_csr.data));
}
{
auto coo = CSRToCOO(csr, true);
......@@ -294,7 +305,7 @@ void _TestCSRToCOO(DLContext ctx) {
}
}
TEST(SpmatTest, TestCSRToCOO) {
TEST(SpmatTest, CSRToCOO) {
_TestCSRToCOO<int32_t>(CPU);
_TestCSRToCOO<int64_t>(CPU);
#if DGL_USE_CUDA
......@@ -303,8 +314,8 @@ TEST(SpmatTest, TestCSRToCOO) {
}
template <typename IDX>
void _TestCSRSliceRows() {
auto csr = CSR2<IDX>();
void _TestCSRSliceRows(DLContext ctx) {
auto csr = CSR2<IDX>(ctx);
auto x = aten::CSRSliceRows(csr, 1, 4);
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
......@@ -312,30 +323,34 @@ void _TestCSRSliceRows() {
// data: [3, 1, 4]
ASSERT_EQ(x.num_rows, 3);
ASSERT_EQ(x.num_cols, 5);
auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 3, 3}), sizeof(IDX)*8, CTX);
auto ti = aten::VecToIdArray(std::vector<IDX>({0, 2, 3}), sizeof(IDX)*8, CTX);
auto td = aten::VecToIdArray(std::vector<IDX>({3, 1, 4}), sizeof(IDX)*8, CTX);
auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 3, 3}), sizeof(IDX)*8, ctx);
auto ti = aten::VecToIdArray(std::vector<IDX>({0, 2, 3}), sizeof(IDX)*8, ctx);
auto td = aten::VecToIdArray(std::vector<IDX>({3, 1, 4}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
auto r = aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX)*8, CTX);
auto r = aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX)*8, ctx);
x = aten::CSRSliceRows(csr, r);
// [[0, 1, 2, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 0, 0, 0]]
// data: [0, 2, 5, 3]
tp = aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 4}), sizeof(IDX)*8, CTX);
ti = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0}), sizeof(IDX)*8, CTX);
td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3}), sizeof(IDX)*8, CTX);
tp = aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 4}), sizeof(IDX)*8, ctx);
ti = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0}), sizeof(IDX)*8, ctx);
td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
}
TEST(SpmatTest, TestCSRSliceRows) {
_TestCSRSliceRows<int32_t>();
_TestCSRSliceRows<int64_t>();
_TestCSRSliceRows<int32_t>(CPU);
_TestCSRSliceRows<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestCSRSliceRows<int32_t>(GPU);
_TestCSRSliceRows<int64_t>(GPU);
#endif
}
template <typename IDX>
......@@ -376,6 +391,29 @@ TEST(SpmatTest, TestCSRHasDuplicate) {
_TestCSRHasDuplicate<int64_t>();
}
template <typename IDX>
void _TestCSRSort(DLContext ctx) {
auto csr = CSR1<IDX>(ctx);
ASSERT_FALSE(aten::CSRIsSorted(csr));
auto csr1 = aten::CSRSort(csr);
ASSERT_FALSE(aten::CSRIsSorted(csr));
ASSERT_TRUE(aten::CSRIsSorted(csr1));
ASSERT_TRUE(csr1.sorted);
aten::CSRSort_(&csr);
ASSERT_TRUE(aten::CSRIsSorted(csr));
ASSERT_TRUE(csr.sorted);
csr = CSR2<IDX>(ctx);
ASSERT_TRUE(aten::CSRIsSorted(csr));
}
TEST(SpmatTest, CSRSort) {
_TestCSRSort<int32_t>(CPU);
_TestCSRSort<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestCSRSort<int32_t>(GPU);
#endif
}
template <typename IDX>
void _TestCOOToCSR(DLContext ctx) {
auto coo = COO1<IDX>(ctx);
......@@ -392,6 +430,7 @@ void _TestCOOToCSR(DLContext ctx) {
ASSERT_EQ(coo.num_cols, csr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
// Convert from row sorted coo
coo = COO1<IDX>(ctx);
auto rs_coo = aten::COOSort(coo, false);
auto rs_csr = CSR1<IDX>(ctx);
......@@ -399,6 +438,8 @@ void _TestCOOToCSR(DLContext ctx) {
ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));
ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));
coo = COO3<IDX>(ctx);
rs_coo = aten::COOSort(coo, false);
......@@ -407,16 +448,20 @@ void _TestCOOToCSR(DLContext ctx) {
ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));
ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));
// Convert from col sorted coo
coo = COO1<IDX>(ctx);
auto src_coo = aten::COOSort(coo, true);
auto src_csr = CSR1<IDX>(ctx);
auto src_tcsr = aten::COOToCSR(src_coo);
ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.indptr, src_tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.indices, src_tcsr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.data, src_tcsr.data));
ASSERT_TRUE(src_tcsr.sorted);
ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indptr, src_csr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indices, src_coo.col));
ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.data, src_coo.data));
coo = COO3<IDX>(ctx);
src_coo = aten::COOSort(coo, true);
......@@ -424,12 +469,13 @@ void _TestCOOToCSR(DLContext ctx) {
src_tcsr = aten::COOToCSR(src_coo);
ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.indptr, src_tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.indices, src_tcsr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.data, src_tcsr.data));
ASSERT_TRUE(src_tcsr.sorted);
ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indptr, src_csr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indices, src_coo.col));
ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.data, src_coo.data));
}
TEST(SpmatTest, TestCOOToCSR) {
TEST(SpmatTest, COOToCSR) {
_TestCOOToCSR<int32_t>(CPU);
_TestCOOToCSR<int64_t>(CPU);
#ifdef DGL_USE_CUDA
......@@ -453,12 +499,37 @@ TEST(SpmatTest, TestCOOHasDuplicate) {
template <typename IDX>
void _TestCOOSort(DLContext ctx) {
auto coo = COO3<IDX>(ctx);
auto sr_coo = COOSort(coo, false);
ASSERT_EQ(coo.num_rows, sr_coo.num_rows);
ASSERT_EQ(coo.num_cols, sr_coo.num_cols);
ASSERT_TRUE(sr_coo.row_sorted);
auto flags = COOIsSorted(sr_coo);
ASSERT_TRUE(flags.first);
flags = COOIsSorted(coo); // original coo should stay the same
ASSERT_FALSE(flags.first);
ASSERT_FALSE(flags.second);
auto src_coo = COOSort(coo, true);
ASSERT_EQ(coo.num_rows, src_coo.num_rows);
ASSERT_EQ(coo.num_cols, src_coo.num_cols);
ASSERT_TRUE(src_coo.row_sorted);
ASSERT_TRUE(src_coo.col_sorted);
flags = COOIsSorted(src_coo);
ASSERT_TRUE(flags.first);
ASSERT_TRUE(flags.second);
// sort inplace
COOSort_(&coo);
ASSERT_TRUE(coo.row_sorted);
flags = COOIsSorted(coo);
ASSERT_TRUE(flags.first);
COOSort_(&coo, true);
ASSERT_TRUE(coo.row_sorted);
ASSERT_TRUE(coo.col_sorted);
flags = COOIsSorted(coo);
ASSERT_TRUE(flags.first);
ASSERT_TRUE(flags.second);
// COO3
// [[0, 1, 2, 0, 0],
......@@ -489,7 +560,7 @@ void _TestCOOSort(DLContext ctx) {
ASSERT_TRUE(ArrayEQ<IDX>(src_coo.data, sort_col_data));
}
TEST(SpmatTest, TestCOOSort) {
TEST(SpmatTest, COOSort) {
_TestCOOSort<int32_t>(CPU);
_TestCOOSort<int64_t>(CPU);
#ifdef DGL_USE_CUDA
......
Subproject commit c3cceac115c072fb63df1836ff46d8c60d9eb304
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment