Commit a8a4fcba authored by Minjie Wang's avatar Minjie Wang
Browse files

quickly integrating with tree-lstm example

parent 882e2a7b
......@@ -8,23 +8,17 @@ from torch.utils.data import DataLoader
import dgl
import dgl.data as data
import dgl.ndarray as nd
from tree_lstm import TreeLSTM
def _batch_to_cuda(batch):
    """Return a copy of an SSTBatch with its tensor fields moved to GPU.

    The graph structure itself stays untouched; only the tensor fields
    (``nid_with_word``, ``wordid``, ``label``) are transferred with
    ``.cuda()``.

    Parameters
    ----------
    batch : data.SSTBatch
        A host-resident minibatch from the SST dataset.

    Returns
    -------
    data.SSTBatch
        A new batch whose tensors live on the current CUDA device.
    """
    return data.SSTBatch(graph=batch.graph,
                         nid_with_word = batch.nid_with_word.cuda(),
                         wordid = batch.wordid.cuda(),
                         label = batch.label.cuda())
import dgl.context as ctx
def tensor_topo_traverse(g, cuda, args):
n = g.number_of_nodes()
if cuda:
adjmat = g.cached_graph.adjmat().get(ctx.gpu(args.gpu))
adjmat = g._graph.adjacency_matrix().get(nd.gpu(args.gpu))
mask = th.ones((n, 1)).cuda()
else:
adjmat = g.cached_graph.adjmat().get(ctx.cpu())
adjmat = g._graph.adjacency_matrix().get(nd.cpu())
mask = th.ones((n, 1))
degree = th.spmm(adjmat, mask)
while th.sum(mask) != 0.:
......@@ -39,10 +33,17 @@ def main(args):
cuda = args.gpu >= 0
if cuda:
th.cuda.set_device(args.gpu)
def _batcher(trees):
    """Collate a list of tree graphs into one batched graph.

    Used as the DataLoader ``collate_fn``.  When running on GPU
    (``cuda`` is a closure variable of the enclosing ``main``), every
    node-representation tensor is moved to the current CUDA device.

    Parameters
    ----------
    trees : list of dgl.DGLGraph
        The trees forming one minibatch.

    Returns
    -------
    dgl.DGLGraph
        The batched graph (node representations on GPU when cuda=True).
    """
    bg = dgl.batch(trees)
    if cuda:
        reprs = bg.get_n_repr()
        # Bug fix: the original wrote ``{key : reprs[key].cuda()}`` with
        # an undefined name ``key`` (NameError at runtime) and would have
        # kept at most one field.  Move *every* field to the GPU instead.
        bg.set_n_repr({key: value.cuda() for key, value in reprs.items()})
    return bg
trainset = data.SST()
train_loader = DataLoader(dataset=trainset,
batch_size=args.batch_size,
collate_fn=data.SST.batcher,
collate_fn=_batcher,
shuffle=False,
num_workers=0)
#testset = data.SST(mode='test')
......@@ -69,18 +70,15 @@ def main(args):
dur = []
for epoch in range(args.epochs):
t_epoch = time.time()
for step, batch in enumerate(train_loader):
g = batch.graph
if cuda:
batch = _batch_to_cuda(batch)
for step, graph in enumerate(train_loader):
if step >= 3:
t0 = time.time()
label = graph.pop_n_repr('y')
# traverse graph
giter = list(tensor_topo_traverse(g, False, args))
logits = model(batch, zero_initializer, iterator=giter, train=True)
giter = list(tensor_topo_traverse(graph, False, args))
logits = model(graph, zero_initializer, iterator=giter, train=True)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, batch.label)
loss = F.nll_loss(logp, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
......@@ -89,11 +87,11 @@ def main(args):
if step > 0 and step % args.log_every == 0:
pred = th.argmax(logits, 1)
acc = th.sum(th.eq(batch.label, pred))
acc = th.sum(th.eq(label, pred))
mean_dur = np.mean(dur)
print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
"Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
epoch, step, loss.item(), acc.item()/len(batch.label),
epoch, step, loss.item(), acc.item() / args.batch_size,
mean_dur, args.batch_size / mean_dur))
print("Epoch time(s):", time.time() - t_epoch)
......
......@@ -10,23 +10,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
def topological_traverse(G):
    """Generate nodes of *G* level by level in topological order.

    Each yielded list holds the nodes whose remaining in-degree has
    reached zero; consuming a level unlocks its successors for the
    next one.  Note the yielded lists are mutated (drained) by the
    generator as it advances, so snapshot them if you need the values
    after resuming.
    """
    remaining = {}   # node -> outstanding in-degree (only nodes with deg > 0)
    frontier = []    # nodes currently at in-degree zero
    for node, deg in G.in_degree():
        if deg > 0:
            remaining[node] = deg
        else:
            frontier.append(node)
    while True:
        yield frontier
        unlocked = []
        # Drain the current level (pop from the tail, as the original did,
        # so the ordering of the next level is preserved).
        while frontier:
            node = frontier.pop()
            for _, successor in G.edges(node):
                remaining[successor] -= 1
                if remaining[successor] == 0:
                    del remaining[successor]
                    unlocked.append(successor)
        if not unlocked:
            break
        frontier = unlocked
import dgl
class ChildSumTreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size):
......@@ -83,13 +67,13 @@ class TreeLSTM(nn.Module):
else:
raise RuntimeError('Unknown cell type:', cell_type)
def forward(self, batch, zero_initializer, h=None, c=None, iterator=None, train=True):
def forward(self, graph, zero_initializer, h=None, c=None, iterator=None, train=True):
"""Compute tree-lstm prediction given a batch.
Parameters
----------
batch : dgl.data.SSTBatch
The data batch.
graph : dgl.DGLGraph
The batched trees.
zero_initializer : callable
Function to return zero value tensor.
h : Tensor, optional
......@@ -104,15 +88,17 @@ class TreeLSTM(nn.Module):
logits : Tensor
The prediction of each node.
"""
g = batch.graph
g = graph
n = g.number_of_nodes()
g.register_message_func(self.cell.message_func, batchable=True)
g.register_reduce_func(self.cell.reduce_func, batchable=True)
g.register_apply_node_func(self.cell.apply_func, batchable=True)
# feed embedding
embeds = self.embedding(batch.wordid)
x = zero_initializer((n, self.x_size))
x = x.index_copy(0, batch.nid_with_word, embeds)
wordid = g.pop_n_repr('x')
mask = (wordid != dgl.data.SST.PAD_WORD)
wordid = wordid * mask.long()
embeds = self.embedding(wordid)
x = embeds * th.unsqueeze(mask, 1).float()
if h is None:
h = zero_initializer((n, self.h_size))
h_tild = zero_initializer((n, self.h_size))
......
......@@ -8,6 +8,12 @@ namespace dgl {
class GraphOp {
public:
/*!
* \brief Return the line graph.
* \param graph The input graph.
* \return the line graph
*/
static Graph LineGraph(const Graph* graph);
/*!
* \brief Return a disjoint union of the input graphs.
*
......
"""High-performance graph structure query component.
TODO: Currently implemented by igraph. Should replace with more efficient
solution later.
"""
from __future__ import absolute_import
import igraph
from . import backend as F
from .backend import Tensor
from . import utils
class CachedGraph:
    """Graph-structure cache backed by a directed ``igraph.Graph``.

    Answers topology queries (edge lists, degrees, adjacency matrix)
    for a DGLGraph.  After :meth:`freeze` is called, all mutating
    operations raise ``RuntimeError``.
    """

    def __init__(self):
        # Underlying igraph storage; always directed.
        self._graph = igraph.Graph(directed=True)
        # Once True, mutation methods raise instead of modifying.
        self._freeze = False

    def add_nodes(self, num_nodes):
        """Add ``num_nodes`` new vertices.

        Raises
        ------
        RuntimeError
            If the graph has been frozen.
        """
        if self._freeze:
            raise RuntimeError('Freezed cached graph cannot be mutated.')
        self._graph.add_vertices(num_nodes)

    def add_edge(self, u, v):
        """Add a single edge from vertex ``u`` to vertex ``v``.

        Raises
        ------
        RuntimeError
            If the graph has been frozen.
        """
        if self._freeze:
            raise RuntimeError('Freezed cached graph cannot be mutated.')
        self._graph.add_edge(u, v)

    def add_edges(self, u, v):
        """Add edges pairwise from ``u[i]`` to ``v[i]``.

        Raises
        ------
        RuntimeError
            If the graph has been frozen.
        """
        if self._freeze:
            raise RuntimeError('Freezed cached graph cannot be mutated.')
        # The edge will be assigned ids equal to the order.
        uvs = list(utils.edge_iter(u, v))
        self._graph.add_edges(uvs)

    def get_edge_id(self, u, v):
        """Return the ids of the edges (u[i], v[i]) as a utils.Index."""
        uvs = list(utils.edge_iter(u, v))
        eids = self._graph.get_eids(uvs)
        return utils.toindex(eids)

    def in_edges(self, v):
        """Get in-edges of the vertices.

        Parameters
        ----------
        v : utils.Index
            The vertex ids.

        Returns
        -------
        src : utils.Index
            The src vertex ids.
        dst : utils.Index
            The dst vertex ids.
        orphan : utils.Index
            The vertices that have no in-edges.
        """
        src = []
        dst = []
        orphan = []
        for vv in utils.node_iter(v):
            uu = self._graph.predecessors(vv)
            if len(uu) == 0:
                orphan.append(vv)
            else:
                # One (u, vv) pair per predecessor u of vv.
                src += uu
                dst += [vv] * len(uu)
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        orphan = utils.toindex(orphan)
        return src, dst, orphan

    def out_edges(self, u):
        """Get out-edges of the vertices.

        Parameters
        ----------
        u : utils.Index
            The vertex ids.

        Returns
        -------
        src : utils.Index
            The src vertex ids.
        dst : utils.Index
            The dst vertex ids.
        orphan : utils.Index
            The vertices that have no out-edges.
        """
        src = []
        dst = []
        orphan = []
        for uu in utils.node_iter(u):
            vv = self._graph.successors(uu)
            if len(vv) == 0:
                orphan.append(uu)
            else:
                # One (uu, v) pair per successor v of uu.
                src += [uu] * len(vv)
                dst += vv
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        orphan = utils.toindex(orphan)
        return src, dst, orphan

    def in_degrees(self, v):
        """Return the in-degree of each vertex in ``v`` as a utils.Index."""
        degs = self._graph.indegree(list(v))
        return utils.toindex(degs)

    def num_edges(self):
        """Return the total number of edges in the graph."""
        return self._graph.ecount()

    @utils.cached_member
    def edges(self):
        """Return all edges as a (src, dst) pair of utils.Index.

        NOTE(review): decorated with ``utils.cached_member`` —
        presumably memoized, so the graph should be frozen before the
        first call; verify against utils.
        """
        elist = self._graph.get_edgelist()
        src = [u for u, _ in elist]
        dst = [v for _, v in elist]
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        return src, dst

    @utils.cached_member
    def adjmat(self):
        """Return a sparse adjacency matrix.

        The row dimension represents the dst nodes; the column dimension
        represents the src nodes.  The result is a utils.CtxCachedObject;
        callers obtain the matrix on a device via ``.get(ctx)``.
        """
        elist = self._graph.get_edgelist()
        src = F.tensor([u for u, _ in elist], dtype=F.int64)
        dst = F.tensor([v for _, v in elist], dtype=F.int64)
        src = F.unsqueeze(src, 0)
        dst = F.unsqueeze(dst, 0)
        # Row index = dst, column index = src.
        idx = F.pack([dst, src])
        n = self._graph.vcount()
        dat = F.ones((len(elist),))
        mat = F.sparse_tensor(idx, dat, [n, n])
        return utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))

    def freeze(self):
        """Mark the graph immutable; later mutations raise RuntimeError."""
        self._freeze = True
def create_cached_graph(dglgraph):
    """Build a frozen CachedGraph mirroring a DGLGraph's topology.

    Nodes are added through the public API; edges are bulk-inserted
    straight into the underlying igraph object (bypassing the freeze
    check and the per-pair edge_iter conversion).
    """
    cached = CachedGraph()
    cached.add_nodes(dglgraph.number_of_nodes())
    cached._graph.add_edges(dglgraph.edge_list)
    cached.freeze()
    return cached
......@@ -48,6 +48,7 @@ class DGLGraph(object):
# msg graph & frame
self._msg_graph = create_graph_index()
self._msg_frame = FrameRef()
self.reset_messages()
# registered functions
self._message_func = (None, None)
self._reduce_func = (None, None)
......@@ -112,7 +113,7 @@ class DGLGraph(object):
self._msg_graph.clear()
self._msg_frame.clear()
def clear_messages(self):
def reset_messages(self):
"""Clear all messages."""
self._msg_graph.clear()
self._msg_frame.clear()
......@@ -447,6 +448,7 @@ class DGLGraph(object):
self.clear()
self._graph.from_networkx(nx_graph)
self._msg_graph.add_nodes(self._graph.number_of_nodes())
# copy attributes
def _batcher(lst):
if isinstance(lst[0], Tensor):
return F.pack([F.unsqueeze(x, 0) for x in lst])
......@@ -1078,7 +1080,7 @@ class DGLGraph(object):
new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
# TODO: clear partial messages
self.clear_messages()
self.reset_messages()
# Pack all reducer results together
reordered_v = F.pack(reordered_v)
......
......@@ -20,7 +20,7 @@ void Graph::AddVertices(uint64_t num_vertices) {
void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
CHECK(HasVertex(src) && HasVertex(dst))
<< "In valid vertices: " << src << " " << dst;
<< "Invalid vertices: src=" << src << " dst=" << dst;
dgl_id_t eid = num_edges_++;
adjlist_[src].succ.push_back(dst);
adjlist_[src].edge_id.push_back(eid);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment