"...text-generation-inference.git" did not exist on "21267f3ca3f121302b86c1702cc2da6091164c55"
Unverified commit 3721822e, authored by Minjie Wang and committed by GitHub

Subgraph API (#39)

* subgraph

* more test cases

* WIP

* new FrameRef and test

* separate nx init code

* WIP

* subgraph code and test

* line graph code and test

* adding new test for adding new features on line graphs

* no backtracking line graph

* fix inplace relabel
parent 1f16f29b
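
The diff below reworks the feature storage (Frame/FrameRef), moves the networkx dict overlays into a dedicated adapter module, and adds DGLSubGraph plus a line-graph generator. As a rough sketch of how the new pieces fit together (mirroring the tests in this commit; tensor sizes are arbitrary, and the feature-sharing behavior follows the code shown below):

import torch as th
import networkx as nx
import dgl

# star_graph(3) has 3 undirected edges, which become 6 directed edges.
g = dgl.DGLGraph(nx.star_graph(3))
g.set_e_repr(th.randn(6, 4))

# Structure-only subgraph over a node subset, with relabeled node ids.
sg = g.subgraph([0, 1])

# The line graph's node frame aliases g's edge frame, so writing node
# features on lg shows up as edge features on g.
lg = dgl.line_graph(g)
lg.set_n_repr({'w': th.zeros(6, 4)})
assert (g.get_e_repr()['w'] == 0).all()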
from .base import ALL
from .batch import batch, unbatch
from .context import cpu, gpu
from .graph import DGLGraph, __MSG__, __REPR__
from .subgraph import DGLSubGraph
from .generator import *
@@ -58,6 +58,7 @@ unsqueeze = th.unsqueeze
reshape = th.reshape
zeros = th.zeros
ones = th.ones
spmm = th.spmm
sort = th.sort
arange = th.arange
...
"""Columnar storage for graph attributes.""" """Columnar storage for graph attributes."""
from __future__ import absolute_import from __future__ import absolute_import
from collections import MutableMapping
import numpy as np
import dgl.backend as F import dgl.backend as F
from dgl.backend import Tensor from dgl.backend import Tensor
from dgl.utils import LazyDict from dgl.utils import LazyDict
class Frame: class Frame(MutableMapping):
def __init__(self, data=None): def __init__(self, data=None):
if data is None: if data is None:
self._columns = dict() self._columns = dict()
self._num_rows = 0 self._num_rows = 0
else: else:
self._columns = data self._columns = dict(data)
self._num_rows = F.shape(list(data.values())[0])[0] self._num_rows = F.shape(list(data.values())[0])[0]
for k, v in data.items(): for k, v in data.items():
assert F.shape(v)[0] == self._num_rows assert F.shape(v)[0] == self._num_rows
...@@ -32,25 +35,18 @@ class Frame: ...@@ -32,25 +35,18 @@ class Frame:
return key in self._columns return key in self._columns
def __getitem__(self, key): def __getitem__(self, key):
if isinstance(key, str): # get column
return self._columns[key] return self._columns[key]
else:
return self.select_rows(key)
def __setitem__(self, key, val): def __setitem__(self, key, val):
if isinstance(key, str): # set column
self._columns[key] = val self.add_column(key, val)
else:
self.update_rows(key, val)
def __delitem__(self, key): def __delitem__(self, key):
# delete column # delete column
del self._columns[key] del self._columns[key]
if len(self._columns) == 0:
def pop(self, key): self._num_rows = 0
col = self._columns[key]
del self._columns[key]
return col
def add_column(self, name, col): def add_column(self, name, col):
if self.num_columns == 0: if self.num_columns == 0:
...@@ -60,35 +56,164 @@ class Frame: ...@@ -60,35 +56,164 @@ class Frame:
self._columns[name] = col self._columns[name] = col
def append(self, other): def append(self, other):
if not isinstance(other, Frame):
other = Frame(data=other)
if len(self._columns) == 0: if len(self._columns) == 0:
self._columns = other._columns for key, col in other.items():
self._num_rows = other._num_rows self._columns[key] = col
else: else:
assert self.schemes == other.schemes for key, col in other.items():
self._columns = {key : F.pack([self[key], other[key]]) for key in self._columns} self._columns[key] = F.pack([self[key], col])
self._num_rows += other._num_rows # TODO(minjie): sanity check for num_rows
if len(self._columns) != 0:
self._num_rows = F.shape(list(self._columns.values())[0])[0]
def clear(self): def clear(self):
self._columns = {} self._columns = {}
self._num_rows = 0 self._num_rows = 0
def select_rows(self, rowids): def __iter__(self):
return iter(self._columns)
def __len__(self):
return self.num_columns
class FrameRef(MutableMapping):
def __init__(self, frame=None, index=None):
self._frame = frame if frame is not None else Frame()
if index is None:
self._index = slice(0, self._frame.num_rows)
else:
# check no duplicate index
assert len(index) == len(np.unique(index))
self._index = index
self._index_tensor = None
@property
def schemes(self):
return self._frame.schemes
@property
def num_columns(self):
return self._frame.num_columns
@property
def num_rows(self):
if isinstance(self._index, slice):
return self._index.stop
else:
return len(self._index)
def __contains__(self, key):
return key in self._frame
def __getitem__(self, key):
if isinstance(key, str):
return self.get_column(key)
else:
return self.select_rows(key)
def select_rows(self, query):
rowids = self._getrowid(query)
def _lazy_select(key): def _lazy_select(key):
return F.gather_row(self._columns[key], rowids) return F.gather_row(self._frame[key], rowids)
return LazyDict(_lazy_select, keys=self._columns.keys()) return LazyDict(_lazy_select, keys=self.schemes)
def get_column(self, name):
col = self._frame[name]
if self.is_span_whole_column():
return col
else:
return F.gather_row(col, self.index_tensor())
def update_rows(self, rowids, other): def __setitem__(self, key, val):
if not isinstance(other, Frame): if isinstance(key, str):
other = Frame(data=other) self.add_column(key, val)
for key in other.schemes: else:
assert key in self._columns self.update_rows(key, val)
self._columns[key] = F.scatter_row(self[key], rowids, other[key])
def add_column(self, name, col):
shp = F.shape(col)
if self.is_span_whole_column():
if self.num_columns == 0:
self._index = slice(0, shp[0])
self._clear_cache()
assert shp[0] == self.num_rows
self._frame[name] = col
else:
if name in self._frame:
fcol = self._frame[name]
else:
fcol = F.zeros((self._frame.num_rows,) + shp[1:])
newfcol = F.scatter_row(fcol, self.index_tensor(), col)
self._frame[name] = newfcol
def update_rows(self, query, other):
rowids = self._getrowid(query)
for key, col in other.items():
self._frame[key] = F.scatter_row(self._frame[key], rowids, col)
def __delitem__(self, key):
if isinstance(key, str):
del self._frame[key]
if len(self._frame) == 0:
self.clear()
else:
self.delete_rows(key)
def delete_rows(self, query):
query = F.asnumpy(query)
if isinstance(self._index, slice):
self._index = list(range(self._index.start, self._index.stop))
arr = np.array(self._index, dtype=np.int32)
self._index = list(np.delete(arr, query))
self._clear_cache()
def append(self, other):
span_whole = self.is_span_whole_column()
contiguous = self.is_contiguous()
old_nrows = self._frame.num_rows
self._frame.append(other)
# update index
if span_whole:
self._index = slice(0, self._frame.num_rows)
else:
new_idx = list(range(self._index.start, self._index.stop))
new_idx += list(range(old_nrows, self._frame.num_rows))
self._index = new_idx
self._clear_cache()
def clear(self):
self._frame.clear()
self._index = slice(0, 0)
self._clear_cache()
def __iter__(self): def __iter__(self):
for key, col in self._columns.items(): return iter(self._frame)
yield key, col
def __len__(self): def __len__(self):
return self.num_columns return self.num_columns
def is_contiguous(self):
# NOTE: this check could have false negative
return isinstance(self._index, slice)
def is_span_whole_column(self):
return self.is_contiguous() and self.num_rows == self._frame.num_rows
def _getrowid(self, query):
if isinstance(self._index, slice):
# shortcut for identical mapping
return query
else:
return F.gather_row(self.index_tensor(), query)
def index_tensor(self):
# TODO(minjie): context
if self._index_tensor is None:
if self.is_contiguous():
self._index_tensor = F.arange(self._index.stop, dtype=F.int64)
else:
self._index_tensor = F.tensor(self._index, dtype=F.int64)
return self._index_tensor
def _clear_cache(self):
self._index_tensor = None
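
A minimal sketch of the Frame/FrameRef split above, using only the APIs defined in this file (the column name 'h' and the sizes are arbitrary): a Frame is a plain column store, while a FrameRef is a row view whose writes land in the shared underlying frame.

import torch as th
from dgl.frame import Frame, FrameRef

data = Frame({'h': th.arange(10.).view(10, 1)})
ref = FrameRef(data, index=[2, 3, 4])

assert ref.num_rows == 3          # only three rows are visible
print(ref['h'])                   # rows 2..4 of the underlying column

# writes through the view are scattered back into the shared frame
ref['h'] = th.zeros(3, 1)
assert (data['h'][2:5] == 0).all()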
"""Line graph generator."""
from __future__ import absolute_import
import networkx as nx
import numpy as np
import dgl.backend as F
from dgl.graph import DGLGraph
from dgl.frame import FrameRef
def line_graph(G, no_backtracking=False):
"""Create the line graph that shares the underlying features.
The node features of the result line graph will share the edge features
of the given graph.
Parameters
----------
G : DGLGraph
The input graph.
no_backtracking : bool
Whether the backtracking edges are included in the line graph.
If i~j and j~i are two edges in original graph G, then
(i,j)~(j,i) and (j,i)~(i,j) are the "backtracking" edges on
the line graph.
"""
L = nx.DiGraph()
for eid, from_node in enumerate(G.edge_list):
L.add_node(from_node)
for to_node in G.edges(from_node[1]):
if no_backtracking and to_node[1] == from_node[0]:
continue
L.add_edge(from_node, to_node)
relabel_map = {}
for i, e in enumerate(G.edge_list):
relabel_map[e] = i
nx.relabel.relabel_nodes(L, relabel_map, copy=False)
return DGLGraph(L, node_frame=G._edge_frame)
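
As a tiny worked example of the generator above (a sketch mirroring the tests later in this commit): a directed 2-cycle has two edges that reverse each other, so its line graph has two nodes whose only edges are backtracking ones.

import dgl

g = dgl.DGLGraph()
g.add_edge(0, 1)   # edge id 0 (addition order)
g.add_edge(1, 0)   # edge id 1

lg = dgl.line_graph(g)
assert lg.number_of_nodes() == 2   # one line-graph node per edge of g
assert lg.number_of_edges() == 2   # (0->1)~(1->0) and (1->0)~(0->1)

# both line-graph edges are backtracking, so the flag removes them all
lg2 = dgl.line_graph(g, no_backtracking=True)
assert lg2.number_of_edges() == 0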
@@ -2,63 +2,24 @@
"""
from __future__ import absolute_import

import networkx as nx
from networkx.classes.digraph import DiGraph

import dgl
from dgl.base import ALL, is_all
import dgl.backend as F
from dgl.backend import Tensor
import dgl.builtin as builtin
from dgl.cached_graph import CachedGraph, create_cached_graph
import dgl.context as context
from dgl.frame import FrameRef
from dgl.nx_adapt import nx_init
import dgl.scheduler as scheduler
import dgl.utils as utils

__MSG__ = "__MSG__"
__REPR__ = "__REPR__"
class DGLGraph(DiGraph):
    """Base graph class specialized for neural networks on graphs.
@@ -67,42 +28,41 @@ class DGLGraph(DiGraph):
    Parameters
    ----------
    graph_data : graph data
        Data to initialize graph. Same as networkx's semantics.
    node_frame : dgl.frame.FrameRef
        Node feature storage.
    edge_frame : dgl.frame.FrameRef
        Edge feature storage.
    attr : keyword arguments, optional
        Attributes to add to graph as key=value pairs.
    """
    def __init__(self,
                 graph_data=None,
                 node_frame=None,
                 edge_frame=None,
                 **attr):
        # TODO(minjie): maintaining node/edge list is costly when graph is large.
        self._edge_list = []
        nx_init(self,
                self._add_node_callback,
                self._add_edge_callback,
                self._del_node_callback,
                self._del_edge_callback,
                graph_data,
                **attr)
        # cached graph and storage
        self._cached_graph = None
        self._node_frame = node_frame if node_frame is not None else FrameRef()
        self._edge_frame = edge_frame if edge_frame is not None else FrameRef()
        # other class members
        self._msg_graph = None
        self._msg_frame = FrameRef()
        self._message_func = None
        self._reduce_func = None
        self._update_func = None
        self._edge_func = None
        self._context = context.cpu()

    def get_n_attr_list(self):
        return self._node_frame.schemes
@@ -129,7 +89,7 @@ class DGLGraph(DiGraph):
            The node(s).
        """
        # sanity check
        if is_all(u):
            num_nodes = self.number_of_nodes()
        else:
            u = utils.convert_to_id_tensor(u, self.context)
@@ -140,7 +100,7 @@ class DGLGraph(DiGraph):
        else:
            assert F.shape(hu)[0] == num_nodes
        # set
        if is_all(u):
            if isinstance(hu, dict):
                for key, val in hu.items():
                    self._node_frame[key] = val
@@ -161,7 +121,7 @@ class DGLGraph(DiGraph):
        u : node, container or tensor
            The node(s).
        """
        if is_all(u):
            if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
                return self._node_frame[__REPR__]
            else:
@@ -204,8 +164,8 @@ class DGLGraph(DiGraph):
            The destination node(s).
        """
        # sanity check
        u_is_all = is_all(u)
        v_is_all = is_all(v)
        assert u_is_all == v_is_all
        if u_is_all:
            num_edges = self.number_of_edges()
@@ -244,7 +204,7 @@ class DGLGraph(DiGraph):
            The edge id(s).
        """
        # sanity check
        if is_all(eid):
            num_edges = self.number_of_edges()
        else:
            eid = utils.convert_to_id_tensor(eid, self.context)
@@ -255,7 +215,7 @@ class DGLGraph(DiGraph):
        else:
            assert F.shape(h_uv)[0] == num_edges
        # set
        if is_all(eid):
            if isinstance(h_uv, dict):
                for key, val in h_uv.items():
                    self._edge_frame[key] = val
@@ -278,8 +238,8 @@ class DGLGraph(DiGraph):
        v : node, container or tensor
            The destination node(s).
        """
        u_is_all = is_all(u)
        v_is_all = is_all(v)
        assert u_is_all == v_is_all
        if u_is_all:
            if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
@@ -313,7 +273,7 @@ class DGLGraph(DiGraph):
        eid : int, container or tensor
            The edge id(s).
        """
        if is_all(eid):
            if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
                return self._edge_frame[__REPR__]
            else:
@@ -956,6 +916,38 @@ class DGLGraph(DiGraph):
        self.update_by_edge(u, v,
                message_func, reduce_func, update_func, batchable)

    def subgraph(self, nodes):
        """Generate the subgraph among the given nodes.

        The generated graph contains only the graph structure. The node/edge
        features are not shared implicitly. Use `copy_from` to get node/edge
        features from the parent graph.

        Parameters
        ----------
        nodes : list, or iterable
            A container of the nodes to construct the subgraph from.

        Returns
        -------
        G : DGLGraph
            The subgraph.
        """
        return dgl.DGLSubGraph(self, nodes)

    def copy_from(self, graph):
        """Copy node/edge features from the given graph.

        All old features will be removed.

        Parameters
        ----------
        graph : DGLGraph
            The graph to copy from.
        """
        # TODO
        pass
    def draw(self):
        """Plot the graph using dot."""
        from networkx.drawing.nx_agraph import graphviz_layout
@@ -984,6 +976,28 @@ class DGLGraph(DiGraph):
        self._msg_graph.add_nodes(self.number_of_nodes())
        self._msg_frame.clear()
    @property
    def edge_list(self):
        """Return edges in the addition order."""
        return self._edge_list

    def get_edge_id(self, u, v):
        """Return the consecutive edge id(s) assigned in addition order.

        Parameters
        ----------
        u : node, container or tensor
            The source node(s).
        v : node, container or tensor
            The destination node(s).

        Returns
        -------
        eid : tensor
            The tensor containing the edge id(s).
        """
        return self.cached_graph.get_edge_id(u, v)

    def _nodes_or_all(self, nodes):
        return self.nodes() if nodes == ALL else nodes
@@ -991,22 +1005,29 @@ class DGLGraph(DiGraph):
        return self.edges() if edges == ALL else edges

    def _add_node_callback(self, node):
        #print('New node:', node)
        self._cached_graph = None

    def _del_node_callback(self, node):
        #print('Del node:', node)
        raise RuntimeError('Node removal is not supported currently.')
        # unreachable until removal is supported; kept as a sketch of the plan
        node = utils.convert_to_id_tensor(node)
        self._node_frame.delete_rows(node)
        self._cached_graph = None

    def _add_edge_callback(self, u, v):
        #print('New edge:', u, v)
        self._edge_list.append((u, v))
        self._cached_graph = None

    def _del_edge_callback(self, u, v):
        #print('Del edge:', u, v)
        raise RuntimeError('Edge removal is not supported currently.')
        # unreachable until removal is supported; kept as a sketch of the plan
        u = utils.convert_to_id_tensor(u)
        v = utils.convert_to_id_tensor(v)
        eid = self.get_edge_id(u, v)
        self._edge_frame.delete_rows(eid)
        self._cached_graph = None

def _get_repr(attr_dict):
    if len(attr_dict) == 1 and __REPR__ in attr_dict:
...
"""Utility functions for networkx adapter."""
from __future__ import absolute_import
from collections import MutableMapping
import networkx as nx
import networkx.convert as convert
class NodeDict(MutableMapping):
def __init__(self, add_cb, del_cb):
self._dict = {}
self._add_cb = add_cb
self._del_cb = del_cb
def __setitem__(self, key, val):
self._add_cb(key)
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
self._del_cb(key)
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
class AdjOuterDict(MutableMapping):
def __init__(self, add_cb, del_cb):
self._dict = {}
self._add_cb = add_cb
self._del_cb = del_cb
def __setitem__(self, key, val):
val.src = key
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
for val in self._dict[key]:
self._del_cb(key, val)
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
class AdjInnerDict(MutableMapping):
def __init__(self, add_cb, del_cb):
self._dict = {}
self.src = None
self._add_cb = add_cb
self._del_cb = del_cb
def __setitem__(self, key, val):
if self.src is not None and key not in self._dict:
self._add_cb(self.src, key)
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
if self.src is not None:
self._del_cb(self.src, key)
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
def nx_init(obj,
add_node_cb,
add_edge_cb,
del_node_cb,
del_edge_cb,
graph_data,
**attr):
"""Init the object to be compatible with networkx's DiGraph.
Parameters
----------
obj : any
The object to be init.
add_node_cb : callable
The callback function when node is added.
add_edge_cb : callable
The callback function when edge is added.
graph_data : graph data
Data to initialize graph. Same as networkx's semantics.
attr : keyword arguments, optional
Attributes to add to graph as key=value pairs.
"""
# The following codes work for networkx 2.1.
obj.adjlist_outer_dict_factory = None
obj.adjlist_inner_dict_factory = lambda : AdjInnerDict(add_edge_cb, del_edge_cb)
obj.edge_attr_dict_factory = dict
obj.root_graph = obj
obj.graph = {}
obj._node = NodeDict(add_node_cb, del_node_cb)
obj._adj = AdjOuterDict(add_edge_cb, del_edge_cb)
obj._pred = dict()
obj._succ = obj._adj
if graph_data is not None:
convert.to_networkx_graph(graph_data, create_using=obj)
obj.graph.update(attr)
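
To make the callback wiring concrete, here is a minimal sketch; the TracedGraph class is hypothetical (not part of this commit) and assumes the networkx 2.1 internals noted in the comment above, skipping DiGraph.__init__ in favor of nx_init the same way DGLGraph does.

import networkx as nx
from dgl.nx_adapt import nx_init

class TracedGraph(nx.DiGraph):
    def __init__(self):
        nx_init(self,
                lambda n: print('add node', n),
                lambda u, v: print('add edge', u, v),
                lambda n: print('del node', n),
                lambda u, v: print('del edge', u, v),
                None)

g = TracedGraph()
g.add_edge(0, 1)
# prints: add node 0 / add node 1 / add edge 0 1
# the reverse adjacency write goes through _pred, whose inner dicts keep
# src=None, so the edge callback fires exactly once per added edge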
"""DGLSubGraph"""
from __future__ import absolute_import
import networkx as nx
import dgl.backend as F
from dgl.frame import Frame, FrameRef
from dgl.graph import DGLGraph
from dgl.nx_adapt import nx_init
import dgl.utils as utils
class DGLSubGraph(DGLGraph):
# TODO(gaiyu): ReadOnlyGraph
def __init__(self,
parent,
nodes):
# create subgraph and relabel
nx_sg = nx.DiGraph.subgraph(parent, nodes)
# node id
# TODO(minjie): context
nid = F.tensor(nodes, dtype=F.int64)
# edge id
# TODO(minjie): slow, context
u, v = zip(*nx_sg.edges)
u = list(u)
v = list(v)
eid = parent.cached_graph.get_edge_id(u, v)
# relabel
self._node_mapping = utils.build_relabel_dict(nodes)
nx_sg = nx.relabel.relabel_nodes(nx_sg, self._node_mapping)
# init
self._edge_list = []
nx_init(self,
self._add_node_callback,
self._add_edge_callback,
self._del_node_callback,
self._del_edge_callback,
nx_sg,
**parent.graph)
# cached graph and storage
self._cached_graph = None
if parent._node_frame.num_rows == 0:
self._node_frame = FrameRef()
else:
self._node_frame = FrameRef(Frame(parent._node_frame[nid]))
if parent._edge_frame.num_rows == 0:
self._edge_frame = FrameRef()
else:
self._edge_frame = FrameRef(Frame(parent._edge_frame[eid]))
# other class members
self._msg_graph = None
self._msg_frame = FrameRef()
self._message_func = parent._message_func
self._reduce_func = parent._reduce_func
self._update_func = parent._update_func
self._edge_func = parent._edge_func
self._context = parent._context
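
A short usage sketch of the new subgraph path (a hypothetical toy graph, mirroring test_subgraph below): the given nodes are relabeled to consecutive ids in order, and the selected feature rows are materialized into fresh frames.

import torch as th
import dgl

g = dgl.DGLGraph()
for i in range(4):
    g.add_node(i)
g.add_edge(0, 1)
g.add_edge(1, 2)
g.add_edge(2, 3)
g.set_n_repr({'h': th.randn(4, 2)})

sg = g.subgraph([1, 2, 3])        # nodes 1, 2, 3 become 0, 1, 2
assert sg.number_of_nodes() == 3
assert sg.get_n_repr()['h'].shape == (3, 2)   # rows 1..3 of g's 'h'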
@@ -71,23 +71,40 @@ class LazyDict(Mapping):
        self._fn = fn
        self._keys = keys

    def keys(self):
        return self._keys

    def __getitem__(self, key):
        if key not in self._keys:
            raise KeyError(key)
        return self._fn(key)

    def __contains__(self, key):
        return key in self._keys

    def __iter__(self):
        return iter(self._keys)

    def __len__(self):
        return len(self._keys)

class ReadOnlyDict(Mapping):
    """A readonly dictionary wrapper."""
    def __init__(self, dict_like):
        self._dict_like = dict_like

    def keys(self):
        return self._dict_like.keys()

    def __getitem__(self, key):
        return self._dict_like[key]

    def __contains__(self, key):
        return key in self._dict_like

    def __iter__(self):
        return iter(self._dict_like)

    def __len__(self):
        return len(self._dict_like)
def build_relabel_map(x):
    """Relabel the input ids to consecutive ids that start from zero.
@@ -113,6 +130,26 @@ def build_relabel_map(x):
    old_to_new[unique_x] = F.astype(F.arange(len(unique_x)), F.int64)
    return unique_x, old_to_new

def build_relabel_dict(x):
    """Relabel the input ids to consecutive ids that start from zero.

    The new id follows the order of the given node id list.

    Parameters
    ----------
    x : list
        The input ids.

    Returns
    -------
    relabel_dict : dict
        Dict from old id to new id.
    """
    relabel_dict = {}
    for i, v in enumerate(x):
        relabel_dict[v] = i
    return relabel_dict

def edge_broadcasting(u, v):
    """Convert one-many and many-one edges to many-many."""
    if len(u) != len(v) and len(u) == 1:
...
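
A small sketch of LazyDict's behavior (toy column dict; only APIs from this file are used): values are computed per key at lookup time, which is what lets FrameRef.select_rows gather only the columns a caller actually reads.

import torch as th
from dgl.utils import LazyDict

columns = {'h': th.arange(6.).view(3, 2)}
rowids = th.tensor([0, 2])

lazy = LazyDict(lambda key: columns[key][rowids], keys=columns.keys())
print(lazy['h'])                   # rows 0 and 2, gathered on demand
assert 'h' in lazy and len(lazy) == 1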
@@ -36,8 +36,8 @@ def generate_graph(grad=False):
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    ncol = Variable(th.randn(10, D), requires_grad=grad)
    g.set_n_repr({'h' : ncol})
    return g

def test_batch_setter_getter():
@@ -196,9 +196,20 @@ def test_update_routines():
    assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
    reduce_msg_shapes.clear()

def _test_delete():
    g = generate_graph()
    ecol = Variable(th.randn(17, D))
    g.set_e_repr({'e' : ecol})
    assert g.get_n_repr()['h'].shape[0] == 10
    assert g.get_e_repr()['e'].shape[0] == 17
    g.remove_node(0)
    assert g.get_n_repr()['h'].shape[0] == 9
    assert g.get_e_repr()['e'].shape[0] == 8

if __name__ == '__main__':
    test_batch_setter_getter()
    test_batch_setter_autograd()
    test_batch_send()
    test_batch_recv()
    test_update_routines()
    #_test_delete()
import torch as th
from torch.autograd import Variable
import numpy as np

from dgl.frame import Frame, FrameRef

N = 10
D = 5

def check_eq(a, b):
    return a.shape == b.shape and np.allclose(a.numpy(), b.numpy())

def check_fail(fn):
    try:
        fn()
        return False
    except:
        return True

def create_test_data(grad=False):
    c1 = Variable(th.randn(N, D), requires_grad=grad)
@@ -32,30 +38,101 @@ def test_create():
    assert len(f1.schemes) == 0
    assert f1.num_rows == 0

def test_column1():
    # Test frame column getter/setter
    data = create_test_data()
    f = Frame(data)
    assert f.num_rows == N
    assert len(f) == 3
    assert check_eq(f['a1'], data['a1'])
    f['a1'] = data['a2']
    assert check_eq(f['a2'], data['a2'])
    # adding a column of a different length should fail
    def failed_add_col():
        f['a4'] = th.zeros([N+1, D])
    assert check_fail(failed_add_col)
    # delete all the columns
    del f['a1']
    del f['a2']
    assert len(f) == 1
    del f['a3']
    assert f.num_rows == 0
    assert len(f) == 0
    # adding a column of a different length should now succeed
    f['a4'] = th.zeros([N+1, D])
    assert f.num_rows == N+1
    assert len(f) == 1

def test_column2():
    # Test frameref column getter/setter
    data = Frame(create_test_data())
    f = FrameRef(data, [3, 4, 5, 6, 7])
    assert f.num_rows == 5
    assert len(f) == 3
    assert check_eq(f['a1'], data['a1'][3:8])
    # setting a column should be reflected in the referenced data
    f['a1'] = th.zeros([5, D])
    assert check_eq(data['a1'][3:8], th.zeros([5, D]))
    # adding a new column should zero-pad the rows outside the reference
    f['a4'] = th.ones([5, D])
    assert len(data) == 4
    assert check_eq(data['a4'][0:3], th.zeros([3, D]))
    assert check_eq(data['a4'][3:8], th.ones([5, D]))
    assert check_eq(data['a4'][8:10], th.zeros([2, D]))

def test_append1():
    # test append API on Frame
    data = create_test_data()
    f1 = Frame()
    f2 = Frame(data)
    f1.append(data)
    assert f1.num_rows == N
    f1.append(f2)
    assert f1.num_rows == 2 * N
    c1 = f1['a1']
    assert c1.shape == (2 * N, D)
    truth = th.cat([data['a1'], data['a1']])
    assert check_eq(truth, c1)

def test_append2():
    # test append on FrameRef
    data = Frame(create_test_data())
    f = FrameRef(data)
    assert f.is_contiguous()
    assert f.is_span_whole_column()
    assert f.num_rows == N
    # appending to the underlying frame should not be reflected on the ref
    data.append(data)
    assert f.is_contiguous()
    assert not f.is_span_whole_column()
    assert f.num_rows == N
    # appending to the FrameRef should work
    f.append(data)
    assert not f.is_contiguous()
    assert not f.is_span_whole_column()
    assert f.num_rows == 3 * N
    new_idx = list(range(N)) + list(range(2*N, 4*N))
    assert check_eq(f.index_tensor(), th.tensor(new_idx))
    assert data.num_rows == 4 * N

def test_row1():
    # test row getter/setter
    data = create_test_data()
    f = FrameRef(Frame(data))

    # getter
    # test non-duplicate keys
    rowid = th.tensor([0, 2])
    rows = f[rowid]
    for k, v in rows.items():
        assert v.shape == (len(rowid), D)
        assert check_eq(v, data[k][rowid])
    # test duplicate keys
    rowid = th.tensor([8, 2, 2, 1])
    rows = f[rowid]
    for k, v in rows.items():
        assert v.shape == (len(rowid), D)
        assert check_eq(v, data[k][rowid])

    # setter
    rowid = th.tensor([0, 2, 4])
@@ -64,12 +141,13 @@ def test_row_getter_setter():
        'a3' : th.zeros((len(rowid), D)),
        }
    f[rowid] = vals
    for k, v in f[rowid].items():
        assert check_eq(v, th.zeros((len(rowid), D)))

def test_row2():
    # test row getter/setter autograd compatibility
    data = create_test_data(grad=True)
    f = FrameRef(Frame(data))

    # getter
    c1 = f['a1']
@@ -77,13 +155,13 @@ def test_row_getter_setter_grad():
    rowid = th.tensor([0, 2])
    rows = f[rowid]
    rows['a1'].backward(th.ones((len(rowid), D)))
    assert check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
    c1.grad.data.zero_()
    # test duplicate keys
    rowid = th.tensor([8, 2, 2, 1])
    rows = f[rowid]
    rows['a1'].backward(th.ones((len(rowid), D)))
    assert check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
    c1.grad.data.zero_()

    # setter
@@ -96,26 +174,64 @@ def test_row_getter_setter_grad():
    f[rowid] = vals
    c11 = f['a1']
    c11.backward(th.ones((N, D)))
    assert check_eq(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
    assert check_eq(vals['a1'].grad, th.ones((len(rowid), D)))
    assert vals['a2'].grad is None

def test_row3():
    # test row delete
    data = Frame(create_test_data())
    f = FrameRef(data)
    assert f.is_contiguous()
    assert f.is_span_whole_column()
    assert f.num_rows == N
    del f[th.tensor([2, 3])]
    assert not f.is_contiguous()
    assert not f.is_span_whole_column()
    # delete is lazy: it only reflects on the ref; the
    # underlying storage is not touched
    assert f.num_rows == N - 2
    assert data.num_rows == N
    newidx = list(range(N))
    newidx.pop(2)
    newidx.pop(2)
    for k, v in f.items():
        assert check_eq(v, data[k][th.tensor(newidx)])

def test_sharing():
    data = Frame(create_test_data())
    f1 = FrameRef(data, index=[0, 1, 2, 3])
    f2 = FrameRef(data, index=[2, 3, 4, 5, 6])
    # test read
    for k, v in f1.items():
        assert check_eq(data[k][0:4], v)
    for k, v in f2.items():
        assert check_eq(data[k][2:7], v)
    f2_a1 = f2['a1']
    # test write
    # updates to one ref's unshared rows should not be seen by the other.
    f1[th.tensor([0, 1])] = {
            'a1' : th.zeros([2, D]),
            'a2' : th.zeros([2, D]),
            'a3' : th.zeros([2, D]),
            }
    assert check_eq(f2['a1'], f2_a1)
    # updates to the shared rows should be seen by the other.
    f1[th.tensor([2, 3])] = {
            'a1' : th.ones([2, D]),
            'a2' : th.ones([2, D]),
            'a3' : th.ones([2, D]),
            }
    f2_a1[0:2] = th.ones([2, D])
    assert check_eq(f2['a1'], f2_a1)

if __name__ == '__main__':
    test_create()
    test_column1()
    test_column2()
    test_append1()
    test_append2()
    test_row1()
    test_row2()
    test_row3()
    test_sharing()
import torch as th
import networkx as nx
import numpy as np
import dgl

D = 5

def check_eq(a, b):
    return a.shape == b.shape and np.allclose(a.numpy(), b.numpy())

def test_line_graph():
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
    G.set_e_repr(th.randn((2*N, D)))
    n_edges = len(G.edges)
    L = dgl.line_graph(G)
    assert L.number_of_nodes() == 2*N
    # updating node features on the line graph should be reflected in the
    # edge features of the original graph.
    u = [0, 0, 2, 3]
    v = [1, 2, 0, 0]
    eid = G.get_edge_id(u, v)
    L.set_n_repr(th.zeros((4, D)), eid)
    assert check_eq(G.get_e_repr(u, v), th.zeros((4, D)))
    # adding a new node feature on the line graph should likewise create a
    # new edge feature on the original graph.
    data = th.randn(n_edges, D)
    L.set_n_repr({'w': data})
    assert check_eq(G.get_e_repr()['w'], data)

def test_no_backtracking():
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
    G.set_e_repr(th.randn((2*N, D)))
    L = dgl.line_graph(G, no_backtracking=True)
    assert L.number_of_nodes() == 2*N
    for i in range(1, N):
        e1 = G.get_edge_id(0, i)
        e2 = G.get_edge_id(i, 0)
        assert not L.has_edge(e1, e2)
        assert not L.has_edge(e2, e1)

if __name__ == '__main__':
    test_line_graph()
    test_no_backtracking()
import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.graph import DGLGraph

D = 5

def check_eq(a, b):
    return a.shape == b.shape and np.allclose(a.numpy(), b.numpy())

def generate_graph(grad=False):
    g = DGLGraph()
    for i in range(10):
        g.add_node(i) # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    ncol = Variable(th.randn(10, D), requires_grad=grad)
    ecol = Variable(th.randn(17, D), requires_grad=grad)
    g.set_n_repr({'h' : ncol})
    g.set_e_repr({'l' : ecol})
    return g

def test_subgraph():
    g = generate_graph()
    h = g.get_n_repr()['h']
    l = g.get_e_repr()['l']
    sg = g.subgraph([0, 2, 3, 6, 7, 9])
    sh = sg.get_n_repr()['h']
    assert check_eq(h[th.tensor([0, 2, 3, 6, 7, 9])], sh)
    '''
    s, d, eid
    0, 1, 0
    1, 9, 1
    0, 2, 2
    2, 9, 3
    0, 3, 4
    3, 9, 5
    0, 4, 6
    4, 9, 7
    0, 5, 8
    5, 9, 9
    0, 6, 10
    6, 9, 11
    0, 7, 12
    7, 9, 13
    0, 8, 14
    8, 9, 15
    9, 0, 16
    '''
    eid = th.tensor([2, 3, 4, 5, 10, 11, 12, 13, 16])
    assert check_eq(l[eid], sg.get_e_repr()['l'])

if __name__ == '__main__':
    test_subgraph()