"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c052791b5fe29ce8a308bf63dda97aa205b729be"
Commit a8a4fcba authored by Minjie Wang

quickly integrating with tree-lstm example

parent 882e2a7b
@@ -8,23 +8,17 @@ from torch.utils.data import DataLoader
 import dgl
 import dgl.data as data
+import dgl.ndarray as nd
 from tree_lstm import TreeLSTM
-def _batch_to_cuda(batch):
-    return data.SSTBatch(graph=batch.graph,
-                         nid_with_word = batch.nid_with_word.cuda(),
-                         wordid = batch.wordid.cuda(),
-                         label = batch.label.cuda())
-import dgl.context as ctx
 def tensor_topo_traverse(g, cuda, args):
     n = g.number_of_nodes()
     if cuda:
-        adjmat = g.cached_graph.adjmat().get(ctx.gpu(args.gpu))
+        adjmat = g._graph.adjacency_matrix().get(nd.gpu(args.gpu))
         mask = th.ones((n, 1)).cuda()
     else:
-        adjmat = g.cached_graph.adjmat().get(ctx.cpu())
+        adjmat = g._graph.adjacency_matrix().get(nd.cpu())
         mask = th.ones((n, 1))
     degree = th.spmm(adjmat, mask)
     while th.sum(mask) != 0.:
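The traversal above relies on the orientation of the adjacency matrix (rows index dst nodes, columns index src nodes), so th.spmm(adjmat, mask) yields, for each node, the number of still-unprocessed predecessors; nodes that reach zero form the next topological frontier. A minimal self-contained sketch of the same loop in plain PyTorch (the toy edge list is invented for illustration):

import torch as th

# Toy tree, edges pointing child -> parent: 0->2, 1->2, 2->4, 3->4.
# Rows of the sparse matrix index dst nodes, columns index src nodes.
n = 5
src = th.tensor([0, 1, 2, 3])
dst = th.tensor([2, 2, 4, 4])
adjmat = th.sparse_coo_tensor(th.stack([dst, src]), th.ones(4), (n, n))

mask = th.ones((n, 1))                        # 1 = node not yet processed
while th.sum(mask) != 0.:
    degree = th.spmm(adjmat, mask)            # active predecessors per node
    frontier = (degree == 0.) * (mask == 1.)  # ready and unprocessed
    print(frontier.squeeze(1).nonzero().squeeze(1).tolist())
    mask = mask - frontier.float()
# prints [0, 1, 3], then [2], then [4]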
@@ -39,10 +33,17 @@ def main(args):
     cuda = args.gpu >= 0
     if cuda:
         th.cuda.set_device(args.gpu)
+    def _batcher(trees):
+        bg = dgl.batch(trees)
+        if cuda:
+            reprs = bg.get_n_repr()
+            reprs = {key: reprs[key].cuda() for key in reprs}
+            bg.set_n_repr(reprs)
+        return bg
     trainset = data.SST()
     train_loader = DataLoader(dataset=trainset,
                               batch_size=args.batch_size,
-                              collate_fn=data.SST.batcher,
+                              collate_fn=_batcher,
                               shuffle=False,
                               num_workers=0)
     #testset = data.SST(mode='test')
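For readers unfamiliar with the hook being swapped here: DataLoader passes the raw list of dataset items for each batch to collate_fn, which is what lets a closure like _batcher batch the trees and move node features to the GPU in one place. A tiny standalone sketch (the toy dataset is invented for illustration):

from torch.utils.data import DataLoader

samples = list(range(10))               # stand-in for trainset's trees

def _collate(items):
    # items is the raw list of dataset entries for one batch
    return {'batch': items, 'size': len(items)}

loader = DataLoader(samples, batch_size=4, collate_fn=_collate, shuffle=False)
for out in loader:
    print(out)  # {'batch': [0, 1, 2, 3], 'size': 4}, then [4..7], then [8, 9]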
@@ -69,18 +70,15 @@ def main(args):
     dur = []
     for epoch in range(args.epochs):
         t_epoch = time.time()
-        for step, batch in enumerate(train_loader):
-            g = batch.graph
-            if cuda:
-                batch = _batch_to_cuda(batch)
+        for step, graph in enumerate(train_loader):
             if step >= 3:
                 t0 = time.time()
+            label = graph.pop_n_repr('y')
             # traverse graph
-            giter = list(tensor_topo_traverse(g, False, args))
-            logits = model(batch, zero_initializer, iterator=giter, train=True)
+            giter = list(tensor_topo_traverse(graph, False, args))
+            logits = model(graph, zero_initializer, iterator=giter, train=True)
             logp = F.log_softmax(logits, 1)
-            loss = F.nll_loss(logp, batch.label)
+            loss = F.nll_loss(logp, label)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
@@ -89,11 +87,11 @@ def main(args):
             if step > 0 and step % args.log_every == 0:
                 pred = th.argmax(logits, 1)
-                acc = th.sum(th.eq(batch.label, pred))
+                acc = th.sum(th.eq(label, pred))
                 mean_dur = np.mean(dur)
                 print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                       "Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
-                    epoch, step, loss.item(), acc.item()/len(batch.label),
+                    epoch, step, loss.item(), acc.item() / args.batch_size,
                     mean_dur, args.batch_size / mean_dur))
         print("Epoch time(s):", time.time() - t_epoch)
...
@@ -10,23 +10,7 @@ import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-def topological_traverse(G):
-    indegree_map = {v: d for v, d in G.in_degree() if d > 0}
-    # These nodes have zero indegree and ready to be returned.
-    zero_indegree = [v for v, d in G.in_degree() if d == 0]
-    while True:
-        yield zero_indegree
-        next_zero_indegree = []
-        while zero_indegree:
-            node = zero_indegree.pop()
-            for _, child in G.edges(node):
-                indegree_map[child] -= 1
-                if indegree_map[child] == 0:
-                    next_zero_indegree.append(child)
-                    del indegree_map[child]
-        if len(next_zero_indegree) == 0:
-            break
-        zero_indegree = next_zero_indegree
+import dgl
 class ChildSumTreeLSTMCell(nn.Module):
     def __init__(self, x_size, h_size):
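The deleted helper is the standard Kahn-style "topological generations" pattern: repeatedly emit the zero-in-degree frontier and decrement its children's in-degrees. A standalone rendering of the same idea on a toy networkx DiGraph (networkx assumed available; this is a sketch, not DGL code):

import networkx as nx

def topological_generations(G):
    indegree_map = {v: d for v, d in G.in_degree() if d > 0}
    zero_indegree = [v for v, d in G.in_degree() if d == 0]
    while zero_indegree:
        yield zero_indegree
        next_zero = []
        for node in zero_indegree:
            for _, child in G.edges(node):
                indegree_map[child] -= 1
                if indegree_map[child] == 0:
                    next_zero.append(child)
                    del indegree_map[child]
        zero_indegree = next_zero

G = nx.DiGraph([(0, 2), (1, 2), (2, 4), (3, 4)])
print(list(topological_generations(G)))   # [[0, 1, 3], [2], [4]]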
@@ -83,13 +67,13 @@ class TreeLSTM(nn.Module):
         else:
             raise RuntimeError('Unknown cell type:', cell_type)
-    def forward(self, batch, zero_initializer, h=None, c=None, iterator=None, train=True):
+    def forward(self, graph, zero_initializer, h=None, c=None, iterator=None, train=True):
         """Compute tree-lstm prediction given a batch.
         Parameters
         ----------
-        batch : dgl.data.SSTBatch
-            The data batch.
+        graph : dgl.DGLGraph
+            The batched trees.
         zero_initializer : callable
             Function to return zero value tensor.
         h : Tensor, optional
@@ -104,15 +88,17 @@ class TreeLSTM(nn.Module):
         logits : Tensor
             The prediction of each node.
         """
-        g = batch.graph
+        g = graph
         n = g.number_of_nodes()
         g.register_message_func(self.cell.message_func, batchable=True)
         g.register_reduce_func(self.cell.reduce_func, batchable=True)
         g.register_apply_node_func(self.cell.apply_func, batchable=True)
         # feed embedding
-        embeds = self.embedding(batch.wordid)
-        x = zero_initializer((n, self.x_size))
-        x = x.index_copy(0, batch.nid_with_word, embeds)
+        wordid = g.pop_n_repr('x')
+        mask = (wordid != dgl.data.SST.PAD_WORD)
+        wordid = wordid * mask.long()
+        embeds = self.embedding(wordid)
+        x = embeds * th.unsqueeze(mask, 1).float()
         if h is None:
             h = zero_initializer((n, self.h_size))
             h_tild = zero_initializer((n, self.h_size))
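The masking idiom introduced above deserves a gloss: padded node ids are first remapped to a valid index (multiplying by the mask sends them to 0) so the embedding lookup stays in range, and the resulting embeddings are then zeroed out. A minimal sketch in plain torch (the PAD value and sizes are stand-ins, not dgl.data.SST's actual constants):

import torch as th
import torch.nn as nn

PAD = -1                                  # stand-in for dgl.data.SST.PAD_WORD
embedding = nn.Embedding(10, 4)

wordid = th.tensor([3, PAD, 7, PAD])      # PAD marks nodes with no word
mask = (wordid != PAD)                    # [True, False, True, False]
safe_id = wordid * mask.long()            # PAD entries become index 0
x = embedding(safe_id) * mask.unsqueeze(1).float()
print(x[1])                               # all zeros for the padded node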
...
@@ -8,6 +8,12 @@ namespace dgl {
 class GraphOp {
  public:
+  /*!
+   * \brief Return the line graph.
+   * \param graph The input graph.
+   * \return the line graph
+   */
+  static Graph LineGraph(const Graph* graph);
   /*!
    * \brief Return a disjoint union of the input graphs.
    *
...
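The line graph declared above is the graph whose nodes are the edges of the input, with an arc from edge (u, v) to edge (v, w) whenever the first edge's destination is the second's source. A pure-Python sketch of that construction (illustrative only, not the DGL implementation):

def line_graph(edges):
    """Nodes of the result are edge ids of the input; edge i -> edge j
    iff the dst of edge i equals the src of edge j."""
    out = []
    for i, (u, v) in enumerate(edges):
        for j, (x, y) in enumerate(edges):
            if i != j and v == x:
                out.append((i, j))
    return out

print(line_graph([(0, 1), (1, 2), (2, 0)]))  # [(0, 1), (1, 2), (2, 0)]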
"""High-performance graph structure query component.
TODO: Currently implemented by igraph. Should replace with more efficient
solution later.
"""
from __future__ import absolute_import
import igraph
from . import backend as F
from .backend import Tensor
from . import utils
class CachedGraph:
def __init__(self):
self._graph = igraph.Graph(directed=True)
self._freeze = False
def add_nodes(self, num_nodes):
if self._freeze:
raise RuntimeError('Freezed cached graph cannot be mutated.')
self._graph.add_vertices(num_nodes)
def add_edge(self, u, v):
if self._freeze:
raise RuntimeError('Freezed cached graph cannot be mutated.')
self._graph.add_edge(u, v)
def add_edges(self, u, v):
if self._freeze:
raise RuntimeError('Freezed cached graph cannot be mutated.')
# The edge will be assigned ids equal to the order.
uvs = list(utils.edge_iter(u, v))
self._graph.add_edges(uvs)
def get_edge_id(self, u, v):
uvs = list(utils.edge_iter(u, v))
eids = self._graph.get_eids(uvs)
return utils.toindex(eids)
def in_edges(self, v):
"""Get in-edges of the vertices.
Parameters
----------
v : utils.Index
The vertex ids.
Returns
-------
src : utils.Index
The src vertex ids.
dst : utils.Index
The dst vertex ids.
orphan : utils.Index
The vertice that have no in-edges.
"""
src = []
dst = []
orphan = []
for vv in utils.node_iter(v):
uu = self._graph.predecessors(vv)
if len(uu) == 0:
orphan.append(vv)
else:
src += uu
dst += [vv] * len(uu)
src = utils.toindex(src)
dst = utils.toindex(dst)
orphan = utils.toindex(orphan)
return src, dst, orphan
def out_edges(self, u):
"""Get out-edges of the vertices.
Parameters
----------
v : utils.Index
The vertex ids.
Returns
-------
src : utils.Index
The src vertex ids.
dst : utils.Index
The dst vertex ids.
orphan : utils.Index
The vertice that have no out-edges.
"""
src = []
dst = []
orphan = []
for uu in utils.node_iter(u):
vv = self._graph.successors(uu)
if len(vv) == 0:
orphan.append(uu)
else:
src += [uu] * len(vv)
dst += vv
src = utils.toindex(src)
dst = utils.toindex(dst)
orphan = utils.toindex(orphan)
return src, dst, orphan
def in_degrees(self, v):
degs = self._graph.indegree(list(v))
return utils.toindex(degs)
def num_edges(self):
return self._graph.ecount()
@utils.cached_member
def edges(self):
elist = self._graph.get_edgelist()
src = [u for u, _ in elist]
dst = [v for _, v in elist]
src = utils.toindex(src)
dst = utils.toindex(dst)
return src, dst
@utils.cached_member
def adjmat(self):
"""Return a sparse adjacency matrix.
The row dimension represents the dst nodes; the column dimension
represents the src nodes.
"""
elist = self._graph.get_edgelist()
src = F.tensor([u for u, _ in elist], dtype=F.int64)
dst = F.tensor([v for _, v in elist], dtype=F.int64)
src = F.unsqueeze(src, 0)
dst = F.unsqueeze(dst, 0)
idx = F.pack([dst, src])
n = self._graph.vcount()
dat = F.ones((len(elist),))
mat = F.sparse_tensor(idx, dat, [n, n])
return utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
def freeze(self):
self._freeze = True
def create_cached_graph(dglgraph):
cg = CachedGraph()
cg.add_nodes(dglgraph.number_of_nodes())
cg._graph.add_edges(dglgraph.edge_list)
cg.freeze()
return cg
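A hypothetical driver for the class above, to show the intended call pattern; it assumes plain int lists are accepted wherever utils.Index is expected, which the source does not guarantee:

g = CachedGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2], [1, 2, 3])      # chain 0 -> 1 -> 2 -> 3
eid = g.get_edge_id([1], [2])          # id of the edge 1 -> 2
src, dst, orphan = g.in_edges([0, 3])  # node 0 has no in-edges -> orphan
g.freeze()                             # any further add_* now raises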
@@ -48,6 +48,7 @@ class DGLGraph(object):
         # msg graph & frame
         self._msg_graph = create_graph_index()
         self._msg_frame = FrameRef()
+        self.reset_messages()
         # registered functions
         self._message_func = (None, None)
         self._reduce_func = (None, None)
@@ -112,7 +113,7 @@ class DGLGraph(object):
         self._msg_graph.clear()
         self._msg_frame.clear()
-    def clear_messages(self):
+    def reset_messages(self):
         """Clear all messages."""
         self._msg_graph.clear()
         self._msg_frame.clear()
@@ -447,6 +448,7 @@ class DGLGraph(object):
         self.clear()
         self._graph.from_networkx(nx_graph)
         self._msg_graph.add_nodes(self._graph.number_of_nodes())
+        # copy attributes
         def _batcher(lst):
             if isinstance(lst[0], Tensor):
                 return F.pack([F.unsqueeze(x, 0) for x in lst])
@@ -1078,7 +1080,7 @@ class DGLGraph(object):
         new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
         # TODO: clear partial messages
-        self.clear_messages()
+        self.reset_messages()
         # Pack all reducer results together
         reordered_v = F.pack(reordered_v)
...
@@ -20,7 +20,7 @@ void Graph::AddVertices(uint64_t num_vertices) {
 void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
   CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
   CHECK(HasVertex(src) && HasVertex(dst))
-    << "In valid vertices: " << src << " " << dst;
+    << "Invalid vertices: src=" << src << " dst=" << dst;
   dgl_id_t eid = num_edges_++;
   adjlist_[src].succ.push_back(dst);
   adjlist_[src].edge_id.push_back(eid);
...