Unverified Commit a0d0b1ea authored by Mufei Li, committed by GitHub

[Model] Fix + batched DGMG (#175)

* DGMG with batch size 1

* Fix

* Adjustment

* Fix

* Fix

* Fix

* Fix

* Fix has_node and __contains__

* Batched implementation for DGMG

* Remove redundant dependency

* Adjustment

* Fix

* Add comments
parent 5cda368d
@@ -3,11 +3,16 @@
 This is an implementation of [Learning Deep Generative Models of Graphs](https://arxiv.org/pdf/1803.03324.pdf) by
 Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, Peter Battaglia.
 
-# Dependency
+## Dependency
 - Python 3.5.2
 - [Pytorch 0.4.1](https://pytorch.org/)
 - [Matplotlib 2.2.2](https://matplotlib.org/)
 
-# Usage
+## Usage
 - Train with batch size 1: `python main.py`
+- Train with batch size larger than 1: `python main_batch.py`.
+
+## Acknowledgement
+We would like to thank Yujia Li for providing details on the implementation.
@@ -89,10 +89,13 @@ class CycleDataset(Dataset):
 
     def __getitem__(self, index):
         return self.dataset[index]
 
-    def collate(self, batch):
+    def collate_single(self, batch):
         assert len(batch) == 1, 'Currently we do not support batched training'
         return batch[0]
 
+    def collate_batch(self, batch):
+        return batch
+
 def dglGraph_to_adj_list(g):
     adj_list = {}
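For context (not part of the diff): a minimal sketch of how the two collate functions are wired into `torch.utils.data.DataLoader`, mirroring `main.py` and `main_batch.py` below; the dataset path `'cycles.p'` is just the script default.

```python
from torch.utils.data import DataLoader
from cycles import CycleDataset

dataset = CycleDataset(fname='cycles.p')

# Non-batched training: one sample per step; collate_single unwraps the singleton list.
single_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0,
                           collate_fn=dataset.collate_single)

# Batched training: collate_batch returns the list of samples untouched,
# leaving the actual batching to the batched DGMG model.
batch_loader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0,
                          collate_fn=dataset.collate_batch)
```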
@@ -112,14 +115,6 @@ class CycleModelEvaluation(object):
         self.dir = dir
 
-    def _initialize(self):
-        self.num_samples_examined = 0
-        self.average_size = 0
-        self.valid_size_ratio = 0
-        self.cycle_ratio = 0
-        self.valid_ratio = 0
-
     def rollout_and_examine(self, model, num_samples):
         assert not model.training, 'You need to call model.eval().'
@@ -132,14 +127,22 @@ class CycleModelEvaluation(object):
 
         for i in range(num_samples):
             sampled_graph = model()
+            if isinstance(sampled_graph, list):
+                # The batched implementation returns a list of DGLGraph
+                # objects. Calling model() without arguments generates a
+                # single graph, as in the non-batched implementation;
+                # batched generation is also supported at inference time,
+                # so feel free to modify the code.
+                sampled_graph = sampled_graph[0]
+
             sampled_adj_list = dglGraph_to_adj_list(sampled_graph)
             adj_lists_to_plot.append(sampled_adj_list)
 
-            generated_graph_size = sampled_graph.number_of_nodes()
-            valid_size = (self.v_min <= generated_graph_size <= self.v_max)
+            graph_size = sampled_graph.number_of_nodes()
+            valid_size = (self.v_min <= graph_size <= self.v_max)
             cycle = is_cycle(sampled_graph)
 
-            num_total_size += generated_graph_size
+            num_total_size += graph_size
 
             if valid_size:
                 num_valid_size += 1
@@ -150,7 +153,7 @@ class CycleModelEvaluation(object):
             if valid_size and cycle:
                 num_valid += 1
 
-            if len(adj_lists_to_plot) == 4:
+            if len(adj_lists_to_plot) >= 4:
                 plot_times += 1
                 fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)
                 axes = {0: ax0, 1: ax1, 2: ax2, 3: ax3}
@@ -197,7 +200,6 @@ class CycleModelEvaluation(object):
             f.write(msg)
 
         print('Saved model evaluation statistics to {}'.format(model_eval_path))
-        self._initialize()
 
 class CyclePrinting(object):
...
@@ -7,6 +7,7 @@ This implementation works with a minibatch of size 1 only for both training and
 
 import argparse
 import datetime
 import time
+import torch
 
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.nn.utils import clip_grad_norm_
@@ -31,7 +32,7 @@ def main(opts):
         raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
 
     data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0,
-                             collate_fn=dataset.collate)
+                             collate_fn=dataset.collate_single)
 
     # Initialize_model
     model = DGMG(v_max=opts['max_size'],
@@ -96,6 +97,9 @@ def main(opts):
     print('On average, an epoch takes {}.'.format(datetime.timedelta(
         seconds=(t3-t2) / opts['nepochs'])))
 
+    del model.g
+    torch.save(model, './model.pth')
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='DGMG')
...
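Not part of the commit, just a hedged sketch of how the checkpoint saved above might be reloaded for generation. It assumes the './model.pth' path used in `main.py`, that the DGLGraph held in `model.g` is deleted before saving because it is rebuilt on each forward pass, and that calling the model with no arguments performs a rollout, as in `CycleModelEvaluation.rollout_and_examine`.

```python
import torch

# main.py pickles the whole module with torch.save, so torch.load
# gives back a ready-to-use DGMG instance.
model = torch.load('./model.pth')
model.eval()

with torch.no_grad():
    sampled_graph = model()  # generate a single graph, as the evaluator does
print('Generated a graph with {} nodes.'.format(sampled_graph.number_of_nodes()))
```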
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
This implementation works with a minibatch of size larger than 1 for training and 1 for inference.
"""
import argparse
import datetime
import time
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from model_batch import DGMG
def main(opts):
t1 = time.time()
# Setup dataset and data loader
if opts['dataset'] == 'cycles':
from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting
dataset = CycleDataset(fname=opts['path_to_dataset'])
evaluator = CycleModelEvaluation(v_min=opts['min_size'],
v_max=opts['max_size'],
dir = opts['log_dir'])
printer = CyclePrinting(num_epochs=opts['nepochs'],
num_batches=len(dataset) // opts['batch_size'])
else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
data_loader = DataLoader(dataset, batch_size=opts['batch_size'], shuffle=True, num_workers=0,
collate_fn=dataset.collate_batch)
# Initialize_model
model = DGMG(v_max=opts['max_size'],
node_hidden_size=opts['node_hidden_size'],
num_prop_rounds=opts['num_propagation_rounds'])
# Initialize optimizer
if opts['optimizer'] == 'Adam':
optimizer = Adam(model.parameters(), lr=opts['lr'])
else:
raise ValueError('Unsupported argument for the optimizer')
t2 = time.time()
# Training
model.train()
for epoch in range(opts['nepochs']):
for batch, data in enumerate(data_loader):
log_prob = model(batch_size=opts['batch_size'], actions=data)
loss = - log_prob / opts['batch_size']
batch_avg_prob = (log_prob / opts['batch_size']).detach().exp()
batch_avg_loss = loss.item()
optimizer.zero_grad()
loss.backward()
if opts['clip_grad']:
clip_grad_norm_(model.parameters(), opts['clip_bound'])
optimizer.step()
printer.update(epoch + 1, {'averaged loss': batch_avg_loss,
'averaged prob': batch_avg_prob})
t3 = time.time()
model.eval()
evaluator.rollout_and_examine(model, opts['num_generated_samples'])
evaluator.write_summary()
t4 = time.time()
print('It took {} to setup.'.format(datetime.timedelta(seconds=t2-t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3-t2)))
print('It took {} to finish evaluation.'.format(datetime.timedelta(seconds=t4-t3)))
print('--------------------------------------------------------------------------')
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3-t2) / opts['nepochs'])))
del model.g_list
torch.save(model, './model_batched.pth')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='batched DGMG')
# configure
parser.add_argument('--seed', type=int, default=9284, help='random seed')
# dataset
parser.add_argument('--dataset', choices=['cycles'], default='cycles',
help='dataset to use')
parser.add_argument('--path-to-dataset', type=str, default='cycles.p',
help='load the dataset if it exists, '
'generate it and save to the path otherwise')
# log
parser.add_argument('--log-dir', default='./results',
help='folder to save info like experiment configuration '
'or model evaluation results')
# optimization
parser.add_argument('--batch-size', type=int, default=10,
help='batch size to use for training')
parser.add_argument('--clip-grad', action='store_true', default=True,
help='gradient clipping is required to prevent gradient explosion')
parser.add_argument('--clip-bound', type=float, default=0.25,
help='constraint of gradient norm for gradient clipping')
args = parser.parse_args()
from utils import setup
opts = setup(args)
main(opts)
\ No newline at end of file
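A clarifying note, not in the commit: in the batched training loop above, `model(batch_size=..., actions=data)` is assumed to return the log-likelihood summed over all decision sequences in the minibatch (the batched model itself lives in the collapsed `model_batch.py`), so dividing by the batch size gives a per-graph average. A minimal sketch of the bookkeeping around one step, with `log_prob` standing in for that returned tensor:

```python
import torch

def batched_step_stats(log_prob: torch.Tensor, batch_size: int):
    # Negative mean log-likelihood over the minibatch; minimizing it is
    # equivalent to maximizing the average per-graph log-probability.
    loss = -log_prob / batch_size
    # Geometric-mean probability of the observed decision sequences, kept only for logging.
    avg_prob = (log_prob / batch_size).detach().exp()
    return loss, avg_prob
```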
This diff is collapsed.
@@ -285,11 +285,11 @@ class DGLGraph(object):
 
         --------
         has_nodes
         """
-        return self.has_node(vid)
+        return self._graph.has_node(vid)
 
     def __contains__(self, vid):
         """Same as has_node."""
-        return self.has_node(vid)
+        return self._graph.has_node(vid)
 
     def has_nodes(self, vids):
         """Return true if the nodes exist.
...
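For clarity, and not part of the diff: before this change both `has_node` and `__contains__` returned `self.has_node(vid)`, so `has_node` called itself and any membership check recursed until a `RecursionError`; delegating to the internal graph index `self._graph` breaks the cycle. A small sketch of the fixed behaviour, assuming the usual `DGLGraph`/`add_nodes` API of this release:

```python
from dgl import DGLGraph  # assumes the package-level DGLGraph export

g = DGLGraph()
g.add_nodes(3)  # nodes 0, 1, 2

# Both calls now go through g._graph.has_node instead of recursing.
assert g.has_node(2)
assert 2 in g            # __contains__ delegates to the same fixed path
assert not g.has_node(5)
```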