"docs/vscode:/vscode.git/clone" did not exist on "0974b4c6067165434fa715654b355b41beb5fceb"
Unverified Commit b0d9e7aa authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Refactor] Separating graph and sparse matrix operations (#699)

* WIP: array refactoring

* WIP: implementation

* wip

* most csr part

* WIP: on coo

* WIP: coo

* finish refactoring immutable graph

* compiled

* fix undefined ndarray copy bug; add COOToCSR when coo has no data array

* fix bug in COOToCSR

* fix bug in CSR constructor

* fix bug in in_edges(vid)

* fix OutEdges bug

* pass test_graph

* pass test_graph

* fix bug in CSR constructor

* fix bug in CSR constructor

* fix bug in CSR constructor

* fix stupid bug

* pass gpu test

* remove debug printout

* fix lint

* rm biparate grpah

* fix lint

* address comments

* fix bug in Clone

* cpp utests
parent f79188da
...@@ -65,9 +65,8 @@ def test_query(): ...@@ -65,9 +65,8 @@ def test_query():
assert g.has_node(i) assert g.has_node(i)
assert i in g assert i in g
assert not g.has_node(11) assert not g.has_node(11)
assert not g.has_node(-1) assert not 11 in g
assert not -1 in g assert F.allclose(g.has_nodes([0,2,10,11]), F.tensor([1,1,0,0]))
assert F.allclose(g.has_nodes([-1,0,2,10,11]), F.tensor([0,1,1,0,0]))
src, dst = edge_pair_input() src, dst = edge_pair_input()
for u, v in zip(src, dst): for u, v in zip(src, dst):
...@@ -137,9 +136,8 @@ def test_query(): ...@@ -137,9 +136,8 @@ def test_query():
assert g.has_node(i) assert g.has_node(i)
assert i in g assert i in g
assert not g.has_node(11) assert not g.has_node(11)
assert not g.has_node(-1) assert not 11 in g
assert not -1 in g assert F.allclose(g.has_nodes([0,2,10,11]), F.tensor([1,1,0,0]))
assert F.allclose(g.has_nodes([-1,0,2,10,11]), F.tensor([0,1,1,0,0]))
src, dst = edge_pair_input(sort=True) src, dst = edge_pair_input(sort=True)
for u, v in zip(src, dst): for u, v in zip(src, dst):
......
#ifndef TEST_COMMON_H_
#define TEST_COMMON_H_
#include <dgl/runtime/ndarray.h>
template <typename T>
inline T* Ptr(dgl::runtime::NDArray nd) {
return static_cast<T*>(nd->data);
}
inline int64_t* PI64(dgl::runtime::NDArray nd) {
return static_cast<int64_t*>(nd->data);
}
inline int32_t* PI32(dgl::runtime::NDArray nd) {
return static_cast<int32_t*>(nd->data);
}
inline int64_t Len(dgl::runtime::NDArray nd) {
return nd->shape[0];
}
template <typename T>
inline bool ArrayEQ(dgl::runtime::NDArray a1, dgl::runtime::NDArray a2) {
if (a1->ndim != a2->ndim) return false;
int64_t num = 1;
for (int i = 0; i < a1->ndim; ++i) {
if (a1->shape[i] != a2->shape[i])
return false;
num *= a1->shape[i];
}
for (int64_t i = 0; i < num; ++i)
if (static_cast<T*>(a1->data)[i] != static_cast<T*>(a2->data)[i])
return false;
return true;
}
static constexpr DLContext CTX = DLContext{kDLCPU, 0};
#endif // TEST_COMMON_H_
#include <gtest/gtest.h>
#include <dgl/array.h>
#include "./common.h"
using namespace dgl;
using namespace dgl::runtime;
TEST(ArrayTest, TestCreate) {
IdArray a = aten::NewIdArray(100, CTX, 32);
ASSERT_EQ(a->dtype.bits, 32);
ASSERT_EQ(a->shape[0], 100);
a = aten::NewIdArray(0);
ASSERT_EQ(a->shape[0], 0);
std::vector<int64_t> vec = {2, 94, 232, 30};
a = aten::VecToIdArray(vec, 32);
ASSERT_EQ(Len(a), vec.size());
ASSERT_EQ(a->dtype.bits, 32);
for (int i = 0; i < Len(a); ++i) {
ASSERT_EQ(Ptr<int32_t>(a)[i], vec[i]);
}
a = aten::VecToIdArray(std::vector<int32_t>());
ASSERT_EQ(Len(a), 0);
};
TEST(ArrayTest, TestRange) {
IdArray a = aten::Range(10, 10, 64, CTX);
ASSERT_EQ(Len(a), 0);
a = aten::Range(10, 20, 32, CTX);
ASSERT_EQ(Len(a), 10);
ASSERT_EQ(a->dtype.bits, 32);
for (int i = 0; i < 10; ++i)
ASSERT_EQ(Ptr<int32_t>(a)[i], i + 10);
};
TEST(ArrayTest, TestFull) {
IdArray a = aten::Full(-100, 0, 32, CTX);
ASSERT_EQ(Len(a), 0);
a = aten::Full(-100, 13, 64, CTX);
ASSERT_EQ(Len(a), 13);
ASSERT_EQ(a->dtype.bits, 64);
for (int i = 0; i < 13; ++i)
ASSERT_EQ(Ptr<int64_t>(a)[i], -100);
};
TEST(ArrayTest, TestClone) {
IdArray a = aten::NewIdArray(0);
IdArray b = aten::Clone(a);
ASSERT_EQ(Len(b), 0);
a = aten::Range(0, 10, 32, CTX);
b = aten::Clone(a);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(PI32(b)[i], i);
}
PI32(b)[0] = -1;
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(PI32(a)[i], i);
}
};
TEST(ArrayTest, TestAsNumBits) {
IdArray a = aten::Range(0, 10, 32, CTX);
a = aten::AsNumBits(a, 64);
ASSERT_EQ(a->dtype.bits, 64);
for (int i = 0; i < 10; ++i)
ASSERT_EQ(PI64(a)[i], i);
};
template <typename IDX>
void _TestArith() {
const int N = 100;
IdArray a = aten::Full(-10, N, sizeof(IDX)*8, CTX);
IdArray b = aten::Full(7, N, sizeof(IDX)*8, CTX);
IdArray c = aten::Add(a, b);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], -3);
c = aten::Sub(a, b);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], -17);
c = aten::Mul(a, b);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], -70);
c = aten::Div(a, b);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], -1);
const int val = -3;
c = aten::Add(a, val);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], -13);
c = aten::Sub(a, val);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], -7);
c = aten::Mul(a, val);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], 30);
c = aten::Div(a, val);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], 3);
c = aten::Add(val, b);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], 4);
c = aten::Sub(val, b);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], -10);
c = aten::Mul(val, b);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], -21);
c = aten::Div(val, b);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], 0);
a = aten::Range(0, N, sizeof(IDX)*8, CTX);
c = aten::LT(a, 50);
for (int i = 0; i < N; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], (int)(i < 50));
}
TEST(ArrayTest, TestArith) {
_TestArith<int32_t>();
_TestArith<int64_t>();
};
template <typename IDX>
void _TestHStack() {
IdArray a = aten::Range(0, 100, sizeof(IDX)*8, CTX);
IdArray b = aten::Range(100, 200, sizeof(IDX)*8, CTX);
IdArray c = aten::HStack(a, b);
ASSERT_EQ(c->ndim, 1);
ASSERT_EQ(c->shape[0], 200);
for (int i = 0; i < 200; ++i)
ASSERT_EQ(Ptr<IDX>(c)[i], i);
}
TEST(ArrayTest, TestHStack) {
_TestHStack<int32_t>();
_TestHStack<int64_t>();
}
template <typename IDX>
void _TestIndexSelect() {
IdArray a = aten::Range(0, 100, sizeof(IDX)*8, CTX);
ASSERT_EQ(aten::IndexSelect(a, 50), 50);
IdArray b = aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX)*8, CTX);
IdArray c = aten::IndexSelect(a, b);
ASSERT_TRUE(ArrayEQ<IDX>(b, c));
}
TEST(ArrayTest, TestIndexSelect) {
_TestIndexSelect<int32_t>();
_TestIndexSelect<int64_t>();
}
template <typename IDX>
void _TestRelabel_() {
IdArray a = aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX)*8, CTX);
IdArray b = aten::VecToIdArray(std::vector<IDX>({20, 5, 6}), sizeof(IDX)*8, CTX);
IdArray c = aten::Relabel_({a, b});
IdArray ta = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, CTX);
IdArray tb = aten::VecToIdArray(std::vector<IDX>({1, 3, 4}), sizeof(IDX)*8, CTX);
IdArray tc = aten::VecToIdArray(std::vector<IDX>({0, 20, 10, 5, 6}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(a, ta));
ASSERT_TRUE(ArrayEQ<IDX>(b, tb));
ASSERT_TRUE(ArrayEQ<IDX>(c, tc));
}
TEST(ArrayTest, TestRelabel_) {
_TestRelabel_<int32_t>();
_TestRelabel_<int64_t>();
}
#include <gtest/gtest.h>
#include <dgl/array.h>
#include "./common.h"
using namespace dgl;
using namespace dgl::runtime;
namespace {
template <typename IDX>
aten::CSRMatrix CSR1() {
// [[0, 1, 1, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
// [0, 0, 0, 0, 0]]
// data: [0, 2, 3, 1, 4]
aten::CSRMatrix csr;
csr.num_rows = 4;
csr.num_cols = 5;
csr.indptr = aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX)*8, CTX);
csr.indices = aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX)*8, CTX);
csr.data = aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 1, 4}), sizeof(IDX)*8, CTX);
return csr;
}
template <typename IDX>
aten::CSRMatrix CSR2() {
// has duplicate entries
// [[0, 1, 2, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
// [0, 0, 0, 0, 0]]
// data: [0, 2, 5, 3, 1, 4]
aten::CSRMatrix csr;
csr.num_rows = 4;
csr.num_cols = 5;
csr.indptr = aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, CTX);
csr.indices = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, CTX);
csr.data = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, CTX);
return csr;
}
template <typename IDX>
aten::COOMatrix COO1() {
// [[0, 1, 1, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
// [0, 0, 0, 0, 0]]
// data: [0, 2, 3, 1, 4]
// row : [0, 2, 0, 1, 2]
// col : [1, 2, 2, 0, 3]
aten::COOMatrix coo;
coo.num_rows = 4;
coo.num_cols = 5;
coo.row = aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2}), sizeof(IDX)*8, CTX);
coo.col = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 3}), sizeof(IDX)*8, CTX);
return coo;
}
template <typename IDX>
aten::COOMatrix COO2() {
// has duplicate entries
// [[0, 1, 2, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
// [0, 0, 0, 0, 0]]
// data: [0, 2, 5, 3, 1, 4]
// row : [0, 2, 0, 1, 2, 0]
// col : [1, 2, 2, 0, 3, 2]
aten::COOMatrix coo;
coo.num_rows = 4;
coo.num_cols = 5;
coo.row = aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX)*8, CTX);
coo.col = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 3, 2}), sizeof(IDX)*8, CTX);
return coo;
}
}
template <typename IDX>
void _TestCSRIsNonZero() {
auto csr = CSR1<IDX>();
ASSERT_TRUE(aten::CSRIsNonZero(csr, 0, 1));
ASSERT_FALSE(aten::CSRIsNonZero(csr, 0, 0));
IdArray r = aten::VecToIdArray(std::vector<IDX>({2, 2, 0, 0}), sizeof(IDX)*8, CTX);
IdArray c = aten::VecToIdArray(std::vector<IDX>({1, 1, 1, 3}), sizeof(IDX)*8, CTX);
IdArray x = aten::CSRIsNonZero(csr, r, c);
IdArray tx = aten::VecToIdArray(std::vector<IDX>({0, 0, 1, 0}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}
TEST(SpmatTest, TestCSRIsNonZero) {
_TestCSRIsNonZero<int32_t>();
_TestCSRIsNonZero<int64_t>();
}
template <typename IDX>
void _TestCSRGetRowNNZ() {
auto csr = CSR2<IDX>();
ASSERT_EQ(aten::CSRGetRowNNZ(csr, 0), 3);
ASSERT_EQ(aten::CSRGetRowNNZ(csr, 3), 0);
IdArray r = aten::VecToIdArray(std::vector<IDX>({0, 3}), sizeof(IDX)*8, CTX);
IdArray x = aten::CSRGetRowNNZ(csr, r);
IdArray tx = aten::VecToIdArray(std::vector<IDX>({3, 0}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}
TEST(SpmatTest, TestCSRGetRowNNZ) {
_TestCSRGetRowNNZ<int32_t>();
_TestCSRGetRowNNZ<int64_t>();
}
template <typename IDX>
void _TestCSRGetRowColumnIndices() {
auto csr = CSR2<IDX>();
auto x = aten::CSRGetRowColumnIndices(csr, 0);
auto tx = aten::VecToIdArray(std::vector<IDX>({1, 2, 2}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
x = aten::CSRGetRowColumnIndices(csr, 1);
tx = aten::VecToIdArray(std::vector<IDX>({0}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
x = aten::CSRGetRowColumnIndices(csr, 3);
tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}
TEST(SpmatTest, TestCSRGetRowColumnIndices) {
_TestCSRGetRowColumnIndices<int32_t>();
_TestCSRGetRowColumnIndices<int64_t>();
}
template <typename IDX>
void _TestCSRGetRowData() {
auto csr = CSR2<IDX>();
auto x = aten::CSRGetRowData(csr, 0);
auto tx = aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
x = aten::CSRGetRowData(csr, 1);
tx = aten::VecToIdArray(std::vector<IDX>({3}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
x = aten::CSRGetRowData(csr, 3);
tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}
TEST(SpmatTest, TestCSRGetRowData) {
_TestCSRGetRowData<int32_t>();
_TestCSRGetRowData<int64_t>();
}
template <typename IDX>
void _TestCSRGetData() {
auto csr = CSR2<IDX>();
auto x = aten::CSRGetData(csr, 0, 0);
auto tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
x = aten::CSRGetData(csr, 0, 2);
tx = aten::VecToIdArray(std::vector<IDX>({2, 5}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
auto r = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, CTX);
auto c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, CTX);
x = aten::CSRGetData(csr, r, c);
tx = aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}
TEST(SpmatTest, TestCSRGetData) {
_TestCSRGetData<int32_t>();
_TestCSRGetData<int64_t>();
}
template <typename IDX>
void _TestCSRGetDataAndIndices() {
auto csr = CSR2<IDX>();
auto r = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, CTX);
auto c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, CTX);
auto x = aten::CSRGetDataAndIndices(csr, r, c);
auto tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, CTX);
auto tc = aten::VecToIdArray(std::vector<IDX>({1, 2, 2}), sizeof(IDX)*8, CTX);
auto td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x[0], tr));
ASSERT_TRUE(ArrayEQ<IDX>(x[1], tc));
ASSERT_TRUE(ArrayEQ<IDX>(x[2], td));
}
TEST(SpmatTest, TestCSRGetDataAndIndices) {
_TestCSRGetDataAndIndices<int32_t>();
_TestCSRGetDataAndIndices<int64_t>();
}
template <typename IDX>
void _TestCSRTranspose() {
auto csr = CSR2<IDX>();
auto csr_t = aten::CSRTranspose(csr);
// [[0, 1, 0, 0],
// [1, 0, 0, 0],
// [2, 0, 1, 0],
// [0, 0, 1, 0],
// [0, 0, 0, 0]]
// data: [3, 0, 2, 5, 1, 4]
ASSERT_EQ(csr_t.num_rows, 5);
ASSERT_EQ(csr_t.num_cols, 4);
auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 2, 5, 6, 6}), sizeof(IDX)*8, CTX);
auto ti = aten::VecToIdArray(std::vector<IDX>({1, 0, 0, 0, 2, 2}), sizeof(IDX)*8, CTX);
auto td = aten::VecToIdArray(std::vector<IDX>({3, 0, 2, 5, 1, 4}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(csr_t.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(csr_t.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(csr_t.data, td));
}
TEST(SpmatTest, TestCSRTranspose) {
_TestCSRTranspose<int32_t>();
_TestCSRTranspose<int64_t>();
}
template <typename IDX>
void _TestCSRToCOO() {
auto csr = CSR2<IDX>();
{
auto coo = CSRToCOO(csr, false);
ASSERT_EQ(coo.num_rows, 4);
ASSERT_EQ(coo.num_cols, 5);
auto tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, CTX);
auto tc = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, CTX);
auto td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr));
ASSERT_TRUE(ArrayEQ<IDX>(coo.col, tc));
ASSERT_TRUE(ArrayEQ<IDX>(coo.data, td));
}
{
auto coo = CSRToCOO(csr, true);
ASSERT_EQ(coo.num_rows, 4);
ASSERT_EQ(coo.num_cols, 5);
auto tcoo = COO2<IDX>();
ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tcoo.row));
ASSERT_TRUE(ArrayEQ<IDX>(coo.col, tcoo.col));
}
}
TEST(SpmatTest, TestCSRToCOO) {
_TestCSRToCOO<int32_t>();
_TestCSRToCOO<int64_t>();
}
template <typename IDX>
void _TestCSRSliceRows() {
auto csr = CSR2<IDX>();
auto x = aten::CSRSliceRows(csr, 1, 4);
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
// [0, 0, 0, 0, 0]]
// data: [3, 1, 4]
ASSERT_EQ(x.num_rows, 3);
ASSERT_EQ(x.num_cols, 5);
auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 3, 3}), sizeof(IDX)*8, CTX);
auto ti = aten::VecToIdArray(std::vector<IDX>({0, 2, 3}), sizeof(IDX)*8, CTX);
auto td = aten::VecToIdArray(std::vector<IDX>({3, 1, 4}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
auto r = aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX)*8, CTX);
x = aten::CSRSliceRows(csr, r);
// [[0, 1, 2, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 0, 0, 0]]
// data: [0, 2, 5, 3]
tp = aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 4}), sizeof(IDX)*8, CTX);
ti = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0}), sizeof(IDX)*8, CTX);
td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
}
TEST(SpmatTest, TestCSRSliceRows) {
_TestCSRSliceRows<int32_t>();
_TestCSRSliceRows<int64_t>();
}
template <typename IDX>
void _TestCSRSliceMatrix() {
auto csr = CSR2<IDX>();
auto r = aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX)*8, CTX);
auto c = aten::VecToIdArray(std::vector<IDX>({1, 2, 3}), sizeof(IDX)*8, CTX);
auto x = aten::CSRSliceMatrix(csr, r, c);
// [[1, 2, 0],
// [0, 0, 0],
// [0, 0, 0]]
// data: [0, 2, 5]
ASSERT_EQ(x.num_rows, 3);
ASSERT_EQ(x.num_cols, 3);
auto tp = aten::VecToIdArray(std::vector<IDX>({0, 3, 3, 3}), sizeof(IDX)*8, CTX);
auto ti = aten::VecToIdArray(std::vector<IDX>({0, 1, 1}), sizeof(IDX)*8, CTX);
auto td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
}
TEST(SpmatTest, TestCSRSliceMatrix) {
_TestCSRSliceMatrix<int32_t>();
_TestCSRSliceMatrix<int64_t>();
}
template <typename IDX>
void _TestCSRHasDuplicate() {
auto csr = CSR1<IDX>();
ASSERT_FALSE(aten::CSRHasDuplicate(csr));
csr = CSR2<IDX>();
ASSERT_TRUE(aten::CSRHasDuplicate(csr));
}
TEST(SpmatTest, TestCSRHasDuplicate) {
_TestCSRHasDuplicate<int32_t>();
_TestCSRHasDuplicate<int64_t>();
}
template <typename IDX>
void _TestCOOToCSR() {
auto coo = COO1<IDX>();
auto csr = CSR1<IDX>();
auto tcsr = aten::COOToCSR(coo);
ASSERT_EQ(coo.num_rows, csr.num_rows);
ASSERT_EQ(coo.num_cols, csr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(csr.indices, tcsr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(csr.data, tcsr.data));
coo = COO2<IDX>();
csr = CSR2<IDX>();
tcsr = aten::COOToCSR(coo);
ASSERT_EQ(coo.num_rows, csr.num_rows);
ASSERT_EQ(coo.num_cols, csr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(csr.indices, tcsr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(csr.data, tcsr.data));
}
TEST(SpmatTest, TestCOOToCSR) {
_TestCOOToCSR<int32_t>();
_TestCOOToCSR<int64_t>();
}
template <typename IDX>
void _TestCOOHasDuplicate() {
auto csr = COO1<IDX>();
ASSERT_FALSE(aten::COOHasDuplicate(csr));
csr = COO2<IDX>();
ASSERT_TRUE(aten::COOHasDuplicate(csr));
}
TEST(SpmatTest, TestCOOHasDuplicate) {
_TestCOOHasDuplicate<int32_t>();
_TestCOOHasDuplicate<int64_t>();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment