"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "cd0268cd408d19d91f870e36fdffd031085abe13"
Commit b2e4bdc0 authored by Minjie Wang's avatar Minjie Wang
Browse files

passed basic test batching

parent 44db98c4
...@@ -123,7 +123,7 @@ class FrameRef(MutableMapping): ...@@ -123,7 +123,7 @@ class FrameRef(MutableMapping):
def select_rows(self, query): def select_rows(self, query):
rowids = self._getrowid(query) rowids = self._getrowid(query)
def _lazy_select(key): def _lazy_select(key):
idx = rowids.totensor(F.get_context(self._frame[key])) idx = rowids.tousertensor(F.get_context(self._frame[key]))
return F.gather_row(self._frame[key], idx) return F.gather_row(self._frame[key], idx)
return utils.LazyDict(_lazy_select, keys=self.schemes) return utils.LazyDict(_lazy_select, keys=self.schemes)
...@@ -132,7 +132,7 @@ class FrameRef(MutableMapping): ...@@ -132,7 +132,7 @@ class FrameRef(MutableMapping):
if self.is_span_whole_column(): if self.is_span_whole_column():
return col return col
else: else:
idx = self.index().totensor(F.get_context(col)) idx = self.index().tousertensor(F.get_context(col))
return F.gather_row(col, idx) return F.gather_row(col, idx)
def __setitem__(self, key, val): def __setitem__(self, key, val):
...@@ -156,7 +156,7 @@ class FrameRef(MutableMapping): ...@@ -156,7 +156,7 @@ class FrameRef(MutableMapping):
else: else:
fcol = F.zeros((self._frame.num_rows,) + shp[1:]) fcol = F.zeros((self._frame.num_rows,) + shp[1:])
fcol = F.to_context(fcol, colctx) fcol = F.to_context(fcol, colctx)
idx = self.index().totensor(colctx) idx = self.index().tousertensor(colctx)
newfcol = F.scatter_row(fcol, idx, col) newfcol = F.scatter_row(fcol, idx, col)
self._frame[name] = newfcol self._frame[name] = newfcol
...@@ -167,7 +167,7 @@ class FrameRef(MutableMapping): ...@@ -167,7 +167,7 @@ class FrameRef(MutableMapping):
# add new column # add new column
tmpref = FrameRef(self._frame, rowids) tmpref = FrameRef(self._frame, rowids)
tmpref.add_column(key, col) tmpref.add_column(key, col)
idx = rowids.totensor(F.get_context(self._frame[key])) idx = rowids.tousertensor(F.get_context(self._frame[key]))
self._frame[key] = F.scatter_row(self._frame[key], idx, col) self._frame[key] = F.scatter_row(self._frame[key], idx, col)
def __delitem__(self, key): def __delitem__(self, key):
...@@ -223,8 +223,8 @@ class FrameRef(MutableMapping): ...@@ -223,8 +223,8 @@ class FrameRef(MutableMapping):
# shortcut for identical mapping # shortcut for identical mapping
return query return query
else: else:
idxtensor = self.index().totensor() idxtensor = self.index().tousertensor()
return utils.toindex(F.gather_row(idxtensor, query.totensor())) return utils.toindex(F.gather_row(idxtensor, query.tousertensor()))
def index(self): def index(self):
if self._index is None: if self._index is None:
......
...@@ -8,12 +8,14 @@ import dgl ...@@ -8,12 +8,14 @@ import dgl
from .base import ALL, is_all, __MSG__, __REPR__ from .base import ALL, is_all, __MSG__, __REPR__
from . import backend as F from . import backend as F
from .backend import Tensor from .backend import Tensor
from .graph_index import GraphIndex
from .frame import FrameRef, merge_frames from .frame import FrameRef, merge_frames
from . import scheduler
from . import utils
from .function.message import BundledMessageFunction from .function.message import BundledMessageFunction
from .function.reducer import BundledReduceFunction from .function.reducer import BundledReduceFunction
from .graph_index import GraphIndex
from . import scheduler
from . import utils
__all__ = ['DLGraph']
class DGLGraph(object): class DGLGraph(object):
"""Base graph class specialized for neural networks on graphs. """Base graph class specialized for neural networks on graphs.
...@@ -63,9 +65,11 @@ class DGLGraph(object): ...@@ -63,9 +65,11 @@ class DGLGraph(object):
Optional node representations. Optional node representations.
""" """
self._graph.add_nodes(num) self._graph.add_nodes(num)
self._msg_graph.add_nodes(num)
#TODO(minjie): change frames #TODO(minjie): change frames
assert reprs is None
def add_edge(self, u, v, repr=None): def add_edge(self, u, v, reprs=None):
"""Add one edge. """Add one edge.
Parameters Parameters
...@@ -74,11 +78,12 @@ class DGLGraph(object): ...@@ -74,11 +78,12 @@ class DGLGraph(object):
The src node. The src node.
v : int v : int
The dst node. The dst node.
repr : dict reprs : dict
Optional edge representation. Optional edge representation.
""" """
self._graph.add_edge(u, v) self._graph.add_edge(u, v)
#TODO(minjie): change frames #TODO(minjie): change frames
assert reprs is None
def add_edges(self, u, v, reprs=None): def add_edges(self, u, v, reprs=None):
"""Add many edges. """Add many edges.
...@@ -96,6 +101,7 @@ class DGLGraph(object): ...@@ -96,6 +101,7 @@ class DGLGraph(object):
v = utils.toindex(v) v = utils.toindex(v)
self._graph.add_edges(u, v) self._graph.add_edges(u, v)
#TODO(minjie): change frames #TODO(minjie): change frames
assert reprs is None
def clear(self): def clear(self):
"""Clear the graph and its storage.""" """Clear the graph and its storage."""
...@@ -483,6 +489,8 @@ class DGLGraph(object): ...@@ -483,6 +489,8 @@ class DGLGraph(object):
dict dict
Representation dict Representation dict
""" """
if len(self.node_attr_schemes()) == 0:
return dict()
if is_all(u): if is_all(u):
if len(self._node_frame) == 1 and __REPR__ in self._node_frame: if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame[__REPR__] return self._node_frame[__REPR__]
...@@ -535,7 +543,7 @@ class DGLGraph(object): ...@@ -535,7 +543,7 @@ class DGLGraph(object):
v_is_all = is_all(v) v_is_all = is_all(v)
assert u_is_all == v_is_all assert u_is_all == v_is_all
if u_is_all: if u_is_all:
num_edges = self.cached_graph.num_edges() num_edges = self.number_of_edges()
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
...@@ -553,7 +561,7 @@ class DGLGraph(object): ...@@ -553,7 +561,7 @@ class DGLGraph(object):
else: else:
self._edge_frame[__REPR__] = h_uv self._edge_frame[__REPR__] = h_uv
else: else:
eid = self.cached_graph.get_edge_id(u, v) eid = self._graph.edge_ids(u, v)
if utils.is_dict_like(h_uv): if utils.is_dict_like(h_uv):
self._edge_frame[eid] = h_uv self._edge_frame[eid] = h_uv
else: else:
...@@ -571,7 +579,7 @@ class DGLGraph(object): ...@@ -571,7 +579,7 @@ class DGLGraph(object):
""" """
# sanity check # sanity check
if is_all(eid): if is_all(eid):
num_edges = self.cached_graph.num_edges() num_edges = self.number_of_edges()
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid)
num_edges = len(eid) num_edges = len(eid)
...@@ -611,6 +619,8 @@ class DGLGraph(object): ...@@ -611,6 +619,8 @@ class DGLGraph(object):
u_is_all = is_all(u) u_is_all = is_all(u)
v_is_all = is_all(v) v_is_all = is_all(v)
assert u_is_all == v_is_all assert u_is_all == v_is_all
if len(self.edge_attr_schemes()) == 0:
return dict()
if u_is_all: if u_is_all:
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__] return self._edge_frame[__REPR__]
...@@ -619,7 +629,7 @@ class DGLGraph(object): ...@@ -619,7 +629,7 @@ class DGLGraph(object):
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
eid = self.cached_graph.get_edge_id(u, v) eid = self._graph.edge_ids(u, v)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame.select_rows(eid)[__REPR__] return self._edge_frame.select_rows(eid)[__REPR__]
else: else:
...@@ -653,6 +663,8 @@ class DGLGraph(object): ...@@ -653,6 +663,8 @@ class DGLGraph(object):
dict dict
Representation dict Representation dict
""" """
if len(self.edge_attr_schemes()) == 0:
return dict()
if is_all(eid): if is_all(eid):
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__] return self._edge_frame[__REPR__]
...@@ -843,8 +855,8 @@ class DGLGraph(object): ...@@ -843,8 +855,8 @@ class DGLGraph(object):
def _batch_send(self, u, v, message_func): def _batch_send(self, u, v, message_func):
if is_all(u) and is_all(v): if is_all(u) and is_all(v):
u, v = self.cached_graph.edges() u, v, _ = self._graph.edges(sorted=True)
self.msg_graph.add_edges(u, v) self._msg_graph.add_edges(u, v)
# call UDF # call UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr() edge_reprs = self.get_e_repr()
...@@ -853,11 +865,10 @@ class DGLGraph(object): ...@@ -853,11 +865,10 @@ class DGLGraph(object):
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v) u, v = utils.edge_broadcasting(u, v)
eid = self.cached_graph.get_edge_id(u, v) self._msg_graph.add_edges(u, v)
self.msg_graph.add_edges(u, v)
# call UDF # call UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid) edge_reprs = self.get_e_repr(u, v)
msgs = message_func(src_reprs, edge_reprs) msgs = message_func(src_reprs, edge_reprs)
if utils.is_dict_like(msgs): if utils.is_dict_like(msgs):
self._msg_frame.append(msgs) self._msg_frame.append(msgs)
...@@ -909,7 +920,7 @@ class DGLGraph(object): ...@@ -909,7 +920,7 @@ class DGLGraph(object):
def _batch_update_edge(self, u, v, edge_func): def _batch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v): if is_all(u) and is_all(v):
u, v = self.cached_graph.edges() u, v = self._graph.edges(sorted=True)
# call the UDF # call the UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v) dst_reprs = self.get_n_repr(v)
...@@ -920,7 +931,7 @@ class DGLGraph(object): ...@@ -920,7 +931,7 @@ class DGLGraph(object):
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v) u, v = utils.edge_broadcasting(u, v)
eid = self.cached_graph.get_edge_id(u, v) eid = self._graph.edge_ids(u, v)
# call the UDF # call the UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v) dst_reprs = self.get_n_repr(v)
...@@ -1005,7 +1016,7 @@ class DGLGraph(object): ...@@ -1005,7 +1016,7 @@ class DGLGraph(object):
v = utils.toindex(v) v = utils.toindex(v)
# degree bucketing # degree bucketing
degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v) degrees, v_buckets = scheduler.degree_bucketing(self._msg_graph, v)
if degrees == [0]: if degrees == [0]:
# no message has been sent to the specified node # no message has been sent to the specified node
return return
...@@ -1020,8 +1031,7 @@ class DGLGraph(object): ...@@ -1020,8 +1031,7 @@ class DGLGraph(object):
continue continue
bkt_len = len(v_bkt) bkt_len = len(v_bkt)
dst_reprs = self.get_n_repr(v_bkt) dst_reprs = self.get_n_repr(v_bkt)
uu, vv, _ = self.msg_graph.in_edges(v_bkt) uu, vv, in_msg_ids = self._msg_graph.in_edges(v_bkt)
in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
in_msgs = self._msg_frame.select_rows(in_msg_ids) in_msgs = self._msg_frame.select_rows(in_msg_ids)
# Reshape the column tensor to (B, Deg, ...). # Reshape the column tensor to (B, Deg, ...).
def _reshape_fn(msg): def _reshape_fn(msg):
...@@ -1033,7 +1043,7 @@ class DGLGraph(object): ...@@ -1033,7 +1043,7 @@ class DGLGraph(object):
else: else:
reshaped_in_msgs = utils.LazyDict( reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes) lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
reordered_v.append(v_bkt.totensor()) reordered_v.append(v_bkt.tousertensor())
new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs)) new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
# TODO: clear partial messages # TODO: clear partial messages
...@@ -1087,7 +1097,7 @@ class DGLGraph(object): ...@@ -1087,7 +1097,7 @@ class DGLGraph(object):
# no edges to be triggered # no edges to be triggered
assert len(v) == 0 assert len(v) == 0
return return
unique_v = utils.toindex(F.unique(v.totensor())) unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO(minjie): better way to figure out `batchable` flag # TODO(minjie): better way to figure out `batchable` flag
if message_func == "default": if message_func == "default":
...@@ -1135,10 +1145,10 @@ class DGLGraph(object): ...@@ -1135,10 +1145,10 @@ class DGLGraph(object):
v = utils.toindex(v) v = utils.toindex(v)
if len(v) == 0: if len(v) == 0:
return return
uu, vv, _ = self.cached_graph.in_edges(v) uu, vv, _ = self._graph.in_edges(v)
self.send_and_recv(uu, vv, message_func, reduce_func, self.send_and_recv(uu, vv, message_func, reduce_func,
apply_node_func=None, batchable=batchable) apply_node_func=None, batchable=batchable)
unique_v = F.unique(v.totensor()) unique_v = F.unique(v.tousertensor())
self.apply_nodes(unique_v, apply_node_func, batchable=batchable) self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
def push(self, def push(self,
...@@ -1165,7 +1175,7 @@ class DGLGraph(object): ...@@ -1165,7 +1175,7 @@ class DGLGraph(object):
u = utils.toindex(u) u = utils.toindex(u)
if len(u) == 0: if len(u) == 0:
return return
uu, vv, _ = self.cached_graph.out_edges(u) uu, vv, _ = self._graph.out_edges(u)
self.send_and_recv(uu, vv, message_func, self.send_and_recv(uu, vv, message_func,
reduce_func, apply_node_func, batchable=batchable) reduce_func, apply_node_func, batchable=batchable)
...@@ -1309,8 +1319,10 @@ class DGLGraph(object): ...@@ -1309,8 +1319,10 @@ class DGLGraph(object):
reduce_func) reduce_func)
def clear_messages(self): def clear_messages(self):
"""Clear all messages."""
self._msg_graph.clear() self._msg_graph.clear()
self._msg_frame.clear() self._msg_frame.clear()
self._msg_graph.add_nodes(self.number_of_nodes())
def _get_repr(attr_dict): def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict: if len(attr_dict) == 1 and __REPR__ in attr_dict:
......
...@@ -7,6 +7,7 @@ used with C++ library. ...@@ -7,6 +7,7 @@ used with C++ library.
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import ctypes import ctypes
import functools
import operator import operator
import numpy as _np import numpy as _np
...@@ -18,7 +19,7 @@ from . import backend as F ...@@ -18,7 +19,7 @@ from . import backend as F
class NDArray(NDArrayBase): class NDArray(NDArrayBase):
"""Lightweight NDArray class for DGL framework.""" """Lightweight NDArray class for DGL framework."""
def __len__(self): def __len__(self):
return reduce(operator.mul, self.shape, 1) return functools.reduce(operator.mul, self.shape, 1)
def cpu(dev_id=0): def cpu(dev_id=0):
"""Construct a CPU device """Construct a CPU device
......
...@@ -11,12 +11,12 @@ from . import utils ...@@ -11,12 +11,12 @@ from . import utils
__all__ = ["degree_bucketing", "get_executor"] __all__ = ["degree_bucketing", "get_executor"]
def degree_bucketing(cached_graph, v): def degree_bucketing(graph, v):
"""Create degree bucketing scheduling policy. """Create degree bucketing scheduling policy.
Parameters Parameters
---------- ----------
cached_graph : dgl.cached_graph.CachedGraph graph : dgl.graph_index.GraphIndex
the graph the graph
v : dgl.utils.Index v : dgl.utils.Index
the nodes to gather messages the nodes to gather messages
...@@ -29,7 +29,7 @@ def degree_bucketing(cached_graph, v): ...@@ -29,7 +29,7 @@ def degree_bucketing(cached_graph, v):
list of node id buckets; nodes belong to the same bucket have list of node id buckets; nodes belong to the same bucket have
the same degree the same degree
""" """
degrees = F.asnumpy(cached_graph.in_degrees(v).totensor()) degrees = np.array(graph.in_degrees(v).tolist())
unique_degrees = list(np.unique(degrees)) unique_degrees = list(np.unique(degrees))
v_np = np.array(v.tolist()) v_np = np.array(v.tolist())
v_bkt = [] v_bkt = []
......
...@@ -141,9 +141,9 @@ def edge_broadcasting(u, v): ...@@ -141,9 +141,9 @@ def edge_broadcasting(u, v):
The dst id(s) after broadcasting The dst id(s) after broadcasting
""" """
if len(u) != len(v) and len(u) == 1: if len(u) != len(v) and len(u) == 1:
u = toindex(F.broadcast_to(u.totensor(), v.totensor())) u = toindex(F.broadcast_to(u.tousertensor(), v.tousertensor()))
elif len(u) != len(v) and len(v) == 1: elif len(u) != len(v) and len(v) == 1:
v = toindex(F.broadcast_to(v.totensor(), u.totensor())) v = toindex(F.broadcast_to(v.tousertensor(), u.tousertensor()))
else: else:
assert len(u) == len(v) assert len(u) == len(v)
return u, v return u, v
...@@ -205,7 +205,7 @@ def build_relabel_map(x): ...@@ -205,7 +205,7 @@ def build_relabel_map(x):
One can use advanced indexing to convert an old id tensor to a One can use advanced indexing to convert an old id tensor to a
new id tensor: new_id = old_to_new[old_id] new id tensor: new_id = old_to_new[old_id]
""" """
x = x.totensor() x = x.tousertensor()
unique_x, _ = F.sort(F.unique(x)) unique_x, _ = F.sort(F.unique(x))
map_len = int(F.max(unique_x)) + 1 map_len = int(F.max(unique_x)) + 1
old_to_new = F.zeros(map_len, dtype=F.int64) old_to_new = F.zeros(map_len, dtype=F.int64)
...@@ -312,6 +312,6 @@ def reorder(dict_like, index): ...@@ -312,6 +312,6 @@ def reorder(dict_like, index):
""" """
new_dict = {} new_dict = {}
for key, val in dict_like.items(): for key, val in dict_like.items():
idx_ctx = index.totensor(F.get_context(val)) idx_ctx = index.tousertensor(F.get_context(val))
new_dict[key] = F.gather_row(val, idx_ctx) new_dict[key] = F.gather_row(val, idx_ctx)
return new_dict return new_dict
...@@ -31,7 +31,7 @@ void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) { ...@@ -31,7 +31,7 @@ void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array."; CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array."; CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0]; const auto srclen = src_ids->shape[0];
const auto dstlen = src_ids->shape[0]; const auto dstlen = dst_ids->shape[0];
const int64_t* src_data = static_cast<int64_t*>(src_ids->data); const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data); const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
if (srclen == 1) { if (srclen == 1) {
...@@ -78,7 +78,7 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const { ...@@ -78,7 +78,7 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array."; CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array."; CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0]; const auto srclen = src_ids->shape[0];
const auto dstlen = src_ids->shape[0]; const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen); const auto rstlen = std::max(srclen, dstlen);
BoolArray rst = BoolArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx); BoolArray rst = BoolArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
...@@ -150,7 +150,7 @@ IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const { ...@@ -150,7 +150,7 @@ IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array."; CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array."; CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0]; const auto srclen = src_ids->shape[0];
const auto dstlen = src_ids->shape[0]; const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen); const auto rstlen = std::max(srclen, dstlen);
IdArray rst = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx); IdArray rst = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
......
...@@ -27,8 +27,7 @@ def apply_node_func(node): ...@@ -27,8 +27,7 @@ def apply_node_func(node):
def generate_graph(grad=False): def generate_graph(grad=False):
g = DGLGraph() g = DGLGraph()
for i in range(10): g.add_nodes(10) # 10 nodes.
g.add_node(i) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
for i in range(1, 9): for i in range(1, 9):
g.add_edge(0, i) g.add_edge(0, i)
...@@ -198,7 +197,7 @@ def test_update_routines(): ...@@ -198,7 +197,7 @@ def test_update_routines():
def test_reduce_0deg(): def test_reduce_0deg():
g = DGLGraph() g = DGLGraph()
g.add_nodes_from([0, 1, 2, 3, 4]) g.add_nodes(5)
g.add_edge(1, 0) g.add_edge(1, 0)
g.add_edge(2, 0) g.add_edge(2, 0)
g.add_edge(3, 0) g.add_edge(3, 0)
...@@ -218,7 +217,7 @@ def test_reduce_0deg(): ...@@ -218,7 +217,7 @@ def test_reduce_0deg():
def test_pull_0deg(): def test_pull_0deg():
g = DGLGraph() g = DGLGraph()
g.add_nodes_from([0, 1]) g.add_nodes(2)
g.add_edge(0, 1) g.add_edge(0, 1)
def _message(src, edge): def _message(src, edge):
return src return src
...@@ -243,16 +242,6 @@ def test_pull_0deg(): ...@@ -243,16 +242,6 @@ def test_pull_0deg():
assert th.allclose(new_repr[0], old_repr[0]) assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0]) assert th.allclose(new_repr[1], old_repr[0])
def _test_delete():
g = generate_graph()
ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_e_repr({'e' : ecol})
assert g.get_n_repr()['h'].shape[0] == 10
assert g.get_e_repr()['e'].shape[0] == 17
g.remove_node(0)
assert g.get_n_repr()['h'].shape[0] == 9
assert g.get_e_repr()['e'].shape[0] == 8
if __name__ == '__main__': if __name__ == '__main__':
test_batch_setter_getter() test_batch_setter_getter()
test_batch_setter_autograd() test_batch_setter_autograd()
...@@ -261,4 +250,3 @@ if __name__ == '__main__': ...@@ -261,4 +250,3 @@ if __name__ == '__main__':
test_update_routines() test_update_routines()
test_reduce_0deg() test_reduce_0deg()
test_pull_0deg() test_pull_0deg()
#test_delete()
...@@ -23,8 +23,7 @@ def reduce_func(hv, msgs): ...@@ -23,8 +23,7 @@ def reduce_func(hv, msgs):
def generate_graph(grad=False): def generate_graph(grad=False):
g = DGLGraph() g = DGLGraph()
for i in range(10): g.add_nodes(10)
g.add_node(i) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
for i in range(1, 9): for i in range(1, 9):
g.add_edge(0, i) g.add_edge(0, i)
......
...@@ -113,7 +113,7 @@ def test_append2(): ...@@ -113,7 +113,7 @@ def test_append2():
assert not f.is_span_whole_column() assert not f.is_span_whole_column()
assert f.num_rows == 3 * N assert f.num_rows == 3 * N
new_idx = list(range(N)) + list(range(2*N, 4*N)) new_idx = list(range(N)) + list(range(2*N, 4*N))
assert check_eq(f.index().totensor(), th.tensor(new_idx)) assert check_eq(f.index().tousertensor(), th.tensor(new_idx))
assert data.num_rows == 4 * N assert data.num_rows == 4 * N
def test_row1(): def test_row1():
......
...@@ -5,8 +5,7 @@ from dgl.graph import __REPR__ ...@@ -5,8 +5,7 @@ from dgl.graph import __REPR__
def generate_graph(): def generate_graph():
g = dgl.DGLGraph() g = dgl.DGLGraph()
for i in range(10): g.add_nodes(10) # 10 nodes.
g.add_node(i) # 10 nodes.
h = th.arange(1, 11) h = th.arange(1, 11)
g.set_n_repr({'h': h}) g.set_n_repr({'h': h})
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
...@@ -23,8 +22,7 @@ def generate_graph(): ...@@ -23,8 +22,7 @@ def generate_graph():
def generate_graph1(): def generate_graph1():
"""graph with anonymous repr""" """graph with anonymous repr"""
g = dgl.DGLGraph() g = dgl.DGLGraph()
for i in range(10): g.add_nodes(10) # 10 nodes.
g.add_node(i) # 10 nodes.
h = th.arange(1, 11) h = th.arange(1, 11)
g.set_n_repr(h) g.set_n_repr(h)
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment