"examples/vscode:/vscode.git/clone" did not exist on "d1bea9e87ff9d473ce2de9f99fafe2d37af5dfa8"
Commit b7eb1659 authored by Gan Quan's avatar Gan Quan Committed by Minjie Wang
Browse files

Fix 0deg update_to and Tree-LSTM model (#51)

* WIP

* WIP

* treelstm dataloader

* Main training loop.

* trainable treelstm script

* fix dependency

* cuda training

* Add tensorized topological traversal

* allowing update_to() with no incoming messages

* fixing partial cases
parent 5e75f5db
...@@ -19,15 +19,17 @@ _urls = { ...@@ -19,15 +19,17 @@ _urls = {
'pubmed' : 'https://www.dropbox.com/s/fj5q6pi66xhymcm/pubmed.zip?dl=1', 'pubmed' : 'https://www.dropbox.com/s/fj5q6pi66xhymcm/pubmed.zip?dl=1',
} }
class GCNDataset(object): class CitationGraphDataset(object):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self.mode = mode
self.dir = get_download_dir() self.dir = get_download_dir()
self.zip_file_path='{}/{}.zip'.format(self.dir, name) self.zip_file_path='{}/{}.zip'.format(self.dir, name)
download(_urls[name], path=self.zip_file_path) download(_urls[name], path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, name)) extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, name))
self._load()
def load(self): def _load(self):
"""Loads input data from gcn/data directory """Loads input data from gcn/data directory
ind.name.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object; ind.name.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
...@@ -87,13 +89,6 @@ class GCNDataset(object): ...@@ -87,13 +89,6 @@ class GCNDataset(object):
val_mask = _sample_mask(idx_val, labels.shape[0]) val_mask = _sample_mask(idx_val, labels.shape[0])
test_mask = _sample_mask(idx_test, labels.shape[0]) test_mask = _sample_mask(idx_test, labels.shape[0])
#y_train = np.zeros(labels.shape)
#y_val = np.zeros(labels.shape)
#y_test = np.zeros(labels.shape)
#y_train[train_mask, :] = labels[train_mask, :]
#y_val[val_mask, :] = labels[val_mask, :]
#y_test[test_mask, :] = labels[test_mask, :]
self.graph = graph self.graph = graph
self.features = _preprocess_features(features) self.features = _preprocess_features(features)
self.labels = labels self.labels = labels
...@@ -112,6 +107,12 @@ class GCNDataset(object): ...@@ -112,6 +107,12 @@ class GCNDataset(object):
print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0]))) print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0]))) print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
def __getitem__(self, idx):
    # Single-sample dataset: any index returns the dataset object itself
    # (the whole citation graph is one example).
    return self

def __len__(self):
    # The dataset contains exactly one sample.
    return 1
def _preprocess_features(features): def _preprocess_features(features):
"""Row-normalize feature matrix and convert to tuple representation""" """Row-normalize feature matrix and convert to tuple representation"""
rowsum = np.array(features.sum(1)) rowsum = np.array(features.sum(1))
...@@ -135,18 +136,15 @@ def _sample_mask(idx, l): ...@@ -135,18 +136,15 @@ def _sample_mask(idx, l):
return mask return mask
def load_cora(): def load_cora():
data = GCNDataset('cora') data = CitationGraphDataset('cora')
data.load()
return data return data
def load_citeseer(): def load_citeseer():
data = GCNDataset('citeseer') data = CitationGraphDataset('citeseer')
data.load()
return data return data
def load_pubmed(): def load_pubmed():
data = GCNDataset('pubmed') data = CitationGraphDataset('pubmed')
data.load()
return data return data
class GCNSyntheticDataset(object): class GCNSyntheticDataset(object):
...@@ -196,6 +194,12 @@ class GCNSyntheticDataset(object): ...@@ -196,6 +194,12 @@ class GCNSyntheticDataset(object):
print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0]))) print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0]))) print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
def __getitem__(self, idx):
    # Single-sample dataset: any index returns the dataset object itself.
    return self

def __len__(self):
    # The dataset contains exactly one sample.
    return 1
def get_gnp_generator(args): def get_gnp_generator(args):
n = args.syn_gnp_n n = args.syn_gnp_n
p = (2 * np.log(n) / n) if args.syn_gnp_p == 0. else args.syn_gnp_p p = (2 * np.log(n) / n) if args.syn_gnp_p == 0. else args.syn_gnp_p
......
"""Tree-structured data.
Including:
- Stanford Sentiment Treebank
"""
from __future__ import absolute_import
from collections import namedtuple
from nltk.tree import Tree
from nltk.corpus.reader import BracketParseCorpusReader
import networkx as nx
import dgl
import dgl.backend as F
from dgl.data.utils import download, extract_archive, get_download_dir
# Download URL for the preprocessed Stanford Sentiment Treebank archive.
_urls = {
    'sst' : 'https://www.dropbox.com/s/dw8kr2vuq7k4dqi/sst.zip?dl=1',
}

# Container for a batched set of trees: the batched graph plus the global
# ids of word-bearing nodes, their word ids, and the labels of all nodes
# (see SST.batcher below).
SSTBatch = namedtuple('SSTBatch', ['graph', 'nid_with_word', 'wordid', 'label'])
class SST(object):
    """Stanford Sentiment Treebank dataset of sentiment parse trees.

    Downloads and extracts a preprocessed copy of the corpus, then converts
    each parsed sentence into a ``dgl.DGLGraph`` whose edges point from
    child to parent (see ``_build_tree``).

    Parameters
    ----------
    mode : str
        Which split to load; selects the ``<mode>.txt`` corpus file
        (e.g. ``'train'``).
    vocab_file : str or None
        Optional path to a vocabulary file. Defaults to the ``vocab.txt``
        shipped inside the downloaded archive.
    """
    # Sentinel word id for internal (non-leaf) nodes that carry no token.
    PAD_WORD=-1
    def __init__(self, mode='train', vocab_file=None):
        self.mode = mode
        self.dir = get_download_dir()
        self.zip_file_path='{}/sst.zip'.format(self.dir)
        # Fall back to the vocabulary bundled with the archive.
        self.vocab_file = '{}/sst/vocab.txt'.format(self.dir) if vocab_file is None else vocab_file
        download(_urls['sst'], path=self.zip_file_path)
        extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
        self.trees = []
        self.num_classes = 5
        print('Preprocessing...')
        self._load()
        print('Dataset creation finished. #Trees:', len(self.trees))

    def _load(self):
        """Read the split's corpus file, load the vocabulary, and build one
        graph per sentence into ``self.trees``."""
        files = ['{}.txt'.format(self.mode)]
        corpus = BracketParseCorpusReader('{}/sst'.format(self.dir), files)
        sents = corpus.parsed_sents(files[0])
        # load vocab file: one token per line; id = order of appearance.
        self.vocab = {}
        with open(self.vocab_file) as vf:
            for line in vf.readlines():
                line = line.strip()
                self.vocab[line] = len(self.vocab)
        # build trees
        for sent in sents:
            self.trees.append(self._build_tree(sent))

    def _build_tree(self, root):
        """Convert an ``nltk.Tree`` into a DGLGraph with child->parent edges.

        Node attribute ``x`` is the vocabulary id of the token for leaves,
        or ``SST.PAD_WORD`` for internal nodes; ``y`` is the node's
        sentiment label.
        """
        g = nx.DiGraph()
        def _rec_build(nid, node):
            for child in node:
                # Node ids are assigned in creation order, so the next id is
                # simply the current node count.
                cid = g.number_of_nodes()
                if isinstance(child[0], str):
                    # leaf node
                    word = self.vocab[child[0].lower()]
                    g.add_node(cid, x=word, y=int(child.label()))
                else:
                    g.add_node(cid, x=SST.PAD_WORD, y=int(child.label()))
                    _rec_build(cid, child)
                # Edge direction is child -> parent.
                g.add_edge(cid, nid)
        # add root
        g.add_node(0, x=SST.PAD_WORD, y=int(root.label()))
        _rec_build(0, root)
        return dgl.DGLGraph(g)

    def __getitem__(self, idx):
        # Return the idx-th tree (a DGLGraph).
        return self.trees[idx]

    def __len__(self):
        # Number of trees in this split.
        return len(self.trees)

    @property
    def num_vocabs(self):
        # Size of the loaded vocabulary.
        return len(self.vocab)

    @staticmethod
    def batcher(batch):
        """Collate a list of tree graphs into a single ``SSTBatch``.

        ``nid_with_word`` / ``wordid`` index only nodes that carry a real
        token (x != PAD_WORD); ``label`` collects the label of every node.
        ``gnid`` numbers nodes globally across the batched graph.
        """
        nid_with_word = []
        wordid = []
        label = []
        gnid = 0
        for tree in batch:
            for nid in range(tree.number_of_nodes()):
                if tree.nodes[nid]['x'] != SST.PAD_WORD:
                    nid_with_word.append(gnid)
                    wordid.append(tree.nodes[nid]['x'])
                label.append(tree.nodes[nid]['y'])
                gnid += 1
        batch_trees = dgl.batch(batch)
        return SSTBatch(graph=batch_trees,
                        nid_with_word=F.tensor(nid_with_word, dtype=F.int64),
                        wordid=F.tensor(wordid, dtype=F.int64),
                        label=F.tensor(label, dtype=F.int64))
...@@ -93,20 +93,20 @@ class DGLGraph(DiGraph): ...@@ -93,20 +93,20 @@ class DGLGraph(DiGraph):
else: else:
u = utils.toindex(u) u = utils.toindex(u)
num_nodes = len(u) num_nodes = len(u)
if isinstance(hu, dict): if utils.is_dict_like(hu):
for key, val in hu.items(): for key, val in hu.items():
assert F.shape(val)[0] == num_nodes assert F.shape(val)[0] == num_nodes
else: else:
assert F.shape(hu)[0] == num_nodes assert F.shape(hu)[0] == num_nodes
# set # set
if is_all(u): if is_all(u):
if isinstance(hu, dict): if utils.is_dict_like(hu):
for key, val in hu.items(): for key, val in hu.items():
self._node_frame[key] = val self._node_frame[key] = val
else: else:
self._node_frame[__REPR__] = hu self._node_frame[__REPR__] = hu
else: else:
if isinstance(hu, dict): if utils.is_dict_like(hu):
self._node_frame[u] = hu self._node_frame[u] = hu
else: else:
self._node_frame[u] = {__REPR__ : hu} self._node_frame[u] = {__REPR__ : hu}
...@@ -171,21 +171,21 @@ class DGLGraph(DiGraph): ...@@ -171,21 +171,21 @@ class DGLGraph(DiGraph):
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
num_edges = max(len(u), len(v)) num_edges = max(len(u), len(v))
if isinstance(h_uv, dict): if utils.is_dict_like(h_uv):
for key, val in h_uv.items(): for key, val in h_uv.items():
assert F.shape(val)[0] == num_edges assert F.shape(val)[0] == num_edges
else: else:
assert F.shape(h_uv)[0] == num_edges assert F.shape(h_uv)[0] == num_edges
# set # set
if u_is_all: if u_is_all:
if isinstance(h_uv, dict): if utils.is_dict_like(h_uv):
for key, val in h_uv.items(): for key, val in h_uv.items():
self._edge_frame[key] = val self._edge_frame[key] = val
else: else:
self._edge_frame[__REPR__] = h_uv self._edge_frame[__REPR__] = h_uv
else: else:
eid = self.cached_graph.get_edge_id(u, v) eid = self.cached_graph.get_edge_id(u, v)
if isinstance(h_uv, dict): if utils.is_dict_like(h_uv):
self._edge_frame[eid] = h_uv self._edge_frame[eid] = h_uv
else: else:
self._edge_frame[eid] = {__REPR__ : h_uv} self._edge_frame[eid] = {__REPR__ : h_uv}
...@@ -206,20 +206,20 @@ class DGLGraph(DiGraph): ...@@ -206,20 +206,20 @@ class DGLGraph(DiGraph):
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid)
num_edges = len(eid) num_edges = len(eid)
if isinstance(h_uv, dict): if utils.is_dict_like(h_uv):
for key, val in h_uv.items(): for key, val in h_uv.items():
assert F.shape(val)[0] == num_edges assert F.shape(val)[0] == num_edges
else: else:
assert F.shape(h_uv)[0] == num_edges assert F.shape(h_uv)[0] == num_edges
# set # set
if is_all(eid): if is_all(eid):
if isinstance(h_uv, dict): if utils.is_dict_like(h_uv):
for key, val in h_uv.items(): for key, val in h_uv.items():
self._edge_frame[key] = val self._edge_frame[key] = val
else: else:
self._edge_frame[__REPR__] = h_uv self._edge_frame[__REPR__] = h_uv
else: else:
if isinstance(h_uv, dict): if utils.is_dict_like(h_uv):
self._edge_frame[eid] = h_uv self._edge_frame[eid] = h_uv
else: else:
self._edge_frame[eid] = {__REPR__ : h_uv} self._edge_frame[eid] = {__REPR__ : h_uv}
...@@ -400,7 +400,7 @@ class DGLGraph(DiGraph): ...@@ -400,7 +400,7 @@ class DGLGraph(DiGraph):
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid) edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs) msgs = message_func(src_reprs, edge_reprs)
if isinstance(msgs, dict): if utils.is_dict_like(msgs):
self._msg_frame.append(msgs) self._msg_frame.append(msgs)
else: else:
self._msg_frame.append({__MSG__ : msgs}) self._msg_frame.append({__MSG__ : msgs})
...@@ -522,11 +522,11 @@ class DGLGraph(DiGraph): ...@@ -522,11 +522,11 @@ class DGLGraph(DiGraph):
def _nonbatch_recv(self, u, reduce_func, update_func): def _nonbatch_recv(self, u, reduce_func, update_func):
f_reduce = _get_reduce_func(reduce_func) f_reduce = _get_reduce_func(reduce_func)
f_update = update_func
if is_all(u): if is_all(u):
u = list(range(0, self.number_of_nodes())) u = list(range(0, self.number_of_nodes()))
else: else:
u = utils.toindex(u) u = utils.toindex(u)
for i, uu in enumerate(utils.node_iter(u)): for i, uu in enumerate(utils.node_iter(u)):
# reduce phase # reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__) msgs_batch = [self.edges[vv, uu].pop(__MSG__)
...@@ -536,31 +536,31 @@ class DGLGraph(DiGraph): ...@@ -536,31 +536,31 @@ class DGLGraph(DiGraph):
else: else:
msgs_reduced = f_reduce(_get_repr(self.nodes[uu]), msgs_batch) msgs_reduced = f_reduce(_get_repr(self.nodes[uu]), msgs_batch)
# update phase # update phase
ret = f_update(_get_repr(self.nodes[uu]), msgs_reduced) ret = update_func(_get_repr(self.nodes[uu]), msgs_reduced)
_set_repr(self.nodes[uu], ret) _set_repr(self.nodes[uu], ret)
def _batch_recv(self, v, reduce_func, update_func): def _batch_recv(self, v, reduce_func, update_func):
f_update = update_func if len(v) == 0:
# no vertex to be triggered.
return
null_v, reordered_v, all_reduced_msgs = self._batch_reduce(v, reduce_func) null_v, reordered_v, all_reduced_msgs = self._batch_reduce(v, reduce_func)
if all_reduced_msgs is None: if all_reduced_msgs is None:
# no message; only do recv. # no message; only do recv.
if is_all(v): if is_all(v):
self.set_n_repr(f_update(self.get_n_repr(), None)) self.set_n_repr(update_func(self.get_n_repr(), None))
else: else:
self.set_n_repr(f_update(self.get_n_repr(v), None), v) self.set_n_repr(update_func(self.get_n_repr(v), None), v)
else: else:
# Read the node states in the degree-bucketing order. # Compute new node repr for nodes with no in-coming messages.
if len(null_v) == 0: if len(null_v) == 0:
null_ns = new_null_ns = None new_null_ns = None
else: else:
null_ns = self.get_n_repr(null_v) new_null_ns = update_func(self.get_n_repr(null_v), None)
new_null_ns = f_update(null_ns, None) # Read the node states in the degree-bucketing order.
if len(reordered_v) == 0: if len(reordered_v) == 0:
reordered_ns = new_reordered_ns = None new_reordered_ns = None
else: else:
reordered_ns = self.get_n_repr(reordered_v) new_reordered_ns = update_func(self.get_n_repr(reordered_v), all_reduced_msgs)
new_reordered_ns = f_update(reordered_ns, all_reduced_msgs)
v_tensor = utils.pack2(null_v.totensor(), reordered_v.totensor()) v_tensor = utils.pack2(null_v.totensor(), reordered_v.totensor())
new_ns = utils.pack2(new_null_ns, new_reordered_ns) new_ns = utils.pack2(new_null_ns, new_reordered_ns)
...@@ -569,7 +569,7 @@ class DGLGraph(DiGraph): ...@@ -569,7 +569,7 @@ class DGLGraph(DiGraph):
_, indices = F.sort(v_tensor) _, indices = F.sort(v_tensor)
indices = utils.toindex(indices) indices = utils.toindex(indices)
# TODO(minjie): following code should be included in Frame somehow. # TODO(minjie): following code should be included in Frame somehow.
if isinstance(new_ns, dict): if utils.is_dict_like(new_ns):
for key, val in new_ns.items(): for key, val in new_ns.items():
idx = indices.totensor(F.get_context(val)) idx = indices.totensor(F.get_context(val))
self._node_frame[key] = F.gather_row(val, idx) self._node_frame[key] = F.gather_row(val, idx)
...@@ -581,16 +581,13 @@ class DGLGraph(DiGraph): ...@@ -581,16 +581,13 @@ class DGLGraph(DiGraph):
self.set_n_repr(new_ns, v_tensor) self.set_n_repr(new_ns, v_tensor)
def _batch_reduce(self, v, reduce_func): def _batch_reduce(self, v, reduce_func):
if is_all(v) and len(self._msg_frame) == 0: if self._msg_frame.num_rows == 0:
# no message has been sent # no message has ever been sent
return None, None, None return None, None, None
if is_all(v): if is_all(v):
v = list(range(self.number_of_nodes())) v = list(range(self.number_of_nodes()))
# freeze message graph
self.msg_graph.freeze()
# sanity checks # sanity checks
v = utils.toindex(v) v = utils.toindex(v)
f_reduce = _get_reduce_func(reduce_func) f_reduce = _get_reduce_func(reduce_func)
...@@ -609,7 +606,7 @@ class DGLGraph(DiGraph): ...@@ -609,7 +606,7 @@ class DGLGraph(DiGraph):
null_v_bucket = v_bkt null_v_bucket = v_bkt
continue continue
uu, vv = self.msg_graph.in_edges(v_bkt) uu, vv, _ = self.msg_graph.in_edges(v_bkt)
in_msg_ids = self.msg_graph.get_edge_id(uu, vv) in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
in_msgs = self._msg_frame.select_rows(in_msg_ids) in_msgs = self._msg_frame.select_rows(in_msg_ids)
# Reshape the column tensor to (B, Deg, ...). # Reshape the column tensor to (B, Deg, ...).
...@@ -625,14 +622,13 @@ class DGLGraph(DiGraph): ...@@ -625,14 +622,13 @@ class DGLGraph(DiGraph):
non_null_v_buckets.append(v_bkt) non_null_v_buckets.append(v_bkt)
reduced_msgs.append(f_reduce(dst_reprs, reshaped_in_msgs)) reduced_msgs.append(f_reduce(dst_reprs, reshaped_in_msgs))
# FIXME: this will only trigger if reduced_msgs is empty. Remove?
if len(reduced_msgs) == 0: if len(reduced_msgs) == 0:
# no message has been sent to the specified node # no message has been sent to the specified node
return None, None, None return None, None, None
# TODO: clear partial messages # TODO: clear partial messages
self.clear_messages() self.clear_messages()
# Read the node states in the degree-bucketing order. # Read the node states in the degree-bucketing order.
null_v = utils.toindex(null_v_bucket or []) null_v = utils.toindex(null_v_bucket or [])
reordered_v = utils.toindex( reordered_v = utils.toindex(
...@@ -641,7 +637,7 @@ class DGLGraph(DiGraph): ...@@ -641,7 +637,7 @@ class DGLGraph(DiGraph):
) )
# Pack all reduced msgs together # Pack all reduced msgs together
if isinstance(reduced_msgs[0], dict): if utils.is_dict_like(reduced_msgs[0]):
keys = reduced_msgs[0].keys() keys = reduced_msgs[0].keys()
all_reduced_msgs = { all_reduced_msgs = {
key : F.pack([msg[key] for msg in reduced_msgs]) key : F.pack([msg[key] for msg in reduced_msgs])
...@@ -713,7 +709,10 @@ class DGLGraph(DiGraph): ...@@ -713,7 +709,10 @@ class DGLGraph(DiGraph):
message_func, message_func,
reduce_func, reduce_func,
update_func): update_func):
if is_all(u) and is_all(v): if len(u) == 0:
# no message
assert len(v) == 0
elif is_all(u) and is_all(v):
self.update_all(message_func, reduce_func, update_func, True) self.update_all(message_func, reduce_func, update_func, True)
elif message_func == 'from_src' and reduce_func == 'sum': elif message_func == 'from_src' and reduce_func == 'sum':
# TODO(minjie): check the validity of edges u->v # TODO(minjie): check the validity of edges u->v
...@@ -732,7 +731,6 @@ class DGLGraph(DiGraph): ...@@ -732,7 +731,6 @@ class DGLGraph(DiGraph):
dat = F.ones((len(u),)) dat = F.ones((len(u),))
n = self.number_of_nodes() n = self.number_of_nodes()
m = len(new2old) m = len(new2old)
# TODO(minjie): context
adjmat = F.sparse_tensor(idx, dat, [m, n]) adjmat = F.sparse_tensor(idx, dat, [m, n])
ctx_adjmat = utils.CtxCachedObject(lambda ctx: F.to_context(adjmat, ctx)) ctx_adjmat = utils.CtxCachedObject(lambda ctx: F.to_context(adjmat, ctx))
# TODO(minjie): use lazy dict for reduced_msgs # TODO(minjie): use lazy dict for reduced_msgs
...@@ -784,15 +782,18 @@ class DGLGraph(DiGraph): ...@@ -784,15 +782,18 @@ class DGLGraph(DiGraph):
assert update_func is not None assert update_func is not None
if batchable: if batchable:
v = utils.toindex(v) v = utils.toindex(v)
uu, vv = self.cached_graph.in_edges(v) uu, vv, orphan = self.cached_graph.in_edges(v)
self._batch_update_by_edge(uu, vv, message_func, self._batch_update_by_edge(uu, vv, message_func,
reduce_func, update_func) reduce_func, update_func)
# trigger update function for nodes that have no incoming messages.
self._batch_recv(orphan, reduce_func, update_func)
else: else:
v = utils.toindex(v) v = utils.toindex(v)
for vv in utils.node_iter(v): for vv in utils.node_iter(v):
assert vv in self.nodes assert vv in self.nodes
uu = list(self.pred[vv]) uu = list(self.pred[vv])
self._nonbatch_sendto(uu, vv, message_func) if len(uu) > 0:
self._nonbatch_sendto(uu, vv, message_func)
self._nonbatch_recv(vv, reduce_func, update_func) self._nonbatch_recv(vv, reduce_func, update_func)
def update_from(self, def update_from(self,
...@@ -827,7 +828,7 @@ class DGLGraph(DiGraph): ...@@ -827,7 +828,7 @@ class DGLGraph(DiGraph):
assert update_func is not None assert update_func is not None
if batchable: if batchable:
u = utils.toindex(u) u = utils.toindex(u)
uu, vv = self.cached_graph.out_edges(u) uu, vv, _ = self.cached_graph.out_edges(u)
self._batch_update_by_edge(uu, vv, message_func, self._batch_update_by_edge(uu, vv, message_func,
reduce_func, update_func) reduce_func, update_func)
else: else:
...@@ -888,11 +889,11 @@ class DGLGraph(DiGraph): ...@@ -888,11 +889,11 @@ class DGLGraph(DiGraph):
self._nonbatch_recv(list(self.nodes()), reduce_func, update_func) self._nonbatch_recv(list(self.nodes()), reduce_func, update_func)
def propagate(self, def propagate(self,
iterator='bfs',
message_func=None, message_func=None,
reduce_func=None, reduce_func=None,
update_func=None, update_func=None,
batchable=False, batchable=False,
iterator='bfs',
**kwargs): **kwargs):
"""Propagate messages and update nodes using iterator. """Propagate messages and update nodes using iterator.
...@@ -1072,7 +1073,7 @@ def _get_repr(attr_dict): ...@@ -1072,7 +1073,7 @@ def _get_repr(attr_dict):
return attr_dict return attr_dict
def _set_repr(attr_dict, attr): def _set_repr(attr_dict, attr):
if isinstance(attr, dict): if utils.is_dict_like(attr):
attr_dict.update(attr) attr_dict.update(attr)
else: else:
attr_dict[__REPR__] = attr attr_dict[__REPR__] = attr
......
...@@ -64,6 +64,13 @@ class AdjInnerDict(MutableMapping): ...@@ -64,6 +64,13 @@ class AdjInnerDict(MutableMapping):
def __iter__(self): def __iter__(self):
return iter(self._dict) return iter(self._dict)
class AdjInnerDictFactory(object):
    """Callable factory producing ``AdjInnerDict`` instances.

    Used in place of a lambda for the networkx ``adjlist_inner_dict_factory``
    hook; the two stored callbacks are forwarded to every dict it creates.
    """

    def __init__(self, cb1, cb2):
        # Keep the callbacks so each __call__ can wire them into a new dict.
        self._first_cb = cb1
        self._second_cb = cb2

    def __call__(self):
        """Build and return a fresh AdjInnerDict with the stored callbacks."""
        return AdjInnerDict(self._first_cb, self._second_cb)
def nx_init(obj, def nx_init(obj,
add_node_cb, add_node_cb,
add_edge_cb, add_edge_cb,
...@@ -88,7 +95,7 @@ def nx_init(obj, ...@@ -88,7 +95,7 @@ def nx_init(obj,
""" """
# The following codes work for networkx 2.1. # The following codes work for networkx 2.1.
obj.adjlist_outer_dict_factory = None obj.adjlist_outer_dict_factory = None
obj.adjlist_inner_dict_factory = lambda : AdjInnerDict(add_edge_cb, del_edge_cb) obj.adjlist_inner_dict_factory = AdjInnerDictFactory(add_edge_cb, del_edge_cb)
obj.edge_attr_dict_factory = dict obj.edge_attr_dict_factory = dict
obj.root_graph = obj obj.root_graph = obj
......
...@@ -31,4 +31,5 @@ def degree_bucketing(cached_graph, v): ...@@ -31,4 +31,5 @@ def degree_bucketing(cached_graph, v):
for deg in unique_degrees: for deg in unique_degrees:
idx = np.where(degrees == deg) idx = np.where(degrees == deg)
v_bkt.append(utils.Index(v_np[idx])) v_bkt.append(utils.Index(v_np[idx]))
#print('degree-bucketing:', unique_degrees, [len(b) for b in v_bkt])
return unique_degrees, v_bkt return unique_degrees, v_bkt
...@@ -290,6 +290,9 @@ def cached_member(func): ...@@ -290,6 +290,9 @@ def cached_member(func):
return func(self) return func(self)
return wrapper return wrapper
def is_dict_like(obj):
    """Return True when *obj* implements the mapping protocol (dict-like)."""
    obj_is_mapping = isinstance(obj, Mapping)
    return obj_is_mapping
def pack2(a, b): def pack2(a, b):
if a is None: if a is None:
return b return b
......
...@@ -250,6 +250,35 @@ def test_reduce_0deg(): ...@@ -250,6 +250,35 @@ def test_reduce_0deg():
assert th.allclose(new_repr[1:], old_repr[1:]) assert th.allclose(new_repr[1:], old_repr[1:])
assert th.allclose(new_repr[0], old_repr.sum(0)) assert th.allclose(new_repr[0], old_repr.sum(0))
def test_update_to_0deg():
    """Batched update_to() must run the update function even for nodes that
    receive no incoming messages (accum is None for those nodes)."""
    g = DGLGraph()
    g.add_nodes_from([0, 1])
    g.add_edge(0, 1)

    def msg_fn(src, edge):
        return src

    def reduce_fn(node, msgs):
        # Reduce should only ever run when messages exist.
        assert msgs is not None
        return msgs.sum(1)

    def update_fn(node, accum):
        # No messages -> double the node state; otherwise take the accum.
        return node * 2 if accum is None else accum

    feats = th.randn(2, 5)
    g.set_n_repr(feats)
    # Node 0 has no in-edges: update runs with accum=None.
    g.update_to(0, msg_fn, reduce_fn, update_fn, True)
    out = g.get_n_repr()
    assert th.allclose(out[0], feats[0] * 2)
    assert th.allclose(out[1], feats[1])
    # Node 1 now receives node 0's doubled state as its only message.
    g.update_to(1, msg_fn, reduce_fn, update_fn, True)
    out = g.get_n_repr()
    assert th.allclose(out[1], feats[0] * 2)

    feats = th.randn(2, 5)
    g.set_n_repr(feats)
    # Mixed batch: node 0 gets no message, node 1 gets one from node 0.
    g.update_to([0, 1], msg_fn, reduce_fn, update_fn, True)
    out = g.get_n_repr()
    assert th.allclose(out[0], feats[0] * 2)
    assert th.allclose(out[1], feats[0])
def _test_delete(): def _test_delete():
g = generate_graph() g = generate_graph()
ecol = Variable(th.randn(17, D), requires_grad=grad) ecol = Variable(th.randn(17, D), requires_grad=grad)
...@@ -268,4 +297,5 @@ if __name__ == '__main__': ...@@ -268,4 +297,5 @@ if __name__ == '__main__':
test_batch_recv2() test_batch_recv2()
test_update_routines() test_update_routines()
test_reduce_0deg() test_reduce_0deg()
test_update_to_0deg()
#test_delete() #test_delete()
...@@ -21,13 +21,15 @@ def test_basics(): ...@@ -21,13 +21,15 @@ def test_basics():
u = Index(th.tensor([0, 0, 1, 1, 2, 2])) u = Index(th.tensor([0, 0, 1, 1, 2, 2]))
v = Index(th.tensor([1, 2, 2, 3, 4, 5])) v = Index(th.tensor([1, 2, 2, 3, 4, 5]))
check_eq(cg.get_edge_id(u, v).totensor(), th.tensor([0, 5, 1, 2, 3, 4])) check_eq(cg.get_edge_id(u, v).totensor(), th.tensor([0, 5, 1, 2, 3, 4]))
query = Index(th.tensor([1, 2])) query = Index(th.tensor([0, 1, 2, 5]))
s, d = cg.in_edges(query) s, d, orphan = cg.in_edges(query)
check_eq(s.totensor(), th.tensor([0, 0, 1])) check_eq(s.totensor(), th.tensor([0, 0, 1, 2]))
check_eq(d.totensor(), th.tensor([1, 2, 2])) check_eq(d.totensor(), th.tensor([1, 2, 2, 5]))
s, d = cg.out_edges(query) assert orphan.tolist() == [0]
check_eq(s.totensor(), th.tensor([1, 1, 2, 2])) s, d, orphan = cg.out_edges(query)
check_eq(d.totensor(), th.tensor([2, 3, 4, 5])) check_eq(s.totensor(), th.tensor([0, 0, 1, 1, 2, 2]))
check_eq(d.totensor(), th.tensor([1, 2, 2, 3, 4, 5]))
assert orphan.tolist() == [5]
if __name__ == '__main__': if __name__ == '__main__':
test_basics() test_basics()
...@@ -75,6 +75,22 @@ def _test_update_routines(g): ...@@ -75,6 +75,22 @@ def _test_update_routines(g):
g.update_all() g.update_all()
check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108]) check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])
def _test_update_to_0deg():
    """Non-batched update_to() on a node without in-edges should still run
    the update function, passing accum=None."""
    graph = DGLGraph()
    graph.add_node(0, h=2)
    graph.add_node(1, h=1)
    graph.add_edge(0, 1)

    def msg_fn(src, edge):
        return src

    def reduce_fn(node, msgs):
        assert msgs is not None
        return msgs.sum(1)

    def update_fn(node, accum):
        # Node 0 has no predecessors, so no message should have arrived.
        assert accum is None
        return {'h': node['h'] * 2}

    graph.update_to(0, msg_fn, reduce_fn, update_fn)
    assert graph.nodes[0]['h'] == 4
def test_sendrecv(): def test_sendrecv():
g = generate_graph() g = generate_graph()
register1(g) register1(g)
...@@ -99,6 +115,8 @@ def test_update_routines(): ...@@ -99,6 +115,8 @@ def test_update_routines():
register2(g) register2(g)
_test_update_routines(g) _test_update_routines(g)
_test_update_to_0deg()
if __name__ == '__main__': if __name__ == '__main__':
test_sendrecv() test_sendrecv()
test_multi_sendrecv() test_multi_sendrecv()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment