"tools/python/vscode:/vscode.git/clone" did not exist on "527e26df0806c4daa68c0692bc032a7fca43c2ef"
Commit 3e8b63ec authored by Mufei Li, committed by Minjie Wang

[Model] DGMG Training with Batch Size 1 (#161)

* DGMG with batch size 1

* Fix

* Adjustment

* Fix

* Fix

* Fix

* Fix
# Learning Deep Generative Models of Graphs
This is an implementation of [Learning Deep Generative Models of Graphs](https://arxiv.org/pdf/1803.03324.pdf) by
Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, Peter Battaglia.
# Dependencies
- Python 3.5.2
- [Pytorch 0.4.1](https://pytorch.org/)
- [Matplotlib 2.2.2](https://matplotlib.org/)
# Usage
- Train with batch size 1: `python main.py`
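- Optional flags defined in `main.py` can override the defaults, e.g. `python main.py --batch-size 10 --log-dir ./results` (`--batch-size` controls gradient accumulation; graphs are still processed one at a time)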
"""We intend to make our reproduction as close as possible to the original paper.
The configuration in this file mostly follows the description in the original paper
and is loaded during setup."""
def dataset_based_configure(opts):
if opts['dataset'] == 'cycles':
ds_configure = cycles_configure
else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
opts = {**opts, **ds_configure}
return opts
synthetic_dataset_configure = {
'node_hidden_size': 16,
'num_propagation_rounds': 2,
'optimizer': 'Adam',
'nepochs': 25,
'ds_size': 4000,
'num_generated_samples': 10000,
}
cycles_configure = {
**synthetic_dataset_configure,
**{
'min_size': 10,
'max_size': 20,
'lr': 5e-4,
}
}
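# Illustrative sketch (not part of the original file): dataset_based_configure
# merges the command line options with the dataset-specific configuration above,
# with dataset-specific keys overriding any duplicates:
#
#   opts = {'dataset': 'cycles', 'batch_size': 10}
#   opts = dataset_based_configure(opts)
#   # opts now also contains 'node_hidden_size', 'num_propagation_rounds',
#   # 'optimizer', 'nepochs', 'ds_size', 'num_generated_samples',
#   # 'min_size', 'max_size' and 'lr'.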
import matplotlib.pyplot as plt
import networkx as nx
import os
import pickle
import random
from torch.utils.data import Dataset
def get_previous(i, v_max):
if i == 0:
return v_max
else:
return i - 1
def get_next(i, v_max):
if i == v_max:
return 0
else:
return i + 1
def is_cycle(g):
size = g.number_of_nodes()
if size < 3:
return False
for node in range(size):
neighbors = g.successors(node)
if len(neighbors) != 2:
return False
if get_previous(node, size - 1) not in neighbors:
return False
if get_next(node, size - 1) not in neighbors:
return False
return True
def get_decision_sequence(size):
"""
Get the decision sequence for generating a valid cycle with DGMG, used for
teacher-forcing optimization.
"""
decision_sequence = []
for i in range(size):
decision_sequence.append(0) # Add node
if i != 0:
decision_sequence.append(0) # Add edge
decision_sequence.append(i - 1) # Set destination to be previous node.
if i == size - 1:
decision_sequence.append(0) # Add edge
decision_sequence.append(0) # Set destination to be the root.
decision_sequence.append(1) # Stop adding edge
decision_sequence.append(1) # Stop adding node
return decision_sequence
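# Illustration (not part of the original file): get_decision_sequence(3) returns
# [0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1], i.e. add node 0 and stop adding edges,
# add node 1 and connect it to node 0, add node 2 and connect it to nodes 1 and 0,
# then stop adding nodes.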
def generate_dataset(v_min, v_max, n_samples, fname):
samples = []
for _ in range(n_samples):
size = random.randint(v_min, v_max)
samples.append(get_decision_sequence(size))
with open(fname, 'wb') as f:
pickle.dump(samples, f)
class CycleDataset(Dataset):
def __init__(self, fname):
super(CycleDataset, self).__init__()
with open(fname, 'rb') as f:
self.dataset = pickle.load(f)
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
return self.dataset[index]
def collate(self, batch):
assert len(batch) == 1, 'Currently we do not support batched training'
return batch[0]
def dglGraph_to_adj_list(g):
adj_list = {}
for node in range(g.number_of_nodes()):
# For an undirected graph, successors and
# predecessors are equivalent.
adj_list[node] = g.successors(node).tolist()
return adj_list
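# Illustration (not part of the original file): for a generated 3-cycle with edges
# added in both directions, this returns something like
# {0: [1, 2], 1: [0, 2], 2: [0, 1]} (neighbor order may vary).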
class CycleModelEvaluation(object):
def __init__(self, v_min, v_max, dir):
super(CycleModelEvaluation, self).__init__()
self.v_min = v_min
self.v_max = v_max
self.dir = dir
def _initialize(self):
self.num_samples_examined = 0
self.average_size = 0
self.valid_size_ratio = 0
self.cycle_ratio = 0
self.valid_ratio = 0
def rollout_and_examine(self, model, num_samples):
assert not model.training, 'You need to call model.eval().'
num_total_size = 0
num_valid_size = 0
num_cycle = 0
num_valid = 0
plot_times = 0
adj_lists_to_plot = []
for i in range(num_samples):
sampled_graph = model()
sampled_adj_list = dglGraph_to_adj_list(sampled_graph)
adj_lists_to_plot.append(sampled_adj_list)
generated_graph_size = sampled_graph.number_of_nodes()
valid_size = (self.v_min <= generated_graph_size <= self.v_max)
cycle = is_cycle(sampled_graph)
num_total_size += generated_graph_size
if valid_size:
num_valid_size += 1
if cycle:
num_cycle += 1
if valid_size and cycle:
num_valid += 1
if len(adj_lists_to_plot) == 4:
plot_times += 1
fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)
axes = {0: ax0, 1: ax1, 2: ax2, 3: ax3}
for j in range(4):
nx.draw_circular(nx.from_dict_of_lists(adj_lists_to_plot[j]),
with_labels=True, ax=axes[j])
plt.savefig(self.dir + '/samples/{:d}'.format(plot_times))
plt.close()
adj_lists_to_plot = []
self.num_samples_examined = num_samples
self.average_size = num_total_size / num_samples
self.valid_size_ratio = num_valid_size / num_samples
self.cycle_ratio = num_cycle / num_samples
self.valid_ratio = num_valid / num_samples
def write_summary(self):
def _format_value(v):
if isinstance(v, float):
return '{:.4f}'.format(v)
elif isinstance(v, int):
return '{:d}'.format(v)
else:
return '{}'.format(v)
statistics = {
'num_samples': self.num_samples_examined,
'v_min': self.v_min,
'v_max': self.v_max,
'average_size': self.average_size,
'valid_size_ratio': self.valid_size_ratio,
'cycle_ratio': self.cycle_ratio,
'valid_ratio': self.valid_ratio
}
model_eval_path = os.path.join(self.dir, 'model_eval.txt')
with open(model_eval_path, 'w') as f:
for key, value in statistics.items():
msg = '{}\t{}\n'.format(key, _format_value(value))
f.write(msg)
print('Saved model evaluation statistics to {}'.format(model_eval_path))
self._initialize()
class CyclePrinting(object):
def __init__(self, num_epochs, num_batches):
super(CyclePrinting, self).__init__()
self.num_epochs = num_epochs
self.num_batches = num_batches
self.batch_count = 0
def update(self, epoch, metrics):
self.batch_count = (self.batch_count) % self.num_batches + 1
msg = 'epoch {:d}/{:d}, batch {:d}/{:d}'.format(epoch, self.num_epochs,
self.batch_count, self.num_batches)
for key, value in metrics.items():
msg += ', {}: {:.4f}'.format(key, value)
print(msg)
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
This implementation only supports a minibatch size of 1 for both training and inference;
a larger effective batch size is emulated by accumulating gradients over multiple graphs
before each optimizer step.
"""
import argparse
import datetime
import time
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from model import DGMG
def main(opts):
t1 = time.time()
# Setup dataset and data loader
if opts['dataset'] == 'cycles':
from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting
dataset = CycleDataset(fname=opts['path_to_dataset'])
evaluator = CycleModelEvaluation(v_min=opts['min_size'],
v_max=opts['max_size'],
dir=opts['log_dir'])
printer = CyclePrinting(num_epochs=opts['nepochs'],
num_batches=opts['ds_size'] // opts['batch_size'])
else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0,
collate_fn=dataset.collate)
# Initialize model
model = DGMG(v_max=opts['max_size'],
node_hidden_size=opts['node_hidden_size'],
num_prop_rounds=opts['num_propagation_rounds'])
# Initialize optimizer
if opts['optimizer'] == 'Adam':
optimizer = Adam(model.parameters(), lr=opts['lr'])
else:
raise ValueError('Unsupported argument for the optimizer')
t2 = time.time()
# Training
model.train()
for epoch in range(opts['nepochs']):
batch_count = 0
batch_loss = 0
batch_prob = 0
optimizer.zero_grad()
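# Note (comment added for clarity, not in the original file): the data loader
# yields one graph at a time; gradients are accumulated over opts['batch_size']
# graphs before each optimizer step, emulating minibatch training.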
for i, data in enumerate(data_loader):
log_prob = model(actions=data)
prob = log_prob.detach().exp()
loss = - log_prob / opts['batch_size']
prob_averaged = prob / opts['batch_size']
loss.backward()
batch_loss += loss.item()
batch_prob += prob_averaged.item()
batch_count += 1
if batch_count % opts['batch_size'] == 0:
printer.update(epoch + 1, {'averaged_loss': batch_loss,
'averaged_prob': batch_prob})
if opts['clip_grad']:
clip_grad_norm_(model.parameters(), opts['clip_bound'])
optimizer.step()
batch_loss = 0
batch_prob = 0
optimizer.zero_grad()
t3 = time.time()
model.eval()
evaluator.rollout_and_examine(model, opts['num_generated_samples'])
evaluator.write_summary()
t4 = time.time()
print('It took {} to setup.'.format(datetime.timedelta(seconds=t2-t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3-t2)))
print('It took {} to finish evaluation.'.format(datetime.timedelta(seconds=t4-t3)))
print('--------------------------------------------------------------------------')
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3-t2) / opts['nepochs'])))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='DGMG')
# configure
parser.add_argument('--seed', type=int, default=9284, help='random seed')
# dataset
parser.add_argument('--dataset', choices=['cycles'], default='cycles',
help='dataset to use')
parser.add_argument('--path-to-dataset', type=str, default='cycles.p',
help='load the dataset if it exists, '
'generate it and save to the path otherwise')
# log
parser.add_argument('--log-dir', default='./results',
help='folder to save info like experiment configuration '
'or model evaluation results')
# optimization
parser.add_argument('--batch-size', type=int, default=10,
help='batch size to use for training')
parser.add_argument('--clip-grad', action='store_true', default=True,
help='gradient clipping is required to prevent gradient explosion')
parser.add_argument('--clip-bound', type=float, default=0.25,
help='constraint of gradient norm for gradient clipping')
args = parser.parse_args()
from utils import setup
opts = setup(args)
main(opts)
import networkx as nx
import pickle
import random
import dgl
import numpy as np
import torch
def convert_graph_to_ordering(g):
ordering = []
h = nx.DiGraph()
h.add_edges_from(g.edges)
for n in range(len(h)):
ordering.append(n)
for m in h.predecessors(n):
ordering.append((m, n))
return ordering
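# Illustration (not part of the original file): for a path graph with edges
# [(0, 1), (1, 2)], convert_graph_to_ordering returns [0, 1, (0, 1), 2, (1, 2)],
# i.e. node indices interleaved with the incoming edges of each newly added node.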
def generate_dataset(n, m, n_samples, fname):
samples = []
for _ in range(n_samples):
g = nx.barabasi_albert_graph(n, m)
samples.append(convert_graph_to_ordering(g))
with open(fname, 'wb') as f:
pickle.dump(samples, f)
class DataLoader(object):
def __init__(self, fname, batch_size, shuffle=True):
with open(fname, 'rb') as f:
datasets = pickle.load(f)
if shuffle:
random.shuffle(datasets)
num = len(datasets) // batch_size
# pre-process dataset
self.ground_truth = []
for i in range(num):
batch = datasets[i*batch_size: (i+1)*batch_size]
padded_signals = pad_ground_truth(batch)
merged_graph = generate_merged_graph(batch)
self.ground_truth.append([padded_signals, merged_graph])
def __iter__(self):
return iter(self.ground_truth)
def generate_merged_graph(batch):
n_graphs = len(batch)
graph_list = []
# build each sample graph
new_edges = []
for ordering in batch:
g = dgl.DGLGraph()
node_count = 0
edge_list = []
for step in ordering:
if isinstance(step, int):
node_count += 1
else:
assert isinstance(step, tuple)
edge_list.append(step)
edge_list.append(tuple(reversed(step)))
g.add_nodes_from(range(node_count))
g.add_edges_from(edge_list)
new_edges.append(zip(*edge_list))
graph_list.append(g)
# batch
bg = dgl.batch(graph_list)
# get new edges
new_edges = [bg.query_new_edge(g, *edges) for g, edges in zip(graph_list, new_edges)]
new_src, new_dst = zip(*new_edges)
return bg, new_src, new_dst
def expand_ground_truth(ordering):
node_list = []
action = []
label = []
first_step = True
for i in ordering:
if isinstance(i, int):
if not first_step:
# record the decision not to add another edge
action.append(1)
label.append(0)
node_list.append(-1)
else:
first_step = False
action.append(0) # add node
label.append(1)
node_list.append(i)
else:
assert(isinstance(i, tuple))
action.append(1)
label.append(1)
node_list.append(i[0]) # source node of the edge to add
# record the decision not to add another node
action.append(0)
label.append(0)
node_list.append(-1)
return len(action), action, label, node_list
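# Illustration (not part of the original file): for the ordering
# [0, 1, (0, 1), 2, (1, 2)] this returns
#   length    = 8
#   action    = [0, 1, 0, 1, 1, 0, 1, 0]    (0 = add-node decision, 1 = add-edge decision)
#   label     = [1, 0, 1, 1, 0, 1, 1, 0]    (ground truth answer for each decision)
#   node_list = [0, -1, 1, 0, -1, 2, 1, -1] (new node id or edge source, -1 when unused)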
def pad_ground_truth(batch):
a = []
bz = len(batch)
for sample in batch:
a.append(expand_ground_truth(sample))
length, action, label, node_list = zip(*a)
step = [0] * bz
new_label = []
new_node_list = []
mask_for_batch = []
next_action = 0
count = 0
active_step = [] # steps where at least some graphs are not masked
label1_set = [] # graphs that decide to add a node or an edge
label1_set_tensor = []
while any([step[i] < length[i] for i in range(bz)]):
node_select = []
label_select = []
mask = []
label1 = []
not_all_masked = False
for sample_idx in range(bz):
if step[sample_idx] < length[sample_idx] and \
action[sample_idx][step[sample_idx]] == next_action:
mask.append(1)
node_select.append(node_list[sample_idx][step[sample_idx]])
label_select.append(label[sample_idx][step[sample_idx]])
# if the decision is to add a node or an edge, record sample_idx
if label_select[-1] == 1:
label1.append(sample_idx)
step[sample_idx] += 1
not_all_masked = True
else:
mask.append(0)
node_select.append(-1)
label_select.append(0)
next_action = 1 - next_action
new_node_list.append(torch.LongTensor(node_select))
mask_for_batch.append(torch.ByteTensor(mask))
new_label.append(torch.LongTensor(label_select))
active_step.append(not_all_masked)
label1_set.append(np.array(label1))
label1_set_tensor.append(torch.LongTensor(label1))
count += 1
return count, new_label, new_node_list, mask_for_batch, active_step, label1_set, label1_set_tensor
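# Note (comment added for clarity, not in the original file): pad_ground_truth
# aligns the decision sequences of a batch step by step. It returns the number of
# padded steps, per-step label / node-selection / mask tensors, a per-step flag
# telling whether at least one graph still takes a step, and the indices
# (as arrays and as tensors) of graphs whose label is 1 at each step.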
def elapsed(msg, start, end):
print("{}: {} ms".format(msg, int((end-start)*1000)))
if __name__ == '__main__':
n = 15
m = 2
n_samples = 1024
fname = 'samples.p'
generate_dataset(n, m, n_samples, fname)
import datetime
import matplotlib.pyplot as plt
import os
import random
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.init as init
from pprint import pprint
########################################################################################################################
# configuration #
########################################################################################################################
def mkdir_p(path):
import errno
try:
os.makedirs(path)
print('Created directory {}'.format(path))
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path):
print('Directory {} already exists.'.format(path))
else:
raise
def date_filename(base_dir='./'):
dt = datetime.datetime.now()
return os.path.join(base_dir, '{}_{:02d}-{:02d}-{:02d}'.format(
dt.date(), dt.hour, dt.minute, dt.second
))
def setup_log_dir(opts):
log_dir = '{}'.format(date_filename(opts['log_dir']))
mkdir_p(log_dir)
return log_dir
def save_arg_dict(opts, filename='settings.txt'):
def _format_value(v):
if isinstance(v, float):
return '{:.4f}'.format(v)
elif isinstance(v, int):
return '{:d}'.format(v)
else:
return '{}'.format(v)
save_path = os.path.join(opts['log_dir'], filename)
with open(save_path, 'w') as f:
for key, value in opts.items():
f.write('{}\t{}\n'.format(key, _format_value(value)))
print('Saved settings to {}'.format(save_path))
def setup(args):
opts = args.__dict__.copy()
cudnn.benchmark = False
cudnn.deterministic = True
# Seed
if opts['seed'] is None:
opts['seed'] = random.randint(1, 10000)
random.seed(opts['seed'])
torch.manual_seed(opts['seed'])
# Dataset
from configure import dataset_based_configure
opts = dataset_based_configure(opts)
assert opts['path_to_dataset'] is not None, 'Expect path to dataset to be set.'
if not os.path.exists(opts['path_to_dataset']):
if opts['dataset'] == 'cycles':
from cycles import generate_dataset
generate_dataset(opts['min_size'], opts['max_size'],
opts['ds_size'], opts['path_to_dataset'])
else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
# Optimization
if opts['clip_grad']:
assert opts['clip_bound'] is not None, 'Expect the gradient norm constraint to be set.'
# Log
print('Prepare logging directory...')
log_dir = setup_log_dir(opts)
opts['log_dir'] = log_dir
mkdir_p(log_dir + '/samples')
plt.switch_backend('Agg')
save_arg_dict(opts)
pprint(opts)
return opts
########################################################################################################################
# model #
########################################################################################################################
def weights_init(m):
'''
Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5
Usage:
model = Model()
model.apply(weights_init)
'''
if isinstance(m, nn.Linear):
init.xavier_normal_(m.weight.data)
init.normal_(m.bias.data)
elif isinstance(m, nn.GRUCell):
for param in m.parameters():
if len(param.shape) >= 2:
init.orthogonal_(param.data)
else:
init.normal_(param.data)
def dgmg_message_weight_init(m):
"""
This is similar as the function above where we initialize linear layers from a normal distribution with std
1./10 as suggested by the author. This should only be used for the message passing functions, i.e. fe's in the
paper.
"""
def _weight_init(m):
if isinstance(m, nn.Linear):
init.normal_(m.weight.data, std=1./10)
init.normal_(m.bias.data, std=1./10)
else:
raise ValueError('Expected the input to be of type nn.Linear!')
if isinstance(m, nn.ModuleList):
for layer in m:
layer.apply(_weight_init)
else:
m.apply(_weight_init)
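# Hypothetical usage sketch (names below are illustrative, not from this file):
#   model = DGMG(v_max=20, node_hidden_size=16, num_prop_rounds=2)
#   model.apply(weights_init)
#   # apply dgmg_message_weight_init only to the nn.ModuleList (or nn.Linear)
#   # holding the message passing layers, e.g.
#   dgmg_message_weight_init(message_layers)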