Unverified Commit a0d0b1ea authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Model] Fix + batched DGMG (#175)

* DGMG with batch size 1

* Fix

* Adjustment

* Fix

* Fix

* Fix

* Fix

* Fix has_node and __contains__

* Batched implementation for DGMG

* Remove redundant dependency

* Adjustment

* Fix

* Add comments
parent 5cda368d
......@@ -3,11 +3,16 @@
This is an implementation of [Learning Deep Generative Models of Graphs](https://arxiv.org/pdf/1803.03324.pdf) by
Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, Peter Battaglia.
# Dependency
## Dependency
- Python 3.5.2
- [Pytorch 0.4.1](https://pytorch.org/)
- [Matplotlib 2.2.2](https://matplotlib.org/)
# Usage
## Usage
- Train with batch size 1: `python main.py`
- Train with batch size larger than 1: `python main_batch.py`.
## Acknowledgement
We would like to thank Yujia Li for providing details on the implementation.
......@@ -89,10 +89,13 @@ class CycleDataset(Dataset):
def __getitem__(self, index):
return self.dataset[index]
def collate(self, batch):
def collate_single(self, batch):
assert len(batch) == 1, 'Currently we do not support batched training'
return batch[0]
def collate_batch(self, batch):
return batch
def dglGraph_to_adj_list(g):
adj_list = {}
......@@ -112,14 +115,6 @@ class CycleModelEvaluation(object):
self.dir = dir
def _initialize(self):
self.num_samples_examined = 0
self.average_size = 0
self.valid_size_ratio = 0
self.cycle_ratio = 0
self.valid_ratio = 0
def rollout_and_examine(self, model, num_samples):
assert not model.training, 'You need to call model.eval().'
......@@ -132,14 +127,22 @@ class CycleModelEvaluation(object):
for i in range(num_samples):
sampled_graph = model()
if isinstance(sampled_graph, list):
# When the model is a batched implementation, a list of
# DGLGraph objects is returned. Note that with model(),
# we generate a single graph as with the non-batched
# implementation. We actually support batched generation
# during the inference so feel free to modify the code.
sampled_graph = sampled_graph[0]
sampled_adj_list = dglGraph_to_adj_list(sampled_graph)
adj_lists_to_plot.append(sampled_adj_list)
generated_graph_size = sampled_graph.number_of_nodes()
valid_size = (self.v_min <= generated_graph_size <= self.v_max)
graph_size = sampled_graph.number_of_nodes()
valid_size = (self.v_min <= graph_size <= self.v_max)
cycle = is_cycle(sampled_graph)
num_total_size += generated_graph_size
num_total_size += graph_size
if valid_size:
num_valid_size += 1
......@@ -150,7 +153,7 @@ class CycleModelEvaluation(object):
if valid_size and cycle:
num_valid += 1
if len(adj_lists_to_plot) == 4:
if len(adj_lists_to_plot) >= 4:
plot_times += 1
fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)
axes = {0: ax0, 1: ax1, 2: ax2, 3: ax3}
......@@ -197,7 +200,6 @@ class CycleModelEvaluation(object):
f.write(msg)
print('Saved model evaluation statistics to {}'.format(model_eval_path))
self._initialize()
class CyclePrinting(object):
......
......@@ -7,6 +7,7 @@ This implementation works with a minibatch of size 1 only for both training and
import argparse
import datetime
import time
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
......@@ -31,7 +32,7 @@ def main(opts):
raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0,
collate_fn=dataset.collate)
collate_fn=dataset.collate_single)
# Initialize_model
model = DGMG(v_max=opts['max_size'],
......@@ -96,6 +97,9 @@ def main(opts):
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3-t2) / opts['nepochs'])))
del model.g
torch.save(model, './model.pth')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='DGMG')
......
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
This implementation works with a minibatch of size larger than 1 for training and 1 for inference.
"""
import argparse
import datetime
import time
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from model_batch import DGMG
def main(opts):
t1 = time.time()
# Setup dataset and data loader
if opts['dataset'] == 'cycles':
from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting
dataset = CycleDataset(fname=opts['path_to_dataset'])
evaluator = CycleModelEvaluation(v_min=opts['min_size'],
v_max=opts['max_size'],
dir = opts['log_dir'])
printer = CyclePrinting(num_epochs=opts['nepochs'],
num_batches=len(dataset) // opts['batch_size'])
else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
data_loader = DataLoader(dataset, batch_size=opts['batch_size'], shuffle=True, num_workers=0,
collate_fn=dataset.collate_batch)
# Initialize_model
model = DGMG(v_max=opts['max_size'],
node_hidden_size=opts['node_hidden_size'],
num_prop_rounds=opts['num_propagation_rounds'])
# Initialize optimizer
if opts['optimizer'] == 'Adam':
optimizer = Adam(model.parameters(), lr=opts['lr'])
else:
raise ValueError('Unsupported argument for the optimizer')
t2 = time.time()
# Training
model.train()
for epoch in range(opts['nepochs']):
for batch, data in enumerate(data_loader):
log_prob = model(batch_size=opts['batch_size'], actions=data)
loss = - log_prob / opts['batch_size']
batch_avg_prob = (log_prob / opts['batch_size']).detach().exp()
batch_avg_loss = loss.item()
optimizer.zero_grad()
loss.backward()
if opts['clip_grad']:
clip_grad_norm_(model.parameters(), opts['clip_bound'])
optimizer.step()
printer.update(epoch + 1, {'averaged loss': batch_avg_loss,
'averaged prob': batch_avg_prob})
t3 = time.time()
model.eval()
evaluator.rollout_and_examine(model, opts['num_generated_samples'])
evaluator.write_summary()
t4 = time.time()
print('It took {} to setup.'.format(datetime.timedelta(seconds=t2-t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3-t2)))
print('It took {} to finish evaluation.'.format(datetime.timedelta(seconds=t4-t3)))
print('--------------------------------------------------------------------------')
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3-t2) / opts['nepochs'])))
del model.g_list
torch.save(model, './model_batched.pth')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='batched DGMG')
# configure
parser.add_argument('--seed', type=int, default=9284, help='random seed')
# dataset
parser.add_argument('--dataset', choices=['cycles'], default='cycles',
help='dataset to use')
parser.add_argument('--path-to-dataset', type=str, default='cycles.p',
help='load the dataset if it exists, '
'generate it and save to the path otherwise')
# log
parser.add_argument('--log-dir', default='./results',
help='folder to save info like experiment configuration '
'or model evaluation results')
# optimization
parser.add_argument('--batch-size', type=int, default=10,
help='batch size to use for training')
parser.add_argument('--clip-grad', action='store_true', default=True,
help='gradient clipping is required to prevent gradient explosion')
parser.add_argument('--clip-bound', type=float, default=0.25,
help='constraint of gradient norm for gradient clipping')
args = parser.parse_args()
from utils import setup
opts = setup(args)
main(opts)
\ No newline at end of file
This diff is collapsed.
......@@ -285,11 +285,11 @@ class DGLGraph(object):
--------
has_nodes
"""
return self.has_node(vid)
return self._graph.has_node(vid)
def __contains__(self, vid):
"""Same as has_node."""
return self.has_node(vid)
return self._graph.has_node(vid)
def has_nodes(self, vids):
"""Return true if the nodes exist.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment