Unverified Commit 3721822e authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

Subgraph API (#39)

* subgraph

* more test cases

* WIP

* new FrameRef and test

* separate nx init code

* WIP

* subgraph code and test

* line graph code and test

* adding new test for adding new features on line graphs

* no backtracking line graph

* fix inplace relabel
parent 1f16f29b
from .base import ALL
from .graph import DGLGraph
from .graph import __MSG__, __REPR__
from .context import cpu, gpu
from .batch import batch, unbatch
from .context import cpu, gpu
from .graph import DGLGraph, __MSG__, __REPR__
from .subgraph import DGLSubGraph
from .generator import *
......@@ -58,6 +58,7 @@ unsqueeze = th.unsqueeze
reshape = th.reshape
zeros = th.zeros
ones = th.ones
zeros = th.zeros
spmm = th.spmm
sort = th.sort
arange = th.arange
......
"""Columnar storage for graph attributes."""
from __future__ import absolute_import
from collections import MutableMapping
import numpy as np
import dgl.backend as F
from dgl.backend import Tensor
from dgl.utils import LazyDict
class Frame:
class Frame(MutableMapping):
def __init__(self, data=None):
if data is None:
self._columns = dict()
self._num_rows = 0
else:
self._columns = data
self._columns = dict(data)
self._num_rows = F.shape(list(data.values())[0])[0]
for k, v in data.items():
assert F.shape(v)[0] == self._num_rows
......@@ -32,25 +35,18 @@ class Frame:
return key in self._columns
def __getitem__(self, key):
if isinstance(key, str):
return self._columns[key]
else:
return self.select_rows(key)
# get column
return self._columns[key]
def __setitem__(self, key, val):
if isinstance(key, str):
self._columns[key] = val
else:
self.update_rows(key, val)
# set column
self.add_column(key, val)
def __delitem__(self, key):
# delete column
del self._columns[key]
def pop(self, key):
col = self._columns[key]
del self._columns[key]
return col
if len(self._columns) == 0:
self._num_rows = 0
def add_column(self, name, col):
if self.num_columns == 0:
......@@ -60,35 +56,164 @@ class Frame:
self._columns[name] = col
def append(self, other):
if not isinstance(other, Frame):
other = Frame(data=other)
if len(self._columns) == 0:
self._columns = other._columns
self._num_rows = other._num_rows
for key, col in other.items():
self._columns[key] = col
else:
assert self.schemes == other.schemes
self._columns = {key : F.pack([self[key], other[key]]) for key in self._columns}
self._num_rows += other._num_rows
for key, col in other.items():
self._columns[key] = F.pack([self[key], col])
# TODO(minjie): sanity check for num_rows
if len(self._columns) != 0:
self._num_rows = F.shape(list(self._columns.values())[0])[0]
def clear(self):
self._columns = {}
self._num_rows = 0
def select_rows(self, rowids):
def __iter__(self):
return iter(self._columns)
def __len__(self):
return self.num_columns
class FrameRef(MutableMapping):
    """A mutable row-view (reference) onto an underlying Frame.

    The ref exposes a subset of the frame's rows, selected by ``index``.
    A ``slice`` index marks the view as contiguous; a list index is an
    explicit row-id mapping.
    """
    def __init__(self, frame=None, index=None):
        # Underlying column storage; a fresh empty Frame by default.
        self._frame = frame if frame is not None else Frame()
        if index is None:
            # Span every row currently in the underlying frame.
            self._index = slice(0, self._frame.num_rows)
        else:
            # check no duplicate index
            assert len(index) == len(np.unique(index))
            self._index = index
        # Cached tensor form of the index, built lazily by index_tensor().
        self._index_tensor = None
@property
def schemes(self):
    # Column schemes come straight from the underlying frame.
    return self._frame.schemes

@property
def num_columns(self):
    # Column count of the underlying frame (refs never hide columns).
    return self._frame.num_columns

@property
def num_rows(self):
    # Number of rows this ref exposes, not the frame's total.
    if isinstance(self._index, slice):
        # NOTE(review): uses .stop alone, i.e. assumes the slice always
        # starts at 0 — confirm no caller builds a slice with start > 0.
        return self._index.stop
    else:
        return len(self._index)
def __contains__(self, key):
    # Membership is by column name in the underlying frame.
    return key in self._frame

def __getitem__(self, key):
    # String key -> column access; anything else is treated as row ids.
    if isinstance(key, str):
        return self.get_column(key)
    else:
        return self.select_rows(key)
def select_rows(self, query):
rowids = self._getrowid(query)
def _lazy_select(key):
return F.gather_row(self._columns[key], rowids)
return LazyDict(_lazy_select, keys=self._columns.keys())
return F.gather_row(self._frame[key], rowids)
return LazyDict(_lazy_select, keys=self.schemes)
def get_column(self, name):
    """Return column ``name`` restricted to the rows this ref spans."""
    col = self._frame[name]
    if self.is_span_whole_column():
        # Fast path: the ref covers the whole column, return it as-is.
        return col
    else:
        # Gather only the referenced rows.
        return F.gather_row(col, self.index_tensor())
def update_rows(self, rowids, other):
if not isinstance(other, Frame):
other = Frame(data=other)
for key in other.schemes:
assert key in self._columns
self._columns[key] = F.scatter_row(self[key], rowids, other[key])
def __setitem__(self, key, val):
    # String key -> whole-column write; anything else is a row update
    # where `val` is a mapping of column name -> replacement rows.
    if isinstance(key, str):
        self.add_column(key, val)
    else:
        self.update_rows(key, val)
def add_column(self, name, col):
    """Add or overwrite column ``name`` with ``col`` for the spanned rows.

    When the ref spans the whole frame the column is stored directly;
    otherwise the given rows are scattered into a full-length column
    (zero-padded for rows outside the ref).
    """
    shp = F.shape(col)
    if self.is_span_whole_column():
        if self.num_columns == 0:
            # First column defines the number of rows of the view.
            self._index = slice(0, shp[0])
            self._clear_cache()
        assert shp[0] == self.num_rows
        self._frame[name] = col
    else:
        if name in self._frame:
            # Update in place over the existing full-length column.
            fcol = self._frame[name]
        else:
            # New column: pad non-referenced rows with zeros.
            fcol = F.zeros((self._frame.num_rows,) + shp[1:])
        newfcol = F.scatter_row(fcol, self.index_tensor(), col)
        self._frame[name] = newfcol
def update_rows(self, query, other):
    """Scatter the columns of ``other`` into the rows selected by ``query``.

    ``query`` is interpreted relative to this ref; it is translated to
    frame row ids first, so writes land in the underlying storage.
    """
    rowids = self._getrowid(query)
    for key, col in other.items():
        self._frame[key] = F.scatter_row(self._frame[key], rowids, col)
def __delitem__(self, key):
    # String key -> drop a column from the underlying frame;
    # anything else -> lazily drop rows from this ref only.
    if isinstance(key, str):
        del self._frame[key]
        if len(self._frame) == 0:
            # Last column removed: reset to an empty span.
            self.clear()
    else:
        self.delete_rows(key)

def delete_rows(self, query):
    """Remove the rows at positions ``query`` from this ref (lazy).

    Only the index is updated; the underlying frame keeps its rows.
    """
    query = F.asnumpy(query)
    if isinstance(self._index, slice):
        # Materialize the contiguous span so rows can be removed.
        self._index = list(range(self._index.start, self._index.stop))
    arr = np.array(self._index, dtype=np.int32)
    self._index = list(np.delete(arr, query))
    self._clear_cache()
def append(self, other):
    """Append the rows of ``other`` to the underlying frame and this ref.

    If the ref spanned the whole frame it simply grows with it; otherwise
    the new frame rows are appended to this ref's explicit index.
    """
    span_whole = self.is_span_whole_column()
    contiguous = self.is_contiguous()
    old_nrows = self._frame.num_rows
    self._frame.append(other)
    # update index
    if span_whole:
        self._index = slice(0, self._frame.num_rows)
    else:
        # BUG FIX: the original unconditionally read ._index.start/.stop,
        # which raises AttributeError once the index is already a list
        # (non-contiguous ref). Use the `contiguous` flag — previously
        # computed but unused — to pick how to materialize the index.
        if contiguous:
            new_idx = list(range(self._index.start, self._index.stop))
        else:
            new_idx = list(self._index)
        new_idx += list(range(old_nrows, self._frame.num_rows))
        self._index = new_idx
        self._clear_cache()
def clear(self):
    """Drop every column from the frame and reset this ref to empty."""
    self._frame.clear()
    self._index = slice(0, 0)
    self._clear_cache()
def __iter__(self):
for key, col in self._columns.items():
yield key, col
return iter(self._frame)
def __len__(self):
return self.num_columns
def is_contiguous(self):
    # NOTE: this check could have false negative
    # (a list index that happens to equal 0..n-1 still reports False).
    return isinstance(self._index, slice)

def is_span_whole_column(self):
    # True when the ref is a contiguous view over every frame row.
    return self.is_contiguous() and self.num_rows == self._frame.num_rows

def _getrowid(self, query):
    """Translate ref-relative row ids in ``query`` to frame row ids."""
    if isinstance(self._index, slice):
        # shortcut for identical mapping
        # NOTE(review): assumes the slice starts at 0 — confirm.
        return query
    else:
        return F.gather_row(self.index_tensor(), query)
def index_tensor(self):
    """Return (and cache) the index as an int64 tensor."""
    # TODO(minjie): context
    if self._index_tensor is None:
        if self.is_contiguous():
            # Contiguous span: 0..stop-1 (slice start assumed 0).
            self._index_tensor = F.arange(self._index.stop, dtype=F.int64)
        else:
            self._index_tensor = F.tensor(self._index, dtype=F.int64)
    return self._index_tensor

def _clear_cache(self):
    # Invalidate the cached tensor; rebuilt on next index_tensor() call.
    self._index_tensor = None
"""Line graph generator."""
from __future__ import absolute_import
import networkx as nx
import numpy as np
import dgl.backend as F
from dgl.graph import DGLGraph
from dgl.frame import FrameRef
def line_graph(G, no_backtracking=False):
    """Create the line graph that shares the underlying features.

    The node features of the result line graph will share the edge features
    of the given graph.

    Parameters
    ----------
    G : DGLGraph
        The input graph.
    no_backtracking : bool
        Whether the backtracking edges are included in the line graph.
        If i~j and j~i are two edges in original graph G, then
        (i,j)~(j,i) and (j,i)~(i,j) are the "backtracking" edges on
        the line graph.

    Returns
    -------
    DGLGraph
        The line graph, whose node frame is G's edge frame, so updates to
        the line graph's node features reflect on G's edge features.
    """
    # Build the line graph on (u, v) edge tuples first; nodes are relabeled
    # to integer edge ids afterwards.
    L = nx.DiGraph()
    for eid, from_node in enumerate(G.edge_list):
        L.add_node(from_node)
        # Every edge leaving from_node's destination becomes a successor.
        for to_node in G.edges(from_node[1]):
            if no_backtracking and to_node[1] == from_node[0]:
                # Skip the (i,j)->(j,i) backtracking pair.
                continue
            L.add_edge(from_node, to_node)
    # Map each (u, v) tuple to its edge id in G (edge-addition order).
    relabel_map = {}
    for i, e in enumerate(G.edge_list):
        relabel_map[e] = i
    # copy=False relabels L in place.
    nx.relabel.relabel_nodes(L, relabel_map, copy=False)
    # Sharing G's edge frame as the node frame is what makes features shared.
    return DGLGraph(L, node_frame=G._edge_frame)
......@@ -2,63 +2,24 @@
"""
from __future__ import absolute_import
from collections import MutableMapping
import networkx as nx
from networkx.classes.digraph import DiGraph
import dgl
from dgl.base import ALL, is_all
import dgl.backend as F
from dgl.backend import Tensor
import dgl.builtin as builtin
from dgl.cached_graph import CachedGraph, create_cached_graph
import dgl.context as context
from dgl.frame import Frame
from dgl.frame import FrameRef
from dgl.nx_adapt import nx_init
import dgl.scheduler as scheduler
import dgl.utils as utils
__MSG__ = "__MSG__"
__REPR__ = "__REPR__"
class _NodeDict(MutableMapping):
def __init__(self, cb):
self._dict = {}
self._cb = cb
def __setitem__(self, key, val):
if isinstance(val, _AdjInnerDict):
# This node dict is used as adj_outer_list
val.src = key
elif key not in self._dict:
self._cb(key)
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
# FIXME: add callback
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
class _AdjInnerDict(MutableMapping):
def __init__(self, cb):
self._dict = {}
self.src = None
self._cb = cb
def __setitem__(self, key, val):
if key not in self._dict:
self._cb(self.src, key)
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
# FIXME: add callback
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
class DGLGraph(DiGraph):
"""Base graph class specialized for neural networks on graphs.
......@@ -67,42 +28,41 @@ class DGLGraph(DiGraph):
Parameters
----------
data : graph data
graph_data : graph data
Data to initialize graph. Same as networkx's semantics.
node_frame : dgl.frame.Frame
Node feature storage.
edge_frame : dgl.frame.Frame
Edge feature storage.
attr : keyword arguments, optional
Attributes to add to graph as key=value pairs.
"""
def __init__(self, graph_data=None, **attr):
# setup dict overlay
self.node_dict_factory = lambda : _NodeDict(self._add_node_callback)
# In networkx 2.1, DiGraph is not using this factory. Instead, the outer
# dict uses the same data structure as the node dict.
self.adjlist_outer_dict_factory = None
self.adjlist_inner_dict_factory = lambda : _AdjInnerDict(self._add_edge_callback)
self.edge_attr_dict_factory = dict
self._context = context.cpu()
# call base class init
super(DGLGraph, self).__init__(graph_data, **attr)
self._init_state()
def _init_state(self):
def __init__(self,
graph_data=None,
node_frame=None,
edge_frame=None,
**attr):
# TODO(minjie): maintaining node/edge list is costly when graph is large.
self._edge_list = []
nx_init(self,
self._add_node_callback,
self._add_edge_callback,
self._del_node_callback,
self._del_edge_callback,
graph_data,
**attr)
# cached graph and storage
self._cached_graph = None
self._node_frame = Frame()
self._edge_frame = Frame()
self._node_frame = node_frame if node_frame is not None else FrameRef()
self._edge_frame = edge_frame if edge_frame is not None else FrameRef()
# other class members
self._msg_graph = None
self._msg_frame = Frame()
self._msg_frame = FrameRef()
self._message_func = None
self._reduce_func = None
self._update_func = None
self._edge_func = None
self._edge_cb_state = True
self._edge_list = []
def clear(self):
super(DGLGraph, self).clear()
self._init_state()
self._context = context.cpu()
def get_n_attr_list(self):
return self._node_frame.schemes
......@@ -129,7 +89,7 @@ class DGLGraph(DiGraph):
The node(s).
"""
# sanity check
if isinstance(u, str) and u == ALL:
if is_all(u):
num_nodes = self.number_of_nodes()
else:
u = utils.convert_to_id_tensor(u, self.context)
......@@ -140,7 +100,7 @@ class DGLGraph(DiGraph):
else:
assert F.shape(hu)[0] == num_nodes
# set
if isinstance(u, str) and u == ALL:
if is_all(u):
if isinstance(hu, dict):
for key, val in hu.items():
self._node_frame[key] = val
......@@ -161,7 +121,7 @@ class DGLGraph(DiGraph):
u : node, container or tensor
The node(s).
"""
if isinstance(u, str) and u == ALL:
if is_all(u):
if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame[__REPR__]
else:
......@@ -204,8 +164,8 @@ class DGLGraph(DiGraph):
The destination node(s).
"""
# sanity check
u_is_all = isinstance(u, str) and u == ALL
v_is_all = isinstance(v, str) and v == ALL
u_is_all = is_all(u)
v_is_all = is_all(v)
assert u_is_all == v_is_all
if u_is_all:
num_edges = self.number_of_edges()
......@@ -244,7 +204,7 @@ class DGLGraph(DiGraph):
The edge id(s).
"""
# sanity check
if isinstance(eid, str) and eid == ALL:
if is_all(eid):
num_edges = self.number_of_edges()
else:
eid = utils.convert_to_id_tensor(eid, self.context)
......@@ -255,7 +215,7 @@ class DGLGraph(DiGraph):
else:
assert F.shape(h_uv)[0] == num_edges
# set
if isinstance(eid, str) and eid == ALL:
if is_all(eid):
if isinstance(h_uv, dict):
for key, val in h_uv.items():
self._edge_frame[key] = val
......@@ -278,8 +238,8 @@ class DGLGraph(DiGraph):
v : node, container or tensor
The destination node(s).
"""
u_is_all = isinstance(u, str) and u == ALL
v_is_all = isinstance(v, str) and v == ALL
u_is_all = is_all(u)
v_is_all = is_all(v)
assert u_is_all == v_is_all
if u_is_all:
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
......@@ -313,7 +273,7 @@ class DGLGraph(DiGraph):
eid : int, container or tensor
The edge id(s).
"""
if isinstance(eid, str) and eid == ALL:
if is_all(eid):
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__]
else:
......@@ -956,6 +916,38 @@ class DGLGraph(DiGraph):
self.update_by_edge(u, v,
message_func, reduce_func, update_func, batchable)
def subgraph(self, nodes):
"""Generate the subgraph among the given nodes.
The generated graph contains only the graph structure. The node/edge
features are not shared implicitly. Use `copy_from` to get node/edge
features from parent graph.
Parameters
----------
nodes : list, or iterable
A container of the nodes to construct subgraph.
Returns
-------
G : DGLGraph
The subgraph.
"""
return dgl.DGLSubGraph(self, nodes)
def copy_from(self, graph):
"""Copy node/edge features from the given graph.
All old features will be removed.
Parameters
----------
graph : DGLGraph
The graph to copy from.
"""
# TODO
pass
def draw(self):
"""Plot the graph using dot."""
from networkx.drawing.nx_agraph import graphviz_layout
......@@ -984,6 +976,28 @@ class DGLGraph(DiGraph):
self._msg_graph.add_nodes(self.number_of_nodes())
self._msg_frame.clear()
@property
def edge_list(self):
"""Return edges in the addition order."""
return self._edge_list
def get_edge_id(self, u, v):
"""Return the continuous edge id(s) assigned.
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
Returns
-------
eid : tensor
The tensor contains edge id(s).
"""
return self.cached_graph.get_edge_id(u, v)
def _nodes_or_all(self, nodes):
return self.nodes() if nodes == ALL else nodes
......@@ -991,22 +1005,29 @@ class DGLGraph(DiGraph):
return self.edges() if edges == ALL else edges
def _add_node_callback(self, node):
#print('New node:', node)
self._cached_graph = None
def _add_edge_callback(self, u, v):
# In networkx 2.1, two adjlists are maintained. One for succ, one for pred.
# We only record once for the succ addition.
if self._edge_cb_state:
#print('New edge:', u, v)
self._edge_list.append((u, v))
self._edge_cb_state = not self._edge_cb_state
def _del_node_callback(self, node):
#print('Del node:', node)
raise RuntimeError('Node removal is not supported currently.')
node = utils.convert_to_id_tensor(node)
self._node_frame.delete_rows(node)
self._cached_graph = None
@property
def edge_list(self):
"""Return edges in the addition order."""
return self._edge_list
def _add_edge_callback(self, u, v):
#print('New edge:', u, v)
self._edge_list.append((u, v))
self._cached_graph = None
def _del_edge_callback(self, u, v):
#print('Del edge:', u, v)
raise RuntimeError('Edge removal is not supported currently.')
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
eid = self.get_edge_id(u, v)
self._edge_frame.delete_rows(eid)
self._cached_graph = None
def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict:
......
"""Utility functions for networkx adapter."""
from __future__ import absolute_import
from collections import MutableMapping
import networkx as nx
import networkx.convert as convert
class NodeDict(MutableMapping):
    """Node storage for the networkx overlay.

    Behaves like a plain dict but notifies its owner through callbacks
    whenever a node entry is written or removed.
    """

    def __init__(self, add_cb, del_cb):
        self._store = {}
        self._on_add = add_cb
        self._on_del = del_cb

    def __setitem__(self, key, val):
        # Fire the add callback before storing (fires on overwrite too).
        self._on_add(key)
        self._store[key] = val

    def __getitem__(self, key):
        return self._store[key]

    def __delitem__(self, key):
        # Fire the delete callback before dropping the entry.
        self._on_del(key)
        del self._store[key]

    def __iter__(self):
        return iter(self._store)

    def __len__(self):
        return len(self._store)
class AdjOuterDict(MutableMapping):
    """Outer adjacency storage: maps a source node to its inner dict."""

    def __init__(self, add_cb, del_cb):
        self._store = {}
        self._on_add = add_cb
        self._on_del = del_cb

    def __setitem__(self, key, val):
        # Tag the inner dict with its source node so that it can report
        # (src, dst) pairs when individual edges are inserted/removed.
        val.src = key
        self._store[key] = val

    def __getitem__(self, key):
        return self._store[key]

    def __delitem__(self, key):
        # Report every outgoing edge of `key` as deleted, then drop the row.
        for dst in self._store[key]:
            self._on_del(key, dst)
        del self._store[key]

    def __iter__(self):
        return iter(self._store)

    def __len__(self):
        return len(self._store)
class AdjInnerDict(MutableMapping):
    """Inner adjacency storage: maps a destination node to edge attrs."""

    def __init__(self, add_cb, del_cb):
        self._store = {}
        # Source node this dict belongs to; AdjOuterDict.__setitem__
        # assigns it when the dict is installed in the outer mapping.
        self.src = None
        self._on_add = add_cb
        self._on_del = del_cb

    def __setitem__(self, key, val):
        # Report only genuinely new edges, and only once the owner has
        # tagged this dict with its source node.
        if self.src is not None and key not in self._store:
            self._on_add(self.src, key)
        self._store[key] = val

    def __getitem__(self, key):
        return self._store[key]

    def __delitem__(self, key):
        if self.src is not None:
            self._on_del(self.src, key)
        del self._store[key]

    def __iter__(self):
        return iter(self._store)

    def __len__(self):
        return len(self._store)
def nx_init(obj,
            add_node_cb,
            add_edge_cb,
            del_node_cb,
            del_edge_cb,
            graph_data,
            **attr):
    """Init the object to be compatible with networkx's DiGraph.

    Parameters
    ----------
    obj : any
        The object to be init.
    add_node_cb : callable
        The callback function when node is added.
    add_edge_cb : callable
        The callback function when edge is added.
    del_node_cb : callable
        The callback function when node is deleted.
    del_edge_cb : callable
        The callback function when edge is deleted.
    graph_data : graph data
        Data to initialize graph. Same as networkx's semantics.
    attr : keyword arguments, optional
        Attributes to add to graph as key=value pairs.
    """
    # The following codes work for networkx 2.1.
    # In nx 2.1 the outer adjlist is created directly, so its factory is
    # unused; only the inner-dict factory needs the callbacks.
    obj.adjlist_outer_dict_factory = None
    obj.adjlist_inner_dict_factory = lambda : AdjInnerDict(add_edge_cb, del_edge_cb)
    obj.edge_attr_dict_factory = dict
    obj.root_graph = obj
    obj.graph = {}
    obj._node = NodeDict(add_node_cb, del_node_cb)
    obj._adj = AdjOuterDict(add_edge_cb, del_edge_cb)
    # NOTE(review): _pred is a plain dict, so predecessor bookkeeping does
    # not fire callbacks — only _succ (aliased to _adj) does.
    obj._pred = dict()
    obj._succ = obj._adj
    if graph_data is not None:
        # Populate the structure through networkx's generic converter.
        convert.to_networkx_graph(graph_data, create_using=obj)
    obj.graph.update(attr)
"""DGLSubGraph"""
from __future__ import absolute_import
import networkx as nx
import dgl.backend as F
from dgl.frame import Frame, FrameRef
from dgl.graph import DGLGraph
from dgl.nx_adapt import nx_init
import dgl.utils as utils
class DGLSubGraph(DGLGraph):
    """Subgraph of a DGLGraph induced on a given list of nodes.

    Nodes are relabeled to consecutive ids following the order of
    ``nodes``.  Node/edge features for the selected rows are copied out
    of the parent's frames into fresh frames.
    """
    # TODO(gaiyu): ReadOnlyGraph
    def __init__(self,
                 parent,
                 nodes):
        # create subgraph and relabel
        nx_sg = nx.DiGraph.subgraph(parent, nodes)
        # node id
        # TODO(minjie): context
        nid = F.tensor(nodes, dtype=F.int64)
        # edge id
        # TODO(minjie): slow, context
        # NOTE(review): zip(*...) raises if the induced subgraph has no
        # edges — confirm callers never pass an edgeless node set.
        u, v = zip(*nx_sg.edges)
        u = list(u)
        v = list(v)
        eid = parent.cached_graph.get_edge_id(u, v)
        # relabel: old node id -> its position in `nodes`
        self._node_mapping = utils.build_relabel_dict(nodes)
        nx_sg = nx.relabel.relabel_nodes(nx_sg, self._node_mapping)
        # init the networkx overlay with the relabeled structure
        self._edge_list = []
        nx_init(self,
                self._add_node_callback,
                self._add_edge_callback,
                self._del_node_callback,
                self._del_edge_callback,
                nx_sg,
                **parent.graph)
        # cached graph and storage
        self._cached_graph = None
        # Select the parent's feature rows for this subgraph; an empty
        # parent frame yields an empty ref.
        if parent._node_frame.num_rows == 0:
            self._node_frame = FrameRef()
        else:
            self._node_frame = FrameRef(Frame(parent._node_frame[nid]))
        if parent._edge_frame.num_rows == 0:
            self._edge_frame = FrameRef()
        else:
            self._edge_frame = FrameRef(Frame(parent._edge_frame[eid]))
        # other class members: share the parent's registered funcs/context
        self._msg_graph = None
        self._msg_frame = FrameRef()
        self._message_func = parent._message_func
        self._reduce_func = parent._reduce_func
        self._update_func = parent._update_func
        self._edge_func = parent._edge_func
        self._context = parent._context
......@@ -71,23 +71,40 @@ class LazyDict(Mapping):
self._fn = fn
self._keys = keys
def keys(self):
return self._keys
def __getitem__(self, key):
assert key in self._keys
if not key in self._keys:
raise KeyError(key)
return self._fn(key)
def __contains__(self, key):
return key in self._keys
def __iter__(self):
for key in self._keys:
yield key, self._fn(key)
return iter(self._keys)
def __len__(self):
return len(self._keys)
class ReadOnlyDict(Mapping):
    """An immutable Mapping facade over an existing dict-like object.

    Reads delegate straight to the wrapped object; no mutating methods
    are exposed.
    """

    def __init__(self, dict_like):
        self._backing = dict_like

    def keys(self):
        return self._backing.keys()

    def __getitem__(self, key):
        return self._backing[key]

    def __contains__(self, key):
        return key in self._backing

    def __iter__(self):
        return iter(self._backing)

    def __len__(self):
        return len(self._backing)
def build_relabel_map(x):
"""Relabel the input ids to continuous ids that starts from zero.
......@@ -113,6 +130,26 @@ def build_relabel_map(x):
old_to_new[unique_x] = F.astype(F.arange(len(unique_x)), F.int64)
return unique_x, old_to_new
def build_relabel_dict(x):
    """Relabel the input ids to continuous ids that starts from zero.

    The new id follows the order of the given node id list.

    Parameters
    ----------
    x : list
        The input ids.

    Returns
    -------
    relabel_dict : dict
        Dict from old id to new id.
    """
    # enumerate yields (new_id, old_id); invert into old -> new.
    return {old_id: new_id for new_id, old_id in enumerate(x)}
def edge_broadcasting(u, v):
"""Convert one-many and many-one edges to many-many."""
if len(u) != len(v) and len(u) == 1:
......
......@@ -36,8 +36,8 @@ def generate_graph(grad=False):
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
col = Variable(th.randn(10, D), requires_grad=grad)
g.set_n_repr({'h' : col})
ncol = Variable(th.randn(10, D), requires_grad=grad)
g.set_n_repr({'h' : ncol})
return g
def test_batch_setter_getter():
......@@ -196,9 +196,20 @@ def test_update_routines():
assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
reduce_msg_shapes.clear()
def _test_delete():
    """Deleting a node should shrink both node and edge feature frames."""
    g = generate_graph()
    # BUG FIX: `grad` was an undefined name here (copied from
    # generate_graph's signature) and raised NameError when the test ran;
    # the deletion test does not need gradients.
    ecol = Variable(th.randn(17, D), requires_grad=False)
    g.set_e_repr({'e' : ecol})
    assert g.get_n_repr()['h'].shape[0] == 10
    assert g.get_e_repr()['e'].shape[0] == 17
    g.remove_node(0)
    # Node 0 is incident to 9 of the 17 edges, so 8 edge rows remain.
    assert g.get_n_repr()['h'].shape[0] == 9
    assert g.get_e_repr()['e'].shape[0] == 8
if __name__ == '__main__':
test_batch_setter_getter()
test_batch_setter_autograd()
test_batch_send()
test_batch_recv()
test_update_routines()
#test_delete()
import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.frame import Frame
from dgl.frame import Frame, FrameRef
N = 10
D = 32
D = 5
def check_eq(a, b):
assert a.shape == b.shape
assert th.sum(a == b) == int(np.prod(list(a.shape)))
return a.shape == b.shape and np.allclose(a.numpy(), b.numpy())
def check_fail(fn):
    """Return True iff calling ``fn`` with no arguments raises.

    Catches only ``Exception`` (the original used a bare ``except:``,
    which also swallowed KeyboardInterrupt/SystemExit during a test run).
    """
    try:
        fn()
    except Exception:
        return True
    return False
def create_test_data(grad=False):
c1 = Variable(th.randn(N, D), requires_grad=grad)
......@@ -32,30 +38,101 @@ def test_create():
assert len(f1.schemes) == 0
assert f1.num_rows == 0
def test_col_getter_setter():
def test_column1():
# Test frame column getter/setter
data = create_test_data()
f = Frame(data)
check_eq(f['a1'], data['a1'])
assert f.num_rows == N
assert len(f) == 3
assert check_eq(f['a1'], data['a1'])
f['a1'] = data['a2']
check_eq(f['a2'], data['a2'])
assert check_eq(f['a2'], data['a2'])
# add a different length column should fail
def failed_add_col():
f['a4'] = th.zeros([N+1, D])
assert check_fail(failed_add_col)
# delete all the columns
del f['a1']
del f['a2']
assert len(f) == 1
del f['a3']
assert f.num_rows == 0
assert len(f) == 0
# add a different length column should succeed
f['a4'] = th.zeros([N+1, D])
assert f.num_rows == N+1
assert len(f) == 1
def test_column2():
# Test frameref column getter/setter
data = Frame(create_test_data())
f = FrameRef(data, [3, 4, 5, 6, 7])
assert f.num_rows == 5
assert len(f) == 3
assert check_eq(f['a1'], data['a1'][3:8])
# set column should reflect on the referenced data
f['a1'] = th.zeros([5, D])
assert check_eq(data['a1'][3:8], th.zeros([5, D]))
# add new column should be padded with zero
f['a4'] = th.ones([5, D])
assert len(data) == 4
assert check_eq(data['a4'][0:3], th.zeros([3, D]))
assert check_eq(data['a4'][3:8], th.ones([5, D]))
assert check_eq(data['a4'][8:10], th.zeros([2, D]))
def test_row_getter_setter():
def test_append1():
# test append API on Frame
data = create_test_data()
f = Frame(data)
f1 = Frame()
f2 = Frame(data)
f1.append(data)
assert f1.num_rows == N
f1.append(f2)
assert f1.num_rows == 2 * N
c1 = f1['a1']
assert c1.shape == (2 * N, D)
truth = th.cat([data['a1'], data['a1']])
assert check_eq(truth, c1)
def test_append2():
# test append on FrameRef
data = Frame(create_test_data())
f = FrameRef(data)
assert f.is_contiguous()
assert f.is_span_whole_column()
assert f.num_rows == N
# append on the underlying frame should not reflect on the ref
data.append(data)
assert f.is_contiguous()
assert not f.is_span_whole_column()
assert f.num_rows == N
# append on the FrameRef should work
f.append(data)
assert not f.is_contiguous()
assert not f.is_span_whole_column()
assert f.num_rows == 3 * N
new_idx = list(range(N)) + list(range(2*N, 4*N))
assert check_eq(f.index_tensor(), th.tensor(new_idx))
assert data.num_rows == 4 * N
def test_row1():
# test row getter/setter
data = create_test_data()
f = FrameRef(Frame(data))
# getter
# test non-duplicate keys
rowid = th.tensor([0, 2])
rows = f[rowid]
for k, v in rows:
for k, v in rows.items():
assert v.shape == (len(rowid), D)
check_eq(v, data[k][rowid])
assert check_eq(v, data[k][rowid])
# test duplicate keys
rowid = th.tensor([8, 2, 2, 1])
rows = f[rowid]
for k, v in rows:
for k, v in rows.items():
assert v.shape == (len(rowid), D)
check_eq(v, data[k][rowid])
assert check_eq(v, data[k][rowid])
# setter
rowid = th.tensor([0, 2, 4])
......@@ -64,12 +141,13 @@ def test_row_getter_setter():
'a3' : th.zeros((len(rowid), D)),
}
f[rowid] = vals
for k, v in f[rowid]:
check_eq(v, th.zeros((len(rowid), D)))
for k, v in f[rowid].items():
assert check_eq(v, th.zeros((len(rowid), D)))
def test_row_getter_setter_grad():
def test_row2():
# test row getter/setter autograd compatibility
data = create_test_data(grad=True)
f = Frame(data)
f = FrameRef(Frame(data))
# getter
c1 = f['a1']
......@@ -77,13 +155,13 @@ def test_row_getter_setter_grad():
rowid = th.tensor([0, 2])
rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D)))
check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
assert check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
c1.grad.data.zero_()
# test duplicate keys
rowid = th.tensor([8, 2, 2, 1])
rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D)))
check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
assert check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
c1.grad.data.zero_()
# setter
......@@ -96,26 +174,64 @@ def test_row_getter_setter_grad():
f[rowid] = vals
c11 = f['a1']
c11.backward(th.ones((N, D)))
check_eq(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
check_eq(vals['a1'].grad, th.ones((len(rowid), D)))
assert check_eq(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
assert check_eq(vals['a1'].grad, th.ones((len(rowid), D)))
assert vals['a2'].grad is None
def test_append():
data = create_test_data()
f1 = Frame()
f2 = Frame(data)
f1.append(data)
assert f1.num_rows == N
f1.append(f2)
assert f1.num_rows == 2 * N
c1 = f1['a1']
assert c1.shape == (2 * N, D)
truth = th.cat([data['a1'], data['a1']])
check_eq(truth, c1)
def test_row3():
# test row delete
data = Frame(create_test_data())
f = FrameRef(data)
assert f.is_contiguous()
assert f.is_span_whole_column()
assert f.num_rows == N
del f[th.tensor([2, 3])]
assert not f.is_contiguous()
assert not f.is_span_whole_column()
# delete is lazy: only reflect on the ref while the
# underlying storage should not be touched
assert f.num_rows == N - 2
assert data.num_rows == N
newidx = list(range(N))
newidx.pop(2)
newidx.pop(2)
for k, v in f.items():
assert check_eq(v, data[k][th.tensor(newidx)])
def test_sharing():
data = Frame(create_test_data())
f1 = FrameRef(data, index=[0, 1, 2, 3])
f2 = FrameRef(data, index=[2, 3, 4, 5, 6])
# test read
for k, v in f1.items():
assert check_eq(data[k][0:4], v)
for k, v in f2.items():
assert check_eq(data[k][2:7], v)
f2_a1 = f2['a1']
# test write
# update own ref should not been seen by the other.
f1[th.tensor([0, 1])] = {
'a1' : th.zeros([2, D]),
'a2' : th.zeros([2, D]),
'a3' : th.zeros([2, D]),
}
assert check_eq(f2['a1'], f2_a1)
# update shared space should been seen by the other.
f1[th.tensor([2, 3])] = {
'a1' : th.ones([2, D]),
'a2' : th.ones([2, D]),
'a3' : th.ones([2, D]),
}
f2_a1[0:2] = th.ones([2, D])
assert check_eq(f2['a1'], f2_a1)
if __name__ == '__main__':
test_create()
test_col_getter_setter()
test_append()
test_row_getter_setter()
test_row_getter_setter_grad()
test_column1()
test_column2()
test_append1()
test_append2()
test_row1()
test_row2()
test_row3()
test_sharing()
import torch as th
import networkx as nx
import numpy as np
import dgl
D = 5
def check_eq(a, b):
return a.shape == b.shape and np.allclose(a.numpy(), b.numpy())
def test_line_graph():
N = 5
G = dgl.DGLGraph(nx.star_graph(N))
G.set_e_repr(th.randn((2*N, D)))
n_edges = len(G.edges)
L = dgl.line_graph(G)
assert L.number_of_nodes() == 2*N
# update node features on line graph should reflect to edge features on
# original graph.
u = [0, 0, 2, 3]
v = [1, 2, 0, 0]
eid = G.get_edge_id(u, v)
L.set_n_repr(th.zeros((4, D)), eid)
assert check_eq(G.get_e_repr(u, v), th.zeros((4, D)))
# adding a new node feature on line graph should also reflect to a new
# edge feature on original graph
data = th.randn(n_edges, D)
L.set_n_repr({'w': data})
assert check_eq(G.get_e_repr()['w'], data)
def test_no_backtracking():
N = 5
G = dgl.DGLGraph(nx.star_graph(N))
G.set_e_repr(th.randn((2*N, D)))
L = dgl.line_graph(G, no_backtracking=True)
assert L.number_of_nodes() == 2*N
for i in range(1, N):
e1 = G.get_edge_id(0, i)
e2 = G.get_edge_id(i, 0)
assert not L.has_edge(e1, e2)
assert not L.has_edge(e2, e1)
if __name__ == '__main__':
test_line_graph()
test_no_backtracking()
import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.graph import DGLGraph
D = 5
def check_eq(a, b):
return a.shape == b.shape and np.allclose(a.numpy(), b.numpy())
def generate_graph(grad=False):
g = DGLGraph()
for i in range(10):
g.add_node(i) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
ncol = Variable(th.randn(10, D), requires_grad=grad)
ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_n_repr({'h' : ncol})
g.set_e_repr({'l' : ecol})
return g
def test_subgraph():
g = generate_graph()
h = g.get_n_repr()['h']
l = g.get_e_repr()['l']
sg = g.subgraph([0, 2, 3, 6, 7, 9])
sh = sg.get_n_repr()['h']
check_eq(h[th.tensor([0, 2, 3, 6, 7, 9])], sh)
'''
s, d, eid
0, 1, 0
1, 9, 1
0, 2, 2
2, 9, 3
0, 3, 4
3, 9, 5
0, 4, 6
4, 9, 7
0, 5, 8
5, 9, 9
0, 6, 10
6, 9, 11
0, 7, 12
7, 9, 13
0, 8, 14
8, 9, 15
9, 0, 16
'''
eid = th.tensor([2, 3, 4, 5, 10, 11, 12, 13, 16])
check_eq(l[eid], sg.get_e_repr()['l'])
if __name__ == '__main__':
test_subgraph()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment