Unverified Commit 100d9328 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Update (#2062)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-1-5.us-west-2.compute.internal>
parent a260a6e6
...@@ -4,7 +4,7 @@ This is an implementation of [Learning Deep Generative Models of Graphs](https:/ ...@@ -4,7 +4,7 @@ This is an implementation of [Learning Deep Generative Models of Graphs](https:/
Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, Peter Battaglia. Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, Peter Battaglia.
For molecule generation, see For molecule generation, see
[our model zoo for Chemistry](https://github.com/dmlc/dgl/tree/master/examples/pytorch/model_zoo/chem/generative_models/dgmg). [DGL-LifeSci](https://github.com/awslabs/dgl-lifesci/tree/master/examples/generative_models/dgmg).
## Dependencies ## Dependencies
- Python 3.5.2 - Python 3.5.2
...@@ -13,8 +13,7 @@ For molecule generation, see ...@@ -13,8 +13,7 @@ For molecule generation, see
## Usage ## Usage
- Train with batch size 1: `python3 main.py` `python3 main.py`
- Train with batch size larger than 1: `python3 main_batch.py`.
## Performance ## Performance
...@@ -22,8 +21,7 @@ For molecule generation, see ...@@ -22,8 +21,7 @@ For molecule generation, see
## Speed ## Speed
On AWS p3.2x instance (w/ V100), one epoch takes ~526s for batch size 1 and takes On AWS p3.2x instance (w/ V100), one epoch takes ~526s.
~238s for batch size 10.
## Acknowledgement ## Acknowledgement
......
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
This implementation works with a minibatch of size larger than 1 for training and 1 for inference.
"""
import argparse
import datetime
import time
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from model_batch import DGMG
def main(opts):
    """Train the batched DGMG model, then evaluate it and save it to disk.

    Parameters
    ----------
    opts : dict
        Experiment configuration. The keys read here are 'dataset',
        'path_to_dataset', 'min_size', 'max_size', 'log_dir', 'nepochs',
        'batch_size', 'node_hidden_size', 'num_propagation_rounds',
        'optimizer', 'lr', 'clip_grad', 'clip_bound' and
        'num_generated_samples'.

    Raises
    ------
    ValueError
        If the requested dataset or optimizer is not supported.
    """
    t1 = time.time()

    # Setup dataset and data loader
    if opts['dataset'] == 'cycles':
        from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting

        dataset = CycleDataset(fname=opts['path_to_dataset'])
        evaluator = CycleModelEvaluation(v_min=opts['min_size'],
                                         v_max=opts['max_size'],
                                         dir=opts['log_dir'])
    else:
        raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))

    data_loader = DataLoader(dataset, batch_size=opts['batch_size'],
                             shuffle=True, num_workers=0,
                             collate_fn=dataset.collate_batch)

    # Ask the loader how many batches it yields per epoch: when the dataset
    # size is not a multiple of the batch size there is one extra, smaller
    # batch that `len(dataset) // batch_size` would not count.
    printer = CyclePrinting(num_epochs=opts['nepochs'],
                            num_batches=len(data_loader))

    # Initialize_model
    model = DGMG(v_max=opts['max_size'],
                 node_hidden_size=opts['node_hidden_size'],
                 num_prop_rounds=opts['num_propagation_rounds'])

    # Initialize optimizer
    if opts['optimizer'] == 'Adam':
        optimizer = Adam(model.parameters(), lr=opts['lr'])
    else:
        raise ValueError('Unsupported argument for the optimizer')

    t2 = time.time()

    # Training
    model.train()
    for epoch in range(opts['nepochs']):
        for data in data_loader:
            # Normalise by the number of graphs actually in this batch; the
            # last batch of an epoch may hold fewer than opts['batch_size'].
            num_graphs = len(data)
            log_prob = model(batch_size=num_graphs, actions=data)
            avg_log_prob = log_prob / num_graphs
            loss = -avg_log_prob
            batch_avg_prob = avg_log_prob.detach().exp()
            batch_avg_loss = loss.item()

            optimizer.zero_grad()
            loss.backward()
            if opts['clip_grad']:
                clip_grad_norm_(model.parameters(), opts['clip_bound'])
            optimizer.step()

            printer.update(epoch + 1, {'averaged loss': batch_avg_loss,
                                       'averaged prob': batch_avg_prob})

    t3 = time.time()

    # Evaluation
    model.eval()
    evaluator.rollout_and_examine(model, opts['num_generated_samples'])
    evaluator.write_summary()

    t4 = time.time()

    print('It took {} to setup.'.format(datetime.timedelta(seconds=t2-t1)))
    print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3-t2)))
    print('It took {} to finish evaluation.'.format(datetime.timedelta(seconds=t4-t3)))
    print('--------------------------------------------------------------------------')
    print('On average, an epoch takes {}.'.format(datetime.timedelta(
        seconds=(t3-t2) / opts['nepochs'])))

    # Drop the graphs left over from the last forward pass so that they are
    # not serialized together with the model.
    del model.g_list
    torch.save(model, './model_batched.pth')
if __name__ == '__main__':
    # Command line entry point: parse the options, let `utils.setup` turn
    # them into the experiment configuration and run training.
    parser = argparse.ArgumentParser(description='batched DGMG')

    # configure
    parser.add_argument('--seed', type=int, default=9284, help='random seed')

    # dataset
    parser.add_argument('--dataset', choices=['cycles'], default='cycles',
                        help='dataset to use')
    parser.add_argument('--path-to-dataset', type=str, default='cycles.p',
                        help='load the dataset if it exists, '
                             'generate it and save to the path otherwise')

    # log
    parser.add_argument('--log-dir', default='./results',
                        help='folder to save info like experiment configuration '
                             'or model evaluation results')

    # optimization
    parser.add_argument('--batch-size', type=int, default=10,
                        help='batch size to use for training')
    parser.add_argument('--clip-grad', dest='clip_grad', action='store_true',
                        default=True,
                        help='gradient clipping is required to prevent gradient explosion')
    # `--clip-grad` defaults to True, so on its own it can never switch
    # clipping off; this opt-out flag writes to the same destination.
    parser.add_argument('--no-clip-grad', dest='clip_grad', action='store_false',
                        help='disable gradient clipping')
    parser.add_argument('--clip-bound', type=float, default=0.25,
                        help='constraint of gradient norm for gradient clipping')

    args = parser.parse_args()

    from utils import setup
    opts = setup(args)

    main(opts)
\ No newline at end of file
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from torch.distributions import Bernoulli, Categorical
class GraphEmbed(nn.Module):
    """Compute one embedding per graph as a gated sum of its node states.

    Parameters
    ----------
    node_hidden_size : int
        Size of the node representation stored under ``ndata['hv']``.
        The graph embedding has twice this size.
    """

    def __init__(self, node_hidden_size):
        super().__init__()

        # Setting from the paper
        self.graph_hidden_size = 2 * node_hidden_size

        # Per-node scalar gate in (0, 1)
        self.node_gating = nn.Sequential(
            nn.Linear(node_hidden_size, 1),
            nn.Sigmoid()
        )
        # Projection of a node state into the graph embedding space
        self.node_to_graph = nn.Linear(node_hidden_size,
                                       self.graph_hidden_size)

    def forward(self, g_list):
        """Embed every graph in ``g_list``.

        Parameters
        ----------
        g_list : list
            A list of dgl.DGLGraph objects

        Returns
        -------
        tensor with one row per graph and ``graph_hidden_size`` columns
        """
        # In the batched implementation of DGMG, no graph receives a new
        # node before all graphs have finished adding edges for their
        # latest node, so every graph in the list has the same number of
        # nodes and checking the first one is enough.
        first_graph = g_list[0]
        if first_graph.number_of_nodes() == 0:
            # Empty graphs are embedded as all zeros.
            return torch.zeros(len(g_list), self.graph_hidden_size)

        batched = dgl.batch(g_list)
        node_states = batched.ndata['hv']
        gates = self.node_gating(node_states)
        projected = self.node_to_graph(node_states)
        batched.ndata['hg'] = gates * projected

        return dgl.sum_nodes(batched, 'hg')
class GraphProp(nn.Module):
    """Update node states with several rounds of message passing.

    Each round has its own message network and its own GRU cell.

    Parameters
    ----------
    num_prop_rounds : int
        Number of message passing rounds.
    node_hidden_size : int
        Size of the node representation stored under ``ndata['hv']``.
    """

    def __init__(self, num_prop_rounds, node_hidden_size):
        super(GraphProp, self).__init__()

        self.num_prop_rounds = num_prop_rounds

        # Setting from the paper
        self.node_activation_hidden_size = 2 * node_hidden_size

        message_funcs = []
        node_update_funcs = []
        # One reduce function per round, each bound to its round index so
        # that it picks the matching message network.
        self.reduce_funcs = []

        for t in range(num_prop_rounds):
            # input being [hv, hu, xuv]
            message_funcs.append(nn.Linear(2 * node_hidden_size + 1,
                                           self.node_activation_hidden_size))
            self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
            node_update_funcs.append(
                nn.GRUCell(self.node_activation_hidden_size,
                           node_hidden_size))

        self.message_funcs = nn.ModuleList(message_funcs)
        self.node_update_funcs = nn.ModuleList(node_update_funcs)

    def dgmg_msg(self, edges):
        """
        For an edge u->v, return concat([h_u, x_uv])
        """
        return {'m': torch.cat([edges.src['hv'],
                                edges.data['he']],
                               dim=1)}

    def dgmg_reduce(self, nodes, round):
        """
        Combine each node's own state with its incoming messages, apply the
        message network of the given round and sum over the messages.
        """
        hv_old = nodes.data['hv']
        m = nodes.mailbox['m']
        # Repeat the receiver state once per incoming message before
        # concatenating it with the messages.
        message = torch.cat([
            hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
        node_activation = (self.message_funcs[round](message)).sum(1)

        return {'a': node_activation}

    def forward(self, g_list):
        """Propagate information inside every graph of ``g_list``.

        Parameters
        ----------
        g_list : list
            A list of dgl.DGLGraph objects

        Returns
        -------
        list
            The graphs with updated node states. When none of the graphs
            has an edge there is nothing to propagate and the input graphs
            are returned unchanged (as a new list), so callers can always
            iterate over the result.
        """
        # Without edges no message can be sent; skip batching altogether.
        if not any(g.number_of_edges() for g in g_list):
            return list(g_list)

        # Merge small graphs into a large graph.
        bg = dgl.batch(g_list)

        for t in range(self.num_prop_rounds):
            bg.update_all(message_func=self.dgmg_msg,
                          reduce_func=self.reduce_funcs[t])
            bg.ndata['hv'] = self.node_update_funcs[t](
                bg.ndata['a'], bg.ndata['hv'])

        return dgl.unbatch(bg)
def bernoulli_action_log_prob(logit, action):
    """
    Calculate the log p of an action with respect to a Bernoulli
    distribution across a batch of actions. Use logit rather than
    prob for numerical stability.

    Parameters
    ----------
    logit : tensor of shape [batch_size, 1]
        Logit of taking action 1 for each sample in the batch.
    action : list or tensor of length batch_size
        The action (0 or 1) taken for each sample. Values are converted
        to integer indices, so 0.0/1.0 are accepted as well.

    Returns
    -------
    tensor of shape [batch_size, 1]
        Log probability of the action taken for each sample.
    """
    # Column 0 holds log P(action = 0), column 1 holds log P(action = 1).
    log_probs = torch.cat([F.logsigmoid(-logit), F.logsigmoid(logit)], dim=1)
    # `gather` needs an int64 index living on the same device as its input.
    index = torch.as_tensor(action, dtype=torch.long, device=logit.device)
    return log_probs.gather(1, index.unsqueeze(1))
class AddNode(nn.Module):
    """Agent deciding, for each graph, whether to add one more node.

    Parameters
    ----------
    graph_embed_func : callable
        Maps a list of graphs to their embeddings and exposes the
        embedding size as ``graph_hidden_size``.
    node_hidden_size : int
        Size of the node representation.
    """

    def __init__(self, graph_embed_func, node_hidden_size):
        super().__init__()

        graph_hidden_size = graph_embed_func.graph_hidden_size

        self.graph_op = {'embed': graph_embed_func}

        # Action value meaning "do not add a node"
        self.stop = 1
        self.add_node = nn.Linear(graph_hidden_size, 1)

        # If to add a node, initialize its hv
        self.node_type_embed = nn.Embedding(1, node_hidden_size)
        self.initialize_hv = nn.Linear(node_hidden_size + graph_hidden_size,
                                       node_hidden_size)

        self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)

    def _initialize_node_repr(self, g, node_type, graph_embed):
        """Set 'hv' and 'a' for the node that was added last to ``g``."""
        newest = g.number_of_nodes() - 1
        type_embed = self.node_type_embed(torch.LongTensor([node_type]))
        hv_init = self.initialize_hv(
            torch.cat([type_embed, graph_embed], dim=1))
        g.nodes[newest].data['hv'] = hv_init
        g.nodes[newest].data['a'] = self.init_node_activation

    def prepare_training(self):
        """
        This function will only be called during training.
        It resets the storage of log probabilities for AddNode
        actions; one tensor is appended per call to ``forward``.
        """
        self.log_prob = []

    def forward(self, g_list, a=None):
        """
        Decide for every graph in `g_list` whether it gets a new
        node, and initialize the representation of each node added.

        During training the decisions are given by `a` and their
        log probabilities are recorded. During inference they are
        sampled from the Bernoulli distribution predicted.

        Parameters
        ----------
        g_list : list
            A list of dgl.DGLGraph objects
        a : None or list
            - During training, a is a list of integers specifying
              whether a new node should be added.
            - During inference, a is None.

        Returns
        -------
        g_non_stop : list
            The ``index`` attribute of every graph in `g_list`
            that had a new node added
        """
        embed_fn = self.graph_op['embed']
        batch_graph_embed = embed_fn(g_list)
        batch_logit = self.add_node(batch_graph_embed)
        batch_prob = torch.sigmoid(batch_logit)

        if not self.training:
            sampled = Bernoulli(batch_prob).sample()
            a = sampled.squeeze(1).tolist()

        # Graphs for which a node is added
        g_non_stop = []
        for pos, g in enumerate(g_list):
            action = a[pos]
            if action == self.stop:
                continue
            g_non_stop.append(g.index)
            g.add_nodes(1)
            self._initialize_node_repr(g, action,
                                       batch_graph_embed[pos:pos + 1, :])

        if self.training:
            self.log_prob.append(bernoulli_action_log_prob(batch_logit, a))

        return g_non_stop
class AddEdge(nn.Module):
    """Agent deciding, for each graph, whether its latest node gets another edge.

    Parameters
    ----------
    graph_embed_func : callable
        Maps a list of graphs to their embeddings and exposes the
        embedding size as ``graph_hidden_size``.
    node_hidden_size : int
        Size of the node representation.
    """

    def __init__(self, graph_embed_func, node_hidden_size):
        super().__init__()

        self.graph_op = {'embed': graph_embed_func}
        in_size = graph_embed_func.graph_hidden_size + node_hidden_size
        self.add_edge = nn.Linear(in_size, 1)

    def prepare_training(self):
        """
        This function will only be called during training.
        It resets the storage of log probabilities for AddEdge
        actions; one tensor is appended per call to ``forward``.
        """
        self.log_prob = []

    def forward(self, g_list, a=None):
        """
        Decide for every graph in `g_list` whether one more edge
        should be added, starting from its latest node.

        During training the decisions are given by `a` and their
        log probabilities are recorded. During inference they are
        sampled from the Bernoulli distribution predicted.

        Parameters
        ----------
        g_list : list
            A list of dgl.DGLGraph objects
        a : None or list
            - During training, a is a list of integers specifying
              whether a new edge should be added.
            - During inference, a is None.

        Returns
        -------
        g_to_add_edge : list
            The ``index`` attribute of every graph in `g_list`
            that needs a new edge to be added
        """
        embed_fn = self.graph_op['embed']
        batch_graph_embed = embed_fn(g_list)

        # State of the most recently added node of each graph
        latest_node_states = []
        for g in g_list:
            latest = g.number_of_nodes() - 1
            latest_node_states.append(g.nodes[latest].data['hv'])
        batch_src_embed = torch.cat(latest_node_states, dim=0)

        agent_input = torch.cat([batch_graph_embed, batch_src_embed], dim=1)
        batch_logit = self.add_edge(agent_input)
        batch_prob = torch.sigmoid(batch_logit)

        if not self.training:
            sampled = Bernoulli(batch_prob).sample()
            a = sampled.squeeze(1).tolist()

        # Graphs for which an edge is to be added; action 0 means "add".
        g_to_add_edge = [g.index for pos, g in enumerate(g_list)
                         if a[pos] == 0]

        if self.training:
            self.log_prob.append(bernoulli_action_log_prob(batch_logit, a))

        return g_to_add_edge
class ChooseDestAndUpdate(nn.Module):
    """Agent choosing the destination of a new edge from the latest node.

    Parameters
    ----------
    graph_prop_func : object
        Accepted for interface compatibility; not used by this class.
    node_hidden_size : int
        Size of the node representation.
    """

    def __init__(self, graph_prop_func, node_hidden_size):
        super().__init__()

        # Scores a (candidate destination, source) pair of node states.
        self.choose_dest = nn.Linear(2 * node_hidden_size, 1)

    def _initialize_edge_repr(self, g, src_list, dest_list):
        """Set the feature 'he' of the edges ``src_list -> dest_list``."""
        # For untyped edges, we only add 1 to indicate its existence.
        # For multiple edge types, we can use a one hot representation
        # or an embedding module.
        num_edges = len(src_list)
        g.edges[src_list, dest_list].data['he'] = torch.ones(num_edges, 1)

    def prepare_training(self):
        """
        This function will only be called during training.
        It resets the storage of log probabilities for ChooseDest
        actions. Each element appended by ``forward`` is a tensor
        of shape [1, 1].
        """
        self.log_prob = []

    def forward(self, g_list, d=None):
        """
        For each g in g_list, connect its latest node (src) to a
        destination node, unless the two are connected already.
        Edge features are initialized for newly added edges.

        During training the destination is given by `d` and the
        log P of that choice is recorded. During inference the
        destination is sampled from the Categorical distribution
        predicted.

        Parameters
        ----------
        g_list : list
            A list of dgl.DGLGraph objects
        d : None or list
            - During training, d is a list of integers specifying dst for
              each graph in g_list.
            - During inference, d is None.
        """
        for pos, g in enumerate(g_list):
            src = g.number_of_nodes() - 1
            # Every node added before src is a candidate destination.
            candidates = range(src)

            src_state = g.nodes[src].data['hv'].expand(src, -1)
            candidate_states = g.nodes[candidates].data['hv']
            pair_features = torch.cat([candidate_states, src_state], dim=1)

            dests_scores = self.choose_dest(pair_features).view(1, -1)
            dests_probs = F.softmax(dests_scores, dim=1)

            if self.training:
                dest = d[pos]
            else:
                dest = Categorical(dests_probs).sample().item()

            # Note that we are not considering multigraph here.
            if not g.has_edge_between(src, dest):
                # For undirected graphs, we add edges for both
                # directions so that we can perform graph propagation.
                from_nodes = [src, dest]
                to_nodes = [dest, src]
                g.add_edges(from_nodes, to_nodes)
                self._initialize_edge_repr(g, from_nodes, to_nodes)

            # With a single candidate the log probability is not recorded.
            if self.training and dests_probs.nelement() > 1:
                all_log_probs = F.log_softmax(dests_scores, dim=1)
                self.log_prob.append(all_log_probs[:, dest: dest + 1])
class DGMG(nn.Module):
    """Batched deep generative model of graphs.

    A batch of graphs is built step by step with three kinds of decisions:
    whether to add a node, whether to add an edge from the latest node and
    which existing node that edge should go to.

    Parameters
    ----------
    v_max : int
        Limit on the number of nodes used to stop generation at inference.
    node_hidden_size : int
        Size of the node representation.
    num_prop_rounds : int
        Number of message passing rounds used for graph propagation.
    """

    def __init__(self, v_max, node_hidden_size,
                 num_prop_rounds):
        super(DGMG, self).__init__()

        # Graph configuration
        self.v_max = v_max

        # Graph embedding module
        self.graph_embed = GraphEmbed(node_hidden_size)

        # Graph propagation module
        self.graph_prop = GraphProp(num_prop_rounds,
                                    node_hidden_size)

        # Actions
        self.add_node_agent = AddNode(
            self.graph_embed, node_hidden_size)
        self.add_edge_agent = AddEdge(
            self.graph_embed, node_hidden_size)
        self.choose_dest_agent = ChooseDestAndUpdate(
            self.graph_prop, node_hidden_size)

        # Weight initialization
        self.init_weights()

    def init_weights(self):
        """Apply the initializers from ``utils`` to all sub-modules."""
        from utils import weights_init, dgmg_message_weight_init

        self.graph_embed.apply(weights_init)
        self.graph_prop.apply(weights_init)
        self.add_node_agent.apply(weights_init)
        self.add_edge_agent.apply(weights_init)
        self.choose_dest_agent.apply(weights_init)

        self.graph_prop.message_funcs.apply(dgmg_message_weight_init)

    def prepare(self, batch_size):
        """Reset the state and create ``batch_size`` empty graphs."""
        # Track how many actions have been taken for each graph.
        self.step_count = [0] * batch_size
        self.g_list = []
        # indices for graphs being generated
        self.g_active = list(range(batch_size))

        for i in range(batch_size):
            g = dgl.DGLGraph()
            # Position of the graph in self.g_list
            g.index = i

            # If there are some features for nodes and edges,
            # zero tensors will be set for those of new nodes and edges.
            g.set_n_initializer(dgl.frame.zero_initializer)
            g.set_e_initializer(dgl.frame.zero_initializer)

            self.g_list.append(g)

        if self.training:
            self.add_node_agent.prepare_training()
            self.add_edge_agent.prepare_training()
            self.choose_dest_agent.prepare_training()

    def _get_graphs(self, indices):
        """Return the graphs of ``self.g_list`` at the given indices."""
        return [self.g_list[i] for i in indices]

    def get_action_step(self, indices):
        """
        This function should only be called during training.

        Collect the number of actions taken for each graph
        whose index is in the indices. After collecting
        the number of actions, increment it by 1.
        """
        old_step_count = []

        for i in indices:
            old_step_count.append(self.step_count[i])
            self.step_count[i] += 1

        return old_step_count

    def get_actions(self, mode):
        """
        This function should only be called during training.

        Decide which graphs are related with the next batched
        decision and extract the actions to take for each of
        the graph.

        Parameters
        ----------
        mode : str
            'node' to address all graphs still being generated, 'edge' to
            address the graphs that may get another edge.

        Raises
        ------
        ValueError
            If mode is neither 'node' nor 'edge'.
        """
        if mode == 'node':
            # Graphs being generated
            indices = self.g_active
        elif mode == 'edge':
            # Graphs having more edges to be added
            # starting from the latest node.
            indices = self.g_to_add_edge
        else:
            raise ValueError("Expected mode to be in ['node', 'edge'], "
                             "got {}".format(mode))

        action_indices = self.get_action_step(indices)
        # Actions for all graphs indexed by indices at timestep t
        actions_t = [self.actions[j][step]
                     for j, step in zip(indices, action_indices)]

        return actions_t

    def add_node_and_update(self, a=None):
        """
        Decide if to add a new node for each graph being generated.
        If a new node should be added, update the graph.

        The action(s) a are passed during training and
        sampled (hence None) during inference.

        Returns
        -------
        bool
            True if no graph is being generated any more.
        """
        g_list = self._get_graphs(self.g_active)
        g_non_stop = self.add_node_agent(g_list, a)

        self.g_active = g_non_stop
        # For all newly added nodes we need to decide
        # if an edge is to be added for each of them.
        self.g_to_add_edge = g_non_stop

        return len(self.g_active) == 0

    def add_edge_or_not(self, a=None):
        """
        Decide if a new edge should be added for each
        graph that may need one more edge.

        The action(s) a are passed during training and
        sampled (hence None) during inference.

        Returns
        -------
        bool
            True if at least one graph needs one more edge.
        """
        g_list = self._get_graphs(self.g_to_add_edge)
        g_to_add_edge = self.add_edge_agent(g_list, a)
        self.g_to_add_edge = g_to_add_edge

        return len(self.g_to_add_edge) > 0

    def choose_dest_and_update(self, a=None):
        """
        For each graph that requires one more edge, choose
        destination and connect it to the latest node.
        Add edges for both directions and update the graph.

        The action(s) a are passed during training and
        sampled (hence None) during inference.
        """
        g_list = self._get_graphs(self.g_to_add_edge)
        self.choose_dest_agent(g_list, a)

        # Graph propagation and update node features.
        updated_g_list = self.graph_prop(g_list)

        # Propagation returns graphs in the order of g_to_add_edge; put
        # each one back at its slot in self.g_list.
        for i, g in enumerate(updated_g_list):
            g.index = self.g_to_add_edge[i]
            self.g_list[g.index] = g

    def get_log_prob(self):
        """Sum the log probabilities recorded by the three agents.

        An agent that recorded nothing (for example when no destination
        choice with more than one candidate was made) contributes zero
        instead of making ``torch.cat`` fail on an empty list.

        Returns
        -------
        tensor of shape torch.Size([])
        """
        total = None
        for agent in (self.add_node_agent,
                      self.add_edge_agent,
                      self.choose_dest_agent):
            if not agent.log_prob:
                continue
            agent_sum = torch.cat(agent.log_prob).sum()
            total = agent_sum if total is None else total + agent_sum

        if total is None:
            # Nothing was recorded at all.
            return torch.zeros(())
        return total

    def forward_train(self, actions):
        """
        Go through all decisions in actions and record their
        log probabilities for calculating the loss.

        Parameters
        ----------
        actions : list
            list of decisions extracted for generating a graph using DGMG

        Returns
        -------
        tensor of shape torch.Size([])
            log P(Generate a batch of graphs using DGMG)
        """
        self.actions = actions

        stop = self.add_node_and_update(a=self.get_actions('node'))

        # Some graphs haven't been completely generated.
        while not stop:
            to_add_edge = self.add_edge_or_not(a=self.get_actions('edge'))

            # Some graphs need more edges to be added for the latest node.
            while to_add_edge:
                self.choose_dest_and_update(a=self.get_actions('edge'))
                to_add_edge = self.add_edge_or_not(a=self.get_actions('edge'))

            stop = self.add_node_and_update(a=self.get_actions('node'))

        return self.get_log_prob()

    def forward_inference(self):
        """
        Generate graph(s) on the fly.

        Returns
        -------
        self.g_list : list
            A list of dgl.DGLGraph objects.
        """
        stop = self.add_node_and_update()

        # Some graphs haven't been completely generated and their numbers of
        # nodes do not exceed the limit of self.v_max.
        while (not stop) and (self.g_list[self.g_active[0]].number_of_nodes()
                              < self.v_max + 1):
            num_trials = 0
            to_add_edge = self.add_edge_or_not()

            # Some graphs need more edges to be added for the latest node and
            # the number of trials does not exceed the number of maximum possible
            # edges. Note that this limit on the number of edges eliminate the
            # possibility of multi-graph and one may want to remove it.
            while to_add_edge and (num_trials <
                                   self.g_list[self.g_active[0]].number_of_nodes() - 1):
                self.choose_dest_and_update()
                num_trials += 1
                to_add_edge = self.add_edge_or_not()

            stop = self.add_node_and_update()

        return self.g_list

    def forward(self, batch_size=1, actions=None):
        """Train on given action sequences or generate graphs.

        Parameters
        ----------
        batch_size : int
            Number of graphs to generate at inference. During training it
            is ignored and replaced by ``len(actions)``.
        actions : None or list
            During training, one sequence of decisions per graph.

        Returns
        -------
        The log probability from ``forward_train`` during training, the
        list of generated graphs from ``forward_inference`` otherwise.
        """
        if self.training:
            batch_size = len(actions)

        self.prepare(batch_size)

        if self.training:
            return self.forward_train(actions)
        else:
            return self.forward_inference()
...@@ -765,77 +765,3 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid)) ...@@ -765,77 +765,3 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
# For the complete implementation, see the `DGL DGMG example # For the complete implementation, see the `DGL DGMG example
# <https://github.com/dmlc/dgl/tree/master/examples/pytorch/dgmg>`__. # <https://github.com/dmlc/dgl/tree/master/examples/pytorch/dgmg>`__.
# #
# Batched graph generation
# ---------------------------
#
# Speeding up DGMG is hard because each graph can be generated with a
# unique sequence of actions. One way to explore parallelism is to adopt
# asynchronous gradient descent with multiple processes. Each of them
# works on one graph at a time and the processes are loosely coordinated
# by a parameter server.
#
# DGL explores parallelism in the message-passing framework, on top of
# the framework-provided tensor operation. The earlier tutorial already
# does that in the message propagation and graph embedding phases, but
# only within one graph. For a batch of graphs, a for loop is then needed:
#
# ::
#
# for g in g_list:
# self.graph_prop(g)
#
# Modify the code to work on a batch of graphs at once by replacing
# these lines with the following. On a macOS CPU, this immediately gives
# a six- to seven-fold speedup for the graph propagation part.
# ::
#
# bg = dgl.batch(g_list)
# self.graph_prop(bg)
# g_list = dgl.unbatch(bg)
#
# You have already used this trick of calling ``dgl.batch`` in the
# `Tree-LSTM tutorial
# <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__
# , and it is worth explaining one more time why this is so.
#
# By batching many small graphs, DGL parallelizes message passing across
# the individual graphs of a batch.
#
# With ``dgl.batch``, you merge ``g_{1}, ..., g_{N}`` into one single giant
# graph consisting of :math:`N` isolated small graphs. For example, if we
# have two graphs with adjacency matrices
#
# ::
#
# [0, 1]
# [1, 0]
#
# [0, 1, 0]
# [1, 0, 0]
# [0, 1, 0]
#
# ``dgl.batch`` simply gives a graph whose adjacency matrix is
#
# ::
#
# [0, 1, 0, 0, 0]
# [1, 0, 0, 0, 0]
# [0, 0, 0, 1, 0]
# [0, 0, 1, 0, 0]
# [0, 0, 0, 1, 0]
#
# In DGL, the message function is defined on the edges, thus batching scales
# the processing of edge user-defined functions (UDFs) linearly.
#
# The reduce UDFs, here ``dgmg_reduce``, work on nodes. Each node may
# have a different number of incoming edges. Using ``degree bucketing``, DGL
# internally groups nodes with the same in-degree and calls the reduce UDF once
# for each group. Thus, batching also reduces the number of calls to these UDFs.
#
# The modification of the node/edge features of the batched graph object
# does not take effect on the features of the original small graphs, so we
# need to replace the old graph list with the new graph list
# ``g_list = dgl.unbatch(bg)``.
#
# The complete code to the batched version can also be found in the example.
# On a testbed, you get roughly double the speed when compared to the previous implementation.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment