"examples/vscode:/vscode.git/clone" did not exist on "beb1c017adca5b090e656ebe3fdb7f64215aefa0"
Commit b7eb1659 authored by Gan Quan's avatar Gan Quan Committed by Minjie Wang
Browse files

Fix 0deg update_to and Tree-LSTM model (#51)

* WIP

* WIP

* treelstm dataloader

* Main training loop.

* trainable treelstm script

* fix dependency

* cuda training

* Add tensorized topological traversal

* allowing update_to() with no incoming messages

* fixing partial cases
parent 5e75f5db
......@@ -19,15 +19,17 @@ _urls = {
'pubmed' : 'https://www.dropbox.com/s/fj5q6pi66xhymcm/pubmed.zip?dl=1',
}
class GCNDataset(object):
class CitationGraphDataset(object):
def __init__(self, name):
    """Download, extract and parse the citation-graph dataset *name*.

    `name` must be a key of the module-level `_urls` table
    (e.g. 'cora', 'citeseer', 'pubmed').  Loading happens eagerly:
    `_load()` runs at the end of construction.
    """
    self.name = name
    # NOTE(review): the original assigned `self.mode = mode`, but `mode`
    # is not a parameter of this __init__ (copy-paste residue from the
    # SST dataset) and raised NameError; the line is removed.
    self.dir = get_download_dir()
    self.zip_file_path = '{}/{}.zip'.format(self.dir, name)
    download(_urls[name], path=self.zip_file_path)
    extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, name))
    self._load()
def load(self):
def _load(self):
"""Loads input data from gcn/data directory
ind.name.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
......@@ -87,13 +89,6 @@ class GCNDataset(object):
val_mask = _sample_mask(idx_val, labels.shape[0])
test_mask = _sample_mask(idx_test, labels.shape[0])
#y_train = np.zeros(labels.shape)
#y_val = np.zeros(labels.shape)
#y_test = np.zeros(labels.shape)
#y_train[train_mask, :] = labels[train_mask, :]
#y_val[val_mask, :] = labels[val_mask, :]
#y_test[test_mask, :] = labels[test_mask, :]
self.graph = graph
self.features = _preprocess_features(features)
self.labels = labels
......@@ -112,6 +107,12 @@ class GCNDataset(object):
print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
def __getitem__(self, idx):
    """Return the dataset itself; it acts as a one-item container
    (presumably so it can be fed to a generic data-loader loop —
    the `idx` argument is ignored)."""
    return self

def __len__(self):
    """This dataset always holds exactly one graph."""
    return 1
def _preprocess_features(features):
"""Row-normalize feature matrix and convert to tuple representation"""
rowsum = np.array(features.sum(1))
......@@ -135,18 +136,15 @@ def _sample_mask(idx, l):
return mask
def load_cora():
    """Return the Cora citation-graph dataset (downloads on first use)."""
    # CitationGraphDataset loads everything in __init__, so no separate
    # load() call is needed; the stale GCNDataset construction and
    # data.load() lines (removed-diff residue) are dropped.
    data = CitationGraphDataset('cora')
    return data
def load_citeseer():
    """Return the CiteSeer citation-graph dataset (downloads on first use)."""
    # CitationGraphDataset loads everything in __init__, so no separate
    # load() call is needed; the stale GCNDataset construction and
    # data.load() lines (removed-diff residue) are dropped.
    data = CitationGraphDataset('citeseer')
    return data
def load_pubmed():
    """Return the PubMed citation-graph dataset (downloads on first use)."""
    # CitationGraphDataset loads everything in __init__, so no separate
    # load() call is needed; the stale GCNDataset construction and
    # data.load() lines (removed-diff residue) are dropped.
    data = CitationGraphDataset('pubmed')
    return data
class GCNSyntheticDataset(object):
......@@ -196,6 +194,12 @@ class GCNSyntheticDataset(object):
print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
def __getitem__(self, idx):
    """Return the dataset itself; it acts as a one-item container
    (the `idx` argument is ignored)."""
    return self

def __len__(self):
    """This synthetic dataset always holds exactly one graph."""
    return 1
def get_gnp_generator(args):
n = args.syn_gnp_n
p = (2 * np.log(n) / n) if args.syn_gnp_p == 0. else args.syn_gnp_p
......
"""Tree-structured data.
Including:
- Stanford Sentiment Treebank
"""
from __future__ import absolute_import
from collections import namedtuple
from nltk.tree import Tree
from nltk.corpus.reader import BracketParseCorpusReader
import networkx as nx
import dgl
import dgl.backend as F
from dgl.data.utils import download, extract_archive, get_download_dir
# Download location of the preprocessed SST archive.
_urls = {
    'sst' : 'https://www.dropbox.com/s/dw8kr2vuq7k4dqi/sst.zip?dl=1',
}

# A collated minibatch of trees:
#   graph         -- the batched DGLGraph
#   nid_with_word -- batch-wide node ids that carry a real word (leaves)
#   wordid        -- vocabulary ids for those nodes
#   label         -- sentiment label for every node in the batch
SSTBatch = namedtuple('SSTBatch', ['graph', 'nid_with_word', 'wordid', 'label'])
class SST(object):
    """Stanford Sentiment Treebank dataset.

    Downloads the preprocessed SST archive, parses each sentence into a
    constituency tree with NLTK, and converts every tree to a DGLGraph.
    Each node carries two attributes:
      x -- vocabulary id of the word for leaves, PAD_WORD for inner nodes
      y -- the sentiment label of the subtree (num_classes = 5)
    """

    # Sentinel word id stored on internal (non-leaf) nodes.
    PAD_WORD = -1

    def __init__(self, mode='train', vocab_file=None):
        """Load one split of SST.

        mode       -- split name; '{mode}.txt' must exist in the archive
                      (presumably 'train' / 'dev' / 'test' — confirm
                      against the archive contents).
        vocab_file -- optional path overriding the bundled vocab.txt.
        """
        self.mode = mode
        self.dir = get_download_dir()
        self.zip_file_path = '{}/sst.zip'.format(self.dir)
        self.vocab_file = '{}/sst/vocab.txt'.format(self.dir) if vocab_file is None else vocab_file
        download(_urls['sst'], path=self.zip_file_path)
        extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
        self.trees = []
        self.num_classes = 5
        print('Preprocessing...')
        self._load()
        print('Dataset creation finished. #Trees:', len(self.trees))

    def _load(self):
        """Parse '{mode}.txt' and build one DGLGraph per sentence."""
        files = ['{}.txt'.format(self.mode)]
        corpus = BracketParseCorpusReader('{}/sst'.format(self.dir), files)
        sents = corpus.parsed_sents(files[0])
        # load vocab file: one token per line, line index = word id
        self.vocab = {}
        with open(self.vocab_file) as vf:
            for line in vf.readlines():
                line = line.strip()
                self.vocab[line] = len(self.vocab)
        # build trees
        for sent in sents:
            self.trees.append(self._build_tree(sent))

    def _build_tree(self, root):
        """Convert an nltk.Tree into a DGLGraph.

        Node ids are assigned in discovery order (root gets id 0);
        every edge points from a child to its parent.
        """
        g = nx.DiGraph()

        def _rec_build(nid, node):
            # nid is the graph id of `node`; each child gets the next
            # unused id (g.number_of_nodes() at insertion time).
            for child in node:
                cid = g.number_of_nodes()
                if isinstance(child[0], str):
                    # leaf: `child` is a pre-terminal whose only child
                    # is the word string; look it up case-folded
                    word = self.vocab[child[0].lower()]
                    g.add_node(cid, x=word, y=int(child.label()))
                else:
                    g.add_node(cid, x=SST.PAD_WORD, y=int(child.label()))
                    _rec_build(cid, child)
                # child -> parent edge
                g.add_edge(cid, nid)

        # add root
        g.add_node(0, x=SST.PAD_WORD, y=int(root.label()))
        _rec_build(0, root)
        return dgl.DGLGraph(g)

    def __getitem__(self, idx):
        """Return the idx-th tree as a DGLGraph."""
        return self.trees[idx]

    def __len__(self):
        """Number of trees in this split."""
        return len(self.trees)

    @property
    def num_vocabs(self):
        # Vocabulary size as read by _load().
        return len(self.vocab)

    @staticmethod
    def batcher(batch):
        """Collate a list of trees into a single SSTBatch.

        nid_with_word/wordid index only the word-bearing (leaf) nodes,
        using node ids relative to the batched graph; label records the
        sentiment label of every node.
        """
        nid_with_word = []
        wordid = []
        label = []
        gnid = 0  # running node id across the whole batch
        for tree in batch:
            for nid in range(tree.number_of_nodes()):
                if tree.nodes[nid]['x'] != SST.PAD_WORD:
                    nid_with_word.append(gnid)
                    wordid.append(tree.nodes[nid]['x'])
                label.append(tree.nodes[nid]['y'])
                gnid += 1
        batch_trees = dgl.batch(batch)
        return SSTBatch(graph=batch_trees,
                        nid_with_word=F.tensor(nid_with_word, dtype=F.int64),
                        wordid=F.tensor(wordid, dtype=F.int64),
                        label=F.tensor(label, dtype=F.int64))
......@@ -93,20 +93,20 @@ class DGLGraph(DiGraph):
else:
u = utils.toindex(u)
num_nodes = len(u)
if isinstance(hu, dict):
if utils.is_dict_like(hu):
for key, val in hu.items():
assert F.shape(val)[0] == num_nodes
else:
assert F.shape(hu)[0] == num_nodes
# set
if is_all(u):
if isinstance(hu, dict):
if utils.is_dict_like(hu):
for key, val in hu.items():
self._node_frame[key] = val
else:
self._node_frame[__REPR__] = hu
else:
if isinstance(hu, dict):
if utils.is_dict_like(hu):
self._node_frame[u] = hu
else:
self._node_frame[u] = {__REPR__ : hu}
......@@ -171,21 +171,21 @@ class DGLGraph(DiGraph):
u = utils.toindex(u)
v = utils.toindex(v)
num_edges = max(len(u), len(v))
if isinstance(h_uv, dict):
if utils.is_dict_like(h_uv):
for key, val in h_uv.items():
assert F.shape(val)[0] == num_edges
else:
assert F.shape(h_uv)[0] == num_edges
# set
if u_is_all:
if isinstance(h_uv, dict):
if utils.is_dict_like(h_uv):
for key, val in h_uv.items():
self._edge_frame[key] = val
else:
self._edge_frame[__REPR__] = h_uv
else:
eid = self.cached_graph.get_edge_id(u, v)
if isinstance(h_uv, dict):
if utils.is_dict_like(h_uv):
self._edge_frame[eid] = h_uv
else:
self._edge_frame[eid] = {__REPR__ : h_uv}
......@@ -206,20 +206,20 @@ class DGLGraph(DiGraph):
else:
eid = utils.toindex(eid)
num_edges = len(eid)
if isinstance(h_uv, dict):
if utils.is_dict_like(h_uv):
for key, val in h_uv.items():
assert F.shape(val)[0] == num_edges
else:
assert F.shape(h_uv)[0] == num_edges
# set
if is_all(eid):
if isinstance(h_uv, dict):
if utils.is_dict_like(h_uv):
for key, val in h_uv.items():
self._edge_frame[key] = val
else:
self._edge_frame[__REPR__] = h_uv
else:
if isinstance(h_uv, dict):
if utils.is_dict_like(h_uv):
self._edge_frame[eid] = h_uv
else:
self._edge_frame[eid] = {__REPR__ : h_uv}
......@@ -400,7 +400,7 @@ class DGLGraph(DiGraph):
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
if isinstance(msgs, dict):
if utils.is_dict_like(msgs):
self._msg_frame.append(msgs)
else:
self._msg_frame.append({__MSG__ : msgs})
......@@ -522,11 +522,11 @@ class DGLGraph(DiGraph):
def _nonbatch_recv(self, u, reduce_func, update_func):
f_reduce = _get_reduce_func(reduce_func)
f_update = update_func
if is_all(u):
u = list(range(0, self.number_of_nodes()))
else:
u = utils.toindex(u)
for i, uu in enumerate(utils.node_iter(u)):
# reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__)
......@@ -536,31 +536,31 @@ class DGLGraph(DiGraph):
else:
msgs_reduced = f_reduce(_get_repr(self.nodes[uu]), msgs_batch)
# update phase
ret = f_update(_get_repr(self.nodes[uu]), msgs_reduced)
ret = update_func(_get_repr(self.nodes[uu]), msgs_reduced)
_set_repr(self.nodes[uu], ret)
def _batch_recv(self, v, reduce_func, update_func):
f_update = update_func
if len(v) == 0:
# no vertex to be triggered.
return
null_v, reordered_v, all_reduced_msgs = self._batch_reduce(v, reduce_func)
if all_reduced_msgs is None:
# no message; only do recv.
if is_all(v):
self.set_n_repr(f_update(self.get_n_repr(), None))
self.set_n_repr(update_func(self.get_n_repr(), None))
else:
self.set_n_repr(f_update(self.get_n_repr(v), None), v)
self.set_n_repr(update_func(self.get_n_repr(v), None), v)
else:
# Read the node states in the degree-bucketing order.
# Compute new node repr for nodes with no in-coming messages.
if len(null_v) == 0:
null_ns = new_null_ns = None
new_null_ns = None
else:
null_ns = self.get_n_repr(null_v)
new_null_ns = f_update(null_ns, None)
new_null_ns = update_func(self.get_n_repr(null_v), None)
# Read the node states in the degree-bucketing order.
if len(reordered_v) == 0:
reordered_ns = new_reordered_ns = None
new_reordered_ns = None
else:
reordered_ns = self.get_n_repr(reordered_v)
new_reordered_ns = f_update(reordered_ns, all_reduced_msgs)
new_reordered_ns = update_func(self.get_n_repr(reordered_v), all_reduced_msgs)
v_tensor = utils.pack2(null_v.totensor(), reordered_v.totensor())
new_ns = utils.pack2(new_null_ns, new_reordered_ns)
......@@ -569,7 +569,7 @@ class DGLGraph(DiGraph):
_, indices = F.sort(v_tensor)
indices = utils.toindex(indices)
# TODO(minjie): following code should be included in Frame somehow.
if isinstance(new_ns, dict):
if utils.is_dict_like(new_ns):
for key, val in new_ns.items():
idx = indices.totensor(F.get_context(val))
self._node_frame[key] = F.gather_row(val, idx)
......@@ -581,16 +581,13 @@ class DGLGraph(DiGraph):
self.set_n_repr(new_ns, v_tensor)
def _batch_reduce(self, v, reduce_func):
if is_all(v) and len(self._msg_frame) == 0:
# no message has been sent
if self._msg_frame.num_rows == 0:
# no message has ever been sent
return None, None, None
if is_all(v):
v = list(range(self.number_of_nodes()))
# freeze message graph
self.msg_graph.freeze()
# sanity checks
v = utils.toindex(v)
f_reduce = _get_reduce_func(reduce_func)
......@@ -609,7 +606,7 @@ class DGLGraph(DiGraph):
null_v_bucket = v_bkt
continue
uu, vv = self.msg_graph.in_edges(v_bkt)
uu, vv, _ = self.msg_graph.in_edges(v_bkt)
in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
in_msgs = self._msg_frame.select_rows(in_msg_ids)
# Reshape the column tensor to (B, Deg, ...).
......@@ -625,14 +622,13 @@ class DGLGraph(DiGraph):
non_null_v_buckets.append(v_bkt)
reduced_msgs.append(f_reduce(dst_reprs, reshaped_in_msgs))
# FIXME: this will only trigger if reduced_msgs is empty. Remove?
if len(reduced_msgs) == 0:
# no message has been sent to the specified node
return None, None, None
# TODO: clear partial messages
self.clear_messages()
# Read the node states in the degree-bucketing order.
null_v = utils.toindex(null_v_bucket or [])
reordered_v = utils.toindex(
......@@ -641,7 +637,7 @@ class DGLGraph(DiGraph):
)
# Pack all reduced msgs together
if isinstance(reduced_msgs[0], dict):
if utils.is_dict_like(reduced_msgs[0]):
keys = reduced_msgs[0].keys()
all_reduced_msgs = {
key : F.pack([msg[key] for msg in reduced_msgs])
......@@ -713,7 +709,10 @@ class DGLGraph(DiGraph):
message_func,
reduce_func,
update_func):
if is_all(u) and is_all(v):
if len(u) == 0:
# no message
assert len(v) == 0
elif is_all(u) and is_all(v):
self.update_all(message_func, reduce_func, update_func, True)
elif message_func == 'from_src' and reduce_func == 'sum':
# TODO(minjie): check the validity of edges u->v
......@@ -732,7 +731,6 @@ class DGLGraph(DiGraph):
dat = F.ones((len(u),))
n = self.number_of_nodes()
m = len(new2old)
# TODO(minjie): context
adjmat = F.sparse_tensor(idx, dat, [m, n])
ctx_adjmat = utils.CtxCachedObject(lambda ctx: F.to_context(adjmat, ctx))
# TODO(minjie): use lazy dict for reduced_msgs
......@@ -784,15 +782,18 @@ class DGLGraph(DiGraph):
assert update_func is not None
if batchable:
v = utils.toindex(v)
uu, vv = self.cached_graph.in_edges(v)
uu, vv, orphan = self.cached_graph.in_edges(v)
self._batch_update_by_edge(uu, vv, message_func,
reduce_func, update_func)
# trigger update function for nodes that have no incoming messages.
self._batch_recv(orphan, reduce_func, update_func)
else:
v = utils.toindex(v)
for vv in utils.node_iter(v):
assert vv in self.nodes
uu = list(self.pred[vv])
self._nonbatch_sendto(uu, vv, message_func)
if len(uu) > 0:
self._nonbatch_sendto(uu, vv, message_func)
self._nonbatch_recv(vv, reduce_func, update_func)
def update_from(self,
......@@ -827,7 +828,7 @@ class DGLGraph(DiGraph):
assert update_func is not None
if batchable:
u = utils.toindex(u)
uu, vv = self.cached_graph.out_edges(u)
uu, vv, _ = self.cached_graph.out_edges(u)
self._batch_update_by_edge(uu, vv, message_func,
reduce_func, update_func)
else:
......@@ -888,11 +889,11 @@ class DGLGraph(DiGraph):
self._nonbatch_recv(list(self.nodes()), reduce_func, update_func)
def propagate(self,
iterator='bfs',
message_func=None,
reduce_func=None,
update_func=None,
batchable=False,
iterator='bfs',
**kwargs):
"""Propagate messages and update nodes using iterator.
......@@ -1072,7 +1073,7 @@ def _get_repr(attr_dict):
return attr_dict
def _set_repr(attr_dict, attr):
    """Store a node/edge representation into its attribute dict.

    A mapping is merged key by key; any other value is stored under the
    anonymous __REPR__ column.  (The diff residue left both the old
    `isinstance(attr, dict)` and the new `utils.is_dict_like(attr)`
    condition lines; only the latter is kept — it also accepts Mapping
    subclasses that are not plain dicts.)
    """
    if utils.is_dict_like(attr):
        attr_dict.update(attr)
    else:
        attr_dict[__REPR__] = attr
......
......@@ -64,6 +64,13 @@ class AdjInnerDict(MutableMapping):
def __iter__(self):
    """Iterate over the keys of the wrapped dict."""
    return iter(self._dict)
class AdjInnerDictFactory(object):
    """Callable factory that produces AdjInnerDict instances bound to a
    fixed pair of callbacks.

    Being a plain class (rather than a lambda) makes the factory
    picklable and gives it a stable identity.
    """

    def __init__(self, cb1, cb2):
        # Remember the two callbacks; every produced dict shares them.
        self._cb1, self._cb2 = cb1, cb2

    def __call__(self):
        # networkx calls the inner-dict factory with no arguments.
        return AdjInnerDict(self._cb1, self._cb2)
def nx_init(obj,
add_node_cb,
add_edge_cb,
......@@ -88,7 +95,7 @@ def nx_init(obj,
"""
# The following codes work for networkx 2.1.
obj.adjlist_outer_dict_factory = None
obj.adjlist_inner_dict_factory = lambda : AdjInnerDict(add_edge_cb, del_edge_cb)
obj.adjlist_inner_dict_factory = AdjInnerDictFactory(add_edge_cb, del_edge_cb)
obj.edge_attr_dict_factory = dict
obj.root_graph = obj
......
......@@ -31,4 +31,5 @@ def degree_bucketing(cached_graph, v):
for deg in unique_degrees:
idx = np.where(degrees == deg)
v_bkt.append(utils.Index(v_np[idx]))
#print('degree-bucketing:', unique_degrees, [len(b) for b in v_bkt])
return unique_degrees, v_bkt
......@@ -290,6 +290,9 @@ def cached_member(func):
return func(self)
return wrapper
def is_dict_like(obj):
    """Return True when *obj* implements the Mapping protocol (e.g. dict)."""
    looks_like_mapping = isinstance(obj, Mapping)
    return looks_like_mapping
def pack2(a, b):
if a is None:
return b
......
......@@ -250,6 +250,35 @@ def test_reduce_0deg():
assert th.allclose(new_repr[1:], old_repr[1:])
assert th.allclose(new_repr[0], old_repr.sum(0))
def test_update_to_0deg():
    """Batched update_to must run update_func with accum=None on nodes
    that have no incoming messages (0 in-degree)."""
    g = DGLGraph()
    g.add_nodes_from([0, 1])
    g.add_edge(0, 1)

    def msg_fn(src, edge):
        return src

    def reduce_fn(node, msgs):
        assert msgs is not None
        return msgs.sum(1)

    def update_fn(node, accum):
        # No messages -> double the node state; otherwise take the accum.
        return node * 2 if accum is None else accum

    old_repr = th.randn(2, 5)
    g.set_n_repr(old_repr)

    # Node 0 has no predecessors: update fires with accum=None.
    g.update_to(0, msg_fn, reduce_fn, update_fn, True)
    new_repr = g.get_n_repr()
    assert th.allclose(new_repr[0], old_repr[0] * 2)
    assert th.allclose(new_repr[1], old_repr[1])

    # Node 1 now receives node 0's (already doubled) representation.
    g.update_to(1, msg_fn, reduce_fn, update_fn, True)
    new_repr = g.get_n_repr()
    assert th.allclose(new_repr[1], old_repr[0] * 2)

    old_repr = th.randn(2, 5)
    g.set_n_repr(old_repr)
    # Updating both at once: messages are computed from the pre-update
    # state, so node 1 ends up with the original repr of node 0.
    g.update_to([0, 1], msg_fn, reduce_fn, update_fn, True)
    new_repr = g.get_n_repr()
    assert th.allclose(new_repr[0], old_repr[0] * 2)
    assert th.allclose(new_repr[1], old_repr[0])
def _test_delete():
g = generate_graph()
ecol = Variable(th.randn(17, D), requires_grad=grad)
......@@ -268,4 +297,5 @@ if __name__ == '__main__':
test_batch_recv2()
test_update_routines()
test_reduce_0deg()
test_update_to_0deg()
#test_delete()
......@@ -21,13 +21,15 @@ def test_basics():
u = Index(th.tensor([0, 0, 1, 1, 2, 2]))
v = Index(th.tensor([1, 2, 2, 3, 4, 5]))
check_eq(cg.get_edge_id(u, v).totensor(), th.tensor([0, 5, 1, 2, 3, 4]))
query = Index(th.tensor([1, 2]))
s, d = cg.in_edges(query)
check_eq(s.totensor(), th.tensor([0, 0, 1]))
check_eq(d.totensor(), th.tensor([1, 2, 2]))
s, d = cg.out_edges(query)
check_eq(s.totensor(), th.tensor([1, 1, 2, 2]))
check_eq(d.totensor(), th.tensor([2, 3, 4, 5]))
query = Index(th.tensor([0, 1, 2, 5]))
s, d, orphan = cg.in_edges(query)
check_eq(s.totensor(), th.tensor([0, 0, 1, 2]))
check_eq(d.totensor(), th.tensor([1, 2, 2, 5]))
assert orphan.tolist() == [0]
s, d, orphan = cg.out_edges(query)
check_eq(s.totensor(), th.tensor([0, 0, 1, 1, 2, 2]))
check_eq(d.totensor(), th.tensor([1, 2, 2, 3, 4, 5]))
assert orphan.tolist() == [5]
if __name__ == '__main__':
test_basics()
......@@ -75,6 +75,22 @@ def _test_update_routines(g):
g.update_all()
check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])
def _test_update_to_0deg():
    """Non-batched update_to on a 0-in-degree node must invoke the
    update function with accum=None."""
    g = DGLGraph()
    g.add_node(0, h=2)
    g.add_node(1, h=1)
    g.add_edge(0, 1)

    def msg_fn(src, edge):
        return src

    def reduce_fn(node, msgs):
        assert msgs is not None
        return msgs.sum(1)

    def update_fn(node, accum):
        # Node 0 has no predecessors, so nothing should be accumulated.
        assert accum is None
        return {'h': node['h'] * 2}

    g.update_to(0, msg_fn, reduce_fn, update_fn)
    assert g.nodes[0]['h'] == 4
def test_sendrecv():
g = generate_graph()
register1(g)
......@@ -99,6 +115,8 @@ def test_update_routines():
register2(g)
_test_update_routines(g)
_test_update_to_0deg()
if __name__ == '__main__':
test_sendrecv()
test_multi_sendrecv()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment