Unverified Commit 17b60e1a authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Bugfix] EdgeBatch.edges call attempts tuple item assignment (#747)

* upd

* fig edgebatch edges

* add test

* trigger
parent b2f7f0ee
...@@ -3258,7 +3258,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -3258,7 +3258,7 @@ class DGLGraph(DGLBaseGraph):
""" """
if is_all(edges): if is_all(edges):
eid = ALL eid = ALL
u, v, _ = self._graph.edges() u, v, _ = self._graph.edges('eid')
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u)
......
...@@ -76,8 +76,8 @@ class EdgeBatch(object): ...@@ -76,8 +76,8 @@ class EdgeBatch(object):
in the batch. in the batch.
""" """
if is_all(self._edges[2]): if is_all(self._edges[2]):
self._edges[2] = utils.toindex(F.arange( self._edges = self._edges[:2] + (utils.toindex(F.arange(
0, self._g.number_of_edges())) 0, self._g.number_of_edges())),)
u, v, eid = self._edges u, v, eid = self._edges
return (u.tousertensor(), v.tousertensor(), eid.tousertensor()) return (u.tousertensor(), v.tousertensor(), eid.tousertensor())
......
import backend as F
import dgl
import networkx as nx
import dgl.utils as utils
from dgl import DGLGraph, ALL
from dgl.udf import NodeBatch, EdgeBatch
def test_node_batch():
g = dgl.DGLGraph(nx.path_graph(20))
feat = F.randn((g.number_of_nodes(), 10))
g.ndata['x'] = feat
# test all
v = ALL
n_repr = g.get_n_repr(v)
nbatch = NodeBatch(g, v, n_repr)
assert F.allclose(nbatch.data['x'], feat)
assert nbatch.mailbox is None
assert F.allclose(nbatch.nodes(), g.nodes())
assert nbatch.batch_size() == g.number_of_nodes()
assert len(nbatch) == g.number_of_nodes()
# test partial
v = utils.toindex(F.tensor([0, 3, 5, 7, 9]))
n_repr = g.get_n_repr(v)
nbatch = NodeBatch(g, v, n_repr)
assert F.allclose(nbatch.data['x'], F.gather_row(feat, F.tensor([0, 3, 5, 7, 9])))
assert nbatch.mailbox is None
assert F.allclose(nbatch.nodes(), F.tensor([0, 3, 5, 7, 9]))
assert nbatch.batch_size() == 5
assert len(nbatch) == 5
def test_edge_batch():
d = 10
g = dgl.DGLGraph(nx.path_graph(20))
nfeat = F.randn((g.number_of_nodes(), d))
efeat = F.randn((g.number_of_edges(), d))
g.ndata['x'] = nfeat
g.edata['x'] = efeat
# test all
eid = ALL
u, v, _ = g._graph.edges('eid')
src_data = g.get_n_repr(u)
edge_data = g.get_e_repr(eid)
dst_data = g.get_n_repr(v)
ebatch = EdgeBatch(g, (u, v, eid), src_data, edge_data, dst_data)
assert F.shape(ebatch.src['x'])[0] == g.number_of_edges() and\
F.shape(ebatch.src['x'])[1] == d
assert F.shape(ebatch.dst['x'])[0] == g.number_of_edges() and\
F.shape(ebatch.dst['x'])[1] == d
assert F.shape(ebatch.data['x'])[0] == g.number_of_edges() and\
F.shape(ebatch.data['x'])[1] == d
assert F.allclose(ebatch.edges()[0], u.tousertensor())
assert F.allclose(ebatch.edges()[1], v.tousertensor())
assert F.allclose(ebatch.edges()[2], F.arange(0, g.number_of_edges()))
assert ebatch.batch_size() == g.number_of_edges()
assert len(ebatch) == g.number_of_edges()
# test partial
eid = utils.toindex(F.tensor([0, 3, 5, 7, 11, 13, 15, 27]))
u, v, _ = g._graph.find_edges(eid)
src_data = g.get_n_repr(u)
edge_data = g.get_e_repr(eid)
dst_data = g.get_n_repr(v)
ebatch = EdgeBatch(g, (u, v, eid), src_data, edge_data, dst_data)
assert F.shape(ebatch.src['x'])[0] == 8 and\
F.shape(ebatch.src['x'])[1] == d
assert F.shape(ebatch.dst['x'])[0] == 8 and\
F.shape(ebatch.dst['x'])[1] == d
assert F.shape(ebatch.data['x'])[0] == 8 and\
F.shape(ebatch.data['x'])[1] == d
assert F.allclose(ebatch.edges()[0], u.tousertensor())
assert F.allclose(ebatch.edges()[1], v.tousertensor())
assert F.allclose(ebatch.edges()[2], eid.tousertensor())
assert ebatch.batch_size() == 8
assert len(ebatch) == 8
if __name__ == '__main__':
test_node_batch()
test_edge_batch()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment