"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "91fe0c90690d9a7078b0b03dc059088a6f310777"
Commit b2e4bdc0 authored by Minjie Wang's avatar Minjie Wang
Browse files

passed basic test batching

parent 44db98c4
......@@ -123,7 +123,7 @@ class FrameRef(MutableMapping):
def select_rows(self, query):
rowids = self._getrowid(query)
def _lazy_select(key):
idx = rowids.totensor(F.get_context(self._frame[key]))
idx = rowids.tousertensor(F.get_context(self._frame[key]))
return F.gather_row(self._frame[key], idx)
return utils.LazyDict(_lazy_select, keys=self.schemes)
......@@ -132,7 +132,7 @@ class FrameRef(MutableMapping):
if self.is_span_whole_column():
return col
else:
idx = self.index().totensor(F.get_context(col))
idx = self.index().tousertensor(F.get_context(col))
return F.gather_row(col, idx)
def __setitem__(self, key, val):
......@@ -156,7 +156,7 @@ class FrameRef(MutableMapping):
else:
fcol = F.zeros((self._frame.num_rows,) + shp[1:])
fcol = F.to_context(fcol, colctx)
idx = self.index().totensor(colctx)
idx = self.index().tousertensor(colctx)
newfcol = F.scatter_row(fcol, idx, col)
self._frame[name] = newfcol
......@@ -167,7 +167,7 @@ class FrameRef(MutableMapping):
# add new column
tmpref = FrameRef(self._frame, rowids)
tmpref.add_column(key, col)
idx = rowids.totensor(F.get_context(self._frame[key]))
idx = rowids.tousertensor(F.get_context(self._frame[key]))
self._frame[key] = F.scatter_row(self._frame[key], idx, col)
def __delitem__(self, key):
......@@ -223,8 +223,8 @@ class FrameRef(MutableMapping):
# shortcut for identical mapping
return query
else:
idxtensor = self.index().totensor()
return utils.toindex(F.gather_row(idxtensor, query.totensor()))
idxtensor = self.index().tousertensor()
return utils.toindex(F.gather_row(idxtensor, query.tousertensor()))
def index(self):
if self._index is None:
......
......@@ -8,12 +8,14 @@ import dgl
from .base import ALL, is_all, __MSG__, __REPR__
from . import backend as F
from .backend import Tensor
from .graph_index import GraphIndex
from .frame import FrameRef, merge_frames
from . import scheduler
from . import utils
from .function.message import BundledMessageFunction
from .function.reducer import BundledReduceFunction
from .graph_index import GraphIndex
from . import scheduler
from . import utils
__all__ = ['DLGraph']
class DGLGraph(object):
"""Base graph class specialized for neural networks on graphs.
......@@ -63,9 +65,11 @@ class DGLGraph(object):
Optional node representations.
"""
self._graph.add_nodes(num)
self._msg_graph.add_nodes(num)
#TODO(minjie): change frames
assert reprs is None
def add_edge(self, u, v, repr=None):
def add_edge(self, u, v, reprs=None):
"""Add one edge.
Parameters
......@@ -74,11 +78,12 @@ class DGLGraph(object):
The src node.
v : int
The dst node.
repr : dict
reprs : dict
Optional edge representation.
"""
self._graph.add_edge(u, v)
#TODO(minjie): change frames
assert reprs is None
def add_edges(self, u, v, reprs=None):
"""Add many edges.
......@@ -96,6 +101,7 @@ class DGLGraph(object):
v = utils.toindex(v)
self._graph.add_edges(u, v)
#TODO(minjie): change frames
assert reprs is None
def clear(self):
"""Clear the graph and its storage."""
......@@ -483,6 +489,8 @@ class DGLGraph(object):
dict
Representation dict
"""
if len(self.node_attr_schemes()) == 0:
return dict()
if is_all(u):
if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame[__REPR__]
......@@ -535,7 +543,7 @@ class DGLGraph(object):
v_is_all = is_all(v)
assert u_is_all == v_is_all
if u_is_all:
num_edges = self.cached_graph.num_edges()
num_edges = self.number_of_edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
......@@ -553,7 +561,7 @@ class DGLGraph(object):
else:
self._edge_frame[__REPR__] = h_uv
else:
eid = self.cached_graph.get_edge_id(u, v)
eid = self._graph.edge_ids(u, v)
if utils.is_dict_like(h_uv):
self._edge_frame[eid] = h_uv
else:
......@@ -571,7 +579,7 @@ class DGLGraph(object):
"""
# sanity check
if is_all(eid):
num_edges = self.cached_graph.num_edges()
num_edges = self.number_of_edges()
else:
eid = utils.toindex(eid)
num_edges = len(eid)
......@@ -611,6 +619,8 @@ class DGLGraph(object):
u_is_all = is_all(u)
v_is_all = is_all(v)
assert u_is_all == v_is_all
if len(self.edge_attr_schemes()) == 0:
return dict()
if u_is_all:
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__]
......@@ -619,7 +629,7 @@ class DGLGraph(object):
else:
u = utils.toindex(u)
v = utils.toindex(v)
eid = self.cached_graph.get_edge_id(u, v)
eid = self._graph.edge_ids(u, v)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame.select_rows(eid)[__REPR__]
else:
......@@ -653,6 +663,8 @@ class DGLGraph(object):
dict
Representation dict
"""
if len(self.edge_attr_schemes()) == 0:
return dict()
if is_all(eid):
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__]
......@@ -843,8 +855,8 @@ class DGLGraph(object):
def _batch_send(self, u, v, message_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
self.msg_graph.add_edges(u, v)
u, v, _ = self._graph.edges(sorted=True)
self._msg_graph.add_edges(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr()
......@@ -853,11 +865,10 @@ class DGLGraph(object):
u = utils.toindex(u)
v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v)
eid = self.cached_graph.get_edge_id(u, v)
self.msg_graph.add_edges(u, v)
self._msg_graph.add_edges(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
edge_reprs = self.get_e_repr(u, v)
msgs = message_func(src_reprs, edge_reprs)
if utils.is_dict_like(msgs):
self._msg_frame.append(msgs)
......@@ -909,7 +920,7 @@ class DGLGraph(object):
def _batch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
u, v = self._graph.edges(sorted=True)
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
......@@ -920,7 +931,7 @@ class DGLGraph(object):
u = utils.toindex(u)
v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v)
eid = self.cached_graph.get_edge_id(u, v)
eid = self._graph.edge_ids(u, v)
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
......@@ -1005,7 +1016,7 @@ class DGLGraph(object):
v = utils.toindex(v)
# degree bucketing
degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v)
degrees, v_buckets = scheduler.degree_bucketing(self._msg_graph, v)
if degrees == [0]:
# no message has been sent to the specified node
return
......@@ -1020,8 +1031,7 @@ class DGLGraph(object):
continue
bkt_len = len(v_bkt)
dst_reprs = self.get_n_repr(v_bkt)
uu, vv, _ = self.msg_graph.in_edges(v_bkt)
in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
uu, vv, in_msg_ids = self._msg_graph.in_edges(v_bkt)
in_msgs = self._msg_frame.select_rows(in_msg_ids)
# Reshape the column tensor to (B, Deg, ...).
def _reshape_fn(msg):
......@@ -1033,7 +1043,7 @@ class DGLGraph(object):
else:
reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
reordered_v.append(v_bkt.totensor())
reordered_v.append(v_bkt.tousertensor())
new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
# TODO: clear partial messages
......@@ -1087,7 +1097,7 @@ class DGLGraph(object):
# no edges to be triggered
assert len(v) == 0
return
unique_v = utils.toindex(F.unique(v.totensor()))
unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO(minjie): better way to figure out `batchable` flag
if message_func == "default":
......@@ -1135,10 +1145,10 @@ class DGLGraph(object):
v = utils.toindex(v)
if len(v) == 0:
return
uu, vv, _ = self.cached_graph.in_edges(v)
uu, vv, _ = self._graph.in_edges(v)
self.send_and_recv(uu, vv, message_func, reduce_func,
apply_node_func=None, batchable=batchable)
unique_v = F.unique(v.totensor())
unique_v = F.unique(v.tousertensor())
self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
def push(self,
......@@ -1165,7 +1175,7 @@ class DGLGraph(object):
u = utils.toindex(u)
if len(u) == 0:
return
uu, vv, _ = self.cached_graph.out_edges(u)
uu, vv, _ = self._graph.out_edges(u)
self.send_and_recv(uu, vv, message_func,
reduce_func, apply_node_func, batchable=batchable)
......@@ -1309,8 +1319,10 @@ class DGLGraph(object):
reduce_func)
def clear_messages(self):
"""Clear all messages."""
self._msg_graph.clear()
self._msg_frame.clear()
self._msg_graph.add_nodes(self.number_of_nodes())
def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict:
......
......@@ -7,6 +7,7 @@ used with C++ library.
from __future__ import absolute_import as _abs
import ctypes
import functools
import operator
import numpy as _np
......@@ -18,7 +19,7 @@ from . import backend as F
class NDArray(NDArrayBase):
"""Lightweight NDArray class for DGL framework."""
def __len__(self):
return reduce(operator.mul, self.shape, 1)
return functools.reduce(operator.mul, self.shape, 1)
def cpu(dev_id=0):
"""Construct a CPU device
......
......@@ -11,12 +11,12 @@ from . import utils
__all__ = ["degree_bucketing", "get_executor"]
def degree_bucketing(cached_graph, v):
def degree_bucketing(graph, v):
"""Create degree bucketing scheduling policy.
Parameters
----------
cached_graph : dgl.cached_graph.CachedGraph
graph : dgl.graph_index.GraphIndex
the graph
v : dgl.utils.Index
the nodes to gather messages
......@@ -29,7 +29,7 @@ def degree_bucketing(cached_graph, v):
list of node id buckets; nodes belong to the same bucket have
the same degree
"""
degrees = F.asnumpy(cached_graph.in_degrees(v).totensor())
degrees = np.array(graph.in_degrees(v).tolist())
unique_degrees = list(np.unique(degrees))
v_np = np.array(v.tolist())
v_bkt = []
......
......@@ -141,9 +141,9 @@ def edge_broadcasting(u, v):
The dst id(s) after broadcasting
"""
if len(u) != len(v) and len(u) == 1:
u = toindex(F.broadcast_to(u.totensor(), v.totensor()))
u = toindex(F.broadcast_to(u.tousertensor(), v.tousertensor()))
elif len(u) != len(v) and len(v) == 1:
v = toindex(F.broadcast_to(v.totensor(), u.totensor()))
v = toindex(F.broadcast_to(v.tousertensor(), u.tousertensor()))
else:
assert len(u) == len(v)
return u, v
......@@ -205,7 +205,7 @@ def build_relabel_map(x):
One can use advanced indexing to convert an old id tensor to a
new id tensor: new_id = old_to_new[old_id]
"""
x = x.totensor()
x = x.tousertensor()
unique_x, _ = F.sort(F.unique(x))
map_len = int(F.max(unique_x)) + 1
old_to_new = F.zeros(map_len, dtype=F.int64)
......@@ -312,6 +312,6 @@ def reorder(dict_like, index):
"""
new_dict = {}
for key, val in dict_like.items():
idx_ctx = index.totensor(F.get_context(val))
idx_ctx = index.tousertensor(F.get_context(val))
new_dict[key] = F.gather_row(val, idx_ctx)
return new_dict
......@@ -31,7 +31,7 @@ void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
if (srclen == 1) {
......@@ -78,7 +78,7 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen);
BoolArray rst = BoolArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
......@@ -150,7 +150,7 @@ IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen);
IdArray rst = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
......
......@@ -27,8 +27,7 @@ def apply_node_func(node):
def generate_graph(grad=False):
g = DGLGraph()
for i in range(10):
g.add_node(i) # 10 nodes.
g.add_nodes(10) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
......@@ -198,7 +197,7 @@ def test_update_routines():
def test_reduce_0deg():
g = DGLGraph()
g.add_nodes_from([0, 1, 2, 3, 4])
g.add_nodes(5)
g.add_edge(1, 0)
g.add_edge(2, 0)
g.add_edge(3, 0)
......@@ -218,7 +217,7 @@ def test_reduce_0deg():
def test_pull_0deg():
g = DGLGraph()
g.add_nodes_from([0, 1])
g.add_nodes(2)
g.add_edge(0, 1)
def _message(src, edge):
return src
......@@ -243,16 +242,6 @@ def test_pull_0deg():
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0])
def _test_delete():
g = generate_graph()
ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_e_repr({'e' : ecol})
assert g.get_n_repr()['h'].shape[0] == 10
assert g.get_e_repr()['e'].shape[0] == 17
g.remove_node(0)
assert g.get_n_repr()['h'].shape[0] == 9
assert g.get_e_repr()['e'].shape[0] == 8
if __name__ == '__main__':
test_batch_setter_getter()
test_batch_setter_autograd()
......@@ -261,4 +250,3 @@ if __name__ == '__main__':
test_update_routines()
test_reduce_0deg()
test_pull_0deg()
#test_delete()
......@@ -23,8 +23,7 @@ def reduce_func(hv, msgs):
def generate_graph(grad=False):
g = DGLGraph()
for i in range(10):
g.add_node(i) # 10 nodes.
g.add_nodes(10)
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
......
......@@ -113,7 +113,7 @@ def test_append2():
assert not f.is_span_whole_column()
assert f.num_rows == 3 * N
new_idx = list(range(N)) + list(range(2*N, 4*N))
assert check_eq(f.index().totensor(), th.tensor(new_idx))
assert check_eq(f.index().tousertensor(), th.tensor(new_idx))
assert data.num_rows == 4 * N
def test_row1():
......
......@@ -5,8 +5,7 @@ from dgl.graph import __REPR__
def generate_graph():
g = dgl.DGLGraph()
for i in range(10):
g.add_node(i) # 10 nodes.
g.add_nodes(10) # 10 nodes.
h = th.arange(1, 11)
g.set_n_repr({'h': h})
# create a graph where 0 is the source and 9 is the sink
......@@ -23,8 +22,7 @@ def generate_graph():
def generate_graph1():
"""graph with anonymous repr"""
g = dgl.DGLGraph()
for i in range(10):
g.add_node(i) # 10 nodes.
g.add_nodes(10) # 10 nodes.
h = th.arange(1, 11)
g.set_n_repr(h)
# create a graph where 0 is the source and 9 is the sink
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment