OpenDAS / dgl · Commit 100d9328 (unverified)

Authored Aug 19, 2020 by Mufei Li, committed by GitHub on Aug 19, 2020

Update (#2062)

Co-authored-by: Ubuntu <ubuntu@ip-172-31-1-5.us-west-2.compute.internal>

Parent: a260a6e6

Showing 4 changed files with 3 additions and 775 deletions:

- examples/pytorch/dgmg/README.md (+3, -5)
- examples/pytorch/dgmg/main_batch.py (+0, -118)
- examples/pytorch/dgmg/model_batch.py (+0, -578)
- tutorials/models/3_generative_model/5_dgmg.py (+0, -74)
examples/pytorch/dgmg/README.md

```diff
@@ -4,7 +4,7 @@ This is an implementation of [Learning Deep Generative Models of Graphs](https:/
 Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, Peter Battaglia.
-For molecule generation, see [our model zoo for Chemistry](https://github.com/dmlc/dgl/tree/master/examples/pytorch/model_zoo/chem/generative_models/dgmg).
+For molecule generation, see [DGL-LifeSci](https://github.com/awslabs/dgl-lifesci/tree/master/examples/generative_models/dgmg).

 ## Dependencies
 - Python 3.5.2
@@ -13,8 +13,7 @@ For molecule generation, see
 ## Usage
-- Train with batch size 1: `python3 main.py`
-- Train with batch size larger than 1: `python3 main_batch.py`
+- Train: `python3 main.py`.

 ## Performance
@@ -22,8 +21,7 @@ For molecule generation, see
 ## Speed
-On AWS p3.2x instance (w/ V100), one epoch takes ~526s for batch size 1 and takes
-~238s for batch size 10.
+On AWS p3.2x instance (w/ V100), one epoch takes ~526s.

 ## Acknowledgement
```
examples/pytorch/dgmg/main_batch.py (deleted, 100644 → 0)

```python
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf

This implementation works with a minibatch of size larger than 1 for
training and 1 for inference.
"""
import argparse
import datetime
import time

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from model_batch import DGMG


def main(opts):
    t1 = time.time()

    # Setup dataset and data loader
    if opts['dataset'] == 'cycles':
        from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting

        dataset = CycleDataset(fname=opts['path_to_dataset'])
        evaluator = CycleModelEvaluation(v_min=opts['min_size'],
                                         v_max=opts['max_size'],
                                         dir=opts['log_dir'])
        printer = CyclePrinting(num_epochs=opts['nepochs'],
                                num_batches=len(dataset) // opts['batch_size'])
    else:
        raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))

    data_loader = DataLoader(dataset, batch_size=opts['batch_size'],
                             shuffle=True, num_workers=0,
                             collate_fn=dataset.collate_batch)

    # Initialize model
    model = DGMG(v_max=opts['max_size'],
                 node_hidden_size=opts['node_hidden_size'],
                 num_prop_rounds=opts['num_propagation_rounds'])

    # Initialize optimizer
    if opts['optimizer'] == 'Adam':
        optimizer = Adam(model.parameters(), lr=opts['lr'])
    else:
        raise ValueError('Unsupported argument for the optimizer')

    t2 = time.time()

    # Training
    model.train()
    for epoch in range(opts['nepochs']):
        for batch, data in enumerate(data_loader):
            log_prob = model(batch_size=opts['batch_size'], actions=data)
            loss = - log_prob / opts['batch_size']
            batch_avg_prob = (log_prob / opts['batch_size']).detach().exp()
            batch_avg_loss = loss.item()

            optimizer.zero_grad()
            loss.backward()
            if opts['clip_grad']:
                clip_grad_norm_(model.parameters(), opts['clip_bound'])
            optimizer.step()

            printer.update(epoch + 1, {'averaged loss': batch_avg_loss,
                                       'averaged prob': batch_avg_prob})

    t3 = time.time()

    model.eval()
    evaluator.rollout_and_examine(model, opts['num_generated_samples'])
    evaluator.write_summary()

    t4 = time.time()

    print('It took {} to setup.'.format(datetime.timedelta(seconds=t2 - t1)))
    print('It took {} to finish training.'.format(
        datetime.timedelta(seconds=t3 - t2)))
    print('It took {} to finish evaluation.'.format(
        datetime.timedelta(seconds=t4 - t3)))
    print('--------------------------------------------------------------------------')
    print('On average, an epoch takes {}.'.format(
        datetime.timedelta(seconds=(t3 - t2) / opts['nepochs'])))

    del model.g_list
    torch.save(model, './model_batched.pth')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='batched DGMG')

    # configure
    parser.add_argument('--seed', type=int, default=9284, help='random seed')

    # dataset
    parser.add_argument('--dataset', choices=['cycles'], default='cycles',
                        help='dataset to use')
    parser.add_argument('--path-to-dataset', type=str, default='cycles.p',
                        help='load the dataset if it exists, '
                             'generate it and save to the path otherwise')

    # log
    parser.add_argument('--log-dir', default='./results',
                        help='folder to save info like experiment configuration '
                             'or model evaluation results')

    # optimization
    parser.add_argument('--batch-size', type=int, default=10,
                        help='batch size to use for training')
    parser.add_argument('--clip-grad', action='store_true', default=True,
                        help='gradient clipping is required to prevent '
                             'gradient explosion')
    parser.add_argument('--clip-bound', type=float, default=0.25,
                        help='constraint of gradient norm for gradient clipping')

    args = parser.parse_args()
    from utils import setup
    opts = setup(args)

    main(opts)
```
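The training step above rescales gradients with `clip_grad_norm_` before `optimizer.step()`. As a rough, framework-free sketch of what that call does (the helper name `clip_by_global_norm` is my own, not part of the example):

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale a flat list of gradient values so their global L2 norm does
    not exceed max_norm, mirroring torch.nn.utils.clip_grad_norm_."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads

# A gradient of norm 5.0 clipped to the example's default bound of 0.25.
clipped = clip_by_global_norm([3.0, 4.0], 0.25)
```

Gradients whose norm is already under the bound pass through unchanged; only the overall norm is rescaled, so the gradient direction is preserved.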
examples/pytorch/dgmg/model_batch.py (deleted, 100644 → 0)

```python
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from torch.distributions import Bernoulli, Categorical


class GraphEmbed(nn.Module):
    def __init__(self, node_hidden_size):
        super(GraphEmbed, self).__init__()

        # Setting from the paper
        self.graph_hidden_size = 2 * node_hidden_size

        # Embed graphs
        self.node_gating = nn.Sequential(
            nn.Linear(node_hidden_size, 1),
            nn.Sigmoid()
        )
        self.node_to_graph = nn.Linear(node_hidden_size,
                                       self.graph_hidden_size)

    def forward(self, g_list):
        # With our current batched implementation of DGMG, new nodes
        # are not added for any graph until all graphs are done with
        # adding edges starting from the last node. Therefore all graphs
        # in the g_list should have the same number of nodes.
        if g_list[0].number_of_nodes() == 0:
            return torch.zeros(len(g_list), self.graph_hidden_size)

        bg = dgl.batch(g_list)
        bhv = bg.ndata['hv']
        bg.ndata['hg'] = self.node_gating(bhv) * self.node_to_graph(bhv)
        return dgl.sum_nodes(bg, 'hg')


class GraphProp(nn.Module):
    def __init__(self, num_prop_rounds, node_hidden_size):
        super(GraphProp, self).__init__()

        self.num_prop_rounds = num_prop_rounds

        # Setting from the paper
        self.node_activation_hidden_size = 2 * node_hidden_size

        message_funcs = []
        node_update_funcs = []
        self.reduce_funcs = []
        for t in range(num_prop_rounds):
            # input being [hv, hu, xuv]
            message_funcs.append(nn.Linear(2 * node_hidden_size + 1,
                                           self.node_activation_hidden_size))
            self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
            node_update_funcs.append(
                nn.GRUCell(self.node_activation_hidden_size,
                           node_hidden_size))

        self.message_funcs = nn.ModuleList(message_funcs)
        self.node_update_funcs = nn.ModuleList(node_update_funcs)

    def dgmg_msg(self, edges):
        """For an edge u->v, return concat([h_u, x_uv])."""
        return {'m': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}

    def dgmg_reduce(self, nodes, round):
        hv_old = nodes.data['hv']
        m = nodes.mailbox['m']
        message = torch.cat([
            hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
        node_activation = (self.message_funcs[round](message)).sum(1)
        return {'a': node_activation}

    def forward(self, g_list):
        # Merge small graphs into a large graph.
        bg = dgl.batch(g_list)
        if bg.number_of_edges() == 0:
            return
        else:
            for t in range(self.num_prop_rounds):
                bg.update_all(message_func=self.dgmg_msg,
                              reduce_func=self.reduce_funcs[t])
                bg.ndata['hv'] = self.node_update_funcs[t](
                    bg.ndata['a'], bg.ndata['hv'])
            return dgl.unbatch(bg)


def bernoulli_action_log_prob(logit, action):
    """Calculate the log p of an action with respect to a Bernoulli
    distribution across a batch of actions. Use logit rather than
    prob for numerical stability."""
    log_probs = torch.cat([F.logsigmoid(-logit), F.logsigmoid(logit)], dim=1)
    return log_probs.gather(1, torch.tensor(action).unsqueeze(1))


class AddNode(nn.Module):
    def __init__(self, graph_embed_func, node_hidden_size):
        super(AddNode, self).__init__()

        self.graph_op = {'embed': graph_embed_func}

        self.stop = 1
        self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)

        # If a node is added, initialize its hv
        self.node_type_embed = nn.Embedding(1, node_hidden_size)
        self.initialize_hv = nn.Linear(node_hidden_size +
                                       graph_embed_func.graph_hidden_size,
                                       node_hidden_size)

        self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)

    def _initialize_node_repr(self, g, node_type, graph_embed):
        num_nodes = g.number_of_nodes()
        hv_init = self.initialize_hv(
            torch.cat([self.node_type_embed(torch.LongTensor([node_type])),
                       graph_embed], dim=1))
        g.nodes[num_nodes - 1].data['hv'] = hv_init
        g.nodes[num_nodes - 1].data['a'] = self.init_node_activation

    def prepare_training(self):
        """This function will only be called during training.
        It stores all log probabilities for AddNode actions.
        Each element is a tensor of shape [batch_size, 1]."""
        self.log_prob = []

    def forward(self, g_list, a=None):
        """Decide whether a new node should be added for each graph in
        `g_list`. If a new node is added, initialize its node
        representations. Record graphs for which a new node is added.

        During training, the action is passed rather than made and the
        log P of the action is recorded. During inference, the action
        is sampled from the modeled Bernoulli distribution.

        Parameters
        ----------
        g_list : list
            A list of dgl.DGLGraph objects
        a : None or list
            - During training, a is a list of integers specifying
              whether a new node should be added.
            - During inference, a is None.

        Returns
        -------
        g_non_stop : list
            list of indices to specify which graphs in the
            g_list have a new node added
        """
        # Graphs for which a node is added
        g_non_stop = []

        batch_graph_embed = self.graph_op['embed'](g_list)
        batch_logit = self.add_node(batch_graph_embed)
        batch_prob = torch.sigmoid(batch_logit)

        if not self.training:
            a = Bernoulli(batch_prob).sample().squeeze(1).tolist()

        for i, g in enumerate(g_list):
            action = a[i]
            stop = bool(action == self.stop)

            if not stop:
                g_non_stop.append(g.index)
                g.add_nodes(1)
                self._initialize_node_repr(g, action,
                                           batch_graph_embed[i: i + 1, :])

        if self.training:
            sample_log_prob = bernoulli_action_log_prob(batch_logit, a)
            self.log_prob.append(sample_log_prob)

        return g_non_stop


class AddEdge(nn.Module):
    def __init__(self, graph_embed_func, node_hidden_size):
        super(AddEdge, self).__init__()

        self.graph_op = {'embed': graph_embed_func}
        self.add_edge = nn.Linear(graph_embed_func.graph_hidden_size +
                                  node_hidden_size, 1)

    def prepare_training(self):
        """This function will only be called during training.
        It stores all log probabilities for AddEdge actions.
        Each element is a tensor of shape [batch_size, 1]."""
        self.log_prob = []

    def forward(self, g_list, a=None):
        """Decide whether a new edge should be added for each graph in
        `g_list`. Record graphs for which a new edge is to be added.

        During training, the action is passed rather than made and the
        log P of the action is recorded. During inference, the action
        is sampled from the modeled Bernoulli distribution.

        Parameters
        ----------
        g_list : list
            A list of dgl.DGLGraph objects
        a : None or list
            - During training, a is a list of integers specifying
              whether a new edge should be added.
            - During inference, a is None.

        Returns
        -------
        g_to_add_edge : list
            list of indices to specify which graphs in the
            g_list need a new edge to be added
        """
        # Graphs for which an edge is to be added.
        g_to_add_edge = []

        batch_graph_embed = self.graph_op['embed'](g_list)
        batch_src_embed = torch.cat(
            [g.nodes[g.number_of_nodes() - 1].data['hv'] for g in g_list],
            dim=0)
        batch_logit = self.add_edge(
            torch.cat([batch_graph_embed, batch_src_embed], dim=1))
        batch_prob = torch.sigmoid(batch_logit)

        if not self.training:
            a = Bernoulli(batch_prob).sample().squeeze(1).tolist()

        for i, g in enumerate(g_list):
            action = a[i]
            if action == 0:
                g_to_add_edge.append(g.index)

        if self.training:
            sample_log_prob = bernoulli_action_log_prob(batch_logit, a)
            self.log_prob.append(sample_log_prob)

        return g_to_add_edge


class ChooseDestAndUpdate(nn.Module):
    def __init__(self, graph_prop_func, node_hidden_size):
        super(ChooseDestAndUpdate, self).__init__()

        self.choose_dest = nn.Linear(2 * node_hidden_size, 1)

    def _initialize_edge_repr(self, g, src_list, dest_list):
        # For untyped edges, we only add 1 to indicate its existence.
        # For multiple edge types, we can use a one-hot representation
        # or an embedding module.
        edge_repr = torch.ones(len(src_list), 1)
        g.edges[src_list, dest_list].data['he'] = edge_repr

    def prepare_training(self):
        """This function will only be called during training.
        It stores all log probabilities for ChooseDest actions.
        Each element is a tensor of shape [1, 1]."""
        self.log_prob = []

    def forward(self, g_list, d=None):
        """For each g in g_list, add an edge (src, dest) if (src, dest)
        does not exist. The src is just the latest node in g. Initialize
        edge features if new edges are added.

        During training, dest is passed rather than chosen and the log P
        of the action is recorded. During inference, dest is sampled
        from the modeled Categorical distribution.

        Parameters
        ----------
        g_list : list
            A list of dgl.DGLGraph objects
        d : None or list
            - During training, d is a list of integers specifying dest
              for each graph in g_list.
            - During inference, d is None.
        """
        for i, g in enumerate(g_list):
            src = g.number_of_nodes() - 1
            possible_dests = range(src)

            src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
            possible_dests_embed = g.nodes[possible_dests].data['hv']

            dests_scores = self.choose_dest(
                torch.cat([possible_dests_embed,
                           src_embed_expand], dim=1)).view(1, -1)
            dests_probs = F.softmax(dests_scores, dim=1)

            if not self.training:
                dest = Categorical(dests_probs).sample().item()
            else:
                dest = d[i]

            # Note that we are not considering multigraph here.
            if not g.has_edge_between(src, dest):
                # For undirected graphs, we add edges for both
                # directions so that we can perform graph propagation.
                src_list = [src, dest]
                dest_list = [dest, src]
                g.add_edges(src_list, dest_list)
                self._initialize_edge_repr(g, src_list, dest_list)

            if self.training:
                if dests_probs.nelement() > 1:
                    self.log_prob.append(
                        F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])


class DGMG(nn.Module):
    def __init__(self, v_max, node_hidden_size, num_prop_rounds):
        super(DGMG, self).__init__()

        # Graph configuration
        self.v_max = v_max

        # Graph embedding module
        self.graph_embed = GraphEmbed(node_hidden_size)

        # Graph propagation module
        self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size)

        # Actions
        self.add_node_agent = AddNode(self.graph_embed, node_hidden_size)
        self.add_edge_agent = AddEdge(self.graph_embed, node_hidden_size)
        self.choose_dest_agent = ChooseDestAndUpdate(self.graph_prop,
                                                     node_hidden_size)

        # Weight initialization
        self.init_weights()

    def init_weights(self):
        from utils import weights_init, dgmg_message_weight_init

        self.graph_embed.apply(weights_init)
        self.graph_prop.apply(weights_init)
        self.add_node_agent.apply(weights_init)
        self.add_edge_agent.apply(weights_init)
        self.choose_dest_agent.apply(weights_init)

        self.graph_prop.message_funcs.apply(dgmg_message_weight_init)

    def prepare(self, batch_size):
        # Track how many actions have been taken for each graph.
        self.step_count = [0] * batch_size
        self.g_list = []
        # indices for graphs being generated
        self.g_active = list(range(batch_size))

        for i in range(batch_size):
            g = dgl.DGLGraph()
            g.index = i
            # If there are some features for nodes and edges,
            # zero tensors will be set for those of new nodes and edges.
            g.set_n_initializer(dgl.frame.zero_initializer)
            g.set_e_initializer(dgl.frame.zero_initializer)
            self.g_list.append(g)

        if self.training:
            self.add_node_agent.prepare_training()
            self.add_edge_agent.prepare_training()
            self.choose_dest_agent.prepare_training()

    def _get_graphs(self, indices):
        return [self.g_list[i] for i in indices]

    def get_action_step(self, indices):
        """This function should only be called during training.
        Collect the number of actions taken for each graph whose index
        is in indices. After collecting the number of actions,
        increment it by 1."""
        old_step_count = []
        for i in indices:
            old_step_count.append(self.step_count[i])
            self.step_count[i] += 1
        return old_step_count

    def get_actions(self, mode):
        """This function should only be called during training.
        Decide which graphs are related to the next batched decision
        and extract the actions to take for each of them."""
        if mode == 'node':
            # Graphs being generated
            indices = self.g_active
        elif mode == 'edge':
            # Graphs having more edges to be added
            # starting from the latest node.
            indices = self.g_to_add_edge
        else:
            raise ValueError("Expected mode to be in ['node', 'edge'], "
                             "got {}".format(mode))

        action_indices = self.get_action_step(indices)
        # Actions for all graphs indexed by indices at timestep t
        actions_t = []
        for i, j in enumerate(indices):
            actions_t.append(self.actions[j][action_indices[i]])
        return actions_t

    def add_node_and_update(self, a=None):
        """Decide whether to add a new node for each graph being
        generated. If a new node should be added, update the graph.
        The action(s) a are passed during training and sampled
        (hence None) during inference."""
        g_list = self._get_graphs(self.g_active)
        g_non_stop = self.add_node_agent(g_list, a)
        self.g_active = g_non_stop
        # For all newly added nodes we need to decide
        # if an edge is to be added for each of them.
        self.g_to_add_edge = g_non_stop
        return len(self.g_active) == 0

    def add_edge_or_not(self, a=None):
        """Decide whether a new edge should be added for each graph
        that may need one more edge. The action(s) a are passed during
        training and sampled (hence None) during inference."""
        g_list = self._get_graphs(self.g_to_add_edge)
        g_to_add_edge = self.add_edge_agent(g_list, a)
        self.g_to_add_edge = g_to_add_edge
        return len(self.g_to_add_edge) > 0

    def choose_dest_and_update(self, a=None):
        """For each graph that requires one more edge, choose a
        destination and connect it to the latest node. Add edges for
        both directions and update the graph. The action(s) a are
        passed during training and sampled (hence None) during
        inference."""
        g_list = self._get_graphs(self.g_to_add_edge)
        self.choose_dest_agent(g_list, a)

        # Graph propagation and update of node features.
        updated_g_list = self.graph_prop(g_list)
        for i, g in enumerate(updated_g_list):
            g.index = self.g_to_add_edge[i]
            self.g_list[g.index] = g

    def get_log_prob(self):
        return torch.cat(self.add_node_agent.log_prob).sum() \
               + torch.cat(self.add_edge_agent.log_prob).sum() \
               + torch.cat(self.choose_dest_agent.log_prob).sum()

    def forward_train(self, actions):
        """Go through all decisions in actions and record their log
        probabilities for calculating the loss.

        Parameters
        ----------
        actions : list
            list of decisions extracted for generating a graph using DGMG

        Returns
        -------
        tensor of shape torch.Size([])
            log P(Generate a batch of graphs using DGMG)
        """
        self.actions = actions

        stop = self.add_node_and_update(a=self.get_actions('node'))
        # Some graphs haven't been completely generated.
        while not stop:
            to_add_edge = self.add_edge_or_not(a=self.get_actions('edge'))
            # Some graphs need more edges to be added for the latest node.
            while to_add_edge:
                self.choose_dest_and_update(a=self.get_actions('edge'))
                to_add_edge = self.add_edge_or_not(a=self.get_actions('edge'))
            stop = self.add_node_and_update(a=self.get_actions('node'))

        return self.get_log_prob()

    def forward_inference(self):
        """Generate graph(s) on the fly.

        Returns
        -------
        self.g_list : list
            A list of dgl.DGLGraph objects.
        """
        stop = self.add_node_and_update()
        # Some graphs haven't been completely generated and their numbers
        # of nodes do not exceed the limit of self.v_max.
        while (not stop) and \
                (self.g_list[self.g_active[0]].number_of_nodes() <
                 self.v_max + 1):
            num_trials = 0
            to_add_edge = self.add_edge_or_not()
            # Some graphs need more edges to be added for the latest node
            # and the number of trials does not exceed the number of
            # maximum possible edges. Note that this limit on the number
            # of edges eliminates the possibility of a multigraph and one
            # may want to remove it.
            while to_add_edge and \
                    (num_trials <
                     self.g_list[self.g_active[0]].number_of_nodes() - 1):
                self.choose_dest_and_update()
                num_trials += 1
                to_add_edge = self.add_edge_or_not()
            stop = self.add_node_and_update()
        return self.g_list

    def forward(self, batch_size=1, actions=None):
        if self.training:
            batch_size = len(actions)
        self.prepare(batch_size)

        if self.training:
            return self.forward_train(actions)
        else:
            return self.forward_inference()
```
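The deleted file's `bernoulli_action_log_prob` works with logits rather than probabilities for numerical stability: `F.logsigmoid` never evaluates `log(0)`. The identity it relies on can be sketched for a scalar in plain Python (the helper names here are illustrative, not part of the file):

```python
import math

def log_sigmoid(x):
    # log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|)), stable for large |x|.
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))

def bernoulli_log_prob(logit, action):
    # action == 1 has probability sigmoid(logit);
    # action == 0 has probability sigmoid(-logit) = 1 - sigmoid(logit).
    return log_sigmoid(logit) if action == 1 else log_sigmoid(-logit)
```

A naive `math.log(1.0 / (1.0 + math.exp(-logit)))` underflows to `log(0)` for strongly negative logits, while the form above stays finite, which is exactly why the model stores logits until the loss is computed.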
tutorials/models/3_generative_model/5_dgmg.py

```diff
@@ -765,77 +765,3 @@ print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
 # For the complete implementation, see the `DGL DGMG example
 # <https://github.com/dmlc/dgl/tree/master/examples/pytorch/dgmg>`__.
 #
-# Batched graph generation
-# ---------------------------
-#
-# Speeding up DGMG is hard because each graph can be generated with a
-# unique sequence of actions. One way to exploit parallelism is to adopt
-# asynchronous gradient descent with multiple processes. Each process
-# works on one graph at a time and the processes are loosely coordinated
-# by a parameter server.
-#
-# DGL explores parallelism in the message-passing framework, on top of
-# the framework-provided tensor operations. The earlier tutorial already
-# does that in the message propagation and graph embedding phases, but
-# only within one graph. For a batch of graphs, a for loop is then needed:
-#
-# ::
-#
-#     for g in g_list:
-#         self.graph_prop(g)
-#
-# Modify the code to work on a batch of graphs at once by replacing
-# these lines with the following. On a macOS CPU, you instantly enjoy
-# a six- to seven-fold speedup for the graph propagation part.
-#
-# ::
-#
-#     bg = dgl.batch(g_list)
-#     self.graph_prop(bg)
-#     g_list = dgl.unbatch(bg)
-#
-# You have already used this trick of calling ``dgl.batch`` in the
-# `Tree-LSTM tutorial
-# <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__,
-# and it is worth explaining one more time why it works.
-#
-# By batching many small graphs, DGL parallelizes message passing over
-# the individual graphs of a batch.
-#
-# With ``dgl.batch``, you merge ``g_{1}, ..., g_{N}`` into one single
-# giant graph consisting of :math:`N` isolated small graphs. For example,
-# if we have two graphs with adjacency matrices
-#
-# ::
-#
-#     [0, 1]
-#     [1, 0]
-#
-#     [0, 1, 0]
-#     [1, 0, 0]
-#     [0, 1, 0]
-#
-# ``dgl.batch`` simply gives a graph whose adjacency matrix is the
-# block-diagonal combination of the two
-#
-# ::
-#
-#     [0, 1, 0, 0, 0]
-#     [1, 0, 0, 0, 0]
-#     [0, 0, 0, 1, 0]
-#     [0, 0, 1, 0, 0]
-#     [0, 0, 0, 1, 0]
-#
-# In DGL, the message function is defined on the edges, so batching scales
-# the processing of edge user-defined functions (UDFs) linearly.
-#
-# The reduce UDFs, such as ``dgmg_reduce``, work on nodes. Each node may
-# have a different number of incoming edges. Using degree bucketing, DGL
-# internally groups nodes with the same in-degree and calls the reduce UDF
-# once for each group. Batching thus also reduces the number of calls to
-# these UDFs.
-#
-# Modifications of the node/edge features of the batched graph object do
-# not take effect on the features of the original small graphs, so we
-# need to replace the old graph list with the new one via
-# ``g_list = dgl.unbatch(bg)``.
-#
-# The complete code for the batched version can also be found in the
-# example. On a testbed, you get roughly double the speed compared to the
-# previous implementation.
```
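The block-diagonal merge that ``dgl.batch`` performs on the adjacency structure can be illustrated without DGL at all. The sketch below uses a helper of my own (``batch_adjacency`` is not a DGL API) to batch two adjacency matrices by relabeling the second graph's nodes with an offset:

```python
def batch_adjacency(mats):
    """Place each adjacency matrix as a block on the diagonal of one
    larger matrix, i.e. one graph of several disconnected components."""
    n = sum(len(m) for m in mats)
    batched = [[0] * n for _ in range(n)]
    offset = 0
    for m in mats:
        k = len(m)
        for i in range(k):
            for j in range(k):
                batched[offset + i][offset + j] = m[i][j]
        offset += k
    return batched

# The two example graphs from the tutorial text.
g1 = [[0, 1],
      [1, 0]]
g2 = [[0, 1, 0],
      [1, 0, 0],
      [0, 1, 0]]
bg = batch_adjacency([g1, g2])
```

Because no edge ever crosses a block boundary, message passing on the merged graph is equivalent to message passing on each small graph independently, which is what makes the trick safe.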