Commit 3e8b63ec
Authored Nov 22, 2018 by Mufei Li, committed by Minjie Wang on Nov 22, 2018
Parent: bf6d0025

[Model] DGMG Training with Batch Size 1 (#161)

* DGMG with batch size 1
* Fix
* Adjustment
* Fix
* Fix
* Fix
* Fix

Showing 7 changed files, with 855 additions and 409 deletions (+855, -409):
- examples/pytorch/dgmg/README.md (+13, -0)
- examples/pytorch/dgmg/configure.py (+34, -0)
- examples/pytorch/dgmg/cycles.py (+218, -0)
- examples/pytorch/dgmg/main.py (+130, -0)
- examples/pytorch/dgmg/model.py (+327, -252)
- examples/pytorch/dgmg/util.py (+0, -157)
- examples/pytorch/dgmg/utils.py (+133, -0)
examples/pytorch/dgmg/README.md (new file, 0 → 100644)

# Learning Deep Generative Models of Graphs

This is an implementation of [Learning Deep Generative Models of Graphs](https://arxiv.org/pdf/1803.03324.pdf) by Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu and Peter Battaglia.

# Dependencies

- Python 3.5.2
- [PyTorch 0.4.1](https://pytorch.org/)
- [Matplotlib 2.2.2](https://matplotlib.org/)

# Usage

- Train with batch size 1: `python main.py`
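A minimal sketch (not part of this commit) of sampling from a trained model. The constructor arguments follow the `cycles` configuration in `configure.py`, and `'model.pth'` is a hypothetical checkpoint path:

```python
import torch
from model import DGMG  # model.py from this commit

model = DGMG(v_max=20, node_hidden_size=16, num_prop_rounds=2)
model.load_state_dict(torch.load('model.pth'))  # hypothetical checkpoint path
model.eval()

# In eval mode, forward() runs forward_inference() and returns a dgl.DGLGraph.
g = model()
print('Sampled a graph with {:d} nodes.'.format(g.number_of_nodes()))
```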
examples/pytorch/dgmg/configure.py (new file, 0 → 100644)

```python
"""We intend to make our reproduction as close as possible to the original paper.
The configuration in the file is mostly from the description in the original paper
and will be loaded when setting up."""


def dataset_based_configure(opts):
    if opts['dataset'] == 'cycles':
        ds_configure = cycles_configure
    else:
        raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))

    opts = {**opts, **ds_configure}

    return opts


synthetic_dataset_configure = {
    'node_hidden_size': 16,
    'num_propagation_rounds': 2,
    'optimizer': 'Adam',
    'nepochs': 25,
    'ds_size': 4000,
    'num_generated_samples': 10000,
}

cycles_configure = {
    **synthetic_dataset_configure,
    **{
        'min_size': 10,
        'max_size': 20,
        'lr': 5e-4,
    }
}
```
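A quick illustration (not part of the commit) of how the dictionary merges behave. In `{**opts, **ds_configure}` the later entries win, so dataset-specific keys override any duplicates already present in `opts`:

```python
opts = {'dataset': 'cycles', 'batch_size': 10, 'lr': 1e-3}
opts = dataset_based_configure(opts)

# cycles_configure extends synthetic_dataset_configure, and its keys
# override the ones passed in, so the dataset-specific lr wins.
assert opts['lr'] == 5e-4
assert opts['min_size'] == 10 and opts['max_size'] == 20
assert opts['node_hidden_size'] == 16
```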
examples/pytorch/dgmg/cycles.py (new file, 0 → 100644)

```python
import matplotlib.pyplot as plt
import networkx as nx
import os
import pickle
import random

from torch.utils.data import Dataset


def get_previous(i, v_max):
    if i == 0:
        return v_max
    else:
        return i - 1


def get_next(i, v_max):
    if i == v_max:
        return 0
    else:
        return i + 1


def is_cycle(g):
    size = g.number_of_nodes()

    if size < 3:
        return False

    for node in range(size):
        neighbors = g.successors(node)

        if len(neighbors) != 2:
            return False

        if get_previous(node, size - 1) not in neighbors:
            return False

        if get_next(node, size - 1) not in neighbors:
            return False

    return True


def get_decision_sequence(size):
    """
    Get the decision sequence for generating valid cycles with DGMG for teacher
    forcing optimization.
    """
    decision_sequence = []

    for i in range(size):
        decision_sequence.append(0)  # Add node

        if i != 0:
            decision_sequence.append(0)  # Add edge
            decision_sequence.append(i - 1)  # Set destination to be previous node.

        if i == size - 1:
            decision_sequence.append(0)  # Add edge
            decision_sequence.append(0)  # Set destination to be the root.

        decision_sequence.append(1)  # Stop adding edge

    decision_sequence.append(1)  # Stop adding node

    return decision_sequence


def generate_dataset(v_min, v_max, n_samples, fname):
    samples = []

    for _ in range(n_samples):
        size = random.randint(v_min, v_max)
        samples.append(get_decision_sequence(size))

    with open(fname, 'wb') as f:
        pickle.dump(samples, f)


class CycleDataset(Dataset):
    def __init__(self, fname):
        super(CycleDataset, self).__init__()

        with open(fname, 'rb') as f:
            self.dataset = pickle.load(f)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index]

    def collate(self, batch):
        assert len(batch) == 1, 'Currently we do not support batched training'
        return batch[0]


def dglGraph_to_adj_list(g):
    adj_list = {}

    for node in range(g.number_of_nodes()):
        # For undirected graphs, successors and
        # predecessors are equivalent.
        adj_list[node] = g.successors(node).tolist()

    return adj_list


class CycleModelEvaluation(object):
    def __init__(self, v_min, v_max, dir):
        super(CycleModelEvaluation, self).__init__()

        self.v_min = v_min
        self.v_max = v_max
        self.dir = dir

    def _initialize(self):
        self.num_samples_examined = 0
        self.average_size = 0
        self.valid_size_ratio = 0
        self.cycle_ratio = 0
        self.valid_ratio = 0

    def rollout_and_examine(self, model, num_samples):
        assert not model.training, 'You need to call model.eval().'

        num_total_size = 0
        num_valid_size = 0
        num_cycle = 0
        num_valid = 0
        plot_times = 0
        adj_lists_to_plot = []

        for i in range(num_samples):
            sampled_graph = model()
            sampled_adj_list = dglGraph_to_adj_list(sampled_graph)
            adj_lists_to_plot.append(sampled_adj_list)

            generated_graph_size = sampled_graph.number_of_nodes()
            valid_size = (self.v_min <= generated_graph_size <= self.v_max)
            cycle = is_cycle(sampled_graph)

            num_total_size += generated_graph_size

            if valid_size:
                num_valid_size += 1

            if cycle:
                num_cycle += 1

            if valid_size and cycle:
                num_valid += 1

            if len(adj_lists_to_plot) == 4:
                plot_times += 1
                fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)
                axes = {0: ax0, 1: ax1, 2: ax2, 3: ax3}

                for i in range(4):
                    nx.draw_circular(nx.from_dict_of_lists(adj_lists_to_plot[i]),
                                     with_labels=True, ax=axes[i])

                plt.savefig(self.dir + '/samples/{:d}'.format(plot_times))
                plt.close()

                adj_lists_to_plot = []

        self.num_samples_examined = num_samples
        self.average_size = num_total_size / num_samples
        self.valid_size_ratio = num_valid_size / num_samples
        self.cycle_ratio = num_cycle / num_samples
        self.valid_ratio = num_valid / num_samples

    def write_summary(self):
        def _format_value(v):
            if isinstance(v, float):
                return '{:.4f}'.format(v)
            elif isinstance(v, int):
                return '{:d}'.format(v)
            else:
                return '{}'.format(v)

        statistics = {
            'num_samples': self.num_samples_examined,
            'v_min': self.v_min,
            'v_max': self.v_max,
            'average_size': self.average_size,
            'valid_size_ratio': self.valid_size_ratio,
            'cycle_ratio': self.cycle_ratio,
            'valid_ratio': self.valid_ratio
        }

        model_eval_path = os.path.join(self.dir, 'model_eval.txt')

        with open(model_eval_path, 'w') as f:
            for key, value in statistics.items():
                msg = '{}\t{}\n'.format(key, _format_value(value))
                f.write(msg)

        print('Saved model evaluation statistics to {}'.format(model_eval_path))

        self._initialize()


class CyclePrinting(object):
    def __init__(self, num_epochs, num_batches):
        super(CyclePrinting, self).__init__()

        self.num_epochs = num_epochs
        self.num_batches = num_batches
        self.batch_count = 0

    def update(self, epoch, metrics):
        self.batch_count = (self.batch_count) % self.num_batches + 1

        msg = 'epoch {:d}/{:d}, batch {:d}/{:d}'.format(
            epoch, self.num_epochs, self.batch_count, self.num_batches)
        for key, value in metrics.items():
            msg += ', {}: {:4f}'.format(key, value)

        print(msg)
```
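A small worked example (not in the commit) of the teacher-forcing oracle above. For a triangle (`size=3`), `get_decision_sequence` interleaves add-node decisions, add-edge decisions, and destination choices:

```python
print(get_decision_sequence(3))
# [0, 1,              node 0: add node, stop adding edges
#  0, 0, 0, 1,        node 1: add node, add edge, dest = node 0, stop
#  0, 0, 1, 0, 0, 1,  node 2: add node, edge to node 1, edge to the root, stop
#  1]                 stop adding nodes
```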
examples/pytorch/dgmg/main.py (new file, 0 → 100644)

```python
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf

This implementation works with a minibatch of size 1 only for both training and inference.
"""
import argparse
import datetime
import time

from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from model import DGMG


def main(opts):
    t1 = time.time()

    # Setup dataset and data loader
    if opts['dataset'] == 'cycles':
        from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting

        dataset = CycleDataset(fname=opts['path_to_dataset'])
        evaluator = CycleModelEvaluation(v_min=opts['min_size'],
                                         v_max=opts['max_size'],
                                         dir=opts['log_dir'])
        printer = CyclePrinting(num_epochs=opts['nepochs'],
                                num_batches=opts['ds_size'] // opts['batch_size'])
    else:
        raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))

    data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0,
                             collate_fn=dataset.collate)

    # Initialize model
    model = DGMG(v_max=opts['max_size'],
                 node_hidden_size=opts['node_hidden_size'],
                 num_prop_rounds=opts['num_propagation_rounds'])

    # Initialize optimizer
    if opts['optimizer'] == 'Adam':
        optimizer = Adam(model.parameters(), lr=opts['lr'])
    else:
        raise ValueError('Unsupported argument for the optimizer')

    t2 = time.time()

    # Training
    model.train()
    for epoch in range(opts['nepochs']):
        batch_count = 0
        batch_loss = 0
        batch_prob = 0
        optimizer.zero_grad()

        for i, data in enumerate(data_loader):
            log_prob = model(actions=data)
            prob = log_prob.detach().exp()

            loss = - log_prob / opts['batch_size']
            prob_averaged = prob / opts['batch_size']

            loss.backward()

            batch_loss += loss.item()
            batch_prob += prob_averaged.item()
            batch_count += 1

            if batch_count % opts['batch_size'] == 0:
                printer.update(epoch + 1, {'averaged_loss': batch_loss,
                                           'averaged_prob': batch_prob})

                if opts['clip_grad']:
                    clip_grad_norm_(model.parameters(), opts['clip_bound'])

                optimizer.step()

                batch_loss = 0
                batch_prob = 0
                optimizer.zero_grad()

    t3 = time.time()

    model.eval()
    evaluator.rollout_and_examine(model, opts['num_generated_samples'])
    evaluator.write_summary()

    t4 = time.time()

    print('It took {} to setup.'.format(datetime.timedelta(seconds=t2 - t1)))
    print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3 - t2)))
    print('It took {} to finish evaluation.'.format(datetime.timedelta(seconds=t4 - t3)))
    print('--------------------------------------------------------------------------')
    print('On average, an epoch takes {}.'.format(
        datetime.timedelta(seconds=(t3 - t2) / opts['nepochs'])))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DGMG')

    # configure
    parser.add_argument('--seed', type=int, default=9284, help='random seed')

    # dataset
    parser.add_argument('--dataset', choices=['cycles'], default='cycles',
                        help='dataset to use')
    parser.add_argument('--path-to-dataset', type=str, default='cycles.p',
                        help='load the dataset if it exists, '
                             'generate it and save to the path otherwise')

    # log
    parser.add_argument('--log-dir', default='./results',
                        help='folder to save info like experiment configuration '
                             'or model evaluation results')

    # optimization
    parser.add_argument('--batch-size', type=int, default=10,
                        help='batch size to use for training')
    parser.add_argument('--clip-grad', action='store_true', default=True,
                        help='gradient clipping is required to prevent gradient explosion')
    parser.add_argument('--clip-bound', type=float, default=0.25,
                        help='constraint of gradient norm for gradient clipping')

    args = parser.parse_args()

    from utils import setup
    opts = setup(args)

    main(opts)
```
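The loop above emulates a batch size of `opts['batch_size']` with a size-1 data loader by accumulating gradients and only calling `optimizer.step()` every `batch_size` graphs. A self-contained sketch (not part of the commit) of why scaling each per-sample loss by `1 / batch_size` makes this equivalent to backpropagating one averaged batch loss:

```python
import torch

xs = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0])]

# (a) Accumulate gradients of loss / B per sample, as main.py does.
w_a = torch.ones(1, requires_grad=True)
for x in xs:
    ((w_a * x) ** 2 / len(xs)).backward()

# (b) Backpropagate a single loss averaged over the batch.
w_b = torch.ones(1, requires_grad=True)
torch.stack([(w_b * x) ** 2 for x in xs]).mean().backward()

# Both strategies produce the same gradient.
assert torch.allclose(w_a.grad, w_b.grad)
```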
examples/pytorch/dgmg/model.py (modified, +327, -252)

This diff replaces the previous batched, GCN-based DGMG and its standalone training script with a batch-size-1 implementation built from the paper's modules.

Removed (the old batched implementation):

```python
import dgl
from dgl.graph import DGLGraph
from dgl.nn import GCN
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse
from util import DataLoader, elapsed, generate_dataset
import time


class MLP(nn.Module):
    def __init__(self, num_hidden, num_classes, num_layers):
        super(MLP, self).__init__()
        layers = []
        # hidden layers
        for _ in range(num_layers):
            layers.append(nn.Linear(num_hidden, num_hidden))
            layers.append(nn.Sigmoid())
        # output projection
        layers.append(nn.Linear(num_hidden, num_classes))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


def move2cuda(x):
    # recursively move an object to cuda
    if isinstance(x, torch.Tensor):
        # if Tensor, move directly
        return x.cuda()
    else:
        try:
            # iterable, recursively move each element
            x = [move2cuda(i) for i in x]
            return x
        except:
            # don't do anything for other types like basic types
            return x


class DGMG(nn.Module):
    def __init__(self, node_num_hidden, graph_num_hidden, T,
                 num_MLP_layers=1, loss_func=None, dropout=0.0, use_cuda=False):
        super(DGMG, self).__init__()
        # hidden size of node and graph
        self.node_num_hidden = node_num_hidden
        self.graph_num_hidden = graph_num_hidden
        # use GCN as a simple propagation model
        self.gcn = nn.ModuleList([GCN(node_num_hidden, node_num_hidden, F.relu, dropout)
                                  for _ in range(T)])
        # project node repr to graph repr (higher dimension)
        self.graph_project = nn.Linear(node_num_hidden, graph_num_hidden)
        # add node
        self.fan = MLP(graph_num_hidden, 2, num_MLP_layers)
        # add edge
        self.fae = MLP(graph_num_hidden + node_num_hidden, 1, num_MLP_layers)
        # select node to add edge
        self.fs = MLP(node_num_hidden * 2, 1, num_MLP_layers)
        # init node state
        self.finit = MLP(graph_num_hidden, node_num_hidden, num_MLP_layers)
        # loss function
        self.loss_func = loss_func
        # use gpu
        self.use_cuda = use_cuda

    def decide_add_node(self, hGs):
        h = self.fan(hGs)
        p = F.softmax(h, dim=1)
        # calc loss
        self.loss += self.loss_func(p, self.labels[self.step], self.masks[self.step])

    def decide_add_edge(self, batched_graph, hGs):
        hvs = batched_graph.get_n_repr((self.sample_node_curr_idx - 1).tolist())['h']
        h = self.fae(torch.cat((hGs, hvs), dim=1))
        p = torch.sigmoid(h)
        p = torch.cat([1 - p, p], dim=1)
        self.loss += self.loss_func(p, self.labels[self.step], self.masks[self.step])

    def select_node_to_add_edge(self, batched_graph, indices):
        node_indices = self.sample_node_curr_idx[indices].tolist()
        node_start = self.sample_node_start_idx[indices].tolist()
        node_repr = batched_graph.get_n_repr()['h']
        for i, j, idx in zip(node_start, node_indices, indices):
            hu = node_repr.narrow(0, i, j - i)
            hv = node_repr.narrow(0, j - 1, 1)
            huv = torch.cat((hu, hv.expand(j - i, -1)), dim=1)
            s = F.softmax(self.fs(huv), dim=0).view(1, -1)
            dst = self.node_select[self.step][idx].view(-1)
            self.loss += self.loss_func(s, dst)

    def update_graph_repr(self, batched_graph, hGs, indices, indices_tensor):
        start = self.sample_node_start_idx[indices].tolist()
        stop = self.sample_node_curr_idx[indices].tolist()
        node_repr = batched_graph.get_n_repr()['h']
        graph_repr = self.graph_project(node_repr)
        new_hGs = []
        for i, j in zip(start, stop):
            h = graph_repr.narrow(0, i, j - i)
            hG = torch.sum(h, 0, keepdim=True)
            new_hGs.append(hG)
        new_hGs = torch.cat(new_hGs, dim=0)
        return hGs.index_copy(0, indices_tensor, new_hGs)

    def propagate(self, batched_graph, indices):
        edge_src = [self.sample_edge_src[idx][0: self.sample_edge_count[idx]]
                    for idx in indices]
        edge_dst = [self.sample_edge_dst[idx][0: self.sample_edge_count[idx]]
                    for idx in indices]
        u = np.concatenate(edge_src).tolist()
        v = np.concatenate(edge_dst).tolist()
        for gcn in self.gcn:
            gcn.forward(batched_graph, u, v, attribute='h')

    def forward(self, training=False, ground_truth=None):
        if not training:
            raise NotImplementedError("inference is not implemented yet")

        assert (ground_truth is not None)
        signals, (batched_graph, self.sample_edge_src, self.sample_edge_dst) = ground_truth
        nsteps, self.labels, self.node_select, self.masks, active_step, \
            label1_set, label1_set_tensor = signals

        # init loss
        self.loss = 0
        batch_size = len(self.sample_edge_src)
        # initial node repr for each sample
        hVs = torch.zeros(len(batched_graph), self.node_num_hidden)
        # FIXME: what's the initial graph repr for an empty graph?
        hGs = torch.zeros(batch_size, self.graph_num_hidden)
        if self.use_cuda:
            hVs = hVs.cuda()
            hGs = hGs.cuda()
        batched_graph.set_n_repr({'h': hVs})

        self.sample_node_start_idx = batched_graph.query_node_start_offset()
        self.sample_node_curr_idx = self.sample_node_start_idx.copy()
        self.sample_edge_count = np.zeros(batch_size, dtype=int)

        self.step = 0
        while self.step < nsteps:
            if self.step % 2 == 0:
                # add node step
                if active_step[self.step]:
                    # decide whether to add node
                    self.decide_add_node(hGs)
                    # calculate initial state for new node
                    hvs = self.finit(hGs)
                    # add node
                    update = label1_set[self.step]
                    if len(update) > 0:
                        hvs = torch.index_select(hvs, 0, label1_set_tensor[self.step])
                        scatter_indices = self.sample_node_curr_idx[update]
                        batched_graph.set_n_repr({'h': hvs}, scatter_indices.tolist())
                        self.sample_node_curr_idx[update] += 1
                        # get new graph repr
                        hGs = self.update_graph_repr(batched_graph, hGs, update,
                                                     label1_set_tensor[self.step])
                else:
                    # all samples are masked
                    pass
            else:
                # add edge step
                # decide whether to add edge, which edge to add
                # and also add edge
                self.decide_add_edge(batched_graph, hGs)
                # propagate
                to_add_edge = label1_set[self.step]
                if len(to_add_edge) > 0:
                    # at least one graph needs update
                    self.select_node_to_add_edge(batched_graph, to_add_edge)
                    # update edge count for each sample
                    self.sample_edge_count[to_add_edge] += 2  # undirected graph
                    # perform gcn propagation
                    self.propagate(batched_graph, to_add_edge)
                    # get new graph repr
                    hGs = self.update_graph_repr(batched_graph, hGs,
                                                 label1_set[self.step],
                                                 label1_set_tensor[self.step])
            self.step += 1


def main(args):
    if torch.cuda.is_available() and args.gpu >= 0:
        torch.cuda.set_device(args.gpu)
        use_cuda = True
    else:
        use_cuda = False

    def masked_cross_entropy(x, label, mask=None):
        # x: probability tensor, i.e. after softmax
        x = torch.log(x)
        if mask is not None:
            x = x[mask]
            label = label[mask]
        return F.nll_loss(x, label)

    model = DGMG(args.n_hidden_node, args.n_hidden_graph, args.n_layers,
                 loss_func=masked_cross_entropy, dropout=args.dropout,
                 use_cuda=use_cuda)
    if use_cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # training loop
    for ep in range(args.n_epochs):
        print("epoch: {}".format(ep))
        for idx, ground_truth in enumerate(DataLoader(args.dataset, args.batch_size)):
            if use_cuda:
                count, label, node_list, mask, active, label1, label1_tensor = ground_truth[0]
                label, node_list, mask, label1_tensor = \
                    move2cuda((label, node_list, mask, label1_tensor))
                ground_truth[0] = (count, label, node_list, mask, active,
                                   label1, label1_tensor)
            optimizer.zero_grad()
            # create new empty graphs
            start = time.time()
            model.forward(True, ground_truth)
            end = time.time()
            elapsed("model forward", start, end)
            start = time.time()
            model.loss.backward()
            optimizer.step()
            end = time.time()
            elapsed("model backward", start, end)
            print("iter {}: loss {}".format(idx, model.loss.item()))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DGMG')
    parser.add_argument("--dropout", type=float, default=0,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=20,
                        help="number of training epochs")
    parser.add_argument("--n-hidden-node", type=int, default=16,
                        help="number of hidden DGMG node units")
    parser.add_argument("--n-hidden-graph", type=int, default=32,
                        help="number of hidden DGMG graph units")
    parser.add_argument("--n-layers", type=int, default=2,
                        help="number of hidden gcn layers")
    parser.add_argument("--dataset", type=str, default='samples.p',
                        help="dataset pickle file")
    parser.add_argument("--gen-dataset", type=str, default=None,
                        help="parameters to generate B-A graph datasets. "
                             "Format: <#node>,<#edge>,<#sample>")
    parser.add_argument("--batch-size", type=int, default=32, help="batch size")
    args = parser.parse_args()
    print(args)

    # generate dataset if needed
    if args.gen_dataset is not None:
        n_node, n_edge, n_sample = map(int, args.gen_dataset.split(','))
        generate_dataset(n_node, n_edge, n_sample, args.dataset)

    main(args)
```

Added (the new batch-size-1 implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from torch.distributions import Bernoulli, Categorical

import dgl


class GraphEmbed(nn.Module):
    def __init__(self, node_hidden_size):
        super(GraphEmbed, self).__init__()

        # Setting from the paper
        self.graph_hidden_size = 2 * node_hidden_size

        # Embed graphs
        self.node_gating = nn.Sequential(
            nn.Linear(node_hidden_size, 1),
            nn.Sigmoid()
        )
        self.node_to_graph = nn.Linear(node_hidden_size,
                                       self.graph_hidden_size)

    def forward(self, g):
        if g.number_of_nodes() == 0:
            return torch.zeros(1, self.graph_hidden_size)
        else:
            # Node features are stored as hv in ndata.
            hvs = g.ndata['hv']
            return (self.node_gating(hvs) *
                    self.node_to_graph(hvs)).sum(0, keepdim=True)


class GraphProp(nn.Module):
    def __init__(self, num_prop_rounds, node_hidden_size):
        super(GraphProp, self).__init__()

        self.num_prop_rounds = num_prop_rounds

        # Setting from the paper
        self.node_activation_hidden_size = 2 * node_hidden_size

        message_funcs = []
        self.reduce_funcs = []
        node_update_funcs = []

        for t in range(num_prop_rounds):
            # input being [hv, hu, xuv]
            message_funcs.append(nn.Linear(2 * node_hidden_size + 1,
                                           self.node_activation_hidden_size))
            self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
            node_update_funcs.append(
                nn.GRUCell(self.node_activation_hidden_size, node_hidden_size))

        self.message_funcs = nn.ModuleList(message_funcs)
        self.node_update_funcs = nn.ModuleList(node_update_funcs)

    def dgmg_msg(self, edges):
        """For an edge u->v, return concat([h_u, x_uv])"""
        return {'m': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}

    def dgmg_reduce(self, nodes, round):
        hv_old = nodes.data['hv']
        m = nodes.mailbox['m']
        message = torch.cat([
            hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
        node_activation = (self.message_funcs[round](message)).sum(1)

        return {'a': node_activation}

    def forward(self, g):
        if g.number_of_edges() == 0:
            return
        else:
            for t in range(self.num_prop_rounds):
                g.update_all(message_func=self.dgmg_msg,
                             reduce_func=self.reduce_funcs[t])
                g.ndata['hv'] = self.node_update_funcs[t](
                    g.ndata['a'], g.ndata['hv'])


def bernoulli_action_log_prob(logit, action):
    """Calculate the log p of an action with respect to a Bernoulli
    distribution. Use logit rather than prob for numerical stability."""
    if action == 0:
        return F.logsigmoid(-logit)
    else:
        return F.logsigmoid(logit)


class AddNode(nn.Module):
    def __init__(self, graph_embed_func, node_hidden_size):
        super(AddNode, self).__init__()

        self.graph_op = {'embed': graph_embed_func}

        self.stop = 1
        self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)

        # If to add a node, initialize its hv
        self.node_type_embed = nn.Embedding(1, node_hidden_size)
        self.initialize_hv = nn.Linear(node_hidden_size + \
                                       graph_embed_func.graph_hidden_size,
                                       node_hidden_size)

        self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)

    def _initialize_node_repr(self, g, node_type, graph_embed):
        num_nodes = g.number_of_nodes()
        hv_init = self.initialize_hv(
            torch.cat([self.node_type_embed(torch.LongTensor([node_type])),
                       graph_embed], dim=1))
        g.nodes[num_nodes - 1].data['hv'] = hv_init
        g.nodes[num_nodes - 1].data['a'] = self.init_node_activation

    def prepare_training(self):
        self.log_prob = []

    def forward(self, g, action=None):
        graph_embed = self.graph_op['embed'](g)

        logit = self.add_node(graph_embed)
        prob = torch.sigmoid(logit)

        if not self.training:
            action = Bernoulli(prob).sample().item()
        stop = bool(action == self.stop)

        if not stop:
            g.add_nodes(1)
            self._initialize_node_repr(g, action, graph_embed)

        if self.training:
            sample_log_prob = bernoulli_action_log_prob(logit, action)
            self.log_prob.append(sample_log_prob)

        return stop


class AddEdge(nn.Module):
    def __init__(self, graph_embed_func, node_hidden_size):
        super(AddEdge, self).__init__()

        self.graph_op = {'embed': graph_embed_func}
        self.add_edge = nn.Linear(graph_embed_func.graph_hidden_size + \
                                  node_hidden_size, 1)

    def prepare_training(self):
        self.log_prob = []

    def forward(self, g, action=None):
        graph_embed = self.graph_op['embed'](g)
        src_embed = g.nodes[g.number_of_nodes() - 1].data['hv']

        logit = self.add_edge(torch.cat([graph_embed, src_embed], dim=1))
        prob = torch.sigmoid(logit)

        if not self.training:
            action = Bernoulli(prob).sample().item()
        to_add_edge = bool(action == 0)

        if self.training:
            sample_log_prob = bernoulli_action_log_prob(logit, action)
            self.log_prob.append(sample_log_prob)

        return to_add_edge


class ChooseDestAndUpdate(nn.Module):
    def __init__(self, graph_prop_func, node_hidden_size):
        super(ChooseDestAndUpdate, self).__init__()

        self.graph_op = {'prop': graph_prop_func}
        self.choose_dest = nn.Linear(2 * node_hidden_size, 1)

    def _initialize_edge_repr(self, g, src_list, dest_list):
        # For untyped edges, we only add 1 to indicate its existence.
        # For multiple edge types, we can use a one hot representation
        # or an embedding module.
        edge_repr = torch.ones(len(src_list), 1)
        g.edges[src_list, dest_list].data['he'] = edge_repr

    def prepare_training(self):
        self.log_prob = []

    def forward(self, g, dest):
        src = g.number_of_nodes() - 1
        possible_dests = range(src)

        src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
        possible_dests_embed = g.nodes[possible_dests].data['hv']

        dests_scores = self.choose_dest(
            torch.cat([possible_dests_embed,
                       src_embed_expand], dim=1)).view(1, -1)
        dests_probs = F.softmax(dests_scores, dim=1)

        if not self.training:
            dest = Categorical(dests_probs).sample().item()

        if not g.has_edge_between(src, dest):
            # For undirected graphs, we add edges for both directions
            # so that we can perform graph propagation.
            src_list = [src, dest]
            dest_list = [dest, src]

            g.add_edges(src_list, dest_list)
            self._initialize_edge_repr(g, src_list, dest_list)

            self.graph_op['prop'](g)

        if self.training:
            if dests_probs.nelement() > 1:
                self.log_prob.append(
                    F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])


class DGMG(nn.Module):
    def __init__(self, v_max, node_hidden_size, num_prop_rounds):
        super(DGMG, self).__init__()

        # Graph configuration
        self.v_max = v_max

        # Graph embedding module
        self.graph_embed = GraphEmbed(node_hidden_size)

        # Graph propagation module
        self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size)

        # Actions
        self.add_node_agent = AddNode(self.graph_embed, node_hidden_size)
        self.add_edge_agent = AddEdge(self.graph_embed, node_hidden_size)
        self.choose_dest_agent = ChooseDestAndUpdate(self.graph_prop, node_hidden_size)

        # Weight initialization
        self.init_weights()

    def init_weights(self):
        from utils import weights_init, dgmg_message_weight_init

        self.graph_embed.apply(weights_init)
        self.graph_prop.apply(weights_init)
        self.add_node_agent.apply(weights_init)
        self.add_edge_agent.apply(weights_init)
        self.choose_dest_agent.apply(weights_init)

        self.graph_prop.message_funcs.apply(dgmg_message_weight_init)

    @property
    def action_step(self):
        old_step_count = self.step_count
        self.step_count += 1

        return old_step_count

    def prepare_for_train(self):
        self.step_count = 0

        self.add_node_agent.prepare_training()
        self.add_edge_agent.prepare_training()
        self.choose_dest_agent.prepare_training()

    def add_node_and_update(self, a=None):
        """Decide if to add a new node.
        If a new node should be added, update the graph."""
        return self.add_node_agent(self.g, a)

    def add_edge_or_not(self, a=None):
        """Decide if a new edge should be added."""
        return self.add_edge_agent(self.g, a)

    def choose_dest_and_update(self, a=None):
        """Choose destination and connect it to the latest node.
        Add edges for both directions and update the graph."""
        self.choose_dest_agent(self.g, a)

    def get_log_prob(self):
        return torch.cat(self.add_node_agent.log_prob).sum() \
               + torch.cat(self.add_edge_agent.log_prob).sum() \
               + torch.cat(self.choose_dest_agent.log_prob).sum()

    def forward_train(self, actions):
        self.prepare_for_train()

        stop = self.add_node_and_update(a=actions[self.action_step])
        while not stop:
            to_add_edge = self.add_edge_or_not(a=actions[self.action_step])
            while to_add_edge:
                self.choose_dest_and_update(a=actions[self.action_step])
                to_add_edge = self.add_edge_or_not(a=actions[self.action_step])
            stop = self.add_node_and_update(a=actions[self.action_step])

        return self.get_log_prob()

    def forward_inference(self):
        stop = self.add_node_and_update()
        while (not stop) and (self.g.number_of_nodes() < self.v_max + 1):
            num_trials = 0
            to_add_edge = self.add_edge_or_not()
            while to_add_edge and (num_trials < self.g.number_of_nodes() - 1):
                self.choose_dest_and_update()
                num_trials += 1
                to_add_edge = self.add_edge_or_not()
            stop = self.add_node_and_update()

        return self.g

    def forward(self, actions=None):
        # The graph we will work on
        self.g = dgl.DGLGraph()

        # If there are some features for nodes and edges,
        # zero tensors will be set for those of new nodes and edges.
        self.g.set_n_initializer(dgl.frame.zero_initializer)
        self.g.set_e_initializer(dgl.frame.zero_initializer)

        if self.training:
            return self.forward_train(actions)
        else:
            return self.forward_inference()
```
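An end-to-end sketch (not in the commit) of one teacher-forced training step with the new modules, reusing the decision-sequence oracle from `cycles.py`:

```python
from cycles import get_decision_sequence
from model import DGMG

model = DGMG(v_max=20, node_hidden_size=16, num_prop_rounds=2)
model.train()

# Teacher forcing: forward_train() consumes the oracle decisions one by one
# and returns the summed log-probability of taking them.
log_prob = model(actions=get_decision_sequence(3))
loss = -log_prob
loss.backward()
```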
examples/pytorch/dgmg/util.py (deleted, 100644 → 0)

```python
import networkx as nx
import pickle
import random

import dgl
import numpy as np
import torch


def convert_graph_to_ordering(g):
    ordering = []
    h = nx.DiGraph()
    h.add_edges_from(g.edges)
    for n in range(len(h)):
        ordering.append(n)
        for m in h.predecessors(n):
            ordering.append((m, n))
    return ordering


def generate_dataset(n, m, n_samples, fname):
    samples = []
    for _ in range(n_samples):
        g = nx.barabasi_albert_graph(n, m)
        samples.append(convert_graph_to_ordering(g))
    with open(fname, 'wb') as f:
        pickle.dump(samples, f)


class DataLoader(object):
    def __init__(self, fname, batch_size, shuffle=True):
        with open(fname, 'rb') as f:
            datasets = pickle.load(f)
        if shuffle:
            random.shuffle(datasets)
        num = len(datasets) // batch_size
        # pre-process dataset
        self.ground_truth = []
        for i in range(num):
            batch = datasets[i * batch_size: (i + 1) * batch_size]
            padded_signals = pad_ground_truth(batch)
            merged_graph = generate_merged_graph(batch)
            self.ground_truth.append([padded_signals, merged_graph])

    def __iter__(self):
        return iter(self.ground_truth)


def generate_merged_graph(batch):
    n_graphs = len(batch)
    graph_list = []
    # build each sample graph
    new_edges = []
    for ordering in batch:
        g = dgl.DGLGraph()
        node_count = 0
        edge_list = []
        for step in ordering:
            if isinstance(step, int):
                node_count += 1
            else:
                assert isinstance(step, tuple)
                edge_list.append(step)
                edge_list.append(tuple(reversed(step)))
        g.add_nodes_from(range(node_count))
        g.add_edges_from(edge_list)
        new_edges.append(zip(*edge_list))
        graph_list.append(g)
    # batch
    bg = dgl.batch(graph_list)
    # get new edges
    new_edges = [bg.query_new_edge(g, *edges)
                 for g, edges in zip(graph_list, new_edges)]
    new_src, new_dst = zip(*new_edges)
    return bg, new_src, new_dst


def expand_ground_truth(ordering):
    node_list = []
    action = []
    label = []
    first_step = True
    for i in ordering:
        if isinstance(i, int):
            if not first_step:
                # add not to add edge
                action.append(1)
                label.append(0)
                node_list.append(-1)
            else:
                first_step = False
            action.append(0)  # add node
            label.append(1)
            node_list.append(i)
        else:
            assert (isinstance(i, tuple))
            action.append(1)
            label.append(1)
            node_list.append(i[0])  # select src node to add
    # add not to add node
    action.append(0)
    label.append(0)
    node_list.append(-1)
    return len(action), action, label, node_list


def pad_ground_truth(batch):
    a = []
    bz = len(batch)
    for sample in batch:
        a.append(expand_ground_truth(sample))
    length, action, label, node_list = zip(*a)
    step = [0] * bz
    new_label = []
    new_node_list = []
    mask_for_batch = []
    next_action = 0
    count = 0
    active_step = []  # steps at least some graphs are not masked
    label1_set = []  # graphs who decide to add node or edge
    label1_set_tensor = []
    while any([step[i] < length[i] for i in range(bz)]):
        node_select = []
        label_select = []
        mask = []
        label1 = []
        not_all_masked = False
        for sample_idx in range(bz):
            if step[sample_idx] < length[sample_idx] and \
               action[sample_idx][step[sample_idx]] == next_action:
                mask.append(1)
                node_select.append(node_list[sample_idx][step[sample_idx]])
                label_select.append(label[sample_idx][step[sample_idx]])
                # if decide to add node or add edge, record sample_idx
                if label_select[-1] == 1:
                    label1.append(sample_idx)
                step[sample_idx] += 1
                not_all_masked = True
            else:
                mask.append(0)
                node_select.append(-1)
                label_select.append(0)
        next_action = 1 - next_action
        new_node_list.append(torch.LongTensor(node_select))
        mask_for_batch.append(torch.ByteTensor(mask))
        new_label.append(torch.LongTensor(label_select))
        active_step.append(not_all_masked)
        label1_set.append(np.array(label1))
        label1_set_tensor.append(torch.LongTensor(label1))
        count += 1
    return count, new_label, new_node_list, mask_for_batch, \
           active_step, label1_set, label1_set_tensor


def elapsed(msg, start, end):
    print("{}: {} ms".format(msg, int((end - start) * 1000)))


if __name__ == '__main__':
    n = 15
    m = 2
    n_samples = 1024
    fname = 'samples.p'
    generate_dataset(n, m, n_samples, fname)
```
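For reference, a small illustration (not in the commit) of the ordering the removed helper produced: nodes in index order, each followed by its edges back to already-placed neighbors:

```python
import networkx as nx

g = nx.path_graph(3)  # edges (0, 1) and (1, 2)
print(convert_graph_to_ordering(g))
# [0, 1, (0, 1), 2, (1, 2)]
```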
examples/pytorch/dgmg/utils.py (new file, 0 → 100644)

```python
import datetime
import matplotlib.pyplot as plt
import os
import random
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.init as init

from pprint import pprint

########################################################################################################################
#                                                    configuration                                                     #
########################################################################################################################

def mkdir_p(path):
    import errno
    try:
        os.makedirs(path)
        print('Created directory {}'.format(path))
    except OSError as exc:
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            print('Directory {} already exists.'.format(path))
        else:
            raise


def date_filename(base_dir='./'):
    dt = datetime.datetime.now()
    return os.path.join(base_dir, '{}_{:02d}-{:02d}-{:02d}'.format(
        dt.date(), dt.hour, dt.minute, dt.second))


def setup_log_dir(opts):
    log_dir = '{}'.format(date_filename(opts['log_dir']))
    mkdir_p(log_dir)
    return log_dir


def save_arg_dict(opts, filename='settings.txt'):
    def _format_value(v):
        if isinstance(v, float):
            return '{:.4f}'.format(v)
        elif isinstance(v, int):
            return '{:d}'.format(v)
        else:
            return '{}'.format(v)

    save_path = os.path.join(opts['log_dir'], filename)
    with open(save_path, 'w') as f:
        for key, value in opts.items():
            f.write('{}\t{}\n'.format(key, _format_value(value)))
    print('Saved settings to {}'.format(save_path))


def setup(args):
    opts = args.__dict__.copy()

    cudnn.benchmark = False
    cudnn.deterministic = True

    # Seed
    if opts['seed'] is None:
        opts['seed'] = random.randint(1, 10000)
    random.seed(opts['seed'])
    torch.manual_seed(opts['seed'])

    # Dataset
    from configure import dataset_based_configure
    opts = dataset_based_configure(opts)

    assert opts['path_to_dataset'] is not None, 'Expect path to dataset to be set.'
    if not os.path.exists(opts['path_to_dataset']):
        if opts['dataset'] == 'cycles':
            from cycles import generate_dataset
            generate_dataset(opts['min_size'], opts['max_size'],
                             opts['ds_size'], opts['path_to_dataset'])
        else:
            raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))

    # Optimization
    if opts['clip_grad']:
        assert opts['clip_bound'] is not None, \
            'Expect the gradient norm constraint to be set.'

    # Log
    print('Prepare logging directory...')
    log_dir = setup_log_dir(opts)
    opts['log_dir'] = log_dir
    mkdir_p(log_dir + '/samples')

    plt.switch_backend('Agg')

    save_arg_dict(opts)
    pprint(opts)

    return opts

########################################################################################################################
#                                                        model                                                         #
########################################################################################################################

def weights_init(m):
    '''
    Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5
    Usage:
        model = Model()
        model.apply(weights_init)
    '''
    if isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        init.normal_(m.bias.data)
    elif isinstance(m, nn.GRUCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)


def dgmg_message_weight_init(m):
    """
    This is similar to the function above, where we initialize linear layers
    from a normal distribution with std 1./10 as suggested by the author. This
    should only be used for the message passing functions, i.e. fe's in the
    paper.
    """
    def _weight_init(m):
        if isinstance(m, nn.Linear):
            init.normal_(m.weight.data, std=1./10)
            init.normal_(m.bias.data, std=1./10)
        else:
            raise ValueError('Expected the input to be of type nn.Linear!')

    if isinstance(m, nn.ModuleList):
        for layer in m:
            layer.apply(_weight_init)
    else:
        m.apply(_weight_init)
```
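A usage sketch (not in the commit) of how `model.py` applies the two initializers; the `nn.ModuleList` of `nn.Linear` layers stands in for the message-passing functions (input size `2 * node_hidden_size + 1 = 33` for `node_hidden_size = 16`):

```python
import torch.nn as nn

# General layers get Xavier-normal weights.
embed = nn.Linear(16, 32)
embed.apply(weights_init)

# Message-passing layers get N(0, 0.1) weights; dgmg_message_weight_init
# accepts either a single nn.Linear or an nn.ModuleList of them.
message_funcs = nn.ModuleList([nn.Linear(33, 32) for _ in range(2)])
dgmg_message_weight_init(message_funcs)
```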