Commit 3e8b63ec authored by Mufei Li's avatar Mufei Li Committed by Minjie Wang

[Model] DGMG Training with Batch Size 1 (#161)

* DGMG with batch size 1

* Fix

* Adjustment

* Fix

* Fix

* Fix

* Fix
parent bf6d0025
# Learning Deep Generative Models of Graphs
This is an implementation of [Learning Deep Generative Models of Graphs](https://arxiv.org/pdf/1803.03324.pdf) by
Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu and Peter Battaglia.
# Dependencies
- Python 3.5.2
- [PyTorch 0.4.1](https://pytorch.org/)
- [Matplotlib 2.2.2](https://matplotlib.org/)
# Usage
- Train with batch size 1: `python main.py`
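- Results (experiment settings, generated graph samples and `model_eval.txt`) are saved under a timestamped folder in `./results` by default; see the flags in `main.py` for overrides, e.g. `python main.py --batch-size 10 --log-dir ./results`.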
"""We intend to make our reproduction as close as possible to the original paper.
The configuration in the file is mostly from the description in the original paper
and will be loaded when setting up."""
def dataset_based_configure(opts):
if opts['dataset'] == 'cycles':
ds_configure = cycles_configure
else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
opts = {**opts, **ds_configure}
return opts
synthetic_dataset_configure = {
'node_hidden_size': 16,
'num_propagation_rounds': 2,
'optimizer': 'Adam',
'nepochs': 25,
'ds_size': 4000,
'num_generated_samples': 10000,
}
cycles_configure = {
**synthetic_dataset_configure,
**{
'min_size': 10,
'max_size': 20,
'lr': 5e-4,
}
}
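# A small sketch of how these settings are combined for the 'cycles' dataset:
#
#   opts = dataset_based_configure({'dataset': 'cycles'})
#   # opts now contains node_hidden_size=16, num_propagation_rounds=2,
#   # optimizer='Adam', nepochs=25, ds_size=4000, num_generated_samples=10000,
#   # min_size=10, max_size=20 and lr=5e-4, merged on top of whatever
#   # command-line options utils.setup passes in.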
import matplotlib.pyplot as plt
import networkx as nx
import os
import pickle
import random
from torch.utils.data import Dataset
def get_previous(i, v_max):
if i == 0:
return v_max
else:
return i - 1
def get_next(i, v_max):
if i == v_max:
return 0
else:
return i + 1
def is_cycle(g):
size = g.number_of_nodes()
if size < 3:
return False
for node in range(size):
neighbors = g.successors(node)
if len(neighbors) != 2:
return False
if get_previous(node, size - 1) not in neighbors:
return False
if get_next(node, size - 1) not in neighbors:
return False
return True
def get_decision_sequence(size):
"""
Get the decision sequence for generating valid cycles with DGMG for teacher
forcing optimization.
"""
decision_sequence = []
for i in range(size):
decision_sequence.append(0) # Add node
if i != 0:
decision_sequence.append(0) # Add edge
decision_sequence.append(i - 1) # Set destination to be previous node.
if i == size - 1:
decision_sequence.append(0) # Add edge
decision_sequence.append(0) # Set destination to be the root.
decision_sequence.append(1) # Stop adding edge
decision_sequence.append(1) # Stop adding node
return decision_sequence
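# A worked example of the encoding above: get_decision_sequence(3) returns
# [0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1], i.e. add node 0 and stop adding edges;
# add node 1, add an edge to node 0, stop; add node 2, add an edge to node 1,
# add an edge to node 0 (closing the cycle), stop; finally stop adding nodes.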
def generate_dataset(v_min, v_max, n_samples, fname):
samples = []
for _ in range(n_samples):
size = random.randint(v_min, v_max)
samples.append(get_decision_sequence(size))
with open(fname, 'wb') as f:
pickle.dump(samples, f)
class CycleDataset(Dataset):
def __init__(self, fname):
super(CycleDataset, self).__init__()
with open(fname, 'rb') as f:
self.dataset = pickle.load(f)
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
return self.dataset[index]
def collate(self, batch):
assert len(batch) == 1, 'Currently we do not support batched training'
return batch[0]
def dglGraph_to_adj_list(g):
adj_list = {}
for node in range(g.number_of_nodes()):
# For an undirected graph, successors and
# predecessors are equivalent.
adj_list[node] = g.successors(node).tolist()
return adj_list
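# For example, a generated triangle (a 3-cycle) maps to an adjacency list such as
# {0: [1, 2], 1: [0, 2], 2: [1, 0]}; the order within each neighbor list follows
# the order in which edges were inserted during generation.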
class CycleModelEvaluation(object):
def __init__(self, v_min, v_max, dir):
super(CycleModelEvaluation, self).__init__()
self.v_min = v_min
self.v_max = v_max
self.dir = dir
def _initialize(self):
self.num_samples_examined = 0
self.average_size = 0
self.valid_size_ratio = 0
self.cycle_ratio = 0
self.valid_ratio = 0
def rollout_and_examine(self, model, num_samples):
assert not model.training, 'You need to call model.eval().'
num_total_size = 0
num_valid_size = 0
num_cycle = 0
num_valid = 0
plot_times = 0
adj_lists_to_plot = []
for i in range(num_samples):
sampled_graph = model()
sampled_adj_list = dglGraph_to_adj_list(sampled_graph)
adj_lists_to_plot.append(sampled_adj_list)
generated_graph_size = sampled_graph.number_of_nodes()
valid_size = (self.v_min <= generated_graph_size <= self.v_max)
cycle = is_cycle(sampled_graph)
num_total_size += generated_graph_size
if valid_size:
num_valid_size += 1
if cycle:
num_cycle += 1
if valid_size and cycle:
num_valid += 1
if len(adj_lists_to_plot) == 4:
plot_times += 1
fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)
axes = {0: ax0, 1: ax1, 2: ax2, 3: ax3}
for i in range(4):
nx.draw_circular(nx.from_dict_of_lists(adj_lists_to_plot[i]),
with_labels=True, ax=axes[i])
plt.savefig(self.dir + '/samples/{:d}'.format(plot_times))
plt.close()
adj_lists_to_plot = []
self.num_samples_examined = num_samples
self.average_size = num_total_size / num_samples
self.valid_size_ratio = num_valid_size / num_samples
self.cycle_ratio = num_cycle / num_samples
self.valid_ratio = num_valid / num_samples
def write_summary(self):
def _format_value(v):
if isinstance(v, float):
return '{:.4f}'.format(v)
elif isinstance(v, int):
return '{:d}'.format(v)
else:
return '{}'.format(v)
statistics = {
'num_samples': self.num_samples_examined,
'v_min': self.v_min,
'v_max': self.v_max,
'average_size': self.average_size,
'valid_size_ratio': self.valid_size_ratio,
'cycle_ratio': self.cycle_ratio,
'valid_ratio': self.valid_ratio
}
model_eval_path = os.path.join(self.dir, 'model_eval.txt')
with open(model_eval_path, 'w') as f:
for key, value in statistics.items():
msg = '{}\t{}\n'.format(key, _format_value(value))
f.write(msg)
print('Saved model evaluation statistics to {}'.format(model_eval_path))
self._initialize()
class CyclePrinting(object):
def __init__(self, num_epochs, num_batches):
super(CyclePrinting, self).__init__()
self.num_epochs = num_epochs
self.num_batches = num_batches
self.batch_count = 0
def update(self, epoch, metrics):
self.batch_count = self.batch_count % self.num_batches + 1
msg = 'epoch {:d}/{:d}, batch {:d}/{:d}'.format(epoch, self.num_epochs,
self.batch_count, self.num_batches)
for key, value in metrics.items():
msg += ', {}: {:.4f}'.format(key, value)
print(msg)
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
This implementation processes one graph at a time for both training and inference;
during training, gradients are accumulated over batch_size graphs before each optimizer step.
"""
import argparse
import datetime
import time
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from model import DGMG
def main(opts):
t1 = time.time()
# Setup dataset and data loader
if opts['dataset'] == 'cycles':
from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting
dataset = CycleDataset(fname=opts['path_to_dataset'])
evaluator = CycleModelEvaluation(v_min=opts['min_size'],
v_max=opts['max_size'],
dir=opts['log_dir'])
printer = CyclePrinting(num_epochs=opts['nepochs'],
num_batches=opts['ds_size'] // opts['batch_size'])
else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0,
collate_fn=dataset.collate)
# Initialize model
model = DGMG(v_max=opts['max_size'],
node_hidden_size=opts['node_hidden_size'],
num_prop_rounds=opts['num_propagation_rounds'])
# Initialize optimizer
if opts['optimizer'] == 'Adam':
optimizer = Adam(model.parameters(), lr=opts['lr'])
else:
raise ValueError('Unsupported argument for the optimizer')
t2 = time.time()
# Training
model.train()
for epoch in range(opts['nepochs']):
batch_count = 0
batch_loss = 0
batch_prob = 0
optimizer.zero_grad()
for i, data in enumerate(data_loader):
log_prob = model(actions=data)
prob = log_prob.detach().exp()
loss = - log_prob / opts['batch_size']
prob_averaged = prob / opts['batch_size']
loss.backward()
batch_loss += loss.item()
batch_prob += prob_averaged.item()
batch_count += 1
if batch_count % opts['batch_size'] == 0:
printer.update(epoch + 1, {'averaged_loss': batch_loss,
'averaged_prob': batch_prob})
if opts['clip_grad']:
clip_grad_norm_(model.parameters(), opts['clip_bound'])
optimizer.step()
batch_loss = 0
batch_prob = 0
optimizer.zero_grad()
t3 = time.time()
model.eval()
evaluator.rollout_and_examine(model, opts['num_generated_samples'])
evaluator.write_summary()
t4 = time.time()
print('It took {} to setup.'.format(datetime.timedelta(seconds=t2-t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3-t2)))
print('It took {} to finish evaluation.'.format(datetime.timedelta(seconds=t4-t3)))
print('--------------------------------------------------------------------------')
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3-t2) / opts['nepochs'])))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='DGMG')
# configure
parser.add_argument('--seed', type=int, default=9284, help='random seed')
# dataset
parser.add_argument('--dataset', choices=['cycles'], default='cycles',
help='dataset to use')
parser.add_argument('--path-to-dataset', type=str, default='cycles.p',
help='load the dataset if it exists, '
'generate it and save to the path otherwise')
# log
parser.add_argument('--log-dir', default='./results',
help='folder to save info like experiment configuration '
'or model evaluation results')
# optimization
parser.add_argument('--batch-size', type=int, default=10,
help='batch size to use for training')
parser.add_argument('--clip-grad', action='store_true', default=True,
help='gradient clipping is required to prevent gradient explosion')
parser.add_argument('--clip-bound', type=float, default=0.25,
help='constraint of gradient norm for gradient clipping')
args = parser.parse_args()
from utils import setup
opts = setup(args)
main(opts)
import dgl
from dgl.graph import DGLGraph
from dgl.nn import GCN
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse
from util import DataLoader, elapsed, generate_dataset
import time
class MLP(nn.Module):
def __init__(self, num_hidden, num_classes, num_layers):
super(MLP, self).__init__()
layers = []
# hidden layers
for _ in range(num_layers):
layers.append(nn.Linear(num_hidden, num_hidden))
layers.append(nn.Sigmoid())
# output projection
layers.append(nn.Linear(num_hidden, num_classes))
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
def move2cuda(x):
# Recursively move an object to cuda.
if isinstance(x, torch.Tensor):
# If it is a tensor, move it directly.
return x.cuda()
try:
# If it is iterable, recursively move each element.
return [move2cuda(i) for i in x]
except TypeError:
return x
from functools import partial
from torch.distributions import Bernoulli, Categorical
class GraphEmbed(nn.Module):
def __init__(self, node_hidden_size):
super(GraphEmbed, self).__init__()
# Setting from the paper
self.graph_hidden_size = 2 * node_hidden_size
# Embed graphs
self.node_gating = nn.Sequential(
nn.Linear(node_hidden_size, 1),
nn.Sigmoid()
)
self.node_to_graph = nn.Linear(node_hidden_size,
self.graph_hidden_size)
def forward(self, g):
if g.number_of_nodes() == 0:
return torch.zeros(1, self.graph_hidden_size)
else:
# Node features are stored as hv in ndata.
hvs = g.ndata['hv']
return (self.node_gating(hvs) *
self.node_to_graph(hvs)).sum(0, keepdim=True)
class GraphProp(nn.Module):
def __init__(self, num_prop_rounds, node_hidden_size):
super(GraphProp, self).__init__()
self.num_prop_rounds = num_prop_rounds
# Setting from the paper
self.node_activation_hidden_size = 2 * node_hidden_size
message_funcs = []
self.reduce_funcs = []
node_update_funcs = []
for t in range(num_prop_rounds):
# input being [hv, hu, xuv]
message_funcs.append(nn.Linear(2 * node_hidden_size + 1,
self.node_activation_hidden_size))
self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
node_update_funcs.append(
nn.GRUCell(self.node_activation_hidden_size,
node_hidden_size))
self.message_funcs = nn.ModuleList(message_funcs)
self.node_update_funcs = nn.ModuleList(node_update_funcs)
def dgmg_msg(self, edges):
"""For an edge u->v, return concat([h_u, x_uv])"""
return {'m': torch.cat([edges.src['hv'],
edges.data['he']],
dim=1)}
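# In round t, dgmg_reduce below computes for each node v the activation
# a_v = sum over incoming messages of W_t [h_v, h_u, x_uv] + b_t,
# i.e. the round-t linear message function applied to the concatenation of the
# receiver state, the sender state and the edge feature, summed over the mailbox.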
def dgmg_reduce(self, nodes, round):
hv_old = nodes.data['hv']
m = nodes.mailbox['m']
message = torch.cat([
hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
node_activation = (self.message_funcs[round](message)).sum(1)
return {'a': node_activation}
def forward(self, g):
if g.number_of_edges() == 0:
return
else:
for t in range(self.num_prop_rounds):
g.update_all(message_func=self.dgmg_msg,
reduce_func=self.reduce_funcs[t])
g.ndata['hv'] = self.node_update_funcs[t](
g.ndata['a'], g.ndata['hv'])
def bernoulli_action_log_prob(logit, action):
"""Calculate the log p of an action with respect to a Bernoulli
distribution. Use logit rather than prob for numerical stability."""
if action == 0:
return F.logsigmoid(-logit)
else:
return F.logsigmoid(logit)
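# For example, a logit of 0 gives both actions probability 0.5, so
# bernoulli_action_log_prob(torch.zeros(1, 1), 0) and
# bernoulli_action_log_prob(torch.zeros(1, 1), 1) both equal log(0.5), roughly -0.6931.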
class AddNode(nn.Module):
def __init__(self, graph_embed_func, node_hidden_size):
super(AddNode, self).__init__()
self.graph_op = {'embed': graph_embed_func}
self.stop = 1
self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)
# If a node is to be added, initialize its hv
self.node_type_embed = nn.Embedding(1, node_hidden_size)
self.initialize_hv = nn.Linear(node_hidden_size + \
graph_embed_func.graph_hidden_size,
node_hidden_size)
self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
def _initialize_node_repr(self, g, node_type, graph_embed):
num_nodes = g.number_of_nodes()
hv_init = self.initialize_hv(
torch.cat([
self.node_type_embed(torch.LongTensor([node_type])),
graph_embed], dim=1))
g.nodes[num_nodes - 1].data['hv'] = hv_init
g.nodes[num_nodes - 1].data['a'] = self.init_node_activation
def prepare_training(self):
self.log_prob = []
def forward(self, g, action=None):
graph_embed = self.graph_op['embed'](g)
logit = self.add_node(graph_embed)
prob = torch.sigmoid(logit)
if not self.training:
action = Bernoulli(prob).sample().item()
stop = bool(action == self.stop)
if not stop:
g.add_nodes(1)
self._initialize_node_repr(g, action, graph_embed)
if self.training:
sample_log_prob = bernoulli_action_log_prob(logit, action)
self.log_prob.append(sample_log_prob)
return stop
class AddEdge(nn.Module):
def __init__(self, graph_embed_func, node_hidden_size):
super(AddEdge, self).__init__()
self.graph_op = {'embed': graph_embed_func}
self.add_edge = nn.Linear(graph_embed_func.graph_hidden_size + \
node_hidden_size, 1)
def prepare_training(self):
self.log_prob = []
def forward(self, g, action=None):
graph_embed = self.graph_op['embed'](g)
src_embed = g.nodes[g.number_of_nodes() - 1].data['hv']
logit = self.add_edge(torch.cat(
[graph_embed, src_embed], dim=1))
prob = torch.sigmoid(logit)
if not self.training:
action = Bernoulli(prob).sample().item()
to_add_edge = bool(action == 0)
if self.training:
sample_log_prob = bernoulli_action_log_prob(logit, action)
self.log_prob.append(sample_log_prob)
return to_add_edge
class ChooseDestAndUpdate(nn.Module):
def __init__(self, graph_prop_func, node_hidden_size):
super(ChooseDestAndUpdate, self).__init__()
self.graph_op = {'prop': graph_prop_func}
self.choose_dest = nn.Linear(2 * node_hidden_size, 1)
def _initialize_edge_repr(self, g, src_list, dest_list):
# For untyped edges, we only add 1 to indicate its existence.
# For multiple edge types, we can use a one hot representation
# or an embedding module.
edge_repr = torch.ones(len(src_list), 1)
g.edges[src_list, dest_list].data['he'] = edge_repr
def prepare_training(self):
self.log_prob = []
def forward(self, g, dest):
src = g.number_of_nodes() - 1
possible_dests = range(src)
src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
possible_dests_embed = g.nodes[possible_dests].data['hv']
dests_scores = self.choose_dest(
torch.cat([possible_dests_embed,
src_embed_expand], dim=1)).view(1, -1)
dests_probs = F.softmax(dests_scores, dim=1)
if not self.training:
dest = Categorical(dests_probs).sample().item()
if not g.has_edge_between(src, dest):
# For undirected graphs, we add edges for both directions
# so that we can perform graph propagation.
src_list = [src, dest]
dest_list = [dest, src]
g.add_edges(src_list, dest_list)
self._initialize_edge_repr(g, src_list, dest_list)
self.graph_op['prop'](g)
if self.training:
if dests_probs.nelement() > 1:
self.log_prob.append(
F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])
class DGMG(nn.Module):
def __init__(self, node_num_hidden, graph_num_hidden, T, num_MLP_layers=1, loss_func=None, dropout=0.0, use_cuda=False):
super(DGMG, self).__init__()
# hidden size of node and graph
self.node_num_hidden = node_num_hidden
self.graph_num_hidden = graph_num_hidden
# use GCN as a simple propagation model
self.gcn = nn.ModuleList([GCN(node_num_hidden, node_num_hidden, F.relu, dropout) for _ in range(T)])
# project node repr to graph repr (higher dimension)
self.graph_project = nn.Linear(node_num_hidden, graph_num_hidden)
# add node
self.fan = MLP(graph_num_hidden, 2, num_MLP_layers)
# add edge
self.fae = MLP(graph_num_hidden + node_num_hidden, 1, num_MLP_layers)
# select node to add edge
self.fs = MLP(node_num_hidden * 2, 1, num_MLP_layers)
# init node state
self.finit = MLP(graph_num_hidden, node_num_hidden, num_MLP_layers)
# loss function
self.loss_func = loss_func
# use gpu
self.use_cuda = use_cuda
def decide_add_node(self, hGs):
h = self.fan(hGs)
p = F.softmax(h, dim=1)
# calc loss
self.loss += self.loss_func(p, self.labels[self.step], self.masks[self.step])
def decide_add_edge(self, batched_graph, hGs):
hvs = batched_graph.get_n_repr((self.sample_node_curr_idx - 1).tolist())['h']
h = self.fae(torch.cat((hGs, hvs), dim=1))
p = torch.sigmoid(h)
p = torch.cat([1 - p, p], dim=1)
self.loss += self.loss_func(p, self.labels[self.step], self.masks[self.step])
def select_node_to_add_edge(self, batched_graph, indices):
node_indices = self.sample_node_curr_idx[indices].tolist()
node_start = self.sample_node_start_idx[indices].tolist()
node_repr = batched_graph.get_n_repr()['h']
for i, j, idx in zip(node_start, node_indices, indices):
hu = node_repr.narrow(0, i, j-i)
hv = node_repr.narrow(0, j-1, 1)
huv = torch.cat((hu, hv.expand(j-i, -1)), dim=1)
s = F.softmax(self.fs(huv), dim=0).view(1, -1)
dst = self.node_select[self.step][idx].view(-1)
self.loss += self.loss_func(s, dst)
def update_graph_repr(self, batched_graph, hGs, indices, indices_tensor):
start = self.sample_node_start_idx[indices].tolist()
stop = self.sample_node_curr_idx[indices].tolist()
node_repr = batched_graph.get_n_repr()['h']
graph_repr = self.graph_project(node_repr)
new_hGs = []
for i, j in zip(start, stop):
h = graph_repr.narrow(0, i, j-i)
hG = torch.sum(h, 0, keepdim=True)
new_hGs.append(hG)
new_hGs = torch.cat(new_hGs, dim=0)
return hGs.index_copy(0, indices_tensor, new_hGs)
def propagate(self, batched_graph, indices):
edge_src = [self.sample_edge_src[idx][0: self.sample_edge_count[idx]] for idx in indices]
edge_dst = [self.sample_edge_dst[idx][0: self.sample_edge_count[idx]] for idx in indices]
u = np.concatenate(edge_src).tolist()
v = np.concatenate(edge_dst).tolist()
for gcn in self.gcn:
gcn.forward(batched_graph, u, v, attribute='h')
def forward(self, training=False, ground_truth=None):
if not training:
raise NotImplementedError("inference is not implemented yet")
assert(ground_truth is not None)
signals, (batched_graph, self.sample_edge_src, self.sample_edge_dst) = ground_truth
nsteps, self.labels, self.node_select, self.masks, active_step, label1_set, label1_set_tensor = signals
# init loss
self.loss = 0
batch_size = len(self.sample_edge_src)
# initial node repr for each sample
hVs = torch.zeros(len(batched_graph), self.node_num_hidden)
# FIXME: what's the initial graph repr for an empty graph?
hGs = torch.zeros(batch_size, self.graph_num_hidden)
if self.use_cuda:
hVs = hVs.cuda()
hGs = hGs.cuda()
batched_graph.set_n_repr({'h': hVs})
self.sample_node_start_idx = batched_graph.query_node_start_offset()
self.sample_node_curr_idx = self.sample_node_start_idx.copy()
self.sample_edge_count = np.zeros(batch_size, dtype=int)
self.step = 0
while self.step < nsteps:
if self.step % 2 == 0: # add node step
if active_step[self.step]:
# decide whether to add node
self.decide_add_node(hGs)
# calculate initial state for new node
hvs = self.finit(hGs)
# add node
update = label1_set[self.step]
if len(update) > 0:
hvs = torch.index_select(hvs, 0, label1_set_tensor[self.step])
scatter_indices = self.sample_node_curr_idx[update]
batched_graph.set_n_repr({'h': hvs}, scatter_indices.tolist())
self.sample_node_curr_idx[update] += 1
# get new graph repr
hGs = self.update_graph_repr(batched_graph, hGs, update, label1_set_tensor[self.step])
else:
# all samples are masked
pass
else: # add edge step
# decide whether to add edge, which edge to add
# and also add edge
self.decide_add_edge(batched_graph, hGs)
# propagate
to_add_edge = label1_set[self.step]
if len(to_add_edge) > 0:
# at least one graph needs update
self.select_node_to_add_edge(batched_graph, to_add_edge)
# update edge count for each sample
self.sample_edge_count[to_add_edge] += 2 # undirected graph
# perform gcn propagation
self.propagate(batched_graph, to_add_edge)
# get new graph repr
hGs = self.update_graph_repr(batched_graph, hGs, label1_set[self.step], label1_set_tensor[self.step])
self.step += 1
def main(args):
if torch.cuda.is_available() and args.gpu >= 0:
torch.cuda.set_device(args.gpu)
use_cuda = True
else:
use_cuda = False
def masked_cross_entropy(x, label, mask=None):
# x: probability tensor, i.e. after softmax
x = torch.log(x)
if mask is not None:
x = x[mask]
label = label[mask]
return F.nll_loss(x, label)
model = DGMG(args.n_hidden_node, args.n_hidden_graph, args.n_layers,
loss_func=masked_cross_entropy, dropout=args.dropout, use_cuda=use_cuda)
if use_cuda:
model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# training loop
for ep in range(args.n_epochs):
print("epoch: {}".format(ep))
for idx, ground_truth in enumerate(DataLoader(args.dataset, args.batch_size)):
if use_cuda:
count, label, node_list, mask, active, label1, label1_tensor = ground_truth[0]
label, node_list, mask, label1_tensor = move2cuda((label, node_list, mask, label1_tensor))
ground_truth[0] = (count, label, node_list, mask, active, label1, label1_tensor)
optimizer.zero_grad()
# create new empty graphs
start = time.time()
model.forward(True, ground_truth)
end = time.time()
elapsed("model forward", start, end)
start = time.time()
model.loss.backward()
optimizer.step()
end = time.time()
elapsed("model backward", start, end)
print("iter {}: loss {}".format(idx, model.loss.item()))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='DGMG')
parser.add_argument("--dropout", type=float, default=0,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("--n-hidden-node", type=int, default=16,
help="number of hidden DGMG node units")
parser.add_argument("--n-hidden-graph", type=int, default=32,
help="number of hidden DGMG graph units")
parser.add_argument("--n-layers", type=int, default=2,
help="number of hidden gcn layers")
parser.add_argument("--dataset", type=str, default='samples.p',
help="dataset pickle file")
parser.add_argument("--gen-dataset", type=str, default=None,
help="parameters to generate B-A graph datasets. Format: <#node>,<#edge>,<#sample>")
parser.add_argument("--batch-size", type=int, default=32,
help="batch size")
args = parser.parse_args()
print(args)
# generate dataset if needed
if args.gen_dataset is not None:
n_node, n_edge, n_sample = map(int, args.gen_dataset.split(','))
generate_dataset(n_node, n_edge, n_sample, args.dataset)
main(args)
class DGMG(nn.Module):
def __init__(self, v_max, node_hidden_size,
num_prop_rounds):
super(DGMG, self).__init__()
# Graph configuration
self.v_max = v_max
# Graph embedding module
self.graph_embed = GraphEmbed(node_hidden_size)
# Graph propagation module
self.graph_prop = GraphProp(num_prop_rounds,
node_hidden_size)
# Actions
self.add_node_agent = AddNode(
self.graph_embed, node_hidden_size)
self.add_edge_agent = AddEdge(
self.graph_embed, node_hidden_size)
self.choose_dest_agent = ChooseDestAndUpdate(
self.graph_prop, node_hidden_size)
# Weight initialization
self.init_weights()
def init_weights(self):
from utils import weights_init, dgmg_message_weight_init
self.graph_embed.apply(weights_init)
self.graph_prop.apply(weights_init)
self.add_node_agent.apply(weights_init)
self.add_edge_agent.apply(weights_init)
self.choose_dest_agent.apply(weights_init)
self.graph_prop.message_funcs.apply(dgmg_message_weight_init)
@property
def action_step(self):
old_step_count = self.step_count
self.step_count += 1
return old_step_count
def prepare_for_train(self):
self.step_count = 0
self.add_node_agent.prepare_training()
self.add_edge_agent.prepare_training()
self.choose_dest_agent.prepare_training()
def add_node_and_update(self, a=None):
"""Decide if to add a new node.
If a new node should be added, update the graph."""
return self.add_node_agent(self.g, a)
def add_edge_or_not(self, a=None):
"""Decide if a new edge should be added."""
return self.add_edge_agent(self.g, a)
def choose_dest_and_update(self, a=None):
"""Choose destination and connect it to the latest node.
Add edges for both directions and update the graph."""
self.choose_dest_agent(self.g, a)
def get_log_prob(self):
return torch.cat(self.add_node_agent.log_prob).sum()\
+ torch.cat(self.add_edge_agent.log_prob).sum()\
+ torch.cat(self.choose_dest_agent.log_prob).sum()
def forward_train(self, actions):
self.prepare_for_train()
stop = self.add_node_and_update(a=actions[self.action_step])
while not stop:
to_add_edge = self.add_edge_or_not(a=actions[self.action_step])
while to_add_edge:
self.choose_dest_and_update(a=actions[self.action_step])
to_add_edge = self.add_edge_or_not(a=actions[self.action_step])
stop = self.add_node_and_update(a=actions[self.action_step])
return self.get_log_prob()
def forward_inference(self):
stop = self.add_node_and_update()
while (not stop) and (self.g.number_of_nodes() < self.v_max + 1):
num_trials = 0
to_add_edge = self.add_edge_or_not()
while to_add_edge and (num_trials < self.g.number_of_nodes() - 1):
self.choose_dest_and_update()
num_trials += 1
to_add_edge = self.add_edge_or_not()
stop = self.add_node_and_update()
return self.g
def forward(self, actions=None):
# The graph we will work on
self.g = dgl.DGLGraph()
# If nodes and edges have features, the features of newly
# added nodes and edges will be initialized with zero tensors.
self.g.set_n_initializer(dgl.frame.zero_initializer)
self.g.set_e_initializer(dgl.frame.zero_initializer)
if self.training:
return self.forward_train(actions)
else:
return self.forward_inference()
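# A minimal usage sketch mirroring main.py (get_decision_sequence comes from cycles.py):
# in training mode the model consumes a teacher-forcing decision sequence and returns its
# joint log-probability; in eval mode it rolls out a new graph on its own.
#
#   model = DGMG(v_max=20, node_hidden_size=16, num_prop_rounds=2)
#   model.train()
#   log_prob = model(actions=get_decision_sequence(12))
#   loss = -log_prob  # minimized with Adam in main.py
#   model.eval()
#   sampled_graph = model()  # a dgl.DGLGraph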
import networkx as nx
import pickle
import random
import dgl
import numpy as np
import torch
def convert_graph_to_ordering(g):
ordering = []
h = nx.DiGraph()
h.add_edges_from(g.edges)
for n in range(len(h)):
ordering.append(n)
for m in h.predecessors(n):
ordering.append((m, n))
return ordering
def generate_dataset(n, m, n_samples, fname):
samples = []
for _ in range(n_samples):
g = nx.barabasi_albert_graph(n, m)
samples.append(convert_graph_to_ordering(g))
with open(fname, 'wb') as f:
pickle.dump(samples, f)
class DataLoader(object):
def __init__(self, fname, batch_size, shuffle=True):
with open(fname, 'rb') as f:
datasets = pickle.load(f)
if shuffle:
random.shuffle(datasets)
num = len(datasets) // batch_size
# pre-process dataset
self.ground_truth = []
for i in range(num):
batch = datasets[i*batch_size: (i+1)*batch_size]
padded_signals = pad_ground_truth(batch)
merged_graph = generate_merged_graph(batch)
self.ground_truth.append([padded_signals, merged_graph])
def __iter__(self):
return iter(self.ground_truth)
def generate_merged_graph(batch):
n_graphs = len(batch)
graph_list = []
# build each sample graph
new_edges = []
for ordering in batch:
g = dgl.DGLGraph()
node_count = 0
edge_list = []
for step in ordering:
if isinstance(step, int):
node_count += 1
else:
assert isinstance(step, tuple)
edge_list.append(step)
edge_list.append(tuple(reversed(step)))
g.add_nodes_from(range(node_count))
g.add_edges_from(edge_list)
new_edges.append(zip(*edge_list))
graph_list.append(g)
# batch
bg = dgl.batch(graph_list)
# get new edges
new_edges = [bg.query_new_edge(g, *edges) for g, edges in zip(graph_list, new_edges)]
new_src, new_dst = zip(*new_edges)
return bg, new_src, new_dst
def expand_ground_truth(ordering):
node_list = []
action = []
label = []
first_step = True
for i in ordering:
if isinstance(i, int):
if not first_step:
# record the decision not to add an edge
action.append(1)
label.append(0)
node_list.append(-1)
else:
first_step = False
action.append(0) # add node
label.append(1)
node_list.append(i)
else:
assert(isinstance(i, tuple))
action.append(1)
label.append(1)
node_list.append(i[0]) # select src node to add
# record the decision not to add a node
action.append(0)
label.append(0)
node_list.append(-1)
return len(action), action, label, node_list
def pad_ground_truth(batch):
a = []
bz = len(batch)
for sample in batch:
a.append(expand_ground_truth(sample))
length, action, label, node_list = zip(*a)
step = [0] * bz
new_label = []
new_node_list = []
mask_for_batch = []
next_action = 0
count = 0
active_step = [] # steps where at least some graphs are not masked
label1_set = [] # graphs that decide to add a node or an edge
label1_set_tensor = []
while any([step[i] < length[i] for i in range(bz)]):
node_select = []
label_select = []
mask = []
label1 = []
not_all_masked = False
for sample_idx in range(bz):
if step[sample_idx] < length[sample_idx] and \
action[sample_idx][step[sample_idx]] == next_action:
mask.append(1)
node_select.append(node_list[sample_idx][step[sample_idx]])
label_select.append(label[sample_idx][step[sample_idx]])
# if decide to add node or add edge, record sample_idx
if label_select[-1] == 1:
label1.append(sample_idx)
step[sample_idx] += 1
not_all_masked = True
else:
mask.append(0)
node_select.append(-1)
label_select.append(0)
next_action = 1 - next_action
new_node_list.append(torch.LongTensor(node_select))
mask_for_batch.append(torch.ByteTensor(mask))
new_label.append(torch.LongTensor(label_select))
active_step.append(not_all_masked)
label1_set.append(np.array(label1))
label1_set_tensor.append(torch.LongTensor(label1))
count += 1
return count, new_label, new_node_list, mask_for_batch, active_step, label1_set, label1_set_tensor
def elapsed(msg, start, end):
print("{}: {} ms".format(msg, int((end-start)*1000)))
if __name__ == '__main__':
n = 15
m = 2
n_samples = 1024
fname = 'samples.p'
generate_dataset(n, m, n_samples, fname)
import datetime
import matplotlib.pyplot as plt
import os
import random
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.init as init
from pprint import pprint
########################################################################################################################
# configuration #
########################################################################################################################
def mkdir_p(path):
import errno
try:
os.makedirs(path)
print('Created directory {}'.format(path))
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path):
print('Directory {} already exists.'.format(path))
else:
raise
def date_filename(base_dir='./'):
dt = datetime.datetime.now()
return os.path.join(base_dir, '{}_{:02d}-{:02d}-{:02d}'.format(
dt.date(), dt.hour, dt.minute, dt.second
))
def setup_log_dir(opts):
log_dir = '{}'.format(date_filename(opts['log_dir']))
mkdir_p(log_dir)
return log_dir
def save_arg_dict(opts, filename='settings.txt'):
def _format_value(v):
if isinstance(v, float):
return '{:.4f}'.format(v)
elif isinstance(v, int):
return '{:d}'.format(v)
else:
return '{}'.format(v)
save_path = os.path.join(opts['log_dir'], filename)
with open(save_path, 'w') as f:
for key, value in opts.items():
f.write('{}\t{}\n'.format(key, _format_value(value)))
print('Saved settings to {}'.format(save_path))
def setup(args):
opts = args.__dict__.copy()
cudnn.benchmark = False
cudnn.deterministic = True
# Seed
if opts['seed'] is None:
opts['seed'] = random.randint(1, 10000)
random.seed(opts['seed'])
torch.manual_seed(opts['seed'])
# Dataset
from configure import dataset_based_configure
opts = dataset_based_configure(opts)
assert opts['path_to_dataset'] is not None, 'Expect path to dataset to be set.'
if not os.path.exists(opts['path_to_dataset']):
if opts['dataset'] == 'cycles':
from cycles import generate_dataset
generate_dataset(opts['min_size'], opts['max_size'],
opts['ds_size'], opts['path_to_dataset'])
else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
# Optimization
if opts['clip_grad']:
assert opts['clip_grad'] is not None, 'Expect the gradient norm constraint to be set.'
# Log
print('Prepare logging directory...')
log_dir = setup_log_dir(opts)
opts['log_dir'] = log_dir
mkdir_p(log_dir + '/samples')
plt.switch_backend('Agg')
save_arg_dict(opts)
pprint(opts)
return opts
########################################################################################################################
# model #
########################################################################################################################
def weights_init(m):
'''
Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5
Usage:
model = Model()
model.apply(weights_init)
'''
if isinstance(m, nn.Linear):
init.xavier_normal_(m.weight.data)
init.normal_(m.bias.data)
elif isinstance(m, nn.GRUCell):
for param in m.parameters():
if len(param.shape) >= 2:
init.orthogonal_(param.data)
else:
init.normal_(param.data)
def dgmg_message_weight_init(m):
"""
This is similar to the function above, except that we initialize linear layers from a normal distribution with
std 1./10 as suggested by the authors. This should only be used for the message passing functions, i.e. the fe's
in the paper.
"""
def _weight_init(m):
if isinstance(m, nn.Linear):
init.normal_(m.weight.data, std=1./10)
init.normal_(m.bias.data, std=1./10)
else:
raise ValueError('Expected the input to be of type nn.Linear!')
if isinstance(m, nn.ModuleList):
for layer in m:
layer.apply(_weight_init)
else:
m.apply(_weight_init)