Unverified commit a0d0b1ea authored by Mufei Li, committed by GitHub

[Model] Fix + batched DGMG (#175)

* DGMG with batch size 1

* Fix

* Adjustment

* Fix

* Fix

* Fix

* Fix

* Fix has_node and __contains__

* Batched implementation for DGMG

* Remove redundant dependency

* Adjustment

* Fix

* Add comments
parent 5cda368d
README.md
@@ -3,11 +3,16 @@
This is an implementation of [Learning Deep Generative Models of Graphs](https://arxiv.org/pdf/1803.03324.pdf) by
Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, Peter Battaglia.
-# Dependency
+## Dependency
- Python 3.5.2
- [PyTorch 0.4.1](https://pytorch.org/)
- [Matplotlib 2.2.2](https://matplotlib.org/)
-# Usage
+## Usage
- Train with batch size 1: `python main.py`
- Train with a batch size larger than 1: `python main_batch.py`
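
Each script saves the trained model with `torch.save` at the end of training (`./model.pth` for `main.py`, `./model_batched.pth` for `main_batch.py`). A minimal sketch of loading a saved model back for sampling, assuming the default paths above:

```python
import torch

# torch.save pickles the whole module, so the model classes must be
# importable (run this from the example directory).
model = torch.load('./model.pth')
model.eval()
g = model()  # sample one graph from the trained model
```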
## Acknowledgement
We would like to thank Yujia Li for providing details on the implementation.
cycles.py
@@ -89,10 +89,13 @@ class CycleDataset(Dataset):
def __getitem__(self, index):
return self.dataset[index]
-def collate(self, batch):
+def collate_single(self, batch):
assert len(batch) == 1, 'Currently we do not support batched training'
return batch[0]
+def collate_batch(self, batch):
+return batch
def dglGraph_to_adj_list(g):
adj_list = {}
@@ -112,14 +115,6 @@ class CycleModelEvaluation(object):
self.dir = dir
-def _initialize(self):
-self.num_samples_examined = 0
-self.average_size = 0
-self.valid_size_ratio = 0
-self.cycle_ratio = 0
-self.valid_ratio = 0
def rollout_and_examine(self, model, num_samples):
assert not model.training, 'You need to call model.eval().'
@@ -132,14 +127,22 @@ class CycleModelEvaluation(object):
for i in range(num_samples):
sampled_graph = model()
+if isinstance(sampled_graph, list):
+# The batched implementation returns a list of DGLGraph objects.
+# Calling model() generates a single graph here, as in the
+# non-batched implementation, but batched generation is also
+# supported at inference time, so feel free to modify the code.
+sampled_graph = sampled_graph[0]
sampled_adj_list = dglGraph_to_adj_list(sampled_graph)
adj_lists_to_plot.append(sampled_adj_list)
-generated_graph_size = sampled_graph.number_of_nodes()
-valid_size = (self.v_min <= generated_graph_size <= self.v_max)
+graph_size = sampled_graph.number_of_nodes()
+valid_size = (self.v_min <= graph_size <= self.v_max)
cycle = is_cycle(sampled_graph)
-num_total_size += generated_graph_size
+num_total_size += graph_size
if valid_size:
num_valid_size += 1
@@ -150,7 +153,7 @@ class CycleModelEvaluation(object):
if valid_size and cycle:
num_valid += 1
-if len(adj_lists_to_plot) == 4:
+if len(adj_lists_to_plot) >= 4:
plot_times += 1
fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)
axes = {0: ax0, 1: ax1, 2: ax2, 3: ax3}
@@ -197,7 +200,6 @@ class CycleModelEvaluation(object):
f.write(msg)
print('Saved model evaluation statistics to {}'.format(model_eval_path))
-self._initialize()
class CyclePrinting(object):
...
main.py
@@ -7,6 +7,7 @@ This implementation works with a minibatch of size 1 only for both training and inference.
import argparse
import datetime
import time
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
@@ -31,7 +32,7 @@ def main(opts):
raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0,
-collate_fn=dataset.collate)
+collate_fn=dataset.collate_single)
# Initialize model
model = DGMG(v_max=opts['max_size'],
@@ -96,6 +97,9 @@ def main(opts):
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3-t2) / opts['nepochs'])))
+del model.g
+torch.save(model, './model.pth')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='DGMG')
...

main_batch.py (new file)
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
This implementation supports a minibatch size larger than 1 for training and uses size 1 for inference.
"""
import argparse
import datetime
import time
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from model_batch import DGMG
def main(opts):
t1 = time.time()
# Setup dataset and data loader
if opts['dataset'] == 'cycles':
from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting
dataset = CycleDataset(fname=opts['path_to_dataset'])
evaluator = CycleModelEvaluation(v_min=opts['min_size'],
v_max=opts['max_size'],
dir=opts['log_dir'])
printer = CyclePrinting(num_epochs=opts['nepochs'],
num_batches=len(dataset) // opts['batch_size'])
else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
data_loader = DataLoader(dataset, batch_size=opts['batch_size'], shuffle=True, num_workers=0,
collate_fn=dataset.collate_batch)
# Initialize model
model = DGMG(v_max=opts['max_size'],
node_hidden_size=opts['node_hidden_size'],
num_prop_rounds=opts['num_propagation_rounds'])
# Initialize optimizer
if opts['optimizer'] == 'Adam':
optimizer = Adam(model.parameters(), lr=opts['lr'])
else:
raise ValueError('Unsupported argument for the optimizer')
t2 = time.time()
# Training
model.train()
for epoch in range(opts['nepochs']):
for batch, data in enumerate(data_loader):
log_prob = model(batch_size=opts['batch_size'], actions=data)
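# log_prob is the joint log-likelihood of all recorded decisions across
# the batch; dividing by the batch size gives the per-graph average, and
# its negation below is the loss to minimize.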
loss = - log_prob / opts['batch_size']
batch_avg_prob = (log_prob / opts['batch_size']).detach().exp()
batch_avg_loss = loss.item()
optimizer.zero_grad()
loss.backward()
if opts['clip_grad']:
clip_grad_norm_(model.parameters(), opts['clip_bound'])
optimizer.step()
printer.update(epoch + 1, {'averaged loss': batch_avg_loss,
'averaged prob': batch_avg_prob})
t3 = time.time()
model.eval()
evaluator.rollout_and_examine(model, opts['num_generated_samples'])
evaluator.write_summary()
t4 = time.time()
print('It took {} to setup.'.format(datetime.timedelta(seconds=t2-t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3-t2)))
print('It took {} to finish evaluation.'.format(datetime.timedelta(seconds=t4-t3)))
print('--------------------------------------------------------------------------')
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3-t2) / opts['nepochs'])))
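# torch.save below pickles the whole module, so the DGLGraph objects left
# over from evaluation are dropped before serializing the model.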
del model.g_list
torch.save(model, './model_batched.pth')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='batched DGMG')
# configure
parser.add_argument('--seed', type=int, default=9284, help='random seed')
# dataset
parser.add_argument('--dataset', choices=['cycles'], default='cycles',
help='dataset to use')
parser.add_argument('--path-to-dataset', type=str, default='cycles.p',
help='load the dataset if it exists, '
'generate it and save to the path otherwise')
# log
parser.add_argument('--log-dir', default='./results',
help='folder to save info like experiment configuration '
'or model evaluation results')
# optimization
parser.add_argument('--batch-size', type=int, default=10,
help='batch size to use for training')
parser.add_argument('--clip-grad', action='store_true', default=True,
help='gradient clipping is required to prevent gradient explosion')
parser.add_argument('--clip-bound', type=float, default=0.25,
help='constraint of gradient norm for gradient clipping')
args = parser.parse_args()
from utils import setup
opts = setup(args)
main(opts)
\ No newline at end of file
model_batch.py (new file)

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from torch.distributions import Bernoulli, Categorical
class GraphEmbed(nn.Module):
def __init__(self, node_hidden_size):
super(GraphEmbed, self).__init__()
# Setting from the paper
self.graph_hidden_size = 2 * node_hidden_size
# Embed graphs
self.node_gating = nn.Sequential(
nn.Linear(node_hidden_size, 1),
nn.Sigmoid()
)
self.node_to_graph = nn.Linear(node_hidden_size,
self.graph_hidden_size)
def forward(self, g_list):
# With our current batched implementation of DGMG, no graph adds a
# new node until all graphs have finished adding edges for their
# latest node. Therefore all graphs in g_list have the same number
# of nodes.
if g_list[0].number_of_nodes() == 0:
return torch.zeros(len(g_list), self.graph_hidden_size)
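# Gated readout over node states, computed for all graphs at once by
# batching them: h_G = sum_v sigmoid(f_gate(h_v)) * f_graph(h_v).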
bg = dgl.batch(g_list)
bhv = bg.ndata['hv']
bg.ndata['hg'] = self.node_gating(bhv) * self.node_to_graph(bhv)
return dgl.sum_nodes(bg, 'hg')
class GraphProp(nn.Module):
def __init__(self, num_prop_rounds, node_hidden_size):
super(GraphProp, self).__init__()
self.num_prop_rounds = num_prop_rounds
# Setting from the paper
self.node_activation_hidden_size = 2 * node_hidden_size
message_funcs = []
node_update_funcs = []
self.reduce_funcs = []
for t in range(num_prop_rounds):
# The message input is the concatenation [h_v, h_u, x_uv]
message_funcs.append(nn.Linear(2 * node_hidden_size + 1,
self.node_activation_hidden_size))
self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
node_update_funcs.append(
nn.GRUCell(self.node_activation_hidden_size,
node_hidden_size))
self.message_funcs = nn.ModuleList(message_funcs)
self.node_update_funcs = nn.ModuleList(node_update_funcs)
def dgmg_msg(self, edges):
"""
For an edge u->v, return concat([h_u, x_uv])
"""
return {'m': torch.cat([edges.src['hv'],
edges.data['he']],
dim=1)}
def dgmg_reduce(self, nodes, round):
hv_old = nodes.data['hv']
m = nodes.mailbox['m']
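# m has shape [num_nodes, num_incoming_messages, node_hidden_size + 1];
# expanding hv_old along the message dimension below yields the
# [h_v, h_u, x_uv] input expected by self.message_funcs[round].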
message = torch.cat([
hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
node_activation = (self.message_funcs[round](message)).sum(1)
return {'a': node_activation}
def forward(self, g_list):
# Merge small graphs into a large graph.
bg = dgl.batch(g_list)
if bg.number_of_edges() == 0:
return
else:
for t in range(self.num_prop_rounds):
bg.update_all(message_func=self.dgmg_msg,
reduce_func=self.reduce_funcs[t])
bg.ndata['hv'] = self.node_update_funcs[t](
bg.ndata['a'], bg.ndata['hv'])
return dgl.unbatch(bg)
def bernoulli_action_log_prob(logit, action):
"""
Calculate the log probability of an action under a Bernoulli
distribution, for a batch of actions. Logits are used rather than
probabilities for numerical stability.
"""
log_probs = torch.cat([F.logsigmoid(-logit), F.logsigmoid(logit)], dim=1)
return log_probs.gather(1, torch.tensor(action).unsqueeze(1))
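# Quick sanity check of the identity above: with logit 0 the distribution
# is uniform, and F.logsigmoid(torch.zeros(1)) gives log(0.5) ~ -0.6931
# for either action.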
class AddNode(nn.Module):
def __init__(self, graph_embed_func, node_hidden_size):
super(AddNode, self).__init__()
self.graph_op = {'embed': graph_embed_func}
self.stop = 1
self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)
# If a node is added, initialize its hv
self.node_type_embed = nn.Embedding(1, node_hidden_size)
self.initialize_hv = nn.Linear(node_hidden_size + \
graph_embed_func.graph_hidden_size,
node_hidden_size)
self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
def _initialize_node_repr(self, g, node_type, graph_embed):
num_nodes = g.number_of_nodes()
hv_init = self.initialize_hv(
torch.cat([
self.node_type_embed(torch.LongTensor([node_type])),
graph_embed], dim=1))
g.nodes[num_nodes - 1].data['hv'] = hv_init
g.nodes[num_nodes - 1].data['a'] = self.init_node_activation
def prepare_training(self):
"""
This function will only be called during training.
It stores all log probabilities for AddNode actions.
Each element is a tensor of shape [batch_size, 1].
"""
self.log_prob = []
def forward(self, g_list, a=None):
"""
Decide whether a new node should be added for each graph
in `g_list`. If a new node is added, initialize its node
representation and record the graph.
During training, the action is given rather than decided,
and its log probability is recorded.
During inference, the action is sampled from the modeled
Bernoulli distribution.
Parameters
----------
g_list : list
A list of dgl.DGLGraph objects
a : None or list
- During training, a is a list of integers specifying
whether a new node should be added.
- During inference, a is None.
Returns
-------
g_non_stop : list
Indices of the graphs in g_list to which a new
node was added.
"""
# Graphs for which a node is added
g_non_stop = []
batch_graph_embed = self.graph_op['embed'](g_list)
batch_logit = self.add_node(batch_graph_embed)
batch_prob = torch.sigmoid(batch_logit)
if not self.training:
a = Bernoulli(batch_prob).sample().squeeze(1).tolist()
for i, g in enumerate(g_list):
action = a[i]
stop = bool(action == self.stop)
if not stop:
g_non_stop.append(g.index)
g.add_nodes(1)
self._initialize_node_repr(g, action,
batch_graph_embed[i:i+1, :])
if self.training:
sample_log_prob = bernoulli_action_log_prob(batch_logit, a)
self.log_prob.append(sample_log_prob)
return g_non_stop
class AddEdge(nn.Module):
def __init__(self, graph_embed_func, node_hidden_size):
super(AddEdge, self).__init__()
self.graph_op = {'embed': graph_embed_func}
self.add_edge = nn.Linear(graph_embed_func.graph_hidden_size + \
node_hidden_size, 1)
def prepare_training(self):
"""
This function will only be called during training.
It stores all log probabilities for AddEdge actions.
Each element is a tensor of shape [batch_size, 1].
"""
self.log_prob = []
def forward(self, g_list, a=None):
"""
Decide whether a new edge should be added for each graph
in `g_list`. Record the graphs for which a new edge is to
be added.
During training, the action is given rather than decided,
and its log probability is recorded.
During inference, the action is sampled from the modeled
Bernoulli distribution.
Parameters
----------
g_list : list
A list of dgl.DGLGraph objects
a : None or list
- During training, a is a list of integers specifying
whether a new edge should be added.
- During inference, a is None.
Returns
-------
g_to_add_edge : list
Indices of the graphs in g_list that need a new
edge to be added.
"""
# Graphs for which an edge is to be added.
g_to_add_edge = []
batch_graph_embed = self.graph_op['embed'](g_list)
batch_src_embed = torch.cat([g.nodes[g.number_of_nodes() - 1].data['hv']
for g in g_list], dim=0)
batch_logit = self.add_edge(torch.cat([batch_graph_embed,
batch_src_embed], dim=1))
batch_prob = torch.sigmoid(batch_logit)
if not self.training:
a = Bernoulli(batch_prob).sample().squeeze(1).tolist()
for i, g in enumerate(g_list):
action = a[i]
if action == 0:
g_to_add_edge.append(g.index)
if self.training:
sample_log_prob = bernoulli_action_log_prob(batch_logit, a)
self.log_prob.append(sample_log_prob)
return g_to_add_edge
class ChooseDestAndUpdate(nn.Module):
def __init__(self, graph_prop_func, node_hidden_size):
super(ChooseDestAndUpdate, self).__init__()
self.choose_dest = nn.Linear(2 * node_hidden_size, 1)
def _initialize_edge_repr(self, g, src_list, dest_list):
# For untyped edges, we only add 1 to indicate existence.
# For multiple edge types, one could use a one-hot representation
# or an embedding module.
edge_repr = torch.ones(len(src_list), 1)
g.edges[src_list, dest_list].data['he'] = edge_repr
def prepare_training(self):
"""
This function will only be called during training.
It stores all log probabilities for ChooseDest actions.
Each element is a tensor of shape [1, 1].
"""
self.log_prob = []
def forward(self, g_list, d=None):
"""
For each g in g_list, add an edge (src, dest) if it does
not already exist. The src is simply the latest node in g.
Initialize edge features for newly added edges.
During training, dest is given rather than chosen, and the
log probability of the choice is recorded.
During inference, dest is sampled from the modeled
Categorical distribution.
Parameters
----------
g_list : list
A list of dgl.DGLGraph objects
d : None or list
- During training, d is a list of integers specifying dest for
each graph in g_list.
- During inference, d is None.
"""
for i, g in enumerate(g_list):
src = g.number_of_nodes() - 1
possible_dests = range(src)
src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
possible_dests_embed = g.nodes[possible_dests].data['hv']
dests_scores = self.choose_dest(
torch.cat([possible_dests_embed,
src_embed_expand], dim=1)).view(1, -1)
dests_probs = F.softmax(dests_scores, dim=1)
if not self.training:
dest = Categorical(dests_probs).sample().item()
else:
dest = d[i]
# Note that we do not consider multigraphs here.
if not g.has_edge_between(src, dest):
# For undirected graphs, we add edges for both
# directions so that we can perform graph propagation.
src_list = [src, dest]
dest_list = [dest, src]
g.add_edges(src_list, dest_list)
self._initialize_edge_repr(g, src_list, dest_list)
if self.training:
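# When there is only one possible destination, the choice is
# deterministic (probability 1, log probability 0), so it
# contributes nothing to the loss and is skipped.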
if dests_probs.nelement() > 1:
self.log_prob.append(
F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])
class DGMG(nn.Module):
def __init__(self, v_max, node_hidden_size,
num_prop_rounds):
super(DGMG, self).__init__()
# Graph configuration
self.v_max = v_max
# Graph embedding module
self.graph_embed = GraphEmbed(node_hidden_size)
# Graph propagation module
self.graph_prop = GraphProp(num_prop_rounds,
node_hidden_size)
# Actions
self.add_node_agent = AddNode(
self.graph_embed, node_hidden_size)
self.add_edge_agent = AddEdge(
self.graph_embed, node_hidden_size)
self.choose_dest_agent = ChooseDestAndUpdate(
self.graph_prop, node_hidden_size)
# Weight initialization
self.init_weights()
def init_weights(self):
from utils import weights_init, dgmg_message_weight_init
self.graph_embed.apply(weights_init)
self.graph_prop.apply(weights_init)
self.add_node_agent.apply(weights_init)
self.add_edge_agent.apply(weights_init)
self.choose_dest_agent.apply(weights_init)
self.graph_prop.message_funcs.apply(dgmg_message_weight_init)
def prepare(self, batch_size):
# Track how many actions have been taken for each graph.
self.step_count = [0] * batch_size
self.g_list = []
# indices for graphs being generated
self.g_active = list(range(batch_size))
for i in range(batch_size):
g = dgl.DGLGraph()
g.index = i
# If node or edge features exist, zero tensors will be used
# to initialize them for newly added nodes and edges.
g.set_n_initializer(dgl.frame.zero_initializer)
g.set_e_initializer(dgl.frame.zero_initializer)
self.g_list.append(g)
if self.training:
self.add_node_agent.prepare_training()
self.add_edge_agent.prepare_training()
self.choose_dest_agent.prepare_training()
def _get_graphs(self, indices):
return [self.g_list[i] for i in indices]
def get_action_step(self, indices):
"""
This function should only be called during training.
For each graph whose index is in `indices`, collect the
number of actions taken so far and then increment that
count by 1.
"""
old_step_count = []
for i in indices:
old_step_count.append(self.step_count[i])
self.step_count[i] += 1
return old_step_count
def get_actions(self, mode):
"""
This function should only be called during training.
Decide which graphs are involved in the next batched
decision and extract the action to take for each of
them.
"""
if mode == 'node':
# Graphs being generated
indices = self.g_active
elif mode == 'edge':
# Graphs having more edges to be added
# starting from the latest node.
indices = self.g_to_add_edge
else:
raise ValueError("Expected mode to be in ['node', 'edge'], "
"got {}".format(mode))
action_indices = self.get_action_step(indices)
# Actions for all graphs indexed by indices at timestep t
actions_t = []
for i, j in enumerate(indices):
actions_t.append(self.actions[j][action_indices[i]])
return actions_t
def add_node_and_update(self, a=None):
"""
Decide whether to add a new node for each graph being generated.
If a new node should be added, update the graph.
The action(s) a are passed during training and
sampled (hence None) during inference.
"""
g_list = self._get_graphs(self.g_active)
g_non_stop = self.add_node_agent(g_list, a)
self.g_active = g_non_stop
# For all newly added nodes we need to decide
# if an edge is to be added for each of them.
self.g_to_add_edge = g_non_stop
return len(self.g_active) == 0
def add_edge_or_not(self, a=None):
"""
Decide whether a new edge should be added for each
graph that may need one more edge.
The action(s) a are passed during training and
sampled (hence None) during inference.
"""
g_list = self._get_graphs(self.g_to_add_edge)
g_to_add_edge = self.add_edge_agent(g_list, a)
self.g_to_add_edge = g_to_add_edge
return len(self.g_to_add_edge) > 0
def choose_dest_and_update(self, a=None):
"""
For each graph that requires one more edge, choose a
destination and connect it to the latest node.
Add edges for both directions and update the graph.
The action(s) a are passed during training and
sampled (hence None) during inference.
"""
g_list = self._get_graphs(self.g_to_add_edge)
self.choose_dest_agent(g_list, a)
# Graph propagation and update node features.
updated_g_list = self.graph_prop(g_list)
for i, g in enumerate(updated_g_list):
g.index = self.g_to_add_edge[i]
self.g_list[g.index] = g
def get_log_prob(self):
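# The joint log-likelihood of the generated batch factorizes over the
# three decision types, so we simply sum the records kept by each
# action module.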
return torch.cat(self.add_node_agent.log_prob).sum()\
+ torch.cat(self.add_edge_agent.log_prob).sum()\
+ torch.cat(self.choose_dest_agent.log_prob).sum()
def forward_train(self, actions):
"""
Go through all decisions in actions and record their
log probabilities for calculating the loss.
Parameters
----------
actions : list
list of decision sequences, one per graph, used to
generate a batch of graphs with DGMG
Returns
-------
tensor of shape torch.Size([])
log P(Generate a batch of graphs using DGMG)
"""
self.actions = actions
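# The loop below mirrors the generation process: repeatedly decide
# whether to add a node; after each new node, repeatedly decide
# whether to add an edge and, if so, pick its destination.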
stop = self.add_node_and_update(a=self.get_actions('node'))
# Some graphs haven't been completely generated.
while not stop:
to_add_edge = self.add_edge_or_not(a=self.get_actions('edge'))
# Some graphs need more edges to be added for the latest node.
while to_add_edge:
self.choose_dest_and_update(a=self.get_actions('edge'))
to_add_edge = self.add_edge_or_not(a=self.get_actions('edge'))
stop = self.add_node_and_update(a=self.get_actions('node'))
return self.get_log_prob()
def forward_inference(self):
"""
Generate graph(s) on the fly.
Returns
-------
self.g_list : list
A list of dgl.DGLGraph objects.
"""
stop = self.add_node_and_update()
# Some graphs haven't been completely generated and their numbers of
# nodes do not exceed the limit of self.v_max.
while (not stop) and (self.g_list[self.g_active[0]].number_of_nodes()
< self.v_max + 1):
num_trials = 0
to_add_edge = self.add_edge_or_not()
# Some graphs need more edges to be added for the latest node and
# the number of trials does not exceed the maximum possible number
# of edges. Note that this limit eliminates the possibility of
# multigraphs, and one may want to remove it.
while to_add_edge and (num_trials <
self.g_list[self.g_active[0]].number_of_nodes() - 1):
self.choose_dest_and_update()
num_trials += 1
to_add_edge = self.add_edge_or_not()
stop = self.add_node_and_update()
return self.g_list
def forward(self, batch_size=1, actions=None):
if self.training:
batch_size = len(actions)
self.prepare(batch_size)
if self.training:
return self.forward_train(actions)
else:
return self.forward_inference()
python/dgl/graph.py
@@ -285,11 +285,11 @@ class DGLGraph(object):
--------
has_nodes
"""
-return self.has_node(vid)
+return self._graph.has_node(vid)
def __contains__(self, vid):
"""Same as has_node."""
-return self.has_node(vid)
+return self._graph.has_node(vid)
def has_nodes(self, vids):
"""Return true if the nodes exist.
...