Unverified Commit 0e92dad6 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Transfom] Add support for to_bidirected of heterogeneous graphs (#1793)



* code

* joint_union

* joint union pass all test

* test case for test_transform

* upd

* lint

* lint

* Fix

* list

* Fix

* Fix

* update to_bidirected impl

* Fix

* Fix

* upd

* remove joint_union
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 92321347
......@@ -370,21 +370,21 @@ class BaseHeteroGraph : public runtime::Object {
/*!
* \brief Get restrict sparse format of the graph.
*
*
* \return a string representing the sparse format: 'coo'/'csr'/'csc'/'any'
*/
virtual std::string GetRestrictFormat() const = 0;
/*!
* \brief Return the sparse format in use for the graph.
*
* \return a number of type dgl_format_code_t.
*
* \return a number of type dgl_format_code_t.
*/
virtual dgl_format_code_t GetFormatInUse() const = 0;
/*!
* \brief Return the graph in specified restrict format.
*
*
* \return The new graph.
*/
virtual HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const = 0;
......@@ -421,7 +421,7 @@ class BaseHeteroGraph : public runtime::Object {
/*!
* \brief Extract the induced subgraph by the given vertices.
*
*
* The length of the given vector should be equal to the number of vertex types.
* Empty arrays can be provided if no vertex is needed for the type. The result
* subgraph has the same meta graph with the parent, but some types can have no
......@@ -434,7 +434,7 @@ class BaseHeteroGraph : public runtime::Object {
/*!
* \brief Extract the induced subgraph by the given edges.
*
*
* The length of the given vector should be equal to the number of edge types.
* Empty arrays can be provided if no edge is needed for the type. The result
* subgraph has the same meta graph with the parent, but some types can have no
......@@ -479,7 +479,7 @@ class BaseHeteroGraph : public runtime::Object {
// Define HeteroGraphRef
DGL_DEFINE_OBJECT_REF(HeteroGraphRef, BaseHeteroGraph);
/*!
/*!
* \brief Hetero-subgraph data structure.
*
* This class can be used as arguments and return values of a C API.
......@@ -682,13 +682,27 @@ HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray
*/
HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& nodes);
/*!
* \brief Joint union multiple graphs into one graph.
*
* All input graphs should have the same metagraph.
*
* TODO(xiangsx): remove the meta_graph argument
*
* \param meta_graph Metagraph of the inputs and result.
* \param component_graphs Input graphs
* \return One graph that unions all the components
*/
HeteroGraphPtr JointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
/*!
* \brief Union multiple graphs into one with each input graph as one disjoint component.
*
* All input graphs should have the same metagraph.
*
* TODO(minjie): remove the meta_graph argument
*
*
* \tparam IdType Graph's index data type, can be int32_t or int64_t
* \param meta_graph Metagraph of the inputs and result.
* \param component_graphs Input graphs
......@@ -735,7 +749,7 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
IdArray vertex_sizes,
IdArray edge_sizes);
/*!
/*!
* \brief Structure for pickle/unpickle.
*
* The design principle is to leverage the NDArray class as much as possible so
......@@ -751,7 +765,7 @@ struct HeteroPickleStates : public runtime::Object {
/*! \brief version number */
int64_t version = 0;
/*! \brief Metainformation
/*! \brief Metainformation
*
* metagraph, number of nodes per type, format, flags
*/
......@@ -760,7 +774,7 @@ struct HeteroPickleStates : public runtime::Object {
/*! \brief Arrays representing graph structure (coo or csr) */
std::vector<IdArray> arrays;
/* To support backward compatibility, we have to retain fields in the old
/* To support backward compatibility, we have to retain fields in the old
* version of HeteroPickleStates
*/
......
......@@ -1127,6 +1127,23 @@ def create_heterograph_from_relations(metagraph, rel_graphs, num_nodes_per_type)
return _CAPI_DGLHeteroCreateHeteroGraphWithNumNodes(
metagraph, rel_graphs, num_nodes_per_type.todgltensor())
def joint_union(metagraph, gidx_list):
"""Return a joint union of the input heterographs.
Parameters
----------
metagraph : GraphIndex
Meta-graph.
gidx_list : list of HeteroGraphIndex
Heterographs to be joint_unioned.
Returns
-------
HeteroGraphIndex
joint_unioned Heterograph.
"""
return _CAPI_DGLHeteroJointUnion(metagraph, gidx_list)
def disjoint_union(metagraph, graphs):
"""Return a disjoint union of the input heterographs.
......
......@@ -12,7 +12,7 @@ from . import backend as F
from .graph_index import from_coo
from .graph_index import _get_halo_subgraph_inner_node
from .graph import unbatch
from .convert import graph, bipartite
from .convert import graph, bipartite, heterograph
from . import utils
from .base import EID, NID
from . import ndarray as nd
......@@ -27,6 +27,7 @@ __all__ = [
'reverse_heterograph',
'to_simple_graph',
'to_bidirected',
'to_bidirected_stale',
'laplacian_lambda_max',
'knn_graph',
'segmented_knn_graph',
......@@ -143,6 +144,154 @@ def segmented_knn_graph(x, k, segs):
g = DGLGraph(adj, readonly=True)
return g
def to_bidirected(g, readonly=None, copy_ndata=True,
copy_edata=False, ignore_bipartite=False):
r"""Convert the graph to a bidirected one.
For a graph with edges :math:`(i_1, j_1), \cdots, (i_n, j_n)`, this
function creates a new graph with edges
:math:`(i_1, j_1), \cdots, (i_n, j_n), (j_1, i_1), \cdots, (j_n, i_n)`.
For a heterograph with multiple edge types, we can treat edges corresponding
to each type as a separate graph and convert the graph to a bidirected one
for each of them.
Since **to_bidirected is not well defined for unidirectional bipartite graphs**,
an error will be raised if an edge type of the input heterograph is for a
unidirectional bipartite graph. We can simply skip the edge types corresponding
to unidirectional bipartite graphs by specifying ``ignore_bipartite=True``.
Parameters
----------
g : DGLGraph
The input graph.
readonly : bool, default to be True
Deprecated. There will be no difference between readonly and non-readonly
copy_ndata: bool, optional
If True, the node features of the bidirected graph are copied from
the original graph. If False, the bidirected
graph will not have any node features.
(Default: True)
copy_edata: bool, optional
If True, the features of the reversed edges will be identical to
the original ones."
If False, the bidirected graph will not have any edge
features.
(Default: False)
ignore_bipartite: bool, optional
If True, unidirectional bipartite graphs are ignored and
no error is raised. If False, an error will be raised if
an edge type of the input heterograph is for a unidirectional
bipartite graph.
Returns
-------
DGLGraph
The bidirected graph
Notes
-----
If ``copy_ndata`` is ``True``, same tensors are used as
the node features of the original graph and the new graph.
As a result, users should avoid performing in-place operations
on the node features of the new graph to avoid feature corruption.
On the contrary, edge features are concatenated,
and they are not shared due to concatenation.
For concrete examples, refer to the ``Examples`` section below.
Examples
--------
**Homographs**
>>> g = dgl.graph(th.tensor([0, 0]), th.tensor([0, 1]))
>>> bg1 = dgl.to_bidirected(g)
>>> bg1.edges()
(tensor([0, 0, 0, 1]), tensor([0, 1, 0, 0]))
To remove duplicate edges, see :func:to_simple
**Heterographs with Multiple Edge Types**
g = dgl.heterograph({
>>> ('user', 'wins', 'user'): (th.tensor([0, 2, 0, 2, 2]), th.tensor([1, 1, 2, 1, 0])),
>>> ('user', 'plays', 'game'): (th.tensor([1, 2, 1]), th.tensor([2, 1, 1])),
>>> ('user', 'follows', 'user'): (th.tensor([1, 2, 1), th.tensor([0, 0, 0]))
>>> })
>>> g.nodes['game'].data['hv'] = th.ones(3, 1)
>>> g.edges['wins'].data['h'] = th.tensor([0, 1, 2, 3, 4])
The to_bidirected operation is applied to the subgraph
corresponding to ('user', 'wins', 'user') and the
subgraph corresponding to ('user', 'follows', 'user).
The unidirectional bipartite subgraph ('user', 'plays', 'game')
is ignored. Both the node features and edge features
are shared.
>>> bg = dgl.to_bidirected(g, copy_ndata=True,
copy_edata=True, ignore_bipartite=True)
>>> bg.edges(('user', 'wins', 'user'))
(tensor([0, 2, 0, 2, 2, 1, 1, 2, 1, 0]), tensor([1, 1, 2, 1, 0, 0, 2, 0, 2, 2]))
>>> bg.edges(('user', 'follows', 'user'))
(tensor([1, 2, 1, 0, 0, 0]), tensor([0, 0, 0, 1, 2, 1]))
>>> bg.edges(('user', 'plays', 'game'))
(th.tensor([1, 2, 1]), th.tensor([2, 1, 1]))
>>> bg.nodes['game'].data['hv']
tensor([0, 0, 0])
>>> bg.edges[('user', 'wins', 'user')].data['h']
th.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
"""
if readonly is not None:
dgl_warning("Parameter readonly is deprecated" \
"There will be no difference between readonly and non-readonly DGLGraph")
canonical_etypes = g.canonical_etypes
# fast path
if ignore_bipartite is False:
subgs = {}
for c_etype in canonical_etypes:
if c_etype[0] != c_etype[2]:
assert False, "to_bidirected is not well defined for " \
"unidirectional bipartite graphs" \
", but {} is unidirectional bipartite".format(c_etype)
u, v = g.edges(form='uv', order='eid', etype=c_etype)
subgs[c_etype] = (F.cat([u, v], dim=0), F.cat([v, u], dim=0))
new_g = heterograph(subgs)
else:
subgs = {}
for c_etype in canonical_etypes:
if c_etype[0] != c_etype[2]:
u, v = g.edges(form='uv', order='eid', etype=c_etype)
subgs[c_etype] = (u, v)
else:
u, v = g.edges(form='uv', order='eid', etype=c_etype)
subgs[c_etype] = (F.cat([u, v], dim=0), F.cat([v, u], dim=0))
new_g = heterograph(subgs)
# handle features
if copy_ndata:
# for each ntype
for ntype in g.ntypes:
# for each data field
for k in g.nodes[ntype].data:
new_g.nodes[ntype].data[k] = g.nodes[ntype].data[k]
if copy_edata:
# for each etype
for c_etype in canonical_etypes:
if c_etype[0] != c_etype[2]:
# for each data field
for k in g.edges[c_etype].data:
new_g.edges[c_etype].data[k] = g.edges[c_etype].data[k]
else:
for k in g.edges[c_etype].data:
new_g.edges[c_etype].data[k] = \
F.cat([g.edges[c_etype].data[k], g.edges[c_etype].data[k]], dim=0)
return new_g
def line_graph(g, backtracking=True, shared=False):
"""Return the line graph of this graph.
......@@ -539,7 +688,7 @@ def to_simple_graph(g):
gidx = _CAPI_DGLToSimpleGraph(g._graph)
return DGLGraph(gidx, readonly=True)
def to_bidirected(g, readonly=True):
def to_bidirected_stale(g, readonly=True):
"""Convert the graph to a bidirected graph.
The function generates a new graph with no node/edge feature.
......@@ -872,7 +1021,7 @@ def metis_partition_assignment(g, k, balance_ntypes=None, balance_edges=False):
'''
# METIS works only on symmetric graphs.
# The METIS runs on the symmetric graph to generate the node assignment to partitions.
sym_g = to_bidirected(g, readonly=True)
sym_g = to_bidirected_stale(g, readonly=True)
vwgt = []
# To balance the node types in each partition, we can take advantage of the vertex weights
# in Metis. When vertex weights are provided, Metis will tries to generate partitions with
......
......@@ -434,6 +434,30 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
*rv = HeteroGraphRef(hg_new);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0];
List<HeteroGraphRef> component_graphs = args[1];
CHECK(component_graphs.size() > 1)
<< "Expect graph list to have at least two graphs";
std::vector<HeteroGraphPtr> component_ptrs;
component_ptrs.reserve(component_graphs.size());
const int64_t bits = component_graphs[0]->NumBits();
const DLContext ctx = component_graphs[0]->Context();
for (const auto& component : component_graphs) {
component_ptrs.push_back(component.sptr());
CHECK_EQ(component->NumBits(), bits)
<< "Expect graphs to joint union have the same index dtype(int" << bits
<< "), but got int" << component->NumBits();
CHECK_EQ(component->Context(), ctx)
<< "Expect graphs to joint union have the same context" << ctx
<< "), but got " << component->Context();
}
auto hgptr = JointUnionHeteroGraph(meta_graph.sptr(), component_ptrs);
*rv = HeteroGraphRef(hgptr);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0];
......
......@@ -8,6 +8,92 @@ using namespace dgl::runtime;
namespace dgl {
HeteroGraphPtr JointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0) << "Input graph list has at least two graphs";
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);
// Loop over all canonical etypes
for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
auto pair = meta_graph->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
uint64_t num_src_v = component_graphs[0]->NumVertices(src_vtype);
uint64_t num_dst_v = component_graphs[0]->NumVertices(dst_vtype);
HeteroGraphPtr rgptr = nullptr;
// ALL = CSC | CSR | COO
dgl_format_code_t format = (1 << (SparseFormat2Code(SparseFormat::kCOO)-1)) |
(1 << (SparseFormat2Code(SparseFormat::kCSR)-1)) |
(1 << (SparseFormat2Code(SparseFormat::kCSC)-1));
// get common format
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
CHECK_EQ(num_src_v, component_graphs[i]->NumVertices(src_vtype)) << "Input graph[" << i <<
"] should have same number of src vertices as input graph[0]";
CHECK_EQ(num_dst_v, component_graphs[i]->NumVertices(dst_vtype)) << "Input graph[" << i <<
"] should have same number of dst vertices as input graph[0]";
const std::string restrict_format = cg->GetRelationGraph(etype)->GetRestrictFormat();
const SparseFormat curr_format = ParseSparseFormat(restrict_format);
if (curr_format == SparseFormat::kCOO ||
curr_format == SparseFormat::kCSR ||
curr_format == SparseFormat::kCSC)
format &=(1 << (SparseFormat2Code(curr_format)-1));
}
CHECK_GT(format, 0) << "The conjunction of restrict_format of the relation graphs under " <<
etype << "should not be None.";
// prefer COO
if (FORMAT_HAS_COO(format)) {
std::vector<aten::COOMatrix> coos;
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
aten::COOMatrix coo = cg->GetCOOMatrix(etype);
coos.push_back(coo);
}
aten::COOMatrix res = aten::UnionCoo(coos);
rgptr = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype) ? 1 : 2, res,
SparseFormat::kAny);
} else if (FORMAT_HAS_CSR(format)) {
std::vector<aten::CSRMatrix> csrs;
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
aten::CSRMatrix csr = cg->GetCSRMatrix(etype);
csrs.push_back(csr);
}
aten::CSRMatrix res = aten::UnionCsr(csrs);
rgptr = UnitGraph::CreateFromCSR(
(src_vtype == dst_vtype) ? 1 : 2, res,
SparseFormat::kAny);
} else if (FORMAT_HAS_CSC(format)) {
// CSR and CSC have the same storage format, i.e. CSRMatrix
std::vector<aten::CSRMatrix> cscs;
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
aten::CSRMatrix csc = cg->GetCSCMatrix(etype);
cscs.push_back(csc);
}
aten::CSRMatrix res = aten::UnionCsr(cscs);
rgptr = UnitGraph::CreateFromCSC(
(src_vtype == dst_vtype) ? 1 : 2, res,
SparseFormat::kAny);
}
rel_graphs[etype] = rgptr;
num_nodes_per_type[src_vtype] = num_src_v;
num_nodes_per_type[dst_vtype] = num_dst_v;
}
return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type));
}
HeteroGraphPtr DisjointUnionHeteroGraph2(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
......@@ -24,8 +110,8 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
// ALL = CSC | CSR | COO
dgl_format_code_t format = (1 << (SparseFormat2Code(SparseFormat::kCOO)-1)) |
(1 << (SparseFormat2Code(SparseFormat::kCSR)-1)) |
(1 << (SparseFormat2Code(SparseFormat::kCSC)-1));
(1 << (SparseFormat2Code(SparseFormat::kCSR)-1)) |
(1 << (SparseFormat2Code(SparseFormat::kCSC)-1));
// do some preprocess
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
......@@ -72,6 +158,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
(src_vtype == dst_vtype) ? 1 : 2, res,
SparseFormat::kAny);
} else if (FORMAT_HAS_CSC(format)) {
// CSR and CSC have the same storage format, i.e. CSRMatrix
std::vector<aten::CSRMatrix> cscs;
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
......@@ -185,6 +272,7 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
auto pair = meta_graph->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
// CSR and CSC have the same storage format, i.e. CSRMatrix
aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype);
auto res = aten::DisjointPartitionCsrBySizes(csc,
batch_size,
......
......@@ -8,6 +8,7 @@ import backend as F
import networkx as nx
import unittest, pytest
from dgl import DGLError
from dgl.heterograph_index import joint_union
from utils import parametrize_dtype
def create_test_heterograph(index_dtype):
......@@ -717,7 +718,7 @@ def test_view1(index_dtype):
for i in range(g.number_of_nodes(utype)):
assert out_degrees[i] == src_count[i]
for i in range(g.number_of_nodes(vtype)):
assert in_degrees[i] == dst_count[i]
assert in_degrees[i] == dst_count[i]
edges = {
'follows': ([0, 1], [1, 2]),
......@@ -774,7 +775,7 @@ def test_view1(index_dtype):
ndata = HG.ndata['h']
assert isinstance(ndata, dict)
assert F.array_equal(ndata['user'], f2)
edata = HG.edata['h']
assert isinstance(edata, dict)
assert F.array_equal(edata[('user', 'follows', 'user')], f4)
......@@ -1381,7 +1382,7 @@ def test_level2(index_dtype):
g['plays'].send_and_recv([2, 3], mfunc, rfunc)
y = g.nodes['game'].data['y']
assert F.array_equal(y, F.tensor([[0., 0.], [2., 2.]]))
# test fail case
# fail due to multiple types
fail = False
......@@ -2073,7 +2074,6 @@ def test_reverse(index_dtype):
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
if __name__ == '__main__':
# test_create()
# test_query()
......@@ -2100,6 +2100,7 @@ if __name__ == '__main__':
# test_isolated_ntype()
# test_bipartite()
# test_dtype_cast()
# test_reverse("int32")
test_reverse("int32")
test_format()
pass
......@@ -250,6 +250,84 @@ def test_reverse_shared_frames():
rg.update_all(src_msg, sum_reduce)
assert F.allclose(g.ndata['h'], rg.ndata['h'])
def test_to_bidirected():
# homogeneous graph
g = dgl.graph((F.tensor([0, 1, 3, 1]), F.tensor([1, 2, 0, 2])))
g.ndata['h'] = F.tensor([[0.], [1.], [2.], [1.]])
g.edata['h'] = F.tensor([[3.], [4.], [5.], [6.]])
bg = dgl.to_bidirected(g, copy_ndata=True, copy_edata=True)
u, v = g.edges()
ub, vb = bg.edges()
assert F.array_equal(F.cat([u, v], dim=0), ub)
assert F.array_equal(F.cat([v, u], dim=0), vb)
assert F.array_equal(g.ndata['h'], bg.ndata['h'])
assert F.array_equal(F.cat([g.edata['h'], g.edata['h']], dim=0), bg.edata['h'])
bg.ndata['hh'] = F.tensor([[0.], [1.], [2.], [1.]])
assert ('hh' in g.ndata) is False
bg.edata['hh'] = F.tensor([[0.], [1.], [2.], [1.], [0.], [1.], [2.], [1.]])
assert ('hh' in g.edata) is False
# donot share ndata and edata
bg = dgl.to_bidirected(g, copy_ndata=False, copy_edata=False)
ub, vb = bg.edges()
assert F.array_equal(F.cat([u, v], dim=0), ub)
assert F.array_equal(F.cat([v, u], dim=0), vb)
assert ('h' in bg.ndata) is False
assert ('h' in bg.edata) is False
# zero edge graph
g = dgl.graph([])
bg = dgl.to_bidirected(g, copy_ndata=True, copy_edata=True)
# heterogeneous graph
g = dgl.heterograph({
('user', 'wins', 'user'): (F.tensor([0, 2, 0, 2, 2]), F.tensor([1, 1, 2, 1, 0])),
('user', 'plays', 'game'): (F.tensor([1, 2, 1]), F.tensor([2, 1, 1])),
('user', 'follows', 'user'): (F.tensor([1, 2, 1]), F.tensor([0, 0, 0]))
})
g.nodes['game'].data['hv'] = F.ones((3, 1))
g.nodes['user'].data['hv'] = F.ones((3, 1))
g.edges['wins'].data['h'] = F.tensor([0, 1, 2, 3, 4])
bg = dgl.to_bidirected(g, copy_ndata=True, copy_edata=True, ignore_bipartite=True)
assert F.array_equal(g.nodes['game'].data['hv'], bg.nodes['game'].data['hv'])
assert F.array_equal(g.nodes['user'].data['hv'], bg.nodes['user'].data['hv'])
u, v = g.all_edges(order='eid', etype=('user', 'wins', 'user'))
ub, vb = bg.all_edges(order='eid', etype=('user', 'wins', 'user'))
assert F.array_equal(F.cat([u, v], dim=0), ub)
assert F.array_equal(F.cat([v, u], dim=0), vb)
assert F.array_equal(F.cat([g.edges['wins'].data['h'], g.edges['wins'].data['h']], dim=0),
bg.edges['wins'].data['h'])
u, v = g.all_edges(order='eid', etype=('user', 'follows', 'user'))
ub, vb = bg.all_edges(order='eid', etype=('user', 'follows', 'user'))
assert F.array_equal(F.cat([u, v], dim=0), ub)
assert F.array_equal(F.cat([v, u], dim=0), vb)
u, v = g.all_edges(order='eid', etype=('user', 'plays', 'game'))
ub, vb = bg.all_edges(order='eid', etype=('user', 'plays', 'game'))
assert F.array_equal(u, ub)
assert F.array_equal(v, vb)
assert len(bg.edges['plays'].data) == 0
assert len(bg.edges['follows'].data) == 0
# donot share ndata and edata
bg = dgl.to_bidirected(g, copy_ndata=False, copy_edata=False, ignore_bipartite=True)
assert len(bg.edges['wins'].data) == 0
assert len(bg.edges['plays'].data) == 0
assert len(bg.edges['follows'].data) == 0
assert len(bg.nodes['game'].data) == 0
assert len(bg.nodes['user'].data) == 0
u, v = g.all_edges(order='eid', etype=('user', 'wins', 'user'))
ub, vb = bg.all_edges(order='eid', etype=('user', 'wins', 'user'))
assert F.array_equal(F.cat([u, v], dim=0), ub)
assert F.array_equal(F.cat([v, u], dim=0), vb)
u, v = g.all_edges(order='eid', etype=('user', 'follows', 'user'))
ub, vb = bg.all_edges(order='eid', etype=('user', 'follows', 'user'))
assert F.array_equal(F.cat([u, v], dim=0), ub)
assert F.array_equal(F.cat([v, u], dim=0), vb)
u, v = g.all_edges(order='eid', etype=('user', 'plays', 'game'))
ub, vb = bg.all_edges(order='eid', etype=('user', 'plays', 'game'))
assert F.array_equal(u, ub)
assert F.array_equal(v, vb)
def test_simple_graph():
elist = [(0, 1), (0, 2), (1, 2), (0, 1)]
......@@ -271,7 +349,7 @@ def test_bidirected_graph():
g = dgl.DGLGraph(elist, readonly=in_readonly)
elist.append((1, 2))
elist = set(elist)
big = dgl.to_bidirected(g, out_readonly)
big = dgl.to_bidirected_stale(g, out_readonly)
assert big.number_of_edges() == num_edges
src, dst = big.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
......@@ -926,6 +1004,7 @@ if __name__ == '__main__':
# test_no_backtracking()
test_reverse()
# test_reverse_shared_frames()
test_to_bidirected()
# test_simple_graph()
# test_bidirected_graph()
# test_khop_adj()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment