Unverified commit 96179b0c, authored by Lingfan Yu and committed by GitHub
Browse files

Deep Generative Models of Graphs (#14)

* model code for generative graphs

* batched version for dynamic graph generation using padding

* renaming function train back to forward

* remove old util function for padding DGMG

* override networkx clear to reset state, add dgl.nn

* Dynamic graph without batching

* use relative import path

* load dataset, pad batch

* bug fix

* experimental batch and unbatch

* dgmg batched version

* minor tweak

* move preprocessing padding into data loading

* batch graph test code

* minor

* batched graph class and test cases

* make dgl.nn.gcn a simple layer plus minor fix

* update dgmg model

* test forward using attribute field

* use frame append, minor changes

* moving networkx operations out of forward

* revert some changes

* remove structural immutability check
parent e3bac70b
...@@ -131,6 +131,7 @@ examples/pytorch/data/ind.citeseer.ally ...@@ -131,6 +131,7 @@ examples/pytorch/data/ind.citeseer.ally
examples/pytorch/data/ind.citeseer.allx examples/pytorch/data/ind.citeseer.allx
examples/pytorch/.DS_Store examples/pytorch/.DS_Store
examples/.DS_Store examples/.DS_Store
examples/pytorch/generative_graph/*.p
.DS_Store .DS_Store
# data directory # data directory
......
import dgl
from dgl.graph import DGLGraph
from dgl.nn import GCN
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse
from util import DataLoader, elapsed
import time
class MLP(nn.Module):
    """Sigmoid multi-layer perceptron with a linear output projection.

    The network is ``num_layers`` repetitions of
    ``nn.Linear(num_hidden, num_hidden)`` followed by ``nn.Sigmoid()``, and a
    final ``nn.Linear(num_hidden, num_classes)``.  Everything is stored in a
    single ``nn.Sequential`` under ``self.layers``.
    """

    def __init__(self, num_hidden, num_classes, num_layers):
        super(MLP, self).__init__()
        stack = []
        # hidden blocks: linear transform followed by a sigmoid
        for _ in range(num_layers):
            stack.extend((nn.Linear(num_hidden, num_hidden), nn.Sigmoid()))
        # final projection to the requested output width
        stack.append(nn.Linear(num_hidden, num_classes))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        out = self.layers(x)
        return out
def move2cuda(x):
    """Recursively move every tensor contained in ``x`` to CUDA.

    Parameters
    ----------
    x : object
        A ``torch.Tensor``, a (possibly nested) iterable of such objects, or
        any other value.

    Returns
    -------
    object
        ``x.cuda()`` when ``x`` is a tensor; a ``list`` holding the converted
        elements when ``x`` is iterable (so tuples come back as lists); ``x``
        unchanged for strings, bytes and non-iterable values.

    Notes
    -----
    Errors raised by ``Tensor.cuda()`` propagate to the caller instead of
    being swallowed, so a failed transfer is never mistaken for success.
    """
    if isinstance(x, torch.Tensor):
        # if Tensor, move directly
        return x.cuda()
    if isinstance(x, (str, bytes)):
        # a string iterates into one-character strings, which would recurse
        # without end; treat text as a leaf value
        return x
    try:
        # only the iterability probe is guarded, not the recursion itself
        elements = iter(x)
    except TypeError:
        # don't do anything for other types like basic types
        return x
    # iterable, recursively move each element
    return [move2cuda(element) for element in elements]
class DGMG(nn.Module):
    """Batched training model that builds graphs through node/edge decisions.

    ``forward`` walks through a sequence of alternating "add node" (even
    steps) and "add edge" (odd steps) decisions for a batch of samples stored
    in one batched graph, and accumulates the loss of every decision against
    the supplied ground-truth labels in ``self.loss``.  Only the training
    path is implemented.
    """
    def __init__(self, node_num_hidden, graph_num_hidden, T, num_MLP_layers=1, loss_func=None, dropout=0.0, use_cuda=False):
        """Create the sub-networks.

        Parameters
        ----------
        node_num_hidden : int
            Size of a node representation.
        graph_num_hidden : int
            Size of a graph representation.
        T : int
            Number of GCN layers run on every propagation.
        num_MLP_layers : int
            Number of hidden layers in each decision MLP.
        loss_func : callable
            Called as ``loss_func(prob, label, mask)`` for add-node/add-edge
            decisions and ``loss_func(prob, label)`` for node selection.
        dropout : float
            Forwarded to every GCN layer.
        use_cuda : bool
            When true, ``forward`` creates its initial representations on CUDA.
        """
        super(DGMG, self).__init__()
        # hidden size of node and graph
        self.node_num_hidden = node_num_hidden
        self.graph_num_hidden = graph_num_hidden
        # use GCN as a simple propagation model
        self.gcn = nn.ModuleList([GCN(node_num_hidden, node_num_hidden, F.relu, dropout) for _ in range(T)])
        # project node repr to graph repr (higher dimension)
        self.graph_project = nn.Linear(node_num_hidden, graph_num_hidden)
        # add node
        self.fan = MLP(graph_num_hidden, 2, num_MLP_layers)
        # add edge
        self.fae = MLP(graph_num_hidden + node_num_hidden, 1, num_MLP_layers)
        # select node to add edge
        self.fs = MLP(node_num_hidden * 2, 1, num_MLP_layers)
        # init node state
        self.finit = MLP(graph_num_hidden, node_num_hidden, num_MLP_layers)
        # loss function
        self.loss_func = loss_func
        # use gpu
        self.use_cuda = use_cuda

    def decide_add_node(self, hGs):
        """Score the add-node decision from graph representations ``hGs``.

        Adds the loss for the current step (labels and masks indexed by
        ``self.step``) to ``self.loss``.
        """
        h = self.fan(hGs)
        p = F.softmax(h, dim=1)
        # calc loss
        self.loss += self.loss_func(p, self.labels[self.step], self.masks[self.step])

    def decide_add_edge(self, batched_graph, hGs):
        """Score the add-edge decision for every sample and add its loss.

        The input to ``self.fae`` is the graph representation concatenated
        with the representation of the node at index
        ``sample_node_curr_idx - 1`` of each sample.
        """
        hvs = batched_graph.get_n_repr((self.sample_node_curr_idx - 1).tolist())['h']
        h = self.fae(torch.cat((hGs, hvs), dim=1))
        p = torch.sigmoid(h)
        # turn the single probability into a two-column [no, yes] distribution
        p = torch.cat([1 - p, p], dim=1)
        self.loss += self.loss_func(p, self.labels[self.step], self.masks[self.step])

    def select_node_to_add_edge(self, batched_graph, indices):
        """Add the node-selection loss for the samples listed in ``indices``.

        For each sample, every node in its current range ``[start, curr)`` is
        paired with the node at ``curr - 1``, scored by ``self.fs`` and
        normalised with a softmax over the sample's nodes; the target comes
        from ``self.node_select`` at the current step.
        """
        node_indices = self.sample_node_curr_idx[indices].tolist()
        node_start = self.sample_node_start_idx[indices].tolist()
        node_repr = batched_graph.get_n_repr()['h']
        for i, j, idx in zip(node_start, node_indices, indices):
            # rows i..j-1 are this sample's nodes; row j-1 is paired with each of them
            hu = node_repr.narrow(0, i, j-i)
            hv = node_repr.narrow(0, j-1, 1)
            huv = torch.cat((hu, hv.expand(j-i, -1)), dim=1)
            s = F.softmax(self.fs(huv), dim=0).view(1, -1)
            dst = self.node_select[self.step][idx].view(-1)
            self.loss += self.loss_func(s, dst)

    def update_graph_repr(self, batched_graph, hGs, indices, indices_tensor):
        """Return ``hGs`` with the rows of the given samples recomputed.

        The new row of a sample is the sum of ``self.graph_project`` applied
        to the node representations in its range ``[start, curr)``.
        ``indices`` and ``indices_tensor`` are used side by side (array
        indexing vs. ``index_copy``) — presumably the same sample ids in two
        container types; confirm against the caller.
        """
        start = self.sample_node_start_idx[indices].tolist()
        stop = self.sample_node_curr_idx[indices].tolist()
        node_repr = batched_graph.get_n_repr()['h']
        graph_repr = self.graph_project(node_repr)
        new_hGs = []
        for i, j in zip(start, stop):
            h = graph_repr.narrow(0, i, j-i)
            hG = torch.sum(h, 0, keepdim=True)
            new_hGs.append(hG)
        new_hGs = torch.cat(new_hGs, dim=0)
        return hGs.index_copy(0, indices_tensor, new_hGs)

    def propagate(self, batched_graph, indices):
        """Run every GCN layer over the edges added so far in ``indices``.

        For each listed sample the first ``sample_edge_count[idx]`` entries of
        its source/destination lists are used.
        """
        edge_src = [self.sample_edge_src[idx][0: self.sample_edge_count[idx]] for idx in indices]
        edge_dst = [self.sample_edge_dst[idx][0: self.sample_edge_count[idx]] for idx in indices]
        u = np.concatenate(edge_src).tolist()
        v = np.concatenate(edge_dst).tolist()
        for gcn in self.gcn:
            gcn.forward(batched_graph, u, v, attribute='h')

    def forward(self, training=False, ground_truth=None):
        """Run all decision steps for one batch and accumulate ``self.loss``.

        Parameters
        ----------
        training : bool
            Must be true; otherwise ``NotImplementedError`` is raised.
        ground_truth : sequence
            ``(signals, (batched_graph, edge_src, edge_dst))`` where
            ``signals`` unpacks into ``(nsteps, labels, node_select, masks,
            active_step, label1_set, label1_set_tensor)``.

        Nothing is returned; the accumulated loss is left in ``self.loss``.
        """
        if not training:
            raise NotImplementedError("inference is not implemented yet")
        assert(ground_truth is not None)
        signals, (batched_graph, self.sample_edge_src, self.sample_edge_dst) = ground_truth
        nsteps, self.labels, self.node_select, self.masks, active_step, label1_set, label1_set_tensor = signals
        # init loss
        self.loss = 0
        batch_size = len(self.sample_edge_src)
        # initial node repr for each sample
        hVs = torch.zeros(len(batched_graph), self.node_num_hidden)
        # FIXME: what's the initial graph repr for empty graph?
        hGs = torch.zeros(batch_size, self.graph_num_hidden)
        if self.use_cuda:
            hVs = hVs.cuda()
            hGs = hGs.cuda()
        batched_graph.set_n_repr({'h': hVs})
        # per-sample node range in the batched graph: [start, curr)
        self.sample_node_start_idx = batched_graph.query_node_start_offset()
        self.sample_node_curr_idx = self.sample_node_start_idx.copy()
        self.sample_edge_count = np.zeros(batch_size, dtype=int)
        self.step = 0
        while self.step < nsteps:
            if self.step % 2 == 0: # add node step
                if active_step[self.step]:
                    # decide whether to add node
                    self.decide_add_node(hGs)
                    # calculate initial state for new node
                    hvs = self.finit(hGs)
                    # add node
                    update = label1_set[self.step]
                    if len(update) > 0:
                        hvs = torch.index_select(hvs, 0, label1_set_tensor[self.step])
                        scatter_indices = self.sample_node_curr_idx[update]
                        batched_graph.set_n_repr({'h': hvs}, scatter_indices.tolist())
                        self.sample_node_curr_idx[update] += 1
                        # get new graph repr
                        hGs = self.update_graph_repr(batched_graph, hGs, update, label1_set_tensor[self.step])
                else:
                    # all samples are masked
                    pass
            else: # add edge step
                # decide whether to add edge, which edge to add
                # and also add edge
                self.decide_add_edge(batched_graph, hGs)
                # propagate
                to_add_edge = label1_set[self.step]
                if len(to_add_edge) > 0:
                    # at least one graph needs update
                    self.select_node_to_add_edge(batched_graph, to_add_edge)
                    # update edge count for each sample
                    self.sample_edge_count[to_add_edge] += 2 # undirected graph
                    # perform gcn propagation
                    self.propagate(batched_graph, to_add_edge)
                    # get new graph repr
                    hGs = self.update_graph_repr(batched_graph, hGs, label1_set[self.step], label1_set_tensor[self.step])
            self.step += 1
def main(args):
    """Train a DGMG model with the settings in the parsed arguments ``args``.

    Reads ``gpu``, ``n_hidden_node``, ``n_hidden_graph``, ``n_layers``,
    ``dropout``, ``lr``, ``n_epochs``, ``dataset`` and ``batch_size`` from
    ``args``.  A fresh ``DataLoader`` is built for every epoch, and the time
    of each forward and backward pass plus the batch loss are printed.
    """
    use_cuda = bool(torch.cuda.is_available() and args.gpu >= 0)
    if use_cuda:
        torch.cuda.set_device(args.gpu)

    def masked_cross_entropy(x, label, mask=None):
        # x: probability tensor, i.e. after softmax
        log_prob = torch.log(x)
        if mask is None:
            return F.nll_loss(log_prob, label)
        return F.nll_loss(log_prob[mask], label[mask])

    model = DGMG(args.n_hidden_node, args.n_hidden_graph, args.n_layers,
                 loss_func=masked_cross_entropy, dropout=args.dropout, use_cuda=use_cuda)
    if use_cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # training loop
    for ep in range(args.n_epochs):
        print("epoch: {}".format(ep))
        loader = DataLoader(args.dataset, args.batch_size)
        for idx, ground_truth in enumerate(loader):
            if use_cuda:
                # move the tensor-valued signals and the batched graph to the GPU
                count, label, node_list, mask, active, label1, label1_tensor = ground_truth[0]
                label, node_list, mask, label1_tensor = move2cuda((label, node_list, mask, label1_tensor))
                ground_truth[0] = (count, label, node_list, mask, active, label1, label1_tensor)
                ground_truth[1][0].set_device(dgl.gpu(args.gpu))

            optimizer.zero_grad()

            forward_begin = time.time()
            model.forward(True, ground_truth)
            forward_end = time.time()
            elapsed("model forward", forward_begin, forward_end)

            backward_begin = time.time()
            model.loss.backward()
            optimizer.step()
            backward_end = time.time()
            elapsed("model backward", backward_begin, backward_end)

            print("iter {}: loss {}".format(idx, model.loss.item()))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DGMG')
    # (flag, type, default, help) for every option, in registration order
    option_table = (
        ("--dropout", float, 0, "dropout probability"),
        ("--gpu", int, -1, "gpu"),
        ("--lr", float, 1e-2, "learning rate"),
        ("--n-epochs", int, 20, "number of training epochs"),
        ("--n-hidden-node", int, 16, "number of hidden DGMG node units"),
        ("--n-hidden-graph", int, 32, "number of hidden DGMG graph units"),
        ("--n-layers", int, 2, "number of hidden gcn layers"),
        ("--dataset", str, 'samples.p', "dataset pickle file"),
        ("--batch-size", int, 32, "batch size"),
    )
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value,
                            help=help_text)
    args = parser.parse_args()
    print(args)
    main(args)
import networkx as nx
import pickle
import random
import dgl
import numpy as np
import torch
def convert_graph_to_ordering(g):
    """Flatten graph ``g`` into a list of node and edge construction steps.

    The edges of ``g`` are copied into a fresh ``nx.DiGraph``.  For every
    node of that directed graph, in its iteration order, the node itself is
    emitted, followed by one ``(predecessor, node)`` tuple per predecessor.
    """
    directed = nx.DiGraph()
    directed.add_edges_from(g.edges)
    steps = []
    for node in directed.nodes():
        steps.append(node)
        steps.extend((pred, node) for pred in directed.predecessors(node))
    return steps
def generate_dataset(n=15, m=2, n_samples=1024, fname='samples.p'):
    """Generate random graphs, convert them to orderings and pickle them.

    Parameters
    ----------
    n : int
        Number of nodes, passed to ``nx.barabasi_albert_graph``.
    m : int
        Second argument of ``nx.barabasi_albert_graph``.
    n_samples : int
        Number of graphs to generate.
    fname : str
        Path of the pickle file to write.

    The defaults reproduce the previously hard-coded configuration, so a
    call without arguments behaves as before.
    """
    samples = [
        convert_graph_to_ordering(nx.barabasi_albert_graph(n, m))
        for _ in range(n_samples)
    ]
    with open(fname, 'wb') as f:
        pickle.dump(samples, f)
class DataLoader(object):
    """Load a pickled list of orderings and pre-process it into full batches.

    Each entry of ``self.ground_truth`` is the two-element list
    ``[padded_signals, merged_graph]`` for one batch.  Trailing samples that
    do not fill a whole batch are left out.
    """

    def __init__(self, fname, batch_size, shuffle=True):
        # NOTE: pickle.load can execute arbitrary code; only open trusted files.
        with open(fname, 'rb') as f:
            datasets = pickle.load(f)
        if shuffle:
            random.shuffle(datasets)
        n_batches = len(datasets) // batch_size
        # pre-process dataset
        self.ground_truth = []
        for begin in range(0, n_batches * batch_size, batch_size):
            batch = datasets[begin: begin + batch_size]
            signals = pad_ground_truth(batch)
            merged = generate_merged_graph(batch)
            self.ground_truth.append([signals, merged])

    def __iter__(self):
        """Iterate over the pre-processed batches."""
        return iter(self.ground_truth)
def generate_merged_graph(batch):
    """Build one DGLGraph per ordering and merge them into a batched graph.

    Parameters
    ----------
    batch : iterable
        Orderings; each is a sequence whose ``int`` items stand for nodes and
        whose ``tuple`` items stand for edges.

    Returns
    -------
    tuple
        ``(bg, new_src, new_dst)``: the batched graph plus, for every sample,
        the source and destination ids returned by ``bg.query_new_edge`` for
        that sample's edges.
    """
    graph_list = []
    # one lazily-unzipped (sources, destinations) pair per sample
    per_graph_endpoints = []
    # build each sample graph
    for ordering in batch:
        g = dgl.DGLGraph()
        node_count = 0
        edge_list = []
        for step in ordering:
            if isinstance(step, int):
                node_count += 1
            else:
                assert isinstance(step, tuple)
                # every edge is stored in both directions
                edge_list.append(step)
                edge_list.append(tuple(reversed(step)))
        g.add_nodes_from(range(node_count))
        g.add_edges_from(edge_list)
        per_graph_endpoints.append(zip(*edge_list))
        graph_list.append(g)
    # batch
    bg = dgl.batch(graph_list)
    # get new edges
    relabeled = [bg.query_new_edge(g, *endpoints)
                 for g, endpoints in zip(graph_list, per_graph_endpoints)]
    new_src, new_dst = zip(*relabeled)
    return bg, new_src, new_dst
def expand_ground_truth(ordering):
    """Expand an ordering into parallel action / label / node lists.

    Action ``0`` is an add-node decision and action ``1`` an add-edge
    decision; the label says whether the decision is positive (``1``) or
    negative (``0``).  Every node after the first is preceded by a negative
    add-edge decision, every edge tuple yields a positive add-edge decision
    whose node entry is the tuple's first element, and the sequence ends with
    a negative add-node decision.  Entries without a node use ``-1``.

    Returns ``(number_of_steps, action, label, node_list)``.
    """
    actions, labels, nodes = [], [], []

    def emit(act, lab, node):
        actions.append(act)
        labels.append(lab)
        nodes.append(node)

    seen_node = False
    for item in ordering:
        if not isinstance(item, int):
            assert(isinstance(item, tuple))
            emit(1, 1, item[0])  # select src node to add
            continue
        if seen_node:
            # stop adding edges before moving on to the next node
            emit(1, 0, -1)
        seen_node = True
        emit(0, 1, item)  # add node
    # stop adding nodes
    emit(0, 0, -1)
    return len(actions), actions, labels, nodes
def pad_ground_truth(batch):
    """Align the expanded decisions of a batch into shared, masked steps.

    Steps alternate between action ``0`` and action ``1``, starting with
    ``0``.  At each step a sample takes part only when its next pending
    decision has the expected action; otherwise it is masked out with label
    ``0`` and node ``-1``.

    Returns ``(count, new_label, new_node_list, mask_for_batch, active_step,
    label1_set, label1_set_tensor)`` with one entry per step in every list.
    """
    expanded = [expand_ground_truth(sample) for sample in batch]
    bz = len(batch)
    length, action, label, node_list = zip(*expanded)

    cursor = [0] * bz  # next pending decision of every sample
    expected_action = 0
    new_label, new_node_list, mask_for_batch = [], [], []
    active_step = []  # steps at least some graphs are not masked
    label1_set = []  # graphs who decide to add node or edge
    label1_set_tensor = []

    while any([cursor[k] < length[k] for k in range(bz)]):
        mask, node_select, label_select, label1 = [], [], [], []
        for k in range(bz):
            pos = cursor[k]
            participates = pos < length[k] and action[k][pos] == expected_action
            if not participates:
                mask.append(0)
                node_select.append(-1)
                label_select.append(0)
                continue
            mask.append(1)
            node_select.append(node_list[k][pos])
            decision = label[k][pos]
            label_select.append(decision)
            # if decide to add node or add edge, record sample index
            if decision == 1:
                label1.append(k)
            cursor[k] = pos + 1
        expected_action = 1 - expected_action

        new_node_list.append(torch.LongTensor(node_select))
        mask_for_batch.append(torch.ByteTensor(mask))
        new_label.append(torch.LongTensor(label_select))
        active_step.append(any(mask))
        label1_set.append(np.array(label1))
        label1_set_tensor.append(torch.LongTensor(label1))

    count = len(active_step)
    return count, new_label, new_node_list, mask_for_batch, active_step, label1_set, label1_set_tensor
def elapsed(msg, start, end):
    """Print ``msg`` with the span ``end - start`` (seconds) in whole milliseconds."""
    millis = int((end - start) * 1000)
    print("{}: {} ms".format(msg, millis))
# Running this module as a script (re)generates the pickled dataset.
if __name__ == '__main__':
    generate_dataset()
...@@ -2,3 +2,4 @@ from .base import ALL ...@@ -2,3 +2,4 @@ from .base import ALL
from .graph import DGLGraph from .graph import DGLGraph
from .graph import __MSG__, __REPR__ from .graph import __MSG__, __REPR__
from .context import cpu, gpu from .context import cpu, gpu
from .batch import batch, unbatch
...@@ -12,8 +12,13 @@ def asnumpy(a): ...@@ -12,8 +12,13 @@ def asnumpy(a):
def pack(arrays): def pack(arrays):
return np.concatenate(arrays, axis=0) return np.concatenate(arrays, axis=0)
def unpack(a): def unpack(a, split_size_or_sections=None):
return np.split(a, a.shape[0], axis=0) if split_size_or_sections is None:
indices_or_sections = a.shape[0]
else:
# convert split size to split indices by cumsum
indices_or_sections = np.cumsum(split_size_or_sections)[:-1]
return np.split(a, indices_or_sections, axis=0)
def shape(a): def shape(a):
return a.shape return a.shape
...@@ -32,8 +32,8 @@ def asnumpy(a): ...@@ -32,8 +32,8 @@ def asnumpy(a):
def pack(tensors): def pack(tensors):
return th.cat(tensors) return th.cat(tensors)
def unpack(x): def unpack(x, indices_or_sections=1):
return th.split(x, 1) return th.split(x, indices_or_sections)
def shape(x): def shape(x):
return x.shape return x.shape
......
from dgl.graph import DGLGraph
import dgl.backend as F
import dgl
import numpy as np
class BatchedDGLGraph(DGLGraph):
    """A single graph holding relabeled copies of several DGLGraphs.

    Nodes of the i-th input graph are shifted by ``node_offset[i]``, the
    running total of the node counts of the graphs before it.  Edges are
    added in the order of the input graphs with their endpoints shifted the
    same way.

    Parameters
    ----------
    graph_list : list
        The graphs to merge.
    node_attrs : str or iterable, optional
        Node attribute name(s) to pop from every input graph and pack into
        the merged graph.  When empty or None, each input graph's whole node
        frame is appended instead.
    edge_attrs : str or iterable, optional
        Same as ``node_attrs`` for edge attributes.
    **attr
        Forwarded to ``DGLGraph.__init__``.
    """

    def __init__(self, graph_list, node_attrs=None, edge_attrs=None, **attr):
        super(BatchedDGLGraph, self).__init__(**attr)
        # a lone attribute name would otherwise be iterated character by character
        if isinstance(node_attrs, str):
            node_attrs = [node_attrs]
        if isinstance(edge_attrs, str):
            edge_attrs = [edge_attrs]
        self.graph_list = graph_list
        self.graph_idx = {}
        for idx, g in enumerate(self.graph_list):
            self.graph_idx[g] = idx
        self.num_nodes = [len(g) for g in self.graph_list]
        self.num_edges = [g.size() for g in self.graph_list]
        # calc index offset
        self.node_offset = np.cumsum([0] + self.num_nodes)
        self.edge_offset = np.cumsum([0] + self.num_edges)
        # in-order add relabeled nodes
        self.add_nodes_from(range(self.node_offset[-1]))
        # in-order add relabeled edges
        self.new_edge_list = [np.array(g.edges) + offset
                              for g, offset in zip(self.graph_list, self.node_offset[:-1])]
        self.new_edges = np.concatenate(self.new_edge_list)
        self.add_edges_from(self.new_edges)
        assert self.size() == self.edge_offset[-1]
        # set new node attr
        if node_attrs:
            attrs = {}
            for key in node_attrs:
                vals = [g.pop_n_repr(key) for g in self.graph_list]
                attrs[key] = F.pack(vals)
            self.set_n_repr(attrs)
        else:
            for g in self.graph_list:
                self._node_frame.append(g._node_frame)
        # set new edge attr
        if edge_attrs:
            attrs = {}
            for key in edge_attrs:
                vals = [g.pop_e_repr(key) for g in self.graph_list]
                attrs[key] = F.pack(vals)
            self.set_e_repr(attrs)
        else:
            for g in self.graph_list:
                self._edge_frame.append(g._edge_frame)

    def query_new_node(self, g, u):
        """Return the id(s) in the merged graph of node(s) ``u`` of graph ``g``."""
        idx = self.graph_idx[g]
        offset = self.node_offset[idx]
        # np.ndarray is the array *type*; np.array is a function and makes
        # isinstance raise TypeError
        if isinstance(u, (int, np.ndarray, F.Tensor)):
            return u + offset
        else:
            return np.array(u) + offset

    def query_new_edge(self, g, src, dst):
        """Return the merged-graph endpoints of edge(s) ``src -> dst`` of ``g``."""
        idx = self.graph_idx[g]
        offset = self.node_offset[idx]
        if isinstance(src, (int, np.ndarray, F.Tensor)) and \
                isinstance(dst, (int, np.ndarray, F.Tensor)):
            return src + offset, dst + offset
        else:
            return np.array(src) + offset, np.array(dst) + offset

    def query_node_start_offset(self):
        """Return a copy of the first node id of every input graph."""
        return self.node_offset[:-1].copy()

    def query_edge_start_offset(self):
        """Return a copy of the first edge index of every input graph."""
        return self.edge_offset[:-1].copy()
def unbatch(graph_batch):
    """Unbatch the graph and return a list of subgraphs.

    Every node and edge attribute is popped from the batched graph, split
    according to the per-graph node/edge counts, and written back to the
    original graphs.

    Parameters
    ----------
    graph_batch : DGLGraph
        The batched graph.

    Returns
    -------
    list
        ``graph_batch.graph_list``, with the split attributes set on each graph.
    """
    graph_list = graph_batch.graph_list
    num_graphs = len(graph_list)

    # split and set node attrs
    node_attrs = [{} for _ in range(num_graphs)]  # node attr dict for each graph
    # Snapshot the names first: pop_n_repr presumably removes entries from the
    # container being listed, which must not change while it is iterated.
    for key in list(graph_batch.get_n_attr_list()):
        vals = F.unpack(graph_batch.pop_n_repr(key), graph_batch.num_nodes)
        for attr, val in zip(node_attrs, vals):
            attr[key] = val
    for attr, g in zip(node_attrs, graph_list):
        g.set_n_repr(attr)

    # split and set edge attrs
    edge_attrs = [{} for _ in range(num_graphs)]  # edge attr dict for each graph
    for key in list(graph_batch.get_e_attr_list()):
        vals = F.unpack(graph_batch.pop_e_repr(key), graph_batch.num_edges)
        for attr, val in zip(edge_attrs, vals):
            attr[key] = val
    for attr, g in zip(edge_attrs, graph_list):
        g.set_e_repr(attr)

    return graph_list
# FIXME (lingfan): Do we really need the batch API?
# Can't we let user call BatchedDGLGraph(graph_list) directly
# and make unbatch a member function of BatchedDGLGraph
def batch(graph_list, node_attrs=None, edge_attrs=None):
    """Batch a list of DGLGraphs into one single graph.

    Once batch is called, the structure of both the merged graph and the
    graphs in graph_list must not be mutated, or unbatch's behavior will be
    undefined.

    Parameters
    ----------
    graph_list : iterable
        A list of DGLGraphs to be batched.
    node_attrs : str or iterable
        The node attributes needed for the merged graph.
        It's the user's responsibility to make sure node_attrs exist.
    edge_attrs : str or iterable
        The edge attributes needed for the merged graph.
        It's the user's responsibility to make sure edge_attrs exist.

    Returns
    -------
    BatchedDGLGraph
        One single merged graph.
    """
    return BatchedDGLGraph(graph_list, node_attrs, edge_attrs)
...@@ -33,6 +33,7 @@ class _NodeDict(MutableMapping): ...@@ -33,6 +33,7 @@ class _NodeDict(MutableMapping):
def __getitem__(self, key): def __getitem__(self, key):
return self._dict[key] return self._dict[key]
def __delitem__(self, key): def __delitem__(self, key):
# FIXME: add callback
del self._dict[key] del self._dict[key]
def __len__(self): def __len__(self):
return len(self._dict) return len(self._dict)
...@@ -51,6 +52,7 @@ class _AdjInnerDict(MutableMapping): ...@@ -51,6 +52,7 @@ class _AdjInnerDict(MutableMapping):
def __getitem__(self, key): def __getitem__(self, key):
return self._dict[key] return self._dict[key]
def __delitem__(self, key): def __delitem__(self, key):
# FIXME: add callback
del self._dict[key] del self._dict[key]
def __len__(self): def __len__(self):
return len(self._dict) return len(self._dict)
...@@ -78,6 +80,12 @@ class DGLGraph(DiGraph): ...@@ -78,6 +80,12 @@ class DGLGraph(DiGraph):
self.adjlist_outer_dict_factory = None self.adjlist_outer_dict_factory = None
self.adjlist_inner_dict_factory = lambda : _AdjInnerDict(self._add_edge_callback) self.adjlist_inner_dict_factory = lambda : _AdjInnerDict(self._add_edge_callback)
self.edge_attr_dict_factory = dict self.edge_attr_dict_factory = dict
self._context = context.cpu()
# call base class init
super(DGLGraph, self).__init__(graph_data, **attr)
self._init_state()
def _init_state(self):
# cached graph and storage # cached graph and storage
self._cached_graph = None self._cached_graph = None
self._node_frame = Frame() self._node_frame = Frame()
...@@ -91,9 +99,16 @@ class DGLGraph(DiGraph): ...@@ -91,9 +99,16 @@ class DGLGraph(DiGraph):
self._edge_func = None self._edge_func = None
self._edge_cb_state = True self._edge_cb_state = True
self._edge_list = [] self._edge_list = []
self._context = context.cpu()
# call base class init def clear(self):
super(DGLGraph, self).__init__(graph_data, **attr) super(DGLGraph, self).clear()
self._init_state()
def get_n_attr_list(self):
return self._node_frame.schemes
def get_e_attr_list(self):
return self._edge_frame.schemes
def set_n_repr(self, hu, u=ALL): def set_n_repr(self, hu, u=ALL):
"""Set node(s) representation. """Set node(s) representation.
...@@ -764,6 +779,8 @@ class DGLGraph(DiGraph): ...@@ -764,6 +779,8 @@ class DGLGraph(DiGraph):
new_node_repr = update_func(node_repr, reduced_msgs) new_node_repr = update_func(node_repr, reduced_msgs)
self.set_n_repr(new_node_repr, new2old) self.set_n_repr(new_node_repr, new2old)
else: else:
u = utils.convert_to_id_tensor(u, self.context)
v = utils.convert_to_id_tensor(v, self.context)
self._batch_sendto(u, v, message_func) self._batch_sendto(u, v, message_func)
unique_v = F.unique(v) unique_v = F.unique(v)
self._batch_recv(unique_v, reduce_func, update_func) self._batch_recv(unique_v, reduce_func, update_func)
...@@ -990,6 +1007,7 @@ class DGLGraph(DiGraph): ...@@ -990,6 +1007,7 @@ class DGLGraph(DiGraph):
"""Return edges in the addition order.""" """Return edges in the addition order."""
return self._edge_list return self._edge_list
def _get_repr(attr_dict): def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict: if len(attr_dict) == 1 and __REPR__ in attr_dict:
return attr_dict[__REPR__] return attr_dict[__REPR__]
......
import os
# The backend name comes from the DGLBACKEND environment variable
# (lower-cased) and defaults to 'pytorch'.
__backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower()
if __backend__ == 'numpy':
    # nothing is imported for the numpy backend
    pass
elif __backend__ == 'pytorch':
    from .pytorch import *
else:
    # any other backend name is rejected at import time
    raise Exception("Unsupported backend %s" % __backend__)
"""
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
GCN with SPMV specialization.
"""
import torch.nn as nn
from dgl.base import ALL, is_all
class NodeUpdateModule(nn.Module):
    """Linear node update with an optional activation.

    When an attribute name has been set with ``set_attribute_to_update``,
    ``forward`` reads its input from ``accum[attribute]`` and returns
    ``{attribute: result}``; otherwise ``accum`` is used directly and the
    result tensor is returned as is.
    """

    def __init__(self, in_feats, out_feats, activation=None):
        super(NodeUpdateModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation
        self.attribute = None

    def set_attribute_to_update(self, attribute):
        """Select the attribute that ``forward`` reads and writes."""
        self.attribute = attribute

    def forward(self, node, accum, attribute=None):
        """Transform ``accum``; ``node`` and ``attribute`` are not used here."""
        key = self.attribute
        features = accum[key] if key else accum
        out = self.linear(features)
        if self.activation:
            out = self.activation(out)
        return {key: out} if key else out
class GCN(nn.Module):
    """A single graph-convolution layer built on a ``NodeUpdateModule``.

    Messages are taken ``'from_src'``, reduced with ``'sum'`` and passed to
    the update module.  ``dropout`` is stored on the instance; this class
    does not apply it anywhere itself.
    """

    def __init__(self,
                 in_feats,
                 out_feats,
                 activation,
                 dropout=0):
        super(GCN, self).__init__()
        self.dropout = dropout
        # the only layer: linear update plus activation
        self.update_func = NodeUpdateModule(in_feats, out_feats, activation)

    def forward(self, g, u=ALL, v=ALL, attribute=None):
        """Run one round of message passing on ``g`` and return ``g``.

        With both ``u`` and ``v`` left as ``ALL`` the whole graph is updated;
        otherwise only the edges ``u -> v`` are used.  ``attribute`` names
        the node attribute handled by the update module.
        """
        self.update_func.set_attribute_to_update(attribute)
        whole_graph = is_all(u) and is_all(v)
        if not whole_graph:
            g.update_by_edge(u, v, 'from_src', 'sum', self.update_func, batchable=True)
        else:
            g.update_all('from_src', 'sum', self.update_func, batchable=True)
        return g
import networkx as nx
import dgl
import torch
import numpy as np
def tree1():
    r"""Build a five-node tree whose edges point from the leaves to the root.

            0
           / \
          1   2
         / \
        3   4

    The node representation is the tensor ``[0, 1, 2, 3, 4]``.
    """
    g = dgl.DGLGraph()
    for node in range(5):
        g.add_node(node)
    for child, parent in ((3, 1), (4, 1), (1, 0), (2, 0)):
        g.add_edge(child, parent)
    g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4]))
    return g
def tree2():
    r"""Build a five-node tree whose edges point from the leaves to the root.

            1
           / \
          4   3
         / \
        2   0

    The node representation is the tensor ``[0, 1, 2, 3, 4]``.
    """
    g = dgl.DGLGraph()
    for node in range(5):
        g.add_node(node)
    for child, parent in ((2, 4), (0, 4), (4, 1), (3, 1)):
        g.add_edge(child, parent)
    g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4]))
    return g
def test_batch_unbatch():
    """Batching then unbatching must leave node representations unchanged."""
    trees = [tree1(), tree2()]
    before = [tree.get_n_repr() for tree in trees]
    bg = dgl.batch(trees)
    dgl.unbatch(bg)
    for expected, tree in zip(before, trees):
        assert(expected.equal(tree.get_n_repr()))
def test_batch_sendrecv():
    """Send/recv on the batched graph must sum leaf values into each parent."""
    t1, t2 = tree1(), tree2()
    bg = dgl.batch([t1, t2])
    bg.register_message_func(lambda src, edge: src, batchable=True)
    bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True)
    bg.register_update_func(lambda node, accum: accum, batchable=True)

    # leaf -> parent edges of each tree, in that tree's own node ids
    per_tree_edges = ((t1, [(3, 1), (4, 1)]), (t2, [(2, 4), (0, 4)]))
    sources, targets = [], []
    for tree, edges in per_tree_edges:
        new_u, new_v = bg.query_new_edge(tree, *zip(*edges))
        sources.append(new_u)
        targets.append(new_v)
    u = np.concatenate(sources).tolist()
    v = np.concatenate(targets).tolist()

    bg.sendto(u, v)
    bg.recv(v)
    dgl.unbatch(bg)

    assert t1.get_n_repr()[1] == 7
    assert t2.get_n_repr()[4] == 2
def test_batch_propagate():
    """Two propagation steps on the batched graph must sum values up to each root."""
    t1, t2 = tree1(), tree2()
    bg = dgl.batch([t1, t2])
    bg.register_message_func(lambda src, edge: src, batchable=True)
    bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True)
    bg.register_update_func(lambda node, accum: accum, batchable=True)

    def relabel(edges_t1, edges_t2):
        # map both trees' edges to batched ids and join them into plain lists
        u1, v1 = bg.query_new_edge(t1, *zip(*edges_t1))
        u2, v2 = bg.query_new_edge(t2, *zip(*edges_t2))
        return (np.concatenate((u1, u2)).tolist(),
                np.concatenate((v1, v2)).tolist())

    order = [
        # step 1: leaves to their parents
        relabel([(3, 1), (4, 1)], [(2, 4), (0, 4)]),
        # step 2: second level to the roots
        relabel([(1, 0), (2, 0)], [(4, 1), (3, 1)]),
    ]

    bg.propagate(iterator=order)
    dgl.unbatch(bg)

    assert t1.get_n_repr()[0] == 9
    assert t2.get_n_repr()[1] == 5
# Allow running the batch/unbatch tests directly as a script.
if __name__ == '__main__':
    test_batch_unbatch()
    test_batch_sendrecv()
    test_batch_propagate()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment