OpenDAS / dgl · Commits · 36c7b771

Unverified commit 36c7b771, authored Jun 05, 2020 by Mufei Li, committed by GitHub on Jun 05, 2020.
[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci
* Remove doc

Parent: 94c67203
Changes: 203
Showing 20 changed files with 0 additions and 4233 deletions.
python/dgl/model_zoo/API.md                      +0  -15
python/dgl/model_zoo/__init__.py                 +0  -2
python/dgl/model_zoo/chem/__init__.py            +0  -11
python/dgl/model_zoo/chem/acnn.py                +0  -221
python/dgl/model_zoo/chem/attentive_fp.py        +0  -366
python/dgl/model_zoo/chem/classifiers.py         +0  -173
python/dgl/model_zoo/chem/dgmg.py                +0  -834
python/dgl/model_zoo/chem/gnn.py                 +0  -134
python/dgl/model_zoo/chem/jtnn/__init__.py       +0  -6
python/dgl/model_zoo/chem/jtnn/chemutils.py      +0  -393
python/dgl/model_zoo/chem/jtnn/jtmpn.py          +0  -267
python/dgl/model_zoo/chem/jtnn/jtnn_dec.py       +0  -398
python/dgl/model_zoo/chem/jtnn/jtnn_enc.py       +0  -126
python/dgl/model_zoo/chem/jtnn/jtnn_vae.py       +0  -333
python/dgl/model_zoo/chem/jtnn/mol_tree.py       +0  -30
python/dgl/model_zoo/chem/jtnn/mol_tree_nx.py    +0  -127
python/dgl/model_zoo/chem/jtnn/mpn.py            +0  -189
python/dgl/model_zoo/chem/jtnn/nnutils.py        +0  -58
python/dgl/model_zoo/chem/layers.py              +0  -412
python/dgl/model_zoo/chem/mgcn.py                +0  -138
python/dgl/model_zoo/API.md (deleted, 100644 → 0)

Model Zoo API
==================

We provide two major APIs for the model zoo. For the time being, only PyTorch is supported.

- `model_zoo.chem.[Model_Name]` to load the model skeleton
- `model_zoo.chem.load_pretrained([Pretrained_Model_Name])` to load the model with pretrained weights

Models are placed in `python/dgl/model_zoo/chem`.

Each model should contain the following elements:

- Papers related to the model
- Model's input and output
- Dataset compatible with the model
- Documentation for all the customizable configs
- Credits (contributor information)
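For context, a minimal sketch of how these two APIs were typically used. The hyperparameters and the pretrained checkpoint name 'GCN_Tox21' are assumptions for illustration, not taken from this diff:

# Hypothetical usage of the two model zoo APIs described above.
from dgl import model_zoo

# 1) Load a model skeleton with explicit hyperparameters.
model = model_zoo.chem.GCNClassifier(
    in_feats=74, gcn_hidden_feats=[64, 64], n_tasks=12)

# 2) Load a model with pretrained weights (checkpoint name assumed).
pretrained_model = model_zoo.chem.load_pretrained('GCN_Tox21')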
python/dgl/model_zoo/__init__.py (deleted, 100644 → 0)

"""Package for model zoo."""
from . import chem
python/dgl/model_zoo/chem/__init__.py (deleted, 100644 → 0)

# pylint: disable=C0111
"""Model Zoo Package"""
from .classifiers import GCNClassifier, GATClassifier
from .schnet import SchNet
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .dgmg import DGMG
from .jtnn import DGLJTNNVAE
from .pretrain import load_pretrained
from .attentive_fp import AttentiveFP
from .acnn import ACNN
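The deprecation decorators in the files below all point users to dgl-lifesci. A sketch of the migration they imply, with the target names taken verbatim from the decorator messages (requires dgl-lifesci to be installed):

# Before this commit (deprecated paths):
from dgl.model_zoo.chem import GCNClassifier, GATClassifier, AttentiveFP, ACNN, DGMG

# After the move, per the deprecation messages in this diff:
from dgllife.model import GCNPredictor, GATPredictor, AttentiveFPPredictor, ACNN, DGMG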
python/dgl/model_zoo/chem/acnn.py (deleted, 100644 → 0)

"""Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity"""
# pylint: disable=C0103, C0123
import itertools

import torch
import torch.nn as nn

from ...nn.pytorch import AtomicConv
from ...contrib.deprecation import deprecated


def truncated_normal_(tensor, mean=0., std=1.):
    """Fills the given tensor in-place with elements sampled from the truncated normal
    distribution parameterized by mean and std.

    The generated values follow a normal distribution with specified mean and
    standard deviation, except that values whose magnitude is more than 2 std
    from the mean are dropped.

    We credit Ruotian Luo for this implementation:
    https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15.

    Parameters
    ----------
    tensor : Float32 tensor of arbitrary shape
        Tensor to be filled.
    mean : float
        Mean of the truncated normal distribution.
    std : float
        Standard deviation of the truncated normal distribution.
    """
    shape = tensor.shape
    tmp = tensor.new_empty(shape + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)


class ACNNPredictor(nn.Module):
    """Predictor for ACNN.

    Parameters
    ----------
    in_size : int
        Number of radial filters used.
    hidden_sizes : list of int
        Specifying the hidden sizes for all layers in the predictor.
    weight_init_stddevs : list of float
        Specifying the standard deviations to use for truncated normal
        distributions in initializing weights for the predictor.
    dropouts : list of float
        Specifying the dropouts to use for all layers in the predictor.
    features_to_use : None or float tensor of shape (T)
        In the original paper, these are atomic numbers to consider, representing the types
        of atoms. T for the number of types of atomic numbers. Default to None.
    num_tasks : int
        Output size.
    """
    def __init__(self, in_size, hidden_sizes, weight_init_stddevs,
                 dropouts, features_to_use, num_tasks):
        super(ACNNPredictor, self).__init__()

        if type(features_to_use) != type(None):
            in_size *= len(features_to_use)

        modules = []
        for i, h in enumerate(hidden_sizes):
            linear_layer = nn.Linear(in_size, h)
            truncated_normal_(linear_layer.weight, std=weight_init_stddevs[i])
            modules.append(linear_layer)
            modules.append(nn.ReLU())
            modules.append(nn.Dropout(dropouts[i]))
            in_size = h
        linear_layer = nn.Linear(in_size, num_tasks)
        truncated_normal_(linear_layer.weight, std=weight_init_stddevs[-1])
        modules.append(linear_layer)
        self.project = nn.Sequential(*modules)

    def forward(self, batch_size, frag1_node_indices_in_complex,
                frag2_node_indices_in_complex, ligand_conv_out,
                protein_conv_out, complex_conv_out):
        """Perform the prediction.

        Parameters
        ----------
        batch_size : int
            Number of datapoints in a batch.
        frag1_node_indices_in_complex : Int64 tensor of shape (V1)
            Indices for atoms in the first fragment (protein) in the batched complex.
        frag2_node_indices_in_complex : list of int of length V2
            Indices for atoms in the second fragment (ligand) in the batched complex.
        ligand_conv_out : Float32 tensor of shape (V2, K * T)
            Updated ligand node representations. V2 for the number of atoms in the
            ligand, K for the number of radial filters, and T for the number of types
            of atomic numbers.
        protein_conv_out : Float32 tensor of shape (V1, K * T)
            Updated protein node representations. V1 for the number of
            atoms in the protein, K for the number of radial filters,
            and T for the number of types of atomic numbers.
        complex_conv_out : Float32 tensor of shape (V1 + V2, K * T)
            Updated complex node representations. V1 and V2 separately
            for the number of atoms in the protein and ligand, K for
            the number of radial filters, and T for the number of
            types of atomic numbers.

        Returns
        -------
        Float32 tensor of shape (B, O)
            Predicted protein-ligand binding affinity. B for the number
            of protein-ligand pairs in the batch and O for the number of tasks.
        """
        ligand_feats = self.project(ligand_conv_out)    # (V2, O)
        protein_feats = self.project(protein_conv_out)  # (V1, O)
        complex_feats = self.project(complex_conv_out)  # (V1 + V2, O)

        ligand_energy = ligand_feats.reshape(batch_size, -1).sum(-1, keepdim=True)    # (B, O)
        protein_energy = protein_feats.reshape(batch_size, -1).sum(-1, keepdim=True)  # (B, O)
        complex_ligand_energy = complex_feats[frag1_node_indices_in_complex].reshape(
            batch_size, -1).sum(-1, keepdim=True)
        complex_protein_energy = complex_feats[frag2_node_indices_in_complex].reshape(
            batch_size, -1).sum(-1, keepdim=True)
        complex_energy = complex_ligand_energy + complex_protein_energy

        return complex_energy - (ligand_energy + protein_energy)


class ACNN(nn.Module):
    """Atomic Convolutional Networks.

    The model was proposed in `Atomic Convolutional Networks for
    Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.

    Parameters
    ----------
    hidden_sizes : list of int
        Specifying the hidden sizes for all layers in the predictor.
    weight_init_stddevs : list of float
        Specifying the standard deviations to use for truncated normal
        distributions in initializing weights for the predictor.
    dropouts : list of float
        Specifying the dropouts to use for all layers in the predictor.
    features_to_use : None or float tensor of shape (T)
        In the original paper, these are atomic numbers to consider, representing the types
        of atoms. T for the number of types of atomic numbers. Default to None.
    radial : None or list
        If not None, the list consists of 3 lists of floats, separately for the
        options of interaction cutoff, the options of rbf kernel mean and the
        options of rbf kernel scaling. If None, a default option of
        ``[[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]`` will be used.
    num_tasks : int
        Number of output tasks.
    """
    @deprecated('Import ACNN from dgllife.model instead.')
    def __init__(self, hidden_sizes, weight_init_stddevs, dropouts,
                 features_to_use=None, radial=None, num_tasks=1):
        super(ACNN, self).__init__()

        if radial is None:
            radial = [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]
        # Take the product of sets of options and get a list of 3-tuples.
        radial_params = [x for x in itertools.product(*radial)]
        radial_params = torch.stack(list(map(torch.tensor, zip(*radial_params))), dim=1)

        interaction_cutoffs = radial_params[:, 0]
        rbf_kernel_means = radial_params[:, 1]
        rbf_kernel_scaling = radial_params[:, 2]

        self.ligand_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
                                      rbf_kernel_scaling, features_to_use)
        self.protein_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
                                       rbf_kernel_scaling, features_to_use)
        self.complex_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
                                       rbf_kernel_scaling, features_to_use)
        self.predictor = ACNNPredictor(radial_params.shape[0], hidden_sizes,
                                       weight_init_stddevs, dropouts,
                                       features_to_use, num_tasks)

    def forward(self, graph):
        """Apply the model for prediction.

        Parameters
        ----------
        graph : DGLHeteroGraph
            DGLHeteroGraph consisting of the ligand graph, the protein graph
            and the complex graph, along with preprocessed features.

        Returns
        -------
        Float32 tensor of shape (B, O)
            Predicted protein-ligand binding affinity. B for the number
            of protein-ligand pairs in the batch and O for the number of tasks.
        """
        ligand_graph = graph[('ligand_atom', 'ligand', 'ligand_atom')]
        ligand_graph_node_feats = ligand_graph.ndata['atomic_number']
        assert ligand_graph_node_feats.shape[-1] == 1
        ligand_graph_distances = ligand_graph.edata['distance']
        ligand_conv_out = self.ligand_conv(ligand_graph,
                                           ligand_graph_node_feats,
                                           ligand_graph_distances)

        protein_graph = graph[('protein_atom', 'protein', 'protein_atom')]
        protein_graph_node_feats = protein_graph.ndata['atomic_number']
        assert protein_graph_node_feats.shape[-1] == 1
        protein_graph_distances = protein_graph.edata['distance']
        protein_conv_out = self.protein_conv(protein_graph,
                                             protein_graph_node_feats,
                                             protein_graph_distances)

        complex_graph = graph[:, 'complex', :]
        complex_graph_node_feats = complex_graph.ndata['atomic_number']
        assert complex_graph_node_feats.shape[-1] == 1
        complex_graph_distances = complex_graph.edata['distance']
        complex_conv_out = self.complex_conv(complex_graph,
                                             complex_graph_node_feats,
                                             complex_graph_distances)

        frag1_node_indices_in_complex = torch.where(complex_graph.ndata['_TYPE'] == 0)[0]
        frag2_node_indices_in_complex = list(
            set(range(complex_graph.number_of_nodes())) -
            set(frag1_node_indices_in_complex.tolist()))

        return self.predictor(
            graph.batch_size,
            frag1_node_indices_in_complex,
            frag2_node_indices_in_complex,
            ligand_conv_out, protein_conv_out, complex_conv_out)
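A minimal sketch of the truncated_normal_ initializer defined in the file above, applied to a standalone linear layer; the layer sizes and std are arbitrary choices for illustration:

# truncated_normal_ keeps only samples within 2 standard deviations,
# then rescales by std and shifts by mean.
import torch
import torch.nn as nn

layer = nn.Linear(16, 8)
truncated_normal_(layer.weight, mean=0., std=0.1)
# With overwhelming probability every entry lies within 2 * 0.1 of the mean.
print(layer.weight.abs().max().item())  # ~< 0.2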
python/dgl/model_zoo/chem/attentive_fp.py (deleted, 100644 → 0)

# pylint: disable=C0103, W0612, E1101
"""Pushing the Boundaries of Molecular Representation for Drug Discovery
with the Graph Attention Mechanism"""
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

from ... import function as fn
from ...contrib.deprecation import deprecated
from ...nn.pytorch.softmax import edge_softmax


class AttentiveGRU1(nn.Module):
    """Update node features with attention and GRU.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node (atom) features.
    edge_feat_size : int
        Size for the input edge (bond) features.
    edge_hidden_size : int
        Size for the intermediate edge (bond) representations.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, edge_feat_size, edge_hidden_size, dropout):
        super(AttentiveGRU1, self).__init__()

        self.edge_transform = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(edge_feat_size, edge_hidden_size)
        )
        self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)

    def forward(self, g, edge_logits, edge_feats, node_feats):
        """
        Parameters
        ----------
        g : DGLGraph
        edge_logits : float32 tensor of shape (E, 1)
            The edge logits based on which softmax will be performed for weighting
            edges within 1-hop neighborhoods. E represents the number of edges.
        edge_feats : float32 tensor of shape (E, M1)
            Previous edge features.
        node_feats : float32 tensor of shape (V, M2)
            Previous node features.

        Returns
        -------
        float32 tensor of shape (V, M2)
            Updated node features.
        """
        g = g.local_var()
        g.edata['e'] = edge_softmax(g, edge_logits) * self.edge_transform(edge_feats)
        g.update_all(fn.copy_edge('e', 'm'), fn.sum('m', 'c'))
        context = F.elu(g.ndata['c'])
        return F.relu(self.gru(context, node_feats))


class AttentiveGRU2(nn.Module):
    """Update node features with attention and GRU.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node (atom) features.
    edge_hidden_size : int
        Size for the intermediate edge (bond) representations.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, edge_hidden_size, dropout):
        super(AttentiveGRU2, self).__init__()

        self.project_node = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(node_feat_size, edge_hidden_size)
        )
        self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)

    def forward(self, g, edge_logits, node_feats):
        """
        Parameters
        ----------
        g : DGLGraph
        edge_logits : float32 tensor of shape (E, 1)
            The edge logits based on which softmax will be performed for weighting
            edges within 1-hop neighborhoods. E represents the number of edges.
        node_feats : float32 tensor of shape (V, M2)
            Previous node features.

        Returns
        -------
        float32 tensor of shape (V, M2)
            Updated node features.
        """
        g = g.local_var()
        g.edata['a'] = edge_softmax(g, edge_logits)
        g.ndata['hv'] = self.project_node(node_feats)

        g.update_all(fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'c'))
        context = F.elu(g.ndata['c'])
        return F.relu(self.gru(context, node_feats))


class GetContext(nn.Module):
    """Generate context for each node (atom) by message passing at the beginning.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node (atom) features.
    edge_feat_size : int
        Size for the input edge (bond) features.
    graph_feat_size : int
        Size of the learned graph representation (molecular fingerprint).
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, edge_feat_size, graph_feat_size, dropout):
        super(GetContext, self).__init__()

        self.project_node = nn.Sequential(
            nn.Linear(node_feat_size, graph_feat_size),
            nn.LeakyReLU()
        )
        self.project_edge1 = nn.Sequential(
            nn.Linear(node_feat_size + edge_feat_size, graph_feat_size),
            nn.LeakyReLU()
        )
        self.project_edge2 = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(2 * graph_feat_size, 1),
            nn.LeakyReLU()
        )
        self.attentive_gru = AttentiveGRU1(graph_feat_size, graph_feat_size,
                                           graph_feat_size, dropout)

    def apply_edges1(self, edges):
        """Edge feature update."""
        return {'he1': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}

    def apply_edges2(self, edges):
        """Edge feature update."""
        return {'he2': torch.cat([edges.dst['hv_new'], edges.data['he1']], dim=1)}

    def forward(self, g, node_feats, edge_feats):
        """
        Parameters
        ----------
        g : DGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.
        edge_feats : float32 tensor of shape (E, N2)
            Input edge features. E for the number of edges and N2 for the feature size.

        Returns
        -------
        float32 tensor of shape (V, N3)
            Updated node features.
        """
        g = g.local_var()
        g.ndata['hv'] = node_feats
        g.ndata['hv_new'] = self.project_node(node_feats)
        g.edata['he'] = edge_feats

        g.apply_edges(self.apply_edges1)
        g.edata['he1'] = self.project_edge1(g.edata['he1'])
        g.apply_edges(self.apply_edges2)
        logits = self.project_edge2(g.edata['he2'])

        return self.attentive_gru(g, logits, g.edata['he1'], g.ndata['hv_new'])


class GNNLayer(nn.Module):
    """GNNLayer for updating node features.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    graph_feat_size : int
        Size for the input graph features.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, graph_feat_size, dropout):
        super(GNNLayer, self).__init__()

        self.project_edge = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(2 * node_feat_size, 1),
            nn.LeakyReLU()
        )
        self.attentive_gru = AttentiveGRU2(node_feat_size, graph_feat_size, dropout)

    def apply_edges(self, edges):
        """Edge feature update by concatenating the features of the destination
        and source nodes."""
        return {'he': torch.cat([edges.dst['hv'], edges.src['hv']], dim=1)}

    def forward(self, g, node_feats):
        """
        Parameters
        ----------
        g : DGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.

        Returns
        -------
        float32 tensor of shape (V, N1)
            Updated node features.
        """
        g = g.local_var()
        g.ndata['hv'] = node_feats
        g.apply_edges(self.apply_edges)
        logits = self.project_edge(g.edata['he'])

        return self.attentive_gru(g, logits, node_feats)


class GlobalPool(nn.Module):
    """Graph feature update.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    graph_feat_size : int
        Size for the input graph features.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, graph_feat_size, dropout):
        super(GlobalPool, self).__init__()

        self.compute_logits = nn.Sequential(
            nn.Linear(node_feat_size + graph_feat_size, 1),
            nn.LeakyReLU()
        )
        self.project_nodes = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(node_feat_size, graph_feat_size)
        )
        self.gru = nn.GRUCell(graph_feat_size, graph_feat_size)

    def forward(self, g, node_feats, g_feats, get_node_weight=False):
        """
        Parameters
        ----------
        g : DGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.
        g_feats : float32 tensor of shape (G, N2)
            Input graph features. G for the number of graphs and N2 for the feature size.
        get_node_weight : bool
            Whether to get the weights of atoms during readout.

        Returns
        -------
        float32 tensor of shape (G, N2)
            Updated graph features.
        float32 tensor of shape (V, 1)
            The weights of nodes in readout.
        """
        with g.local_scope():
            g.ndata['z'] = self.compute_logits(
                torch.cat([dgl.broadcast_nodes(g, F.relu(g_feats)), node_feats], dim=1))
            g.ndata['a'] = dgl.softmax_nodes(g, 'z')
            g.ndata['hv'] = self.project_nodes(node_feats)

            context = F.elu(dgl.sum_nodes(g, 'hv', 'a'))
            if get_node_weight:
                return self.gru(context, g_feats), g.ndata['a']
            else:
                return self.gru(context, g_feats)


class AttentiveFP(nn.Module):
    """`Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
    Attention Mechanism <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__

    Parameters
    ----------
    node_feat_size : int
        Size for the input node (atom) features.
    edge_feat_size : int
        Size for the input edge (bond) features.
    num_layers : int
        Number of GNN layers.
    num_timesteps : int
        Number of timesteps for updating the molecular representation with GRU.
    graph_feat_size : int
        Size of the learned graph representation (molecular fingerprint).
    output_size : int
        Size of the prediction (target labels).
    dropout : float
        The probability for performing dropout.
    """
    @deprecated('Import AttentiveFPPredictor from dgllife.model instead.', 'class')
    def __init__(self, node_feat_size, edge_feat_size, num_layers, num_timesteps,
                 graph_feat_size, output_size, dropout):
        super(AttentiveFP, self).__init__()

        self.init_context = GetContext(node_feat_size, edge_feat_size,
                                       graph_feat_size, dropout)
        self.gnn_layers = nn.ModuleList()
        for i in range(num_layers - 1):
            self.gnn_layers.append(GNNLayer(graph_feat_size, graph_feat_size, dropout))
        self.readouts = nn.ModuleList()
        for t in range(num_timesteps):
            self.readouts.append(GlobalPool(graph_feat_size, graph_feat_size, dropout))
        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(graph_feat_size, output_size)
        )

    def forward(self, g, node_feats, edge_feats, get_node_weight=False):
        """
        Parameters
        ----------
        g : DGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.
        edge_feats : float32 tensor of shape (E, N2)
            Input edge features. E for the number of edges and N2 for the feature size.
        get_node_weight : bool
            Whether to get the weights of atoms during readout.

        Returns
        -------
        float32 tensor of shape (G, N3)
            Prediction for the graphs. G for the number of graphs and N3 for the output size.
        node_weights : list of float32 tensors of shape (V, 1)
            Weights of nodes in all readout operations.
        """
        node_feats = self.init_context(g, node_feats, edge_feats)
        for gnn in self.gnn_layers:
            node_feats = gnn(g, node_feats)

        with g.local_scope():
            g.ndata['hv'] = node_feats
            g_feats = dgl.sum_nodes(g, 'hv')

        if get_node_weight:
            node_weights = []
        for readout in self.readouts:
            if get_node_weight:
                g_feats, node_weights_t = readout(g, node_feats, g_feats, get_node_weight)
                node_weights.append(node_weights_t)
            else:
                g_feats = readout(g, node_feats, g_feats)

        if get_node_weight:
            return self.predict(g_feats), node_weights
        else:
            return self.predict(g_feats)
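A hedged sketch of running the AttentiveFP module above on toy graphs. It assumes the DGL API contemporary with this commit (DGLGraph with add_nodes/add_edges); all sizes are arbitrary and no real molecule featurization is involved:

# Two identical toy "molecules" with bonds in both directions.
import dgl
import torch

def toy_graph():
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 1, 2], [1, 0, 2, 1])
    return g

bg = dgl.batch([toy_graph(), toy_graph()])
model = AttentiveFP(node_feat_size=8, edge_feat_size=4, num_layers=2,
                    num_timesteps=2, graph_feat_size=16, output_size=1,
                    dropout=0.1)
node_feats = torch.randn(bg.number_of_nodes(), 8)
edge_feats = torch.randn(bg.number_of_edges(), 4)
print(model(bg, node_feats, edge_feats).shape)  # expected: (2, 1)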
python/dgl/model_zoo/chem/classifiers.py (deleted, 100644 → 0)

# pylint: disable=C0111, C0103, C0200
import torch
import torch.nn as nn
import torch.nn.functional as F

from .gnn import GCNLayer, GATLayer
from ...readout import max_nodes
from ...nn.pytorch import WeightAndSum
from ...contrib.deprecation import deprecated


class MLPBinaryClassifier(nn.Module):
    """MLP for soft binary classification over multiple tasks from molecule representations.

    Parameters
    ----------
    in_feats : int
        Number of input molecular graph features
    hidden_feats : int
        Number of molecular graph features in hidden layers
    n_tasks : int
        Number of tasks, also output size
    dropout : float
        The probability for dropout. Default to be 0., i.e. no
        dropout is performed.
    """
    def __init__(self, in_feats, hidden_feats, n_tasks, dropout=0.):
        super(MLPBinaryClassifier, self).__init__()

        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_feats, hidden_feats),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_feats),
            nn.Linear(hidden_feats, n_tasks)
        )

    def forward(self, h):
        """Perform soft binary classification over multiple tasks

        Parameters
        ----------
        h : FloatTensor of shape (B, M3)
            * B is the number of molecules in a batch
            * M3 is the input molecule feature size, must match in_feats in initialization

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
        """
        return self.predict(h)


class BaseGNNClassifier(nn.Module):
    """GNN based predictor for multitask prediction on molecular graphs.
    We assume each task is a binary classification.

    Parameters
    ----------
    gnn_out_feats : int
        Number of atom representation features after using GNN
    n_tasks : int
        Number of prediction tasks
    classifier_hidden_feats : int
        Number of molecular graph features in hidden layers of the MLP Classifier
    dropout : float
        The probability for dropout. Default to be 0., i.e. no
        dropout is performed.
    """
    def __init__(self, gnn_out_feats, n_tasks, classifier_hidden_feats=128, dropout=0.):
        super(BaseGNNClassifier, self).__init__()
        self.gnn_layers = nn.ModuleList()
        self.weighted_sum_readout = WeightAndSum(gnn_out_feats)
        self.g_feats = 2 * gnn_out_feats
        self.soft_classifier = MLPBinaryClassifier(
            self.g_feats, classifier_hidden_feats, n_tasks, dropout)

    def forward(self, g, feats):
        """Multi-task prediction for a batch of molecules

        Parameters
        ----------
        g : DGLGraph
            DGLGraph with batch size B for processing multiple molecules in parallel
        feats : FloatTensor of shape (N, M0)
            Initial features for all atoms in the batch of molecules

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
            Soft prediction for all tasks on the batch of molecules
        """
        # Update atom features with GNNs
        for gnn in self.gnn_layers:
            feats = gnn(g, feats)

        # Compute molecule features from atom features
        h_g_sum = self.weighted_sum_readout(g, feats)
        with g.local_scope():
            g.ndata['h'] = feats
            h_g_max = max_nodes(g, 'h')
        h_g = torch.cat([h_g_sum, h_g_max], dim=1)

        # Multi-task prediction
        return self.soft_classifier(h_g)


class GCNClassifier(BaseGNNClassifier):
    """GCN based predictor for multitask prediction on molecular graphs.
    We assume each task is a binary classification.

    Parameters
    ----------
    in_feats : int
        Number of input atom features
    gcn_hidden_feats : list of int
        gcn_hidden_feats[i] gives the number of output atom features
        in the i+1-th gcn layer
    n_tasks : int
        Number of prediction tasks
    classifier_hidden_feats : int
        Number of molecular graph features in hidden layers of the MLP Classifier
    dropout : float
        The probability for dropout. Default to be 0., i.e. no
        dropout is performed.
    """
    @deprecated('Import GCNPredictor from dgllife.model instead.', 'class')
    def __init__(self, in_feats, gcn_hidden_feats, n_tasks,
                 classifier_hidden_feats=128, dropout=0.):
        super(GCNClassifier, self).__init__(gnn_out_feats=gcn_hidden_feats[-1],
                                            n_tasks=n_tasks,
                                            classifier_hidden_feats=classifier_hidden_feats,
                                            dropout=dropout)
        for i in range(len(gcn_hidden_feats)):
            out_feats = gcn_hidden_feats[i]
            self.gnn_layers.append(GCNLayer(in_feats, out_feats))
            in_feats = out_feats


class GATClassifier(BaseGNNClassifier):
    """GAT based predictor for multitask prediction on molecular graphs.
    We assume each task is a binary classification.

    Parameters
    ----------
    in_feats : int
        Number of input atom features
    """
    @deprecated('Import GATPredictor from dgllife.model instead.', 'class')
    def __init__(self, in_feats, gat_hidden_feats, num_heads, n_tasks,
                 classifier_hidden_feats=128, dropout=0):
        super(GATClassifier, self).__init__(gnn_out_feats=gat_hidden_feats[-1],
                                            n_tasks=n_tasks,
                                            classifier_hidden_feats=classifier_hidden_feats,
                                            dropout=dropout)
        assert len(gat_hidden_feats) == len(num_heads), \
            'Got gat_hidden_feats with length {:d} and num_heads with length {:d}, ' \
            'expect them to be the same.'.format(len(gat_hidden_feats), len(num_heads))
        num_layers = len(num_heads)
        for l in range(num_layers):
            if l > 0:
                in_feats = gat_hidden_feats[l - 1] * num_heads[l - 1]

            if l == num_layers - 1:
                agg_mode = 'mean'
                agg_act = None
            else:
                agg_mode = 'flatten'
                agg_act = F.elu

            self.gnn_layers.append(GATLayer(in_feats, gat_hidden_feats[l], num_heads[l],
                                            feat_drop=dropout, attn_drop=dropout,
                                            agg_mode=agg_mode, activation=agg_act))
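A hedged sketch of the GCNClassifier above on a batch of toy graphs, assuming the DGL API contemporary with this commit; the 74 input features mimic a common atom featurization size but are an arbitrary choice here:

import dgl
import torch

g1 = dgl.DGLGraph()
g1.add_nodes(3)
g1.add_edges([0, 1, 2], [1, 2, 0])
g2 = dgl.DGLGraph()
g2.add_nodes(2)
g2.add_edges([0, 1], [1, 0])
bg = dgl.batch([g1, g2])

model = GCNClassifier(in_feats=74, gcn_hidden_feats=[64, 64], n_tasks=12)
logits = model(bg, torch.randn(bg.number_of_nodes(), 74))
print(logits.shape)  # expected: (2, 12), one row of task logits per molecule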
python/dgl/model_zoo/chem/dgmg.py (deleted, 100644 → 0)

# pylint: disable=C0103, W0622, R1710, W0104
"""
Learning Deep Generative Models of Graphs
https://arxiv.org/pdf/1803.03324.pdf
"""
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.distributions import Categorical

import dgl
from dgl import DGLGraph
from dgl.contrib.deprecation import deprecated

try:
    from rdkit import Chem
except ImportError:
    pass


class MoleculeEnv(object):
    """MDP environment for generating molecules.

    Parameters
    ----------
    atom_types : list
        E.g. ['C', 'N']
    bond_types : list
        E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
    """
    def __init__(self, atom_types, bond_types):
        super(MoleculeEnv, self).__init__()

        self.atom_types = atom_types
        self.bond_types = bond_types

        self.atom_type_to_id = dict()
        self.bond_type_to_id = dict()

        for id, a_type in enumerate(atom_types):
            self.atom_type_to_id[a_type] = id

        for id, b_type in enumerate(bond_types):
            self.bond_type_to_id[b_type] = id

    def get_decision_sequence(self, mol, atom_order):
        """Extract a decision sequence with which DGMG can generate the
        molecule with a specified atom order.

        Parameters
        ----------
        mol : Chem.rdchem.Mol
        atom_order : list
            Specifies a mapping between the original atom
            indices and the new atom indices. In particular,
            atom_order[i] is re-labeled as i.

        Returns
        -------
        decisions : list
            decisions[i] is a 2-tuple (i, j)
            - If i = 0, j specifies either the type of the atom to add
              self.atom_types[j] or termination with j = len(self.atom_types)
            - If i = 1, j specifies either the type of the bond to add
              self.bond_types[j] or termination with j = len(self.bond_types)
            - If i = 2, j specifies the destination atom id for the bond to add.
              With the formulation of DGMG, j must be created before the decision.
        """
        decisions = []
        old2new = dict()

        for new_id, old_id in enumerate(atom_order):
            atom = mol.GetAtomWithIdx(old_id)
            a_type = atom.GetSymbol()
            decisions.append((0, self.atom_type_to_id[a_type]))
            for bond in atom.GetBonds():
                u = bond.GetBeginAtomIdx()
                v = bond.GetEndAtomIdx()
                if v == old_id:
                    u, v = v, u
                if v in old2new:
                    decisions.append((1, self.bond_type_to_id[bond.GetBondType()]))
                    decisions.append((2, old2new[v]))
            decisions.append((1, len(self.bond_types)))
            old2new[old_id] = new_id
        decisions.append((0, len(self.atom_types)))
        return decisions

    def reset(self, rdkit_mol=False):
        """Setup for generating a new molecule

        Parameters
        ----------
        rdkit_mol : bool
            Whether to keep a Chem.rdchem.Mol object so
            that we know what molecule is being generated
        """
        self.dgl_graph = DGLGraph()
        # If there are some features for nodes and edges,
        # zero tensors will be set for those of new nodes and edges.
        self.dgl_graph.set_n_initializer(dgl.frame.zero_initializer)
        self.dgl_graph.set_e_initializer(dgl.frame.zero_initializer)

        self.mol = None
        if rdkit_mol:
            # RWMol is a molecule class that is intended to be edited.
            self.mol = Chem.RWMol(Chem.MolFromSmiles(''))

    def num_atoms(self):
        """Get the number of atoms for the current molecule.

        Returns
        -------
        int
        """
        return self.dgl_graph.number_of_nodes()

    def add_atom(self, type):
        """Add an atom of the specified type.

        Parameters
        ----------
        type : int
            Should be in the range of [0, len(self.atom_types) - 1]
        """
        self.dgl_graph.add_nodes(1)
        if self.mol is not None:
            self.mol.AddAtom(Chem.Atom(self.atom_types[type]))

    def add_bond(self, u, v, type, bi_direction=True):
        """Add a bond of the specified type between atom u and v.

        Parameters
        ----------
        u : int
            Index for the first atom
        v : int
            Index for the second atom
        type : int
            Index for the bond type
        bi_direction : bool
            Whether to add edges for both directions in the DGLGraph.
            If not, we will only add the edge (u, v).
        """
        if bi_direction:
            self.dgl_graph.add_edges([u, v], [v, u])
        else:
            self.dgl_graph.add_edge(u, v)

        if self.mol is not None:
            self.mol.AddBond(u, v, self.bond_types[type])

    def get_current_smiles(self):
        """Get the generated molecule in SMILES

        Returns
        -------
        s : str
            SMILES
        """
        assert self.mol is not None, 'Expect a Chem.rdchem.Mol object initialized.'
        s = Chem.MolToSmiles(self.mol)
        return s


class GraphEmbed(nn.Module):
    """Compute a molecule representation out of atom representations.

    Parameters
    ----------
    node_hidden_size : int
        Size of atom representation
    """
    def __init__(self, node_hidden_size):
        super(GraphEmbed, self).__init__()

        # Setting from the paper
        self.graph_hidden_size = 2 * node_hidden_size

        # Embed graphs
        self.node_gating = nn.Sequential(
            nn.Linear(node_hidden_size, 1),
            nn.Sigmoid()
        )
        self.node_to_graph = nn.Linear(node_hidden_size, self.graph_hidden_size)

    def forward(self, g):
        """
        Parameters
        ----------
        g : DGLGraph
            Current molecule graph

        Returns
        -------
        tensor of dtype float32 and shape (1, self.graph_hidden_size)
            Computed representation for the current molecule graph
        """
        if g.number_of_nodes() == 0:
            # Use a zero tensor for an empty molecule.
            return torch.zeros(1, self.graph_hidden_size)
        else:
            # Node features are stored as hv in ndata.
            hvs = g.ndata['hv']
            return (self.node_gating(hvs) * self.node_to_graph(hvs)).sum(0, keepdim=True)


class GraphProp(nn.Module):
    """Perform message passing over a molecule graph and update its atom representations.

    Parameters
    ----------
    num_prop_rounds : int
        Number of message passing rounds for each time
    node_hidden_size : int
        Size of atom representation
    edge_hidden_size : int
        Size of bond representation
    """
    def __init__(self, num_prop_rounds, node_hidden_size, edge_hidden_size):
        super(GraphProp, self).__init__()

        self.num_prop_rounds = num_prop_rounds

        # Setting from the paper
        self.node_activation_hidden_size = 2 * node_hidden_size

        message_funcs = []
        self.reduce_funcs = []
        node_update_funcs = []

        for t in range(num_prop_rounds):
            # input being [hv, hu, xuv]
            message_funcs.append(nn.Linear(2 * node_hidden_size + edge_hidden_size,
                                           self.node_activation_hidden_size))
            self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
            node_update_funcs.append(nn.GRUCell(self.node_activation_hidden_size,
                                                node_hidden_size))

        self.message_funcs = nn.ModuleList(message_funcs)
        self.node_update_funcs = nn.ModuleList(node_update_funcs)

    def dgmg_msg(self, edges):
        """For an edge u->v, send a message concat([h_u, x_uv])

        Parameters
        ----------
        edges : batch of edges

        Returns
        -------
        dict
            Dictionary containing messages for the edge batch,
            with the messages being tensors of shape (B, F1),
            B for the number of edges and F1 for the message size.
        """
        return {'m': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}

    def dgmg_reduce(self, nodes, round):
        """Aggregate messages.

        Parameters
        ----------
        nodes : batch of nodes
        round : int
            Update round

        Returns
        -------
        dict
            Dictionary containing aggregated messages for each node
            in the batch, with the messages being tensors of shape
            (B, F2), B for the number of nodes and F2 for the aggregated
            message size
        """
        hv_old = nodes.data['hv']
        m = nodes.mailbox['m']
        # Make copies of original atom representations to match the
        # number of messages.
        message = torch.cat([hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
        node_activation = (self.message_funcs[round](message)).sum(1)

        return {'a': node_activation}

    def forward(self, g):
        """
        Parameters
        ----------
        g : DGLGraph
        """
        if g.number_of_edges() == 0:
            return
        else:
            for t in range(self.num_prop_rounds):
                g.update_all(message_func=self.dgmg_msg,
                             reduce_func=self.reduce_funcs[t])
                g.ndata['hv'] = self.node_update_funcs[t](g.ndata['a'], g.ndata['hv'])


class AddNode(nn.Module):
    """Stop or add an atom of a particular type.

    Parameters
    ----------
    env : MoleculeEnv
        Environment for generating molecules
    graph_embed_func : callable taking g as input
        Function for computing molecule representation
    node_hidden_size : int
        Size of atom representation
    dropout : float
        Probability for dropout
    """
    def __init__(self, env, graph_embed_func, node_hidden_size, dropout):
        super(AddNode, self).__init__()

        self.env = env
        n_node_types = len(env.atom_types)

        self.graph_op = {'embed': graph_embed_func}
        self.stop = n_node_types
        self.add_node = nn.Sequential(
            nn.Linear(graph_embed_func.graph_hidden_size,
                      graph_embed_func.graph_hidden_size),
            nn.Dropout(p=dropout),
            nn.Linear(graph_embed_func.graph_hidden_size, n_node_types + 1)
        )

        # If to add a node, initialize its hv
        self.node_type_embed = nn.Embedding(n_node_types, node_hidden_size)
        self.initialize_hv = nn.Linear(node_hidden_size + \
                                       graph_embed_func.graph_hidden_size,
                                       node_hidden_size)

        self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def _initialize_node_repr(self, g, node_type, graph_embed):
        """Initialize atom representation

        Parameters
        ----------
        g : DGLGraph
        node_type : int
            Index for the type of the new atom
        graph_embed : tensor of dtype float32
            Molecule representation
        """
        num_nodes = g.number_of_nodes()
        hv_init = torch.cat([
            self.node_type_embed(torch.LongTensor([node_type])),
            graph_embed], dim=1)
        hv_init = self.dropout(hv_init)
        hv_init = self.initialize_hv(hv_init)
        g.nodes[num_nodes - 1].data['hv'] = hv_init
        g.nodes[num_nodes - 1].data['a'] = self.init_node_activation

    def prepare_log_prob(self, compute_log_prob):
        """Setup for returning log likelihood

        Parameters
        ----------
        compute_log_prob : bool
            Whether to compute log likelihood
        """
        if compute_log_prob:
            self.log_prob = []
        self.compute_log_prob = compute_log_prob

    def forward(self, action=None):
        """
        Parameters
        ----------
        action : None or int
            If None, a new action will be sampled. If not None,
            teacher forcing will be used to enforce the decision of the
            corresponding action.

        Returns
        -------
        stop : bool
            Whether we stop adding new atoms
        """
        g = self.env.dgl_graph

        graph_embed = self.graph_op['embed'](g)

        logits = self.add_node(graph_embed).view(1, -1)
        probs = F.softmax(logits, dim=1)

        if action is None:
            action = Categorical(probs).sample().item()
        stop = bool(action == self.stop)

        if not stop:
            self.env.add_atom(action)
            self._initialize_node_repr(g, action, graph_embed)

        if self.compute_log_prob:
            sample_log_prob = F.log_softmax(logits, dim=1)[:, action:action + 1]
            self.log_prob.append(sample_log_prob)

        return stop


class AddEdge(nn.Module):
    """Stop or add a bond of a particular type.

    Parameters
    ----------
    env : MoleculeEnv
        Environment for generating molecules
    graph_embed_func : callable taking g as input
        Function for computing molecule representation
    node_hidden_size : int
        Size of atom representation
    dropout : float
        Probability for dropout
    """
    def __init__(self, env, graph_embed_func, node_hidden_size, dropout):
        super(AddEdge, self).__init__()

        self.env = env
        n_bond_types = len(env.bond_types)
        self.stop = n_bond_types

        self.graph_op = {'embed': graph_embed_func}
        self.add_edge = nn.Sequential(
            nn.Linear(graph_embed_func.graph_hidden_size + node_hidden_size,
                      graph_embed_func.graph_hidden_size + node_hidden_size),
            nn.Dropout(p=dropout),
            nn.Linear(graph_embed_func.graph_hidden_size + node_hidden_size,
                      n_bond_types + 1)
        )

    def prepare_log_prob(self, compute_log_prob):
        """Setup for returning log likelihood

        Parameters
        ----------
        compute_log_prob : bool
            Whether to compute log likelihood
        """
        if compute_log_prob:
            self.log_prob = []
        self.compute_log_prob = compute_log_prob

    def forward(self, action=None):
        """
        Parameters
        ----------
        action : None or int
            If None, a new action will be sampled. If not None,
            teacher forcing will be used to enforce the decision of the
            corresponding action.

        Returns
        -------
        stop : bool
            Whether we stop adding new bonds
        action : int
            The type for the new bond
        """
        g = self.env.dgl_graph

        graph_embed = self.graph_op['embed'](g)
        src_embed = g.nodes[g.number_of_nodes() - 1].data['hv']

        logits = self.add_edge(torch.cat([graph_embed, src_embed], dim=1))
        probs = F.softmax(logits, dim=1)

        if action is None:
            action = Categorical(probs).sample().item()
        stop = bool(action == self.stop)

        if self.compute_log_prob:
            sample_log_prob = F.log_softmax(logits, dim=1)[:, action:action + 1]
            self.log_prob.append(sample_log_prob)

        return stop, action


class ChooseDestAndUpdate(nn.Module):
    """Choose the atom to connect for the new bond.

    Parameters
    ----------
    env : MoleculeEnv
        Environment for generating molecules
    graph_prop_func : callable taking g as input
        Function for performing message passing
        and updating atom representations
    node_hidden_size : int
        Size of atom representation
    dropout : float
        Probability for dropout
    """
    def __init__(self, env, graph_prop_func, node_hidden_size, dropout):
        super(ChooseDestAndUpdate, self).__init__()

        self.env = env
        n_bond_types = len(self.env.bond_types)
        # To be used for one-hot encoding of bond type
        self.bond_embedding = torch.eye(n_bond_types)

        self.graph_op = {'prop': graph_prop_func}
        self.choose_dest = nn.Sequential(
            nn.Linear(2 * node_hidden_size + n_bond_types,
                      2 * node_hidden_size + n_bond_types),
            nn.Dropout(p=dropout),
            nn.Linear(2 * node_hidden_size + n_bond_types, 1)
        )

    def _initialize_edge_repr(self, g, src_list, dest_list, edge_embed):
        """Initialize bond representation

        Parameters
        ----------
        g : DGLGraph
        src_list : list of int
            source atoms for new bonds
        dest_list : list of int
            destination atoms for new bonds
        edge_embed : 2D tensor of dtype float32
            Embeddings for the new bonds
        """
        g.edges[src_list, dest_list].data['he'] = edge_embed.expand(len(src_list), -1)

    def prepare_log_prob(self, compute_log_prob):
        """Setup for returning log likelihood

        Parameters
        ----------
        compute_log_prob : bool
            Whether to compute log likelihood
        """
        if compute_log_prob:
            self.log_prob = []
        self.compute_log_prob = compute_log_prob

    def forward(self, bond_type, dest):
        """
        Parameters
        ----------
        bond_type : int
            The type for the new bond
        dest : int or None
            If None, a new action will be sampled. If not None,
            teacher forcing will be used to enforce the decision of the
            corresponding action.
        """
        g = self.env.dgl_graph

        src = g.number_of_nodes() - 1
        possible_dests = range(src)

        src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
        possible_dests_embed = g.nodes[possible_dests].data['hv']
        edge_embed = self.bond_embedding[bond_type:bond_type + 1]

        dests_scores = self.choose_dest(
            torch.cat([possible_dests_embed,
                       src_embed_expand,
                       edge_embed.expand(src, -1)], dim=1)).view(1, -1)
        dests_probs = F.softmax(dests_scores, dim=1)

        if dest is None:
            dest = Categorical(dests_probs).sample().item()

        if not g.has_edge_between(src, dest):
            # For undirected graphs, we add edges for both directions
            # so that we can perform graph propagation.
            src_list = [src, dest]
            dest_list = [dest, src]

            self.env.add_bond(src, dest, bond_type)
            self._initialize_edge_repr(g, src_list, dest_list, edge_embed)

            # Perform message passing when new bonds are added.
            self.graph_op['prop'](g)

        if self.compute_log_prob:
            if dests_probs.nelement() > 1:
                self.log_prob.append(
                    F.log_softmax(dests_scores, dim=1)[:, dest:dest + 1])


def weights_init(m):
    '''Function to initialize weights for models

    Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5

    Usage:
        model = Model()
        model.apply(weights_init)
    '''
    if isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        init.normal_(m.bias.data)
    elif isinstance(m, nn.GRUCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)


def dgmg_message_weight_init(m):
    """Weight initialization for graph propagation module

    These are suggested by the author. This should only be used for
    the message passing functions, i.e. fe's in the paper.
    """
    def _weight_init(m):
        if isinstance(m, nn.Linear):
            init.normal_(m.weight.data, std=1. / 10)
            init.normal_(m.bias.data, std=1. / 10)
        else:
            raise ValueError('Expected the input to be of type nn.Linear!')

    if isinstance(m, nn.ModuleList):
        for layer in m:
            layer.apply(_weight_init)
    else:
        m.apply(_weight_init)


class DGMG(nn.Module):
    """DGMG model

    `Learning Deep Generative Models of Graphs <https://arxiv.org/abs/1803.03324>`__

    Users only need to initialize an instance of this class.

    Parameters
    ----------
    atom_types : list
        E.g. ['C', 'N']
    bond_types : list
        E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
    node_hidden_size : int
        Size of atom representation
    num_prop_rounds : int
        Number of message passing rounds for each time
    dropout : float
        Probability for dropout
    """
    @deprecated('Import DGMG from dgllife.model instead.', 'class')
    def __init__(self, atom_types, bond_types, node_hidden_size,
                 num_prop_rounds, dropout):
        super(DGMG, self).__init__()

        self.env = MoleculeEnv(atom_types, bond_types)

        # Graph embedding module
        self.graph_embed = GraphEmbed(node_hidden_size)

        # Graph propagation module
        # For one-hot encoding, edge_hidden_size is just the number of bond types
        self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size,
                                    len(self.env.bond_types))

        # Actions
        self.add_node_agent = AddNode(
            self.env, self.graph_embed, node_hidden_size, dropout)
        self.add_edge_agent = AddEdge(
            self.env, self.graph_embed, node_hidden_size, dropout)
        self.choose_dest_agent = ChooseDestAndUpdate(
            self.env, self.graph_prop, node_hidden_size, dropout)

        # Weight initialization
        self.init_weights()

    def init_weights(self):
        """Initialize model weights"""
        self.graph_embed.apply(weights_init)
        self.graph_prop.apply(weights_init)
        self.add_node_agent.apply(weights_init)
        self.add_edge_agent.apply(weights_init)
        self.choose_dest_agent.apply(weights_init)

        self.graph_prop.message_funcs.apply(dgmg_message_weight_init)

    def count_step(self):
        """Increment the step by 1."""
        self.step_count += 1

    def prepare_log_prob(self, compute_log_prob):
        """Setup for returning log likelihood

        Parameters
        ----------
        compute_log_prob : bool
            Whether to compute log likelihood
        """
        self.compute_log_prob = compute_log_prob
        self.add_node_agent.prepare_log_prob(compute_log_prob)
        self.add_edge_agent.prepare_log_prob(compute_log_prob)
        self.choose_dest_agent.prepare_log_prob(compute_log_prob)

    def add_node_and_update(self, a=None):
        """Decide if to add a new atom.
        If a new atom should be added, update the graph.

        Parameters
        ----------
        a : None or int
            If None, a new action will be sampled. If not None,
            teacher forcing will be used to enforce the decision of the
            corresponding action.
        """
        self.count_step()

        return self.add_node_agent(a)

    def add_edge_or_not(self, a=None):
        """Decide if to add a new bond.

        Parameters
        ----------
        a : None or int
            If None, a new action will be sampled. If not None,
            teacher forcing will be used to enforce the decision of the
            corresponding action.
        """
        self.count_step()

        return self.add_edge_agent(a)

    def choose_dest_and_update(self, bond_type, a=None):
        """Choose destination and connect it to the latest atom.
        Add edges for both directions and update the graph.

        Parameters
        ----------
        bond_type : int
            The type of the new bond to add
        a : None or int
            If None, a new action will be sampled. If not None,
            teacher forcing will be used to enforce the decision of the
            corresponding action.
        """
        self.count_step()

        self.choose_dest_agent(bond_type, a)

    def get_log_prob(self):
        """Compute the log likelihood for the decision sequence,
        typically corresponding to the generation of a molecule.

        Returns
        -------
        torch.tensor consisting of a float only
        """
        return torch.cat(self.add_node_agent.log_prob).sum() \
               + torch.cat(self.add_edge_agent.log_prob).sum() \
               + torch.cat(self.choose_dest_agent.log_prob).sum()

    def teacher_forcing(self, actions):
        """Generate a molecule according to a sequence of actions.

        Parameters
        ----------
        actions : list of 2-tuples of int
            actions[t] gives (i, j), the action to execute by DGMG at timestep t.
            - If i = 0, j specifies either the type of the atom to add or termination
            - If i = 1, j specifies either the type of the bond to add or termination
            - If i = 2, j specifies the destination atom id for the bond to add.
              With the formulation of DGMG, j must be created before the decision.
        """
        stop_node = self.add_node_and_update(a=actions[self.step_count][1])
        while not stop_node:
            # A new atom was just added.
            stop_edge, bond_type = self.add_edge_or_not(a=actions[self.step_count][1])
            while not stop_edge:
                # A new bond is to be added.
                self.choose_dest_and_update(bond_type, a=actions[self.step_count][1])
                stop_edge, bond_type = self.add_edge_or_not(a=actions[self.step_count][1])
            stop_node = self.add_node_and_update(a=actions[self.step_count][1])

    def rollout(self, max_num_steps):
        """Sample a molecule from the distribution learned by DGMG."""
        stop_node = self.add_node_and_update()
        while (not stop_node) and (self.step_count <= max_num_steps):
            stop_edge, bond_type = self.add_edge_or_not()
            if self.env.num_atoms() == 1:
                stop_edge = True
            while (not stop_edge) and (self.step_count <= max_num_steps):
                self.choose_dest_and_update(bond_type)
                stop_edge, bond_type = self.add_edge_or_not()
            stop_node = self.add_node_and_update()

    def forward(self, actions=None, rdkit_mol=False, compute_log_prob=False,
                max_num_steps=400):
        """
        Parameters
        ----------
        actions : list of 2-tuples or None.
            If actions are not None, generate a molecule according to actions.
            Otherwise, a molecule will be generated based on sampled actions.
        rdkit_mol : bool
            Whether to maintain a Chem.rdchem.Mol object. This brings extra
            computational cost, but is necessary if we are interested in
            knowing the generated molecule.
        compute_log_prob : bool
            Whether to compute log likelihood
        max_num_steps : int
            Maximum number of steps allowed. This only comes into effect
            during inference and prevents the model from not stopping.

        Returns
        -------
        torch.tensor consisting of a float only, optional
            The log likelihood for the actions taken
        str, optional
            The generated molecule in the form of SMILES
        """
        # Initialize an empty molecule
        self.step_count = 0
        self.env.reset(rdkit_mol=rdkit_mol)
        self.prepare_log_prob(compute_log_prob)

        if actions is not None:
            # A sequence of decisions is given, use teacher forcing
            self.teacher_forcing(actions)
        else:
            # Sample a molecule from the distribution learned by DGMG
            self.rollout(max_num_steps)

        if compute_log_prob and rdkit_mol:
            return self.get_log_prob(), self.env.get_current_smiles()

        if compute_log_prob:
            return self.get_log_prob()

        if rdkit_mol:
            return self.env.get_current_smiles()
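A hedged sketch of sampling a molecule with the DGMG module above. It requires RDKit, the atom and bond vocabularies and sizes are illustrative only, and an untrained model will of course sample arbitrary molecules:

import torch
from rdkit import Chem

model = DGMG(atom_types=['C', 'N', 'O'],
             bond_types=[Chem.rdchem.BondType.SINGLE,
                         Chem.rdchem.BondType.DOUBLE,
                         Chem.rdchem.BondType.TRIPLE],
             node_hidden_size=16, num_prop_rounds=2, dropout=0.2)
model.eval()
with torch.no_grad():
    # With no actions given, a rollout samples decisions step by step.
    smiles = model(rdkit_mol=True, max_num_steps=100)
print(smiles)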
python/dgl/model_zoo/chem/gnn.py
deleted
100644 → 0
View file @
94c67203
# pylint: disable=C0103, E1101
"""GNN layers for updating atom representations"""
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
...nn.pytorch
import
GraphConv
,
GATConv
class
GCNLayer
(
nn
.
Module
):
"""Single layer GCN for updating node features
Parameters
----------
in_feats : int
        Number of input atom features
    out_feats : int
        Number of output atom features
    activation : activation function
        Default to be ReLU
    residual : bool
        Whether to use residual connection, default to be True
    batchnorm : bool
        Whether to use batch normalization on the output,
        default to be True
    dropout : float
        The probability for dropout. Default to be 0., i.e. no
        dropout is performed.
    """
    def __init__(self, in_feats, out_feats, activation=F.relu,
                 residual=True, batchnorm=True, dropout=0.):
        super(GCNLayer, self).__init__()

        self.activation = activation
        self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
                                    norm=False, activation=activation)
        self.dropout = nn.Dropout(dropout)

        self.residual = residual
        if residual:
            self.res_connection = nn.Linear(in_feats, out_feats)

        self.bn = batchnorm
        if batchnorm:
            self.bn_layer = nn.BatchNorm1d(out_feats)

    def forward(self, g, feats):
        """Update atom representations

        Parameters
        ----------
        g : DGLGraph
            DGLGraph with batch size B for processing multiple molecules in parallel
        feats : FloatTensor of shape (N, M1)
            * N is the total number of atoms in the batched graph
            * M1 is the input atom feature size, must match in_feats in initialization

        Returns
        -------
        new_feats : FloatTensor of shape (N, M2)
            * M2 is the output atom feature size, must match out_feats in initialization
        """
        new_feats = self.graph_conv(g, feats)
        if self.residual:
            res_feats = self.activation(self.res_connection(feats))
            new_feats = new_feats + res_feats
        new_feats = self.dropout(new_feats)

        if self.bn:
            new_feats = self.bn_layer(new_feats)

        return new_feats

class GATLayer(nn.Module):
    """Single layer GAT for updating node features

    Parameters
    ----------
    in_feats : int
        Number of input atom features
    out_feats : int
        Number of output atom features for each attention head
    num_heads : int
        Number of attention heads
    feat_drop : float
        Dropout applied to the input features
    attn_drop : float
        Dropout applied to attention values of edges
    alpha : float
        Hyperparameter in LeakyReLU, slope for negative values. Default to be 0.2
    residual : bool
        Whether to perform skip connection, default to be True
    agg_mode : str
        The way to aggregate multi-head attention results, can be either
        'flatten' for concatenating all head results or 'mean' for averaging
        all head results
    activation : activation function or None
        Activation function applied to aggregated multi-head results, default to be None.
    """
    def __init__(self, in_feats, out_feats, num_heads, feat_drop, attn_drop,
                 alpha=0.2, residual=True, agg_mode='flatten', activation=None):
        super(GATLayer, self).__init__()

        self.gnn = GATConv(in_feats=in_feats, out_feats=out_feats,
                           num_heads=num_heads, feat_drop=feat_drop,
                           attn_drop=attn_drop, negative_slope=alpha,
                           residual=residual)
        assert agg_mode in ['flatten', 'mean']
        self.agg_mode = agg_mode
        self.activation = activation

    def forward(self, bg, feats):
        """Update atom representations

        Parameters
        ----------
        bg : DGLGraph
            Batched DGLGraphs for processing multiple molecules in parallel
        feats : FloatTensor of shape (N, M1)
            * N is the total number of atoms in the batched graph
            * M1 is the input atom feature size, must match in_feats in initialization

        Returns
        -------
        new_feats : FloatTensor of shape (N, M2)
            * M2 is the output atom feature size. If self.agg_mode == 'flatten', this
              would be out_feats * num_heads, else it would be just out_feats.
        """
        new_feats = self.gnn(bg, feats)
        if self.agg_mode == 'flatten':
            new_feats = new_feats.flatten(1)
        else:
            new_feats = new_feats.mean(1)

        if self.activation is not None:
            new_feats = self.activation(new_feats)

        return new_feats
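For orientation, a minimal usage sketch of the two layers above (not part of this commit; it assumes the DGL 0.4-era DGLGraph API this code was written against, and the toy graph and feature sizes are made up for illustration):

import dgl
import torch

g = dgl.DGLGraph()                         # a toy 3-atom "molecule"
g.add_nodes(3)
g.add_edges([0, 1, 1, 2], [1, 0, 2, 1])    # bonds stored in both directions

feats = torch.randn(3, 74)                 # 74 hypothetical input atom features
gcn = GCNLayer(in_feats=74, out_feats=64)
gcn_out = gcn(g, feats)                    # shape (3, 64)

gat = GATLayer(in_feats=74, out_feats=32, num_heads=4,
               feat_drop=0., attn_drop=0.)
gat_out = gat(g, feats)                    # shape (3, 128) with agg_mode='flatten'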
python/dgl/model_zoo/chem/jtnn/__init__.py deleted 100644 → 0
"""JTNN Module"""
from
.chemutils
import
decode_stereo
from
.jtnn_vae
import
DGLJTNNVAE
from
.mol_tree
import
Vocab
from
.mpn
import
DGLMPN
from
.nnutils
import
create_var
,
cuda
python/dgl/model_zoo/chem/jtnn/chemutils.py deleted 100644 → 0
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W0703, C0200, R1710
from collections import defaultdict

import rdkit.Chem as Chem
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree

MST_MAX_WEIGHT = 100
MAX_NCAND = 2000

def set_atommap(mol, num=0):
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(num)

def get_mol(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    Chem.Kekulize(mol)
    return mol

def get_smiles(mol):
    return Chem.MolToSmiles(mol, kekuleSmiles=True)

def decode_stereo(smiles2D):
    mol = Chem.MolFromSmiles(smiles2D)
    dec_isomers = list(EnumerateStereoisomers(mol))

    dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True))
                   for mol in dec_isomers]
    smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers]

    chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms()
               if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
    if len(chiralN) > 0:
        for mol in dec_isomers:
            for idx in chiralN:
                mol.GetAtomWithIdx(idx).SetChiralTag(
                    Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
            smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))

    return smiles3D

def sanitize(mol):
    try:
        smiles = get_smiles(mol)
        mol = get_mol(smiles)
    except Exception:
        return None
    return mol

def copy_atom(atom):
    new_atom = Chem.Atom(atom.GetSymbol())
    new_atom.SetFormalCharge(atom.GetFormalCharge())
    new_atom.SetAtomMapNum(atom.GetAtomMapNum())
    return new_atom

def copy_edit_mol(mol):
    new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
    for atom in mol.GetAtoms():
        new_atom = copy_atom(atom)
        new_mol.AddAtom(new_atom)
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        bt = bond.GetBondType()
        new_mol.AddBond(a1, a2, bt)
    return new_mol

def get_clique_mol(mol, atoms):
    smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
    new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
    new_mol = copy_edit_mol(new_mol).GetMol()
    new_mol = sanitize(new_mol)
    # We assume this is not None
    return new_mol

def tree_decomp(mol):
    n_atoms = mol.GetNumAtoms()
    if n_atoms == 1:
        return [[0]], []

    cliques = []
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if not bond.IsInRing():
            cliques.append([a1, a2])

    ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
    cliques.extend(ssr)

    nei_list = [[] for i in range(n_atoms)]
    for i in range(len(cliques)):
        for atom in cliques[i]:
            nei_list[atom].append(i)

    # Merge rings with intersection > 2 atoms
    for i in range(len(cliques)):
        if len(cliques[i]) <= 2:
            continue
        for atom in cliques[i]:
            for j in nei_list[atom]:
                if i >= j or len(cliques[j]) <= 2:
                    continue
                inter = set(cliques[i]) & set(cliques[j])
                if len(inter) > 2:
                    cliques[i].extend(cliques[j])
                    cliques[i] = list(set(cliques[i]))
                    cliques[j] = []

    cliques = [c for c in cliques if len(c) > 0]
    nei_list = [[] for i in range(n_atoms)]
    for i in range(len(cliques)):
        for atom in cliques[i]:
            nei_list[atom].append(i)

    # Build edges and add singleton cliques
    edges = defaultdict(int)
    for atom in range(n_atoms):
        if len(nei_list[atom]) <= 1:
            continue
        cnei = nei_list[atom]
        bonds = [c for c in cnei if len(cliques[c]) == 2]
        rings = [c for c in cnei if len(cliques[c]) > 4]
        # In general, if len(cnei) >= 3, a singleton should be added, but 1
        # bond + 2 ring is currently not dealt with.
        if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, c2)] = 1
        elif len(rings) > 2:
            # Multiple (n>2) complex rings
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, c2)] = MST_MAX_WEIGHT - 1
        else:
            for i in range(len(cnei)):
                for j in range(i + 1, len(cnei)):
                    c1, c2 = cnei[i], cnei[j]
                    inter = set(cliques[c1]) & set(cliques[c2])
                    if edges[(c1, c2)] < len(inter):
                        # cnei[i] < cnei[j] by construction
                        edges[(c1, c2)] = len(inter)

    edges = [u + (MST_MAX_WEIGHT - v,) for u, v in edges.items()]
    if len(edges) == 0:
        return cliques, edges

    # Compute maximum spanning tree
    row, col, data = list(zip(*edges))
    n_clique = len(cliques)
    clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
    junc_tree = minimum_spanning_tree(clique_graph)
    row, col = junc_tree.nonzero()
    edges = [(row[i], col[i]) for i in range(len(row))]
    return (cliques, edges)

def atom_equal(a1, a2):
    return a1.GetSymbol() == a2.GetSymbol() and \
        a1.GetFormalCharge() == a2.GetFormalCharge()

# Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
    b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
    if reverse:
        b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
    else:
        b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
    return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])

def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
    prev_nids = [node['nid'] for node in prev_nodes]
    for nei_node in prev_nodes + neighbors:
        nei_id, nei_mol = nei_node['nid'], nei_node['mol']
        amap = nei_amap[nei_id]
        for atom in nei_mol.GetAtoms():
            if atom.GetIdx() not in amap:
                new_atom = copy_atom(atom)
                amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)

        if nei_mol.GetNumBonds() == 0:
            nei_atom = nei_mol.GetAtomWithIdx(0)
            ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
            ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
        else:
            for bond in nei_mol.GetBonds():
                a1 = amap[bond.GetBeginAtom().GetIdx()]
                a2 = amap[bond.GetEndAtom().GetIdx()]
                if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
                    ctr_mol.AddBond(a1, a2, bond.GetBondType())
                elif nei_id in prev_nids:
                    # father node overrides
                    ctr_mol.RemoveBond(a1, a2)
                    ctr_mol.AddBond(a1, a2, bond.GetBondType())
    return ctr_mol

def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
    ctr_mol = copy_edit_mol(ctr_mol)
    nei_amap = {nei['nid']: {} for nei in prev_nodes + neighbors}

    for nei_id, ctr_atom, nei_atom in amap_list:
        nei_amap[nei_id][nei_atom] = ctr_atom

    ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
    return ctr_mol.GetMol()

# This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
    nei_mol, nei_idx = nei_node['mol'], nei_node['nid']
    att_confs = []
    black_list = [atom_idx for nei_id, atom_idx, _ in amap
                  if nei_id in singletons]
    ctr_atoms = [atom for atom in ctr_mol.GetAtoms()
                 if atom.GetIdx() not in black_list]
    ctr_bonds = [bond for bond in ctr_mol.GetBonds()]

    if nei_mol.GetNumBonds() == 0:
        # neighbor singleton
        nei_atom = nei_mol.GetAtomWithIdx(0)
        used_list = [atom_idx for _, atom_idx, _ in amap]
        for atom in ctr_atoms:
            if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
                new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
                att_confs.append(new_amap)
    elif nei_mol.GetNumBonds() == 1:
        # neighbor is a bond
        bond = nei_mol.GetBondWithIdx(0)
        bond_val = int(bond.GetBondTypeAsDouble())
        b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()

        for atom in ctr_atoms:
            # Optimize if atom is carbon (other atoms may change valence)
            if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
                continue
            if atom_equal(atom, b1):
                new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
                att_confs.append(new_amap)
            elif atom_equal(atom, b2):
                new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
                att_confs.append(new_amap)
    else:
        # intersection is an atom
        for a1 in ctr_atoms:
            for a2 in nei_mol.GetAtoms():
                if atom_equal(a1, a2):
                    # Optimize if atom is carbon (other atoms may change
                    # valence)
                    if a1.GetAtomicNum() == 6 and \
                            a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
                        continue
                    new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
                    att_confs.append(new_amap)

        # intersection is a bond
        if ctr_mol.GetNumBonds() > 1:
            for b1 in ctr_bonds:
                for b2 in nei_mol.GetBonds():
                    if ring_bond_equal(b1, b2):
                        new_amap = amap + [
                            (nei_idx, b1.GetBeginAtom().GetIdx(),
                             b2.GetBeginAtom().GetIdx()),
                            (nei_idx, b1.GetEndAtom().GetIdx(),
                             b2.GetEndAtom().GetIdx())]
                        att_confs.append(new_amap)

                    if ring_bond_equal(b1, b2, reverse=True):
                        new_amap = amap + [
                            (nei_idx, b1.GetBeginAtom().GetIdx(),
                             b2.GetEndAtom().GetIdx()),
                            (nei_idx, b1.GetEndAtom().GetIdx(),
                             b2.GetBeginAtom().GetIdx())]
                        att_confs.append(new_amap)
    return att_confs

# Try rings first: speed-up
def enum_assemble_nx(node, neighbors, prev_nodes=None, prev_amap=None):
    if prev_nodes is None:
        prev_nodes = []
    if prev_amap is None:
        prev_amap = []
    all_attach_confs = []
    singletons = [nei_node['nid'] for nei_node in neighbors + prev_nodes
                  if nei_node['mol'].GetNumAtoms() == 1]

    def search(cur_amap, depth):
        if len(all_attach_confs) > MAX_NCAND:
            return None
        if depth == len(neighbors):
            all_attach_confs.append(cur_amap)
            return None

        nei_node = neighbors[depth]
        cand_amap = enum_attach_nx(node['mol'], nei_node, cur_amap, singletons)
        cand_smiles = set()
        candidates = []
        for amap in cand_amap:
            cand_mol = local_attach_nx(
                node['mol'], neighbors[:depth + 1], prev_nodes, amap)
            cand_mol = sanitize(cand_mol)
            if cand_mol is None:
                continue
            smiles = get_smiles(cand_mol)
            if smiles in cand_smiles:
                continue
            cand_smiles.add(smiles)
            candidates.append(amap)

        if len(candidates) == 0:
            return []

        for new_amap in candidates:
            search(new_amap, depth + 1)

    search(prev_amap, 0)
    cand_smiles = set()
    candidates = []
    for amap in all_attach_confs:
        cand_mol = local_attach_nx(node['mol'], neighbors, prev_nodes, amap)
        cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
        smiles = Chem.MolToSmiles(cand_mol)
        if smiles in cand_smiles:
            continue
        cand_smiles.add(smiles)
        Chem.Kekulize(cand_mol)
        candidates.append((smiles, cand_mol, amap))

    return candidates

# Only used for debugging purposes
def dfs_assemble_nx(graph, cur_mol, global_amap, fa_amap,
                    cur_node_id, fa_node_id):
    cur_node = graph.nodes_dict[cur_node_id]
    fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None

    fa_nid = fa_node['nid'] if fa_node is not None else -1
    prev_nodes = [fa_node] if fa_node is not None else []

    children_id = [nei for nei in graph[cur_node_id]
                   if graph.nodes_dict[nei]['nid'] != fa_nid]
    children = [graph.nodes_dict[nei] for nei in children_id]
    neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
    neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(),
                       reverse=True)
    singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
    neighbors = singletons + neighbors

    cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap
                if nid == cur_node['nid']]
    cands = enum_assemble_nx(graph.nodes_dict[cur_node_id], neighbors,
                             prev_nodes, cur_amap)
    if len(cands) == 0:
        return

    cand_smiles, _, cand_amap = zip(*cands)
    label_idx = cand_smiles.index(cur_node['label'])
    label_amap = cand_amap[label_idx]

    for nei_id, ctr_atom, nei_atom in label_amap:
        if nei_id == fa_nid:
            continue
        global_amap[nei_id][nei_atom] = global_amap[cur_node['nid']][ctr_atom]

    # father is already attached
    cur_mol = attach_mols_nx(cur_mol, children, [], global_amap)
    for nei_node_id, nei_node in zip(children_id, children):
        if not nei_node['is_leaf']:
            dfs_assemble_nx(graph, cur_mol, global_amap, label_amap,
                            nei_node_id, cur_node_id)
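For a sense of what the decomposition produces, a small sketch (illustration only, not part of the diff; the molecule is an arbitrary example):

# Junction tree decomposition of cyclobutane with a -CH2OH tail.
mol = get_mol('C1CCC1CO')
cliques, edges = tree_decomp(mol)
# cliques: lists of atom indices (the 4-atom ring plus the two non-ring bonds)
# edges: junction-tree edges connecting overlapping cliques
print(cliques, edges)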
python/dgl/model_zoo/chem/jtnn/jtmpn.py deleted 100644 → 0
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W1508
# pylint: disable=redefined-outer-name
import os

import rdkit.Chem as Chem
import torch
import torch.nn as nn

import dgl.function as DGLF
from dgl import DGLGraph, mean_nodes

from .nnutils import cuda

ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
             'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']

ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM = 5
MAX_NB = 10

PAPER = os.getenv('PAPER', False)

def onek_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]

# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. chiral atoms, E-Z, cis-trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with the highest score.
def atom_features(atom):
    return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
                         + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
                         + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
                         + [atom.GetIsAromatic()]))

def bond_features(bond):
    bt = bond.GetBondType()
    return torch.Tensor([bt == Chem.rdchem.BondType.SINGLE,
                         bt == Chem.rdchem.BondType.DOUBLE,
                         bt == Chem.rdchem.BondType.TRIPLE,
                         bt == Chem.rdchem.BondType.AROMATIC,
                         bond.IsInRing()])

def mol2dgl_single(cand_batch):
    cand_graphs = []
    tree_mess_source_edges = []  # map these edges from trees to...
    tree_mess_target_edges = []  # these edges on candidate graphs
    tree_mess_target_nodes = []
    n_nodes = 0
    atom_x = []
    bond_x = []

    for mol, mol_tree, ctr_node_id in cand_batch:
        n_atoms = mol.GetNumAtoms()

        g = DGLGraph()

        for i, atom in enumerate(mol.GetAtoms()):
            assert i == atom.GetIdx()
            atom_x.append(atom_features(atom))
        g.add_nodes(n_atoms)

        bond_src = []
        bond_dst = []
        for i, bond in enumerate(mol.GetBonds()):
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            begin_idx = a1.GetIdx()
            end_idx = a2.GetIdx()
            features = bond_features(bond)

            bond_src.append(begin_idx)
            bond_dst.append(end_idx)
            bond_x.append(features)
            bond_src.append(end_idx)
            bond_dst.append(begin_idx)
            bond_x.append(features)

            x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
            # Tree node ID in the batch
            x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
            y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
            if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
                if mol_tree.has_edge_between(x_bid, y_bid):
                    tree_mess_target_edges.append(
                        (begin_idx + n_nodes, end_idx + n_nodes))
                    tree_mess_source_edges.append((x_bid, y_bid))
                    tree_mess_target_nodes.append(end_idx + n_nodes)
                if mol_tree.has_edge_between(y_bid, x_bid):
                    tree_mess_target_edges.append(
                        (end_idx + n_nodes, begin_idx + n_nodes))
                    tree_mess_source_edges.append((y_bid, x_bid))
                    tree_mess_target_nodes.append(begin_idx + n_nodes)

        n_nodes += n_atoms
        g.add_edges(bond_src, bond_dst)
        cand_graphs.append(g)

    return cand_graphs, torch.stack(atom_x), \
        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
        torch.LongTensor(tree_mess_source_edges), \
        torch.LongTensor(tree_mess_target_edges), \
        torch.LongTensor(tree_mess_target_nodes)

mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')

class LoopyBPUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(LoopyBPUpdate, self).__init__()
        self.hidden_size = hidden_size

        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, node):
        msg_input = node.data['msg_input']
        msg_delta = self.W_h(node.data['accum_msg'] + node.data['alpha'])
        msg = torch.relu(msg_input + msg_delta)
        return {'msg': msg}

if PAPER:
    mpn_gather_msg = [
        DGLF.copy_edge(edge='msg', out='msg'),
        DGLF.copy_edge(edge='alpha', out='alpha')
    ]
else:
    mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')

if PAPER:
    mpn_gather_reduce = [
        DGLF.sum(msg='msg', out='m'),
        DGLF.sum(msg='alpha', out='accum_alpha'),
    ]
else:
    mpn_gather_reduce = DGLF.sum(msg='msg', out='m')

class GatherUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(GatherUpdate, self).__init__()
        self.hidden_size = hidden_size

        self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)

    def forward(self, node):
        if PAPER:
            #m = node['m']
            m = node.data['m'] + node.data['accum_alpha']
        else:
            m = node.data['m'] + node.data['alpha']
        return {
            'h': torch.relu(self.W_o(torch.cat([node.data['x'], m], 1))),
        }

class DGLJTMPN(nn.Module):
    def __init__(self, hidden_size, depth):
        nn.Module.__init__(self)

        self.depth = depth

        self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)

        self.loopy_bp_updater = LoopyBPUpdate(hidden_size)
        self.gather_updater = GatherUpdate(hidden_size)
        self.hidden_size = hidden_size

        self.n_samples_total = 0
        self.n_nodes_total = 0
        self.n_edges_total = 0
        self.n_passes = 0

    def forward(self, cand_batch, mol_tree_batch):
        cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, \
            tree_mess_tgt_nodes = cand_batch

        n_samples = len(cand_graphs)

        cand_line_graph = cand_graphs.line_graph(backtracking=False, shared=True)

        n_nodes = cand_graphs.number_of_nodes()
        n_edges = cand_graphs.number_of_edges()

        cand_graphs = self.run(
            cand_graphs, cand_line_graph, tree_mess_src_edges,
            tree_mess_tgt_edges, tree_mess_tgt_nodes, mol_tree_batch)

        g_repr = mean_nodes(cand_graphs, 'h')

        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes
        self.n_edges_total += n_edges
        self.n_passes += 1

        return g_repr

    def run(self, cand_graphs, cand_line_graph, tree_mess_src_edges,
            tree_mess_tgt_edges, tree_mess_tgt_nodes, mol_tree_batch):
        n_nodes = cand_graphs.number_of_nodes()

        cand_graphs.apply_edges(
            func=lambda edges: {'src_x': edges.src['x']},
        )

        bond_features = cand_line_graph.ndata['x']
        source_features = cand_line_graph.ndata['src_x']
        features = torch.cat([source_features, bond_features], 1)
        msg_input = self.W_i(features)
        cand_line_graph.ndata.update({
            'msg_input': msg_input,
            'msg': torch.relu(msg_input),
            'accum_msg': torch.zeros_like(msg_input),
        })
        zero_node_state = bond_features.new(n_nodes, self.hidden_size).zero_()
        cand_graphs.ndata.update({
            'm': zero_node_state.clone(),
            'h': zero_node_state.clone(),
        })

        cand_graphs.edata['alpha'] = \
            cuda(torch.zeros(cand_graphs.number_of_edges(), self.hidden_size))
        cand_graphs.ndata['alpha'] = zero_node_state
        if tree_mess_src_edges.shape[0] > 0:
            if PAPER:
                src_u, src_v = tree_mess_src_edges.unbind(1)
                tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
                alpha = mol_tree_batch.edges[src_u, src_v].data['m']
                cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
            else:
                src_u, src_v = tree_mess_src_edges.unbind(1)
                alpha = mol_tree_batch.edges[src_u, src_v].data['m']
                node_idx = (tree_mess_tgt_nodes
                            .to(device=zero_node_state.device)[:, None]
                            .expand_as(alpha))
                node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
                cand_graphs.ndata['alpha'] = node_alpha
                cand_graphs.apply_edges(
                    func=lambda edges: {'alpha': edges.src['alpha']},
                )

        for i in range(self.depth - 1):
            cand_line_graph.update_all(
                mpn_loopy_bp_msg,
                mpn_loopy_bp_reduce,
                self.loopy_bp_updater,
            )

        cand_graphs.update_all(
            mpn_gather_msg,
            mpn_gather_reduce,
            self.gather_updater,
        )

        return cand_graphs
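As a quick check of the featurizers above (illustration only; 'CCO' is an arbitrary molecule, and the dimensions follow directly from ATOM_FDIM = 23 + 6 + 5 + 1 and BOND_FDIM = 5):

m = Chem.MolFromSmiles('CCO')
fa = atom_features(m.GetAtomWithIdx(0))  # 1-D tensor of length ATOM_FDIM (35)
fb = bond_features(m.GetBondWithIdx(0))  # 1-D tensor of length BOND_FDIM (5)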
python/dgl/model_zoo/chem/jtnn/jtnn_dec.py deleted 100644 → 0
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as DGLF
from dgl import batch, dfs_labeled_edges_generator

from .chemutils import enum_assemble_nx, get_mol
from .mol_tree_nx import DGLMolTree
from .nnutils import GRUUpdate, cuda

MAX_NB = 8
MAX_DECODE_LEN = 100

def dfs_order(forest, roots):
    edges = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)
    for e, l in zip(*edges):
        # I exploited the fact that the reverse edge ID equals 1 XOR the
        # forward edge ID for molecule trees. Normally, I should locate
        # reverse edges using find_edges().
        yield e ^ l, l

dec_tree_node_msg = DGLF.copy_edge(edge='m', out='m')
dec_tree_node_reduce = DGLF.sum(msg='m', out='h')

def dec_tree_node_update(nodes):
    return {'new': nodes.data['new'].clone().zero_()}

dec_tree_edge_msg = [DGLF.copy_src(src='m', out='m'),
                     DGLF.copy_src(src='rm', out='rm')]
dec_tree_edge_reduce = [DGLF.sum(msg='m', out='s'),
                        DGLF.sum(msg='rm', out='accum_rm')]

def have_slots(fa_slots, ch_slots):
    if len(fa_slots) > 2 and len(ch_slots) > 2:
        return True
    matches = []
    for i, s1 in enumerate(fa_slots):
        a1, c1, h1 = s1
        for j, s2 in enumerate(ch_slots):
            a2, c2, h2 = s2
            if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
                matches.append((i, j))

    if len(matches) == 0:
        return False

    fa_match, ch_match = list(zip(*matches))
    if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2:
        # never remove atom from ring
        fa_slots.pop(fa_match[0])
    if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2:
        # never remove atom from ring
        ch_slots.pop(ch_match[0])

    return True

def can_assemble(mol_tree, u, v_node_dict):
    u_node_dict = mol_tree.nodes_dict[u]
    u_neighbors = mol_tree.successors(u)
    u_neighbors_node_dict = [
        mol_tree.nodes_dict[_u] for _u in u_neighbors
        if _u in mol_tree.nodes_dict
    ]
    neis = u_neighbors_node_dict + [v_node_dict]
    for i, nei in enumerate(neis):
        nei['nid'] = i

    neighbors = [nei for nei in neis if nei['mol'].GetNumAtoms() > 1]
    neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(),
                       reverse=True)
    singletons = [nei for nei in neis if nei['mol'].GetNumAtoms() == 1]
    neighbors = singletons + neighbors
    cands = enum_assemble_nx(u_node_dict, neighbors)
    return len(cands) > 0

def create_node_dict(smiles, clique=None):
    if clique is None:
        clique = []
    return dict(
        smiles=smiles,
        mol=get_mol(smiles),
        clique=clique,
    )

class DGLJTNNDecoder(nn.Module):
    def __init__(self, vocab, hidden_size, latent_size, embedding=None):
        nn.Module.__init__(self)

        self.hidden_size = hidden_size
        self.vocab_size = vocab.size()
        self.vocab = vocab

        if embedding is None:
            self.embedding = nn.Embedding(self.vocab_size, hidden_size)
        else:
            self.embedding = embedding

        self.dec_tree_edge_update = GRUUpdate(hidden_size)

        self.W = nn.Linear(latent_size + hidden_size, hidden_size)
        self.U = nn.Linear(latent_size + 2 * hidden_size, hidden_size)
        self.W_o = nn.Linear(hidden_size, self.vocab_size)
        self.U_s = nn.Linear(hidden_size, 1)

    def forward(self, mol_trees, tree_vec):
        '''
        The training procedure which computes the prediction loss given the
        ground truth tree
        '''
        mol_tree_batch = batch(mol_trees)
        mol_tree_batch_lg = mol_tree_batch.line_graph(
            backtracking=False, shared=True)
        n_trees = len(mol_trees)

        return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)

    def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
        node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
        root_ids = node_offset[:-1]
        n_nodes = mol_tree_batch.number_of_nodes()
        n_edges = mol_tree_batch.number_of_edges()

        mol_tree_batch.ndata.update({
            'x': self.embedding(mol_tree_batch.ndata['wid']),
            'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
            # whether it's a newly generated node
            'new': cuda(torch.ones(n_nodes).byte()),
        })

        mol_tree_batch.edata.update({
            's': cuda(torch.zeros(n_edges, self.hidden_size)),
            'm': cuda(torch.zeros(n_edges, self.hidden_size)),
            'r': cuda(torch.zeros(n_edges, self.hidden_size)),
            'z': cuda(torch.zeros(n_edges, self.hidden_size)),
            'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
            'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
            'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
            'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
        })

        mol_tree_batch.apply_edges(
            func=lambda edges: {'src_x': edges.src['x'],
                                'dst_x': edges.dst['x']},
        )

        # input tensors for stop prediction (p) and label prediction (q)
        p_inputs = []
        p_targets = []
        q_inputs = []
        q_targets = []

        # Predict root
        mol_tree_batch.pull(
            root_ids,
            dec_tree_node_msg,
            dec_tree_node_reduce,
            dec_tree_node_update,
        )

        # Extract hidden states and store them for stop/label prediction
        h = mol_tree_batch.nodes[root_ids].data['h']
        x = mol_tree_batch.nodes[root_ids].data['x']
        p_inputs.append(torch.cat([x, h, tree_vec], 1))
        # If the out degree is 0 we don't generate any edges at all
        root_out_degrees = mol_tree_batch.out_degrees(root_ids)
        q_inputs.append(torch.cat([h, tree_vec], 1))
        q_targets.append(mol_tree_batch.nodes[root_ids].data['wid'])

        # Traverse the tree and predict on children
        for eid, p in dfs_order(mol_tree_batch, root_ids):
            u, v = mol_tree_batch.find_edges(eid)

            p_target_list = torch.zeros_like(root_out_degrees)
            p_target_list[root_out_degrees > 0] = 1 - p
            p_target_list = p_target_list[root_out_degrees >= 0]
            p_targets.append(torch.tensor(p_target_list))

            root_out_degrees -= (root_out_degrees == 0).long()
            root_out_degrees -= torch.tensor(
                np.isin(root_ids, v).astype('int64'))

            mol_tree_batch_lg.pull(
                eid,
                dec_tree_edge_msg,
                dec_tree_edge_reduce,
                self.dec_tree_edge_update,
            )
            is_new = mol_tree_batch.nodes[v].data['new']
            mol_tree_batch.pull(
                v,
                dec_tree_node_msg,
                dec_tree_node_reduce,
                dec_tree_node_update,
            )
            # Extract
            n_repr = mol_tree_batch.nodes[v].data
            h = n_repr['h']
            x = n_repr['x']
            tree_vec_set = tree_vec[root_out_degrees >= 0]
            wid = n_repr['wid']
            p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
            # Only newly generated nodes are needed for label prediction
            # NOTE: The following works since the uncomputed messages are zeros.
            q_input = torch.cat([h, tree_vec_set], 1)[is_new]
            q_target = wid[is_new]
            if q_input.shape[0] > 0:
                q_inputs.append(q_input)
                q_targets.append(q_target)
        p_targets.append(torch.zeros((root_out_degrees == 0).sum()).long())

        # Batch compute the stop/label prediction losses
        p_inputs = torch.cat(p_inputs, 0)
        p_targets = cuda(torch.cat(p_targets, 0))
        q_inputs = torch.cat(q_inputs, 0)
        q_targets = torch.cat(q_targets, 0)

        q = self.W_o(torch.relu(self.W(q_inputs)))
        p = self.U_s(torch.relu(self.U(p_inputs)))[:, 0]

        p_loss = F.binary_cross_entropy_with_logits(
            p, p_targets.float(), size_average=False
        ) / n_trees
        q_loss = F.cross_entropy(q, q_targets, size_average=False) / n_trees
        p_acc = ((p > 0).long() == p_targets).sum().float() / \
            p_targets.shape[0]
        q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]

        self.q_inputs = q_inputs
        self.q_targets = q_targets
        self.q = q
        self.p_inputs = p_inputs
        self.p_targets = p_targets
        self.p = p

        return q_loss, p_loss, q_acc, p_acc

    def decode(self, mol_vec):
        assert mol_vec.shape[0] == 1

        mol_tree = DGLMolTree(None)

        init_hidden = cuda(torch.zeros(1, self.hidden_size))

        root_hidden = torch.cat([init_hidden, mol_vec], 1)
        root_hidden = F.relu(self.W(root_hidden))
        root_score = self.W_o(root_hidden)
        _, root_wid = torch.max(root_score, 1)
        root_wid = root_wid.view(1)

        mol_tree.add_nodes(1)   # root
        mol_tree.nodes[0].data['wid'] = root_wid
        mol_tree.nodes[0].data['x'] = self.embedding(root_wid)
        mol_tree.nodes[0].data['h'] = init_hidden
        mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0]))
        mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
            self.vocab.get_smiles(root_wid))

        stack, trace = [], []
        stack.append((0, self.vocab.get_slots(root_wid)))

        all_nodes = {0: root_node_dict}
        first = True
        new_node_id = 0
        new_edge_id = 0

        for step in range(MAX_DECODE_LEN):
            u, u_slots = stack[-1]
            udata = mol_tree.nodes[u].data
            x = udata['x']
            h = udata['h']

            # Predict stop
            p_input = torch.cat([x, h, mol_vec], 1)
            p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
            backtrack = (p_score.item() < 0.5)

            if not backtrack:
                # Predict next clique. Note that the prediction may fail due
                # to lack of assemblable components
                mol_tree.add_nodes(1)
                new_node_id += 1
                v = new_node_id
                mol_tree.add_edges(u, v)
                uv = new_edge_id
                new_edge_id += 1

                if first:
                    mol_tree.edata.update({
                        's': cuda(torch.zeros(1, self.hidden_size)),
                        'm': cuda(torch.zeros(1, self.hidden_size)),
                        'r': cuda(torch.zeros(1, self.hidden_size)),
                        'z': cuda(torch.zeros(1, self.hidden_size)),
                        'src_x': cuda(torch.zeros(1, self.hidden_size)),
                        'dst_x': cuda(torch.zeros(1, self.hidden_size)),
                        'rm': cuda(torch.zeros(1, self.hidden_size)),
                        'accum_rm': cuda(torch.zeros(1, self.hidden_size)),
                    })
                    first = False

                mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x']
                # keeping dst_x 0 is fine as h on new edge doesn't depend on that.

                # DGL doesn't dynamically maintain a line graph.
                mol_tree_lg = mol_tree.line_graph(
                    backtracking=False, shared=True)

                mol_tree_lg.pull(
                    uv,
                    dec_tree_edge_msg,
                    dec_tree_edge_reduce,
                    self.dec_tree_edge_update.update_zm,
                )
                mol_tree.pull(
                    v,
                    dec_tree_node_msg,
                    dec_tree_node_reduce,
                )

                vdata = mol_tree.nodes[v].data
                h_v = vdata['h']
                q_input = torch.cat([h_v, mol_vec], 1)
                q_score = torch.softmax(
                    self.W_o(torch.relu(self.W(q_input))), -1)
                _, sort_wid = torch.sort(q_score, 1, descending=True)
                sort_wid = sort_wid.squeeze()

                next_wid = None
                for wid in sort_wid.tolist()[:5]:
                    slots = self.vocab.get_slots(wid)
                    cand_node_dict = create_node_dict(
                        self.vocab.get_smiles(wid))
                    if (have_slots(u_slots, slots)
                            and can_assemble(mol_tree, u, cand_node_dict)):
                        next_wid = wid
                        next_slots = slots
                        next_node_dict = cand_node_dict
                        break

                if next_wid is None:
                    # Failed adding an actual child; v is a spurious node
                    # and we mark it.
                    vdata['fail'] = cuda(torch.tensor([1]))
                    backtrack = True
                else:
                    next_wid = cuda(torch.tensor([next_wid]))
                    vdata['wid'] = next_wid
                    vdata['x'] = self.embedding(next_wid)
                    mol_tree.nodes_dict[v] = next_node_dict
                    all_nodes[v] = next_node_dict
                    stack.append((v, next_slots))
                    mol_tree.add_edge(v, u)
                    vu = new_edge_id
                    new_edge_id += 1
                    mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data['x']
                    mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data['x']
                    mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data['x']

                    # DGL doesn't dynamically maintain a line graph.
                    mol_tree_lg = mol_tree.line_graph(
                        backtracking=False, shared=True)
                    mol_tree_lg.apply_nodes(
                        self.dec_tree_edge_update.update_r,
                        uv
                    )

            if backtrack:
                if len(stack) == 1:
                    break   # At root, terminate

                pu, _ = stack[-2]
                u_pu = mol_tree.edge_id(u, pu)

                mol_tree_lg.pull(
                    u_pu,
                    dec_tree_edge_msg,
                    dec_tree_edge_reduce,
                    self.dec_tree_edge_update,
                )
                mol_tree.pull(
                    pu,
                    dec_tree_node_msg,
                    dec_tree_node_reduce,
                )
                stack.pop()

        effective_nodes = mol_tree.filter_nodes(
            lambda nodes: nodes.data['fail'] != 1)
        effective_nodes, _ = torch.sort(effective_nodes)
        return mol_tree, all_nodes, effective_nodes
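A tiny illustration of the edge-ID trick used in dfs_order above (illustration only): tree edges are added in forward/reverse pairs, so IDs 2k and 2k + 1 pair up, and XOR with the traversal label flips between them.

eid = 4           # forward edge ID of some tree edge
print(eid ^ 0)    # 4: the edge itself when first entered
print(eid ^ 1)    # 5: its paired reverse edge, used when backtracking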
python/dgl/model_zoo/chem/jtnn/jtnn_enc.py deleted 100644 → 0
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import numpy as np
import torch
import torch.nn as nn

import dgl.function as DGLF
from dgl import batch, bfs_edges_generator

from .nnutils import GRUUpdate, cuda

MAX_NB = 8

def level_order(forest, roots):
    edges = bfs_edges_generator(forest, roots)
    _, leaves = forest.find_edges(edges[-1])
    edges_back = bfs_edges_generator(forest, roots, reverse=True)
    yield from reversed(edges_back)
    yield from edges

enc_tree_msg = [DGLF.copy_src(src='m', out='m'),
                DGLF.copy_src(src='rm', out='rm')]
enc_tree_reduce = [DGLF.sum(msg='m', out='s'),
                   DGLF.sum(msg='rm', out='accum_rm')]
enc_tree_gather_msg = DGLF.copy_edge(edge='m', out='m')
enc_tree_gather_reduce = DGLF.sum(msg='m', out='m')

class EncoderGatherUpdate(nn.Module):
    def __init__(self, hidden_size):
        nn.Module.__init__(self)
        self.hidden_size = hidden_size

        self.W = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, nodes):
        x = nodes.data['x']
        m = nodes.data['m']
        return {
            'h': torch.relu(self.W(torch.cat([x, m], 1))),
        }

class DGLJTNNEncoder(nn.Module):
    def __init__(self, vocab, hidden_size, embedding=None):
        nn.Module.__init__(self)

        self.hidden_size = hidden_size
        self.vocab_size = vocab.size()
        self.vocab = vocab

        if embedding is None:
            self.embedding = nn.Embedding(self.vocab_size, hidden_size)
        else:
            self.embedding = embedding

        self.enc_tree_update = GRUUpdate(hidden_size)
        self.enc_tree_gather_update = EncoderGatherUpdate(hidden_size)

    def forward(self, mol_trees):
        mol_tree_batch = batch(mol_trees)

        # Build line graph to prepare for belief propagation
        mol_tree_batch_lg = mol_tree_batch.line_graph(
            backtracking=False, shared=True)

        return self.run(mol_tree_batch, mol_tree_batch_lg)

    def run(self, mol_tree_batch, mol_tree_batch_lg):
        # Since tree roots are designated as node 0, in the batched graph we
        # can simply find the corresponding node IDs by looking at node_offset
        node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
        root_ids = node_offset[:-1]
        n_nodes = mol_tree_batch.number_of_nodes()
        n_edges = mol_tree_batch.number_of_edges()

        # Assign structure embeddings to tree nodes
        mol_tree_batch.ndata.update({
            'x': self.embedding(mol_tree_batch.ndata['wid']),
            'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
        })

        # Initialize the intermediate variables according to Eq (4)-(8).
        # Also initialize the src_x and dst_x fields.
        # TODO: context?
        mol_tree_batch.edata.update({
            's': cuda(torch.zeros(n_edges, self.hidden_size)),
            'm': cuda(torch.zeros(n_edges, self.hidden_size)),
            'r': cuda(torch.zeros(n_edges, self.hidden_size)),
            'z': cuda(torch.zeros(n_edges, self.hidden_size)),
            'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
            'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
            'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
            'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
        })

        # Send the source/destination node features to edges
        mol_tree_batch.apply_edges(
            func=lambda edges: {'src_x': edges.src['x'],
                                'dst_x': edges.dst['x']},
        )

        # Message passing
        # I exploited the fact that the reduce function is a sum of incoming
        # messages, and the uncomputed messages are zero vectors. Essentially,
        # we can always compute s_ij as the sum of incoming m_ij, no matter
        # if m_ij is actually computed or not.
        for eid in level_order(mol_tree_batch, root_ids):
            #eid = mol_tree_batch.edge_ids(u, v)
            mol_tree_batch_lg.pull(
                eid,
                enc_tree_msg,
                enc_tree_reduce,
                self.enc_tree_update,
            )

        # Readout
        mol_tree_batch.update_all(
            enc_tree_gather_msg,
            enc_tree_gather_reduce,
            self.enc_tree_gather_update,
        )

        root_vecs = mol_tree_batch.nodes[root_ids].data['h']

        return mol_tree_batch, root_vecs
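The readout update above is a plain per-node computation; a standalone sketch of what EncoderGatherUpdate.forward does (illustration only, with made-up sizes):

import torch

hidden_size = 8
upd = EncoderGatherUpdate(hidden_size)
x = torch.randn(5, hidden_size)   # structure embeddings of 5 tree nodes
m = torch.randn(5, hidden_size)   # their summed incoming messages
h = torch.relu(upd.W(torch.cat([x, m], 1)))   # the 'h' returned per node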
python/dgl/model_zoo/chem/jtnn/jtnn_vae.py deleted 100644 → 0
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import rdkit.Chem as Chem

from ....graph import batch, unbatch
from ....contrib.deprecation import deprecated
from ....data.utils import get_download_dir
from .chemutils import (attach_mols_nx, copy_edit_mol, decode_stereo,
                        enum_assemble_nx, set_atommap)
from .jtmpn import DGLJTMPN
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .jtnn_dec import DGLJTNNDecoder
from .jtnn_enc import DGLJTNNEncoder
from .mol_tree import Vocab
from .mpn import DGLMPN
from .mpn import mol2dgl_single as mol2dgl_enc
from .nnutils import cuda, move_dgl_to_cuda

class DGLJTNNVAE(nn.Module):
    """
    `Junction Tree Variational Autoencoder for Molecular Graph Generation
    <https://arxiv.org/abs/1802.04364>`__
    """
    @deprecated('Import DGLJTNNVAE from dgllife.model instead.', 'class')
    def __init__(self, hidden_size, latent_size, depth,
                 vocab=None, vocab_file=None):
        super(DGLJTNNVAE, self).__init__()

        if vocab is None:
            if vocab_file is None:
                vocab_file = '{}/jtnn/{}.txt'.format(get_download_dir(), 'vocab')
            self.vocab = Vocab([x.strip("\r\n ") for x in open(vocab_file)])
        else:
            self.vocab = vocab

        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(self.vocab.size(), hidden_size)
        self.mpn = DGLMPN(hidden_size, depth)
        self.jtnn = DGLJTNNEncoder(self.vocab, hidden_size, self.embedding)
        self.decoder = DGLJTNNDecoder(
            self.vocab, hidden_size, latent_size // 2, self.embedding)
        self.jtmpn = DGLJTMPN(hidden_size, depth)

        self.T_mean = nn.Linear(hidden_size, latent_size // 2)
        self.T_var = nn.Linear(hidden_size, latent_size // 2)
        self.G_mean = nn.Linear(hidden_size, latent_size // 2)
        self.G_var = nn.Linear(hidden_size, latent_size // 2)

        self.n_nodes_total = 0
        self.n_passes = 0
        self.n_edges_total = 0
        self.n_tree_nodes_total = 0

    @staticmethod
    def move_to_cuda(mol_batch):
        for t in mol_batch['mol_trees']:
            move_dgl_to_cuda(t)

        move_dgl_to_cuda(mol_batch['mol_graph_batch'])
        if 'cand_graph_batch' in mol_batch:
            move_dgl_to_cuda(mol_batch['cand_graph_batch'])
        if mol_batch.get('stereo_cand_graph_batch') is not None:
            move_dgl_to_cuda(mol_batch['stereo_cand_graph_batch'])

    def encode(self, mol_batch):
        mol_graphs = mol_batch['mol_graph_batch']
        mol_vec = self.mpn(mol_graphs)

        mol_tree_batch, tree_vec = self.jtnn(mol_batch['mol_trees'])

        self.n_nodes_total += mol_graphs.number_of_nodes()
        self.n_edges_total += mol_graphs.number_of_edges()
        self.n_tree_nodes_total += sum(
            t.number_of_nodes() for t in mol_batch['mol_trees'])
        self.n_passes += 1

        return mol_tree_batch, tree_vec, mol_vec

    def sample(self, tree_vec, mol_vec, e1=None, e2=None):
        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))

        epsilon = cuda(torch.randn(*tree_mean.shape)) if e1 is None else e1
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = cuda(torch.randn(*mol_mean.shape)) if e2 is None else e2
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
        z_mean = torch.cat([tree_mean, mol_mean], 1)
        z_log_var = torch.cat([tree_log_var, mol_log_var], 1)
        return tree_vec, mol_vec, z_mean, z_log_var

    def forward(self, mol_batch, beta=0, e1=None, e2=None):
        self.move_to_cuda(mol_batch)
        mol_trees = mol_batch['mol_trees']
        batch_size = len(mol_trees)

        mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)

        tree_vec, mol_vec, z_mean, z_log_var = self.sample(
            tree_vec, mol_vec, e1, e2)
        kl_loss = -0.5 * torch.sum(
            1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)
        ) / batch_size

        word_loss, topo_loss, word_acc, topo_acc = self.decoder(
            mol_trees, tree_vec)
        assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
        stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)

        loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + \
            beta * kl_loss

        return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc

    def assm(self, mol_batch, mol_tree_batch, mol_vec):
        cands = [mol_batch['cand_graph_batch'], mol_batch['tree_mess_src_e'],
                 mol_batch['tree_mess_tgt_e'], mol_batch['tree_mess_tgt_n']]

        cand_vec = self.jtmpn(cands, mol_tree_batch)
        cand_vec = self.G_mean(cand_vec)

        batch_idx = cuda(torch.LongTensor(mol_batch['cand_batch_idx']))
        mol_vec = mol_vec[batch_idx]

        mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
        cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)
        scores = (mol_vec @ cand_vec)[:, 0, 0]

        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for i, mol_tree in enumerate(mol_batch['mol_trees']):
            comp_nodes = [
                node_id for node_id, node in mol_tree.nodes_dict.items()
                if len(node['cands']) > 1 and not node['is_leaf']
            ]
            cnt += len(comp_nodes)
            # segmented accuracy and cross entropy
            for node_id in comp_nodes:
                node = mol_tree.nodes_dict[node_id]
                label = node['cands'].index(node['label'])
                ncand = len(node['cands'])
                cur_score = scores[tot:tot + ncand]
                tot += ncand

                if cur_score[label].item() >= cur_score.max().item():
                    acc += 1
                label = cuda(torch.LongTensor([label]))
                all_loss.append(F.cross_entropy(
                    cur_score.view(1, -1), label, size_average=False))

        all_loss = sum(all_loss) / len(mol_batch['mol_trees'])
        return all_loss, acc / cnt

    def stereo(self, mol_batch, mol_vec):
        stereo_cands = mol_batch['stereo_cand_graph_batch']
        batch_idx = mol_batch['stereo_cand_batch_idx']
        labels = mol_batch['stereo_cand_labels']
        lengths = mol_batch['stereo_cand_lengths']

        if len(labels) == 0:
            # Only one stereoisomer exists; do nothing
            return cuda(torch.tensor(0.)), 1.

        batch_idx = cuda(torch.LongTensor(batch_idx))
        stereo_cands = self.mpn(stereo_cands)
        stereo_cands = self.G_mean(stereo_cands)
        stereo_labels = mol_vec[batch_idx]
        scores = F.cosine_similarity(stereo_cands, stereo_labels)

        st, acc = 0, 0
        all_loss = []
        for label, le in zip(labels, lengths):
            cur_scores = scores[st:st + le]
            if cur_scores.data[label].item() >= cur_scores.max().item():
                acc += 1
            label = cuda(torch.LongTensor([label]))
            all_loss.append(F.cross_entropy(
                cur_scores.view(1, -1), label, size_average=False))
            st += le

        all_loss = sum(all_loss) / len(labels)
        return all_loss, acc / len(labels)

    def decode(self, tree_vec, mol_vec):
        mol_tree, nodes_dict, effective_nodes = self.decoder.decode(tree_vec)
        effective_nodes_list = effective_nodes.tolist()
        nodes_dict = [nodes_dict[v] for v in effective_nodes_list]

        for i, (node_id, node) in enumerate(
                zip(effective_nodes_list, nodes_dict)):
            node['idx'] = i
            node['nid'] = i + 1
            node['is_leaf'] = True
            if mol_tree.in_degree(node_id) > 1:
                node['is_leaf'] = False
                set_atommap(node['mol'], node['nid'])

        mol_tree_sg = mol_tree.subgraph(effective_nodes)
        mol_tree_sg.copy_from_parent()
        mol_tree_msg, _ = self.jtnn([mol_tree_sg])
        mol_tree_msg = unbatch(mol_tree_msg)[0]
        mol_tree_msg.nodes_dict = nodes_dict

        cur_mol = copy_edit_mol(nodes_dict[0]['mol'])
        global_amap = [{}] + [{} for node in nodes_dict]
        global_amap[1] = {atom.GetIdx(): atom.GetIdx()
                          for atom in cur_mol.GetAtoms()}

        cur_mol = self.dfs_assemble(
            mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None)
        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        if cur_mol is None:
            return None

        smiles2D = Chem.MolToSmiles(cur_mol)
        stereo_cands = decode_stereo(smiles2D)
        if len(stereo_cands) == 1:
            return stereo_cands[0]
        stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]
        stereo_cand_graphs, atom_x, bond_x = \
            zip(*stereo_graphs)
        stereo_cand_graphs = batch(stereo_cand_graphs)
        atom_x = cuda(torch.cat(atom_x))
        bond_x = cuda(torch.cat(bond_x))
        stereo_cand_graphs.ndata['x'] = atom_x
        stereo_cand_graphs.edata['x'] = bond_x
        stereo_cand_graphs.edata['src_x'] = atom_x.new(
            bond_x.shape[0], atom_x.shape[1]).zero_()
        stereo_vecs = self.mpn(stereo_cand_graphs)
        stereo_vecs = self.G_mean(stereo_vecs)
        scores = F.cosine_similarity(stereo_vecs, mol_vec)
        _, max_id = scores.max(0)
        return stereo_cands[max_id.item()]

    def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol,
                     global_amap, fa_amap, cur_node_id, fa_node_id):
        nodes_dict = mol_tree_msg.nodes_dict
        fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None
        cur_node = nodes_dict[cur_node_id]

        fa_nid = fa_node['nid'] if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        children_node_id = [
            v for v in mol_tree_msg.successors(cur_node_id).tolist()
            if nodes_dict[v]['nid'] != fa_nid
        ]
        children = [nodes_dict[v] for v in children_node_id]
        neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
        neighbors = sorted(
            neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
        singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap
                    if nid == cur_node['nid']]
        cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0:
            return None

        cand_smiles, cand_mols, cand_amap = list(zip(*cands))

        cands = [(candmol, mol_tree_msg, cur_node_id)
                 for candmol in cand_mols]

        cand_graphs, atom_x, bond_x, tree_mess_src_edges, \
            tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(cands)

        cand_graphs = batch(cand_graphs)
        atom_x = cuda(atom_x)
        bond_x = cuda(bond_x)
        cand_graphs.ndata['x'] = atom_x
        cand_graphs.edata['x'] = bond_x
        cand_graphs.edata['src_x'] = atom_x.new(
            bond_x.shape[0], atom_x.shape[1]).zero_()

        cand_vecs = self.jtmpn(
            (cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges,
             tree_mess_tgt_nodes),
            mol_tree_msg,
        )
        cand_vecs = self.G_mean(cand_vecs)
        mol_vec = mol_vec.squeeze()
        scores = cand_vecs @ mol_vec

        _, cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        for i in range(len(cand_idx)):
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = \
                    new_global_amap[cur_node['nid']][ctr_atom]

            cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None:
                continue

            result = True
            for nei_node_id, nei_node in zip(children_node_id, children):
                if nei_node['is_leaf']:
                    continue
                cur_mol = self.dfs_assemble(
                    mol_tree_msg, mol_vec, cur_mol, new_global_amap,
                    pred_amap, nei_node_id, cur_node_id)
                if cur_mol is None:
                    result = False
                    break

            if result:
                return cur_mol

        return None
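A minimal construction sketch (illustration only; 'vocab.txt' is a hypothetical vocabulary path, and the sizes follow commonly used JTNN settings rather than anything mandated by this file):

# The latent code is split evenly: tree and graph halves of latent_size // 2 each.
model = DGLJTNNVAE(hidden_size=450, latent_size=56, depth=3,
                   vocab_file='vocab.txt')  # hypothetical path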
python/dgl/model_zoo/chem/jtnn/mol_tree.py deleted 100644 → 0
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import copy

import rdkit.Chem as Chem

def get_slots(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs())
            for atom in mol.GetAtoms()]

class Vocab(object):
    def __init__(self, smiles_list):
        self.vocab = smiles_list
        self.vmap = {x: i for i, x in enumerate(self.vocab)}
        self.slots = [get_slots(smiles) for smiles in self.vocab]

    def get_index(self, smiles):
        return self.vmap[smiles]

    def get_smiles(self, idx):
        return self.vocab[idx]

    def get_slots(self, idx):
        return copy.deepcopy(self.slots[idx])

    def size(self):
        return len(self.vocab)
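A toy Vocab usage (illustration only; the two cluster SMILES are arbitrary):

vocab = Vocab(['C', 'C1=CC=CC=C1'])
vocab.get_index('C')   # 0
vocab.get_smiles(1)    # 'C1=CC=CC=C1'
vocab.get_slots(0)     # [('C', 0, 4)]: symbol, formal charge, total H count
vocab.size()           # 2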
python/dgl/model_zoo/chem/jtnn/mol_tree_nx.py deleted 100644 → 0
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import numpy as np
import rdkit.Chem as Chem

from dgl import DGLGraph

from .chemutils import (decode_stereo, enum_assemble_nx, get_clique_mol,
                        get_mol, get_smiles, set_atommap, tree_decomp)

class DGLMolTree(DGLGraph):
    def __init__(self, smiles):
        DGLGraph.__init__(self)
        self.nodes_dict = {}

        if smiles is None:
            return

        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo generation
        mol = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(mol)
        self.stereo_cands = decode_stereo(self.smiles2D)

        # cliques: a list of lists of atom indices
        cliques, edges = tree_decomp(self.mol)

        root = 0
        for i, c in enumerate(cliques):
            cmol = get_clique_mol(self.mol, c)
            csmiles = get_smiles(cmol)
            self.nodes_dict[i] = dict(
                smiles=csmiles,
                mol=get_mol(csmiles),
                clique=c,
            )
            if min(c) == 0:
                root = i

        self.add_nodes(len(cliques))

        # The clique with atom ID 0 becomes root
        if root > 0:
            for attr in self.nodes_dict[0]:
                self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
                    self.nodes_dict[root][attr], self.nodes_dict[0][attr]

        src = np.zeros((len(edges) * 2,), dtype='int')
        dst = np.zeros((len(edges) * 2,), dtype='int')
        for i, (_x, _y) in enumerate(edges):
            x = 0 if _x == root else root if _x == 0 else _x
            y = 0 if _y == root else root if _y == 0 else _y
            src[2 * i] = x
            dst[2 * i] = y
            src[2 * i + 1] = y
            dst[2 * i + 1] = x
        self.add_edges(src, dst)

        for i in self.nodes_dict:
            self.nodes_dict[i]['nid'] = i + 1
            if self.out_degree(i) > 1:
                # Leaf node mol is not marked
                set_atommap(self.nodes_dict[i]['mol'],
                            self.nodes_dict[i]['nid'])
            self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)

    def treesize(self):
        return self.number_of_nodes()

    def _recover_node(self, i, original_mol):
        node = self.nodes_dict[i]

        clique = []
        clique.extend(node['clique'])
        if not node['is_leaf']:
            for cidx in node['clique']:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])

        for j in self.successors(i).numpy():
            nei_node = self.nodes_dict[j]
            clique.extend(nei_node['clique'])
            if nei_node['is_leaf']:
                # Leaf node, no need to mark
                continue
            for cidx in nei_node['clique']:
                # allow singleton node to override the atom mapping
                if cidx not in node['clique'] or len(nei_node['clique']) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node['nid'])

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        node['label'] = Chem.MolToSmiles(
            Chem.MolFromSmiles(get_smiles(label_mol)))
        node['label_mol'] = get_mol(node['label'])

        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return node['label']

    def _assemble_node(self, i):
        neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy()
                     if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
        neighbors = sorted(
            neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
        singletons = [self.nodes_dict[j] for j in self.successors(i).numpy()
                      if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cands = enum_assemble_nx(self.nodes_dict[i], neighbors)

        if len(cands) > 0:
            self.nodes_dict[i]['cands'], \
                self.nodes_dict[i]['cand_mols'], _ = list(zip(*cands))
            self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
            self.nodes_dict[i]['cand_mols'] = \
                list(self.nodes_dict[i]['cand_mols'])
        else:
            self.nodes_dict[i]['cands'] = []
            self.nodes_dict[i]['cand_mols'] = []

    def recover(self):
        for i in self.nodes_dict:
            self._recover_node(i, self.mol)

    def assemble(self):
        for i in self.nodes_dict:
            self._assemble_node(i)
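A small sketch of building a junction tree (illustration only, assuming the DGL 0.4-era DGLGraph base class this file targets; the molecule is an arbitrary example):

t = DGLMolTree('C1CCC1CO')
t.treesize()       # number of cliques, i.e. nodes in the junction tree
t.nodes_dict[0]    # root clique: its SMILES, RDKit mol, atom indices, nid, ...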
python/dgl/model_zoo/chem/jtnn/mpn.py deleted 100644 → 0
# pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=redefined-outer-name
import
rdkit.Chem
as
Chem
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl.function
as
DGLF
from
dgl
import
DGLGraph
,
mean_nodes
from
.chemutils
import
get_mol
ELEM_LIST
=
[
'C'
,
'N'
,
'O'
,
'S'
,
'F'
,
'Si'
,
'P'
,
'Cl'
,
'Br'
,
'Mg'
,
'Na'
,
'Ca'
,
'Fe'
,
'Al'
,
'I'
,
'B'
,
'K'
,
'Se'
,
'Zn'
,
'H'
,
'Cu'
,
'Mn'
,
'unknown'
]
ATOM_FDIM
=
len
(
ELEM_LIST
)
+
6
+
5
+
4
+
1
BOND_FDIM
=
5
+
6
MAX_NB
=
6
def
onek_encoding_unk
(
x
,
allowable_set
):
if
x
not
in
allowable_set
:
x
=
allowable_set
[
-
1
]
return
[
x
==
s
for
s
in
allowable_set
]
def
atom_features
(
atom
):
return
(
torch
.
Tensor
(
onek_encoding_unk
(
atom
.
GetSymbol
(),
ELEM_LIST
)
+
onek_encoding_unk
(
atom
.
GetDegree
(),
[
0
,
1
,
2
,
3
,
4
,
5
])
+
onek_encoding_unk
(
atom
.
GetFormalCharge
(),
[
-
1
,
-
2
,
1
,
2
,
0
])
+
onek_encoding_unk
(
int
(
atom
.
GetChiralTag
()),
[
0
,
1
,
2
,
3
])
+
[
atom
.
GetIsAromatic
()]))
def
bond_features
(
bond
):
bt
=
bond
.
GetBondType
()
stereo
=
int
(
bond
.
GetStereo
())
fbond
=
[
bt
==
Chem
.
rdchem
.
BondType
.
SINGLE
,
bt
==
Chem
.
rdchem
.
BondType
.
DOUBLE
,
bt
==
Chem
.
rdchem
.
BondType
.
TRIPLE
,
bt
==
Chem
.
rdchem
.
BondType
.
AROMATIC
,
bond
.
IsInRing
()]
fstereo
=
onek_encoding_unk
(
stereo
,
[
0
,
1
,
2
,
3
,
4
,
5
])
return
torch
.
Tensor
(
fbond
+
fstereo
)
def
mol2dgl_single
(
smiles
):
n_edges
=
0
atom_x
=
[]
bond_x
=
[]
mol
=
get_mol
(
smiles
)
n_atoms
=
mol
.
GetNumAtoms
()
n_bonds
=
mol
.
GetNumBonds
()
graph
=
DGLGraph
()
for
i
,
atom
in
enumerate
(
mol
.
GetAtoms
()):
assert
i
==
atom
.
GetIdx
()
atom_x
.
append
(
atom_features
(
atom
))
graph
.
add_nodes
(
n_atoms
)
bond_src
=
[]
bond_dst
=
[]
for
i
,
bond
in
enumerate
(
mol
.
GetBonds
()):
begin_idx
=
bond
.
GetBeginAtom
().
GetIdx
()
end_idx
=
bond
.
GetEndAtom
().
GetIdx
()
features
=
bond_features
(
bond
)
bond_src
.
append
(
begin_idx
)
bond_dst
.
append
(
end_idx
)
bond_x
.
append
(
features
)
# set up the reverse direction
bond_src
.
append
(
end_idx
)
bond_dst
.
append
(
begin_idx
)
bond_x
.
append
(
features
)
graph
.
add_edges
(
bond_src
,
bond_dst
)
n_edges
+=
n_bonds
return
graph
,
torch
.
stack
(
atom_x
),
\
torch
.
stack
(
bond_x
)
if
len
(
bond_x
)
>
0
else
torch
.
zeros
(
0
)
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')

class LoopyBPUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(LoopyBPUpdate, self).__init__()
        self.hidden_size = hidden_size

        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, nodes):
        msg_input = nodes.data['msg_input']
        msg_delta = self.W_h(nodes.data['accum_msg'])
        msg = F.relu(msg_input + msg_delta)
        return {'msg': msg}

mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')

class GatherUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(GatherUpdate, self).__init__()
        self.hidden_size = hidden_size

        self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)

    def forward(self, nodes):
        m = nodes.data['m']
        return {
            'h': F.relu(self.W_o(torch.cat([nodes.data['x'], m], 1))),
        }

class DGLMPN(nn.Module):
    def __init__(self, hidden_size, depth):
        super(DGLMPN, self).__init__()

        self.depth = depth

        self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)

        self.loopy_bp_updater = LoopyBPUpdate(hidden_size)
        self.gather_updater = GatherUpdate(hidden_size)
        self.hidden_size = hidden_size

        self.n_samples_total = 0
        self.n_nodes_total = 0
        self.n_edges_total = 0
        self.n_passes = 0

    def forward(self, mol_graph):
        n_samples = mol_graph.batch_size

        mol_line_graph = mol_graph.line_graph(backtracking=False, shared=True)

        n_nodes = mol_graph.number_of_nodes()
        n_edges = mol_graph.number_of_edges()

        mol_graph = self.run(mol_graph, mol_line_graph)

        # TODO: replace with unbatch or readout
        g_repr = mean_nodes(mol_graph, 'h')

        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes
        self.n_edges_total += n_edges
        self.n_passes += 1

        return g_repr

    def run(self, mol_graph, mol_line_graph):
        n_nodes = mol_graph.number_of_nodes()

        mol_graph.apply_edges(
            func=lambda edges: {'src_x': edges.src['x']},
        )

        e_repr = mol_line_graph.ndata
        bond_features = e_repr['x']
        source_features = e_repr['src_x']

        features = torch.cat([source_features, bond_features], 1)
        msg_input = self.W_i(features)
        mol_line_graph.ndata.update({
            'msg_input': msg_input,
            'msg': F.relu(msg_input),
            'accum_msg': torch.zeros_like(msg_input),
        })
        mol_graph.ndata.update({
            'm': bond_features.new(n_nodes, self.hidden_size).zero_(),
            'h': bond_features.new(n_nodes, self.hidden_size).zero_(),
        })

        for i in range(self.depth - 1):
            mol_line_graph.update_all(
                mpn_loopy_bp_msg,
                mpn_loopy_bp_reduce,
                self.loopy_bp_updater,
            )

        mol_graph.update_all(
            mpn_gather_msg,
            mpn_gather_reduce,
            self.gather_updater,
        )

        return mol_graph
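Putting the pieces together, a minimal end-to-end sketch of the module above; it assumes the DGL API of this release (mutable DGLGraph, dgl.batch) and RDKit for featurization.

```python
import dgl

# Sketch: graph-level representations for a small batch of molecules.
graphs = []
for smiles in ['CCO', 'c1ccccc1']:
    g, atom_x, bond_x = mol2dgl_single(smiles)
    g.ndata['x'] = atom_x
    g.edata['x'] = bond_x
    graphs.append(g)

mpn = DGLMPN(hidden_size=128, depth=3)
mol_repr = mpn(dgl.batch(graphs))   # shape (2, 128): one vector per molecule
```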
python/dgl/model_zoo/chem/jtnn/nnutils.py
deleted
100644 → 0
View file @
94c67203
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import os

import torch
import torch.nn as nn
from torch.autograd import Variable

def create_var(tensor, requires_grad=None):
    if requires_grad is None:
        return Variable(tensor)
    else:
        return Variable(tensor, requires_grad=requires_grad)

def cuda(tensor):
    if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
        return tensor.cuda()
    else:
        return tensor
class GRUUpdate(nn.Module):
    def __init__(self, hidden_size):
        nn.Module.__init__(self)
        self.hidden_size = hidden_size

        self.W_z = nn.Linear(2 * hidden_size, hidden_size)
        self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
        self.U_r = nn.Linear(hidden_size, hidden_size)
        self.W_h = nn.Linear(2 * hidden_size, hidden_size)

    def update_zm(self, node):
        src_x = node.data['src_x']
        s = node.data['s']
        rm = node.data['accum_rm']
        z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
        m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
        m = (1 - z) * s + z * m
        return {'m': m, 'z': z}

    def update_r(self, node, zm=None):
        dst_x = node.data['dst_x']
        m = node.data['m'] if zm is None else zm['m']
        r_1 = self.W_r(dst_x)
        r_2 = self.U_r(m)
        r = torch.sigmoid(r_1 + r_2)
        return {'r': r, 'rm': r * m}

    def forward(self, node):
        dic = self.update_zm(node)
        dic.update(self.update_r(node, zm=dic))
        return dic
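For reference, a sketch of the gated update the module above computes, written in the notation of its node fields: src_x is the node feature, s the sum of inbound messages m_k, accum_rm the sum of r_k ⊙ m_k, and dst_x the destination feature used for the reset gate.

```latex
z = \sigma(W_z\,[x_{src}, s]), \qquad s = \textstyle\sum_k m_k
r_k = \sigma(W_r\,x_{dst,k} + U_r\,m_k)
\tilde{m} = \tanh(W_h\,[x_{src}, \textstyle\sum_k r_k \odot m_k])
m = (1 - z) \odot s + z \odot \tilde{m}
```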
def move_dgl_to_cuda(g):
    g.ndata.update({k: cuda(g.ndata[k]) for k in g.ndata})
    g.edata.update({k: cuda(g.edata[k]) for k in g.edata})
python/dgl/model_zoo/chem/layers.py
deleted
100644 → 0
View file @
94c67203
# -*- coding: utf-8 -*-
# pylint: disable=C0103, E1101, C0111
"""
The implementation of neural network layers used in SchNet and MGCN.
"""
import torch
import torch.nn as nn
from torch.nn import Softplus
import numpy as np

from ... import function as fn
class AtomEmbedding(nn.Module):
    """
    Convert the atom (node) list to atom embeddings.
    Atoms with the same element share the same initial embedding.

    Parameters
    ----------
    dim : int
        Size of embeddings, default to be 128.
    type_num : int
        The largest atomic number of atoms in the dataset, default to be 100.
    pre_train : None or pre-trained embeddings
        Pre-trained embeddings, default to be None.
    """
    def __init__(self, dim=128, type_num=100, pre_train=None):
        super(AtomEmbedding, self).__init__()
        self._dim = dim
        self._type_num = type_num
        if pre_train is not None:
            self.embedding = nn.Embedding.from_pretrained(pre_train, padding_idx=0)
        else:
            self.embedding = nn.Embedding(type_num, dim, padding_idx=0)

    def forward(self, atom_types):
        """
        Parameters
        ----------
        atom_types : int64 tensor of shape (B1)
            Types for atoms in the graph(s), B1 for the number of atoms.

        Returns
        -------
        float32 tensor of shape (B1, self._dim)
            Atom embeddings.
        """
        return self.embedding(atom_types)
class EdgeEmbedding(nn.Module):
    """
    Module for embedding edges. Edges linking the same pairs of atoms share
    the same initial embedding.

    Parameters
    ----------
    dim : int
        Size of embeddings, default to be 128.
    edge_num : int
        Maximum number of edge types allowed, default to be 3000.
    pre_train : Edge embeddings or None
        Pre-trained edge embeddings, default to be None.
    """
    def __init__(self, dim=128, edge_num=3000, pre_train=None):
        super(EdgeEmbedding, self).__init__()
        self._dim = dim
        self._edge_num = edge_num
        if pre_train is not None:
            self.embedding = nn.Embedding.from_pretrained(pre_train, padding_idx=0)
        else:
            self.embedding = nn.Embedding(edge_num, dim, padding_idx=0)

    def generate_edge_type(self, edges):
        """Generate edge types.

        The edge type is based on the types of the src & dst atoms.
        Note that directions are not distinguished, e.g. C-O and O-C are the same edge type.

        To map a pair of nodes to one number, we use an unordered pairing function here.
        See more details in this discussion:
        https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers

        Note that the edge_num should be larger than the square of the maximum atomic
        number in the dataset.

        Parameters
        ----------
        edges : EdgeBatch
            Edges for deciding types

        Returns
        -------
        dict
            Stores the edge types in "etype"
        """
        atom_type_x = edges.src['ntype']
        atom_type_y = edges.dst['ntype']

        return {
            'etype': atom_type_x * atom_type_y +
                     (torch.abs(atom_type_x - atom_type_y) - 1) ** 2 / 4
        }

    def forward(self, g, atom_types):
        """Compute edge embeddings

        Parameters
        ----------
        g : DGLGraph
            The graph to compute edge embeddings
        atom_types : int64 tensor of shape (B1)
            Types for atoms in the graph(s), B1 for the number of atoms.

        Returns
        -------
        float32 tensor of shape (B2, self._dim)
            Computed edge embeddings
        """
        g = g.local_var()
        g.ndata['ntype'] = atom_types
        g.apply_edges(self.generate_edge_type)
        return self.embedding(g.edata.pop('etype'))
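A worked example of the unordered pairing above (illustrative atomic numbers). The key property is symmetry in the two endpoints; note that producing an integer type id relies on `/` performing integer division on int64 tensors, as it did in the PyTorch versions this code targeted.

```python
import torch

def pair(a, b):
    # Same expression as generate_edge_type, applied to atomic numbers.
    return a * b + (torch.abs(a - b) - 1) ** 2 / 4

c, o = torch.tensor([6]), torch.tensor([8])
print(pair(c, o), pair(o, c))   # identical in both directions (48 under integer division)
```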
class ShiftSoftplus(nn.Module):
    """
    ShiftSoftplus activation function:
        1/beta * log(1 + exp(beta * x)) - log(shift)

    Parameters
    ----------
    beta : int
        Default to be 1.
    shift : int
        Default to be 2.
    threshold : int
        Default to be 20.
    """
    def __init__(self, beta=1, shift=2, threshold=20):
        super(ShiftSoftplus, self).__init__()
        self.shift = shift
        self.softplus = Softplus(beta, threshold)

    def forward(self, x):
        """Applies the activation function"""
        return self.softplus(x) - np.log(float(self.shift))
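A quick sanity check of the shift (a sketch): with the defaults the activation equals softplus(x) - log 2, so it passes through zero at the origin.

```python
import torch

act = ShiftSoftplus()             # beta=1, shift=2
print(act(torch.tensor([0.0])))   # tensor([0.]) since softplus(0) = log(2)
```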
class RBFLayer(nn.Module):
    """
    Radial basis functions layer:
        e(d) = exp(- gamma * ||d - mu_k||^2)

    With the default parameters below, this corresponds to:

    * gamma = 10
    * 0 <= mu_k <= 30 for k=1~300

    Parameters
    ----------
    low : int
        Smallest value to take for mu_k, default to be 0.
    high : int
        Largest value to take for mu_k, default to be 30.
    gap : float
        Difference between two consecutive values for mu_k, default to be 0.1.
    dim : int
        Output size for each center, default to be 1.
    """
    def __init__(self, low=0, high=30, gap=0.1, dim=1):
        super(RBFLayer, self).__init__()
        self._low = low
        self._high = high
        self._dim = dim

        self._n_centers = int(np.ceil((high - low) / gap))
        centers = np.linspace(low, high, self._n_centers)
        self.centers = torch.tensor(centers, dtype=torch.float, requires_grad=False)
        self.centers = nn.Parameter(self.centers, requires_grad=False)
        self._fan_out = self._dim * self._n_centers
        self._gap = centers[1] - centers[0]

    def forward(self, edge_distances):
        """
        Parameters
        ----------
        edge_distances : float32 tensor of shape (B, 1)
            Edge distances, B for the number of edges.

        Returns
        -------
        float32 tensor of shape (B, self._fan_out)
            Computed RBF results
        """
        radial = edge_distances - self.centers
        coef = -1 / self._gap
        return torch.exp(coef * (radial ** 2))
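A small numeric sketch of the expansion above: each scalar distance becomes a soft one-hot vector over the centers, peaking at the nearest mu_k. With the defaults there are 300 centers and gamma = 1/gap ≈ 10.

```python
import torch

rbf = RBFLayer(low=0, high=30, gap=0.1)
d = torch.tensor([[1.5], [2.4]])   # two edge distances
out = rbf(d)
print(out.shape)          # torch.Size([2, 300])
print(out[0].argmax())    # index of the center closest to 1.5
```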
class CFConv(nn.Module):
    """
    The continuous-filter convolution layer in SchNet.

    Parameters
    ----------
    rbf_dim : int
        Dimension of the RBF layer output
    dim : int
        Dimension of output, default to be 64
    act : activation function or None
        Activation function, default to be shifted softplus
    """
    def __init__(self, rbf_dim, dim=64, act=None):
        super(CFConv, self).__init__()
        self._rbf_dim = rbf_dim
        self._dim = dim

        if act is None:
            activation = nn.Softplus(beta=0.5, threshold=14)
        else:
            activation = act

        self.project = nn.Sequential(
            nn.Linear(self._rbf_dim, self._dim),
            activation,
            nn.Linear(self._dim, self._dim)
        )

    def forward(self, g, node_weight, rbf_out):
        """
        Parameters
        ----------
        g : DGLGraph
            The graph for performing convolution
        node_weight : float32 tensor of shape (B1, D1)
            The weight of nodes in message passing, B1 for number of nodes and
            D1 for node weight size.
        rbf_out : float32 tensor of shape (B2, D2)
            The output of RBFLayer, B2 for number of edges and D2 for rbf out size.
        """
        g = g.local_var()
        e = self.project(rbf_out)
        g.ndata['node_weight'] = node_weight
        g.edata['e'] = e
        g.update_all(fn.u_mul_e('node_weight', 'e', 'm'), fn.sum('m', 'h'))
        return g.ndata.pop('h')
class Interaction(nn.Module):
    """
    The interaction layer in the SchNet model.

    Parameters
    ----------
    rbf_dim : int
        Dimension of the RBF layer output
    dim : int
        Dimension of intermediate representations
    """
    def __init__(self, rbf_dim, dim):
        super(Interaction, self).__init__()
        self._dim = dim
        self.node_layer1 = nn.Linear(dim, dim, bias=False)
        self.cfconv = CFConv(rbf_dim, dim, Softplus(beta=0.5, threshold=14))
        self.node_layer2 = nn.Sequential(
            nn.Linear(dim, dim),
            Softplus(beta=0.5, threshold=14),
            nn.Linear(dim, dim)
        )

    def forward(self, g, n_feat, rbf_out):
        """
        Parameters
        ----------
        g : DGLGraph
            The graph for performing convolution
        n_feat : float32 tensor of shape (B1, D1)
            Node features, B1 for number of nodes and D1 for feature size.
        rbf_out : float32 tensor of shape (B2, D2)
            The output of RBFLayer, B2 for number of edges and D2 for rbf out size.

        Returns
        -------
        float32 tensor of shape (B1, D1)
            Updated node representations
        """
        n_weight = self.node_layer1(n_feat)
        new_n_feat = self.cfconv(g, n_weight, rbf_out)
        new_n_feat = self.node_layer2(new_n_feat)
        return n_feat + new_n_feat
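A tiny end-to-end sketch tying RBFLayer, CFConv and Interaction together on a 3-atom path graph; it assumes the mutable DGLGraph API of this release, and the distances are made up.

```python
import torch
from dgl import DGLGraph

g = DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 1, 2], [1, 0, 2, 1])   # bidirectional path 0-1-2

rbf = RBFLayer(low=0, high=30, gap=0.1)
inter = Interaction(rbf_dim=rbf._fan_out, dim=64)

n_feat = torch.randn(3, 64)
dist = torch.tensor([[1.2], [1.2], [1.4], [1.4]])
print(inter(g, n_feat, rbf(dist)).shape)   # torch.Size([3, 64]), residual-updated
```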
class VEConv(nn.Module):
    """
    The Vertex-Edge convolution layer in MGCN which takes both edge & vertex features
    into consideration.

    Parameters
    ----------
    rbf_dim : int
        Size of the RBF layer output
    dim : int
        Size of intermediate representations, default to be 64.
    update_edge : bool
        Whether to apply a linear layer to update edge representations, default to be True.
    """
    def __init__(self, rbf_dim, dim=64, update_edge=True):
        super(VEConv, self).__init__()
        self._rbf_dim = rbf_dim
        self._dim = dim
        self._update_edge = update_edge

        self.update_rbf = nn.Sequential(
            nn.Linear(self._rbf_dim, self._dim),
            nn.Softplus(beta=0.5, threshold=14),
            nn.Linear(self._dim, self._dim)
        )

        self.update_efeat = nn.Linear(self._dim, self._dim)

    def forward(self, g, n_feat, e_feat, rbf_out):
        """
        Parameters
        ----------
        g : DGLGraph
            The graph for performing convolution
        n_feat : float32 tensor of shape (B1, D1)
            Node features, B1 for number of nodes and D1 for feature size.
        e_feat : float32 tensor of shape (B2, D2)
            Edge features, B2 for number of edges and D2 for the edge feature size.
        rbf_out : float32 tensor of shape (B2, D3)
            The output of RBFLayer, B2 for number of edges and D3 for rbf out size.

        Returns
        -------
        n_feat : float32 tensor
            Updated node features.
        e_feat : float32 tensor
            (Potentially updated) edge features.
        """
        rbf_out = self.update_rbf(rbf_out)
        if self._update_edge:
            e_feat = self.update_efeat(e_feat)

        g = g.local_var()
        g.ndata.update({'n_feat': n_feat})
        g.edata.update({'rbf_out': rbf_out, 'e_feat': e_feat})
        g.update_all(message_func=[fn.u_mul_e('n_feat', 'rbf_out', 'm_0'),
                                   fn.copy_e('e_feat', 'm_1')],
                     reduce_func=[fn.sum('m_0', 'n_feat_0'),
                                  fn.sum('m_1', 'n_feat_1')])
        n_feat = g.ndata.pop('n_feat_0') + g.ndata.pop('n_feat_1')

        return n_feat, e_feat
class MultiLevelInteraction(nn.Module):
    """
    The multilevel interaction in the MGCN model.

    Parameters
    ----------
    rbf_dim : int
        Dimension of the RBF layer output
    dim : int
        Dimension of intermediate representations
    """
    def __init__(self, rbf_dim, dim):
        super(MultiLevelInteraction, self).__init__()

        self._atom_dim = dim

        self.node_layer1 = nn.Linear(dim, dim, bias=True)
        self.conv_layer = VEConv(rbf_dim, dim)
        self.activation = nn.Softplus(beta=0.5, threshold=14)
        self.edge_layer1 = nn.Linear(dim, dim, bias=True)
        self.node_out = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Softplus(beta=0.5, threshold=14),
            nn.Linear(dim, dim)
        )

    def forward(self, g, n_feat, e_feat, rbf_out):
        """
        Parameters
        ----------
        g : DGLGraph
            The graph for performing convolution
        n_feat : float32 tensor of shape (B1, D1)
            Node features, B1 for number of nodes and D1 for feature size.
        e_feat : float32 tensor of shape (B2, D2)
            Edge features, B2 for number of edges and D2 for the edge feature size.
        rbf_out : float32 tensor of shape (B2, D3)
            The output of RBFLayer, B2 for number of edges and D3 for rbf out size.

        Returns
        -------
        n_feat : float32 tensor
            Updated node representations
        e_feat : float32 tensor
            Updated edge representations
        """
        new_n_feat = self.node_layer1(n_feat)
        new_n_feat, e_feat = self.conv_layer(g, new_n_feat, e_feat, rbf_out)
        new_n_feat = self.node_out(new_n_feat)
        n_feat = n_feat + new_n_feat

        e_feat = self.activation(self.edge_layer1(e_feat))

        return n_feat, e_feat
python/dgl/model_zoo/chem/mgcn.py
deleted
100644 → 0
View file @
94c67203
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621
"""Implementation of MGCN model"""
import torch
import torch.nn as nn

from .layers import AtomEmbedding, RBFLayer, EdgeEmbedding, \
    MultiLevelInteraction
from ...nn.pytorch import SumPooling
from ...contrib.deprecation import deprecated

class MGCNModel(nn.Module):
    """
    `Molecular Property Prediction: A Multilevel
    Quantum Interactions Modeling Perspective <https://arxiv.org/abs/1906.11081>`__

    Parameters
    ----------
    dim : int
        Size for embeddings, default to be 128.
    width : int
        Width in the RBF layer, default to be 1.
    cutoff : float
        The maximum distance between nodes, default to be 5.0.
    edge_dim : int
        Size for edge embeddings, default to be 128.
    output_dim : int
        Number of target properties to predict, default to be 1.
    n_conv : int
        Number of convolutional layers, default to be 3.
    norm : bool
        Whether to perform normalization, default to be False.
    atom_ref : Atom embeddings or None
        If None, random representation initialization will be used. Otherwise,
        they will be used to initialize atom representations. Default to be None.
    pre_train : Atom embeddings or None
        If None, random representation initialization will be used. Otherwise,
        they will be used to initialize atom representations. Default to be None.
    """
    @deprecated('Import MGCNPredictor from dgllife.model instead.', 'class')
    def __init__(self, dim=128, width=1, cutoff=5.0, edge_dim=128,
                 output_dim=1, n_conv=3, norm=False, atom_ref=None, pre_train=None):
        super(MGCNModel, self).__init__()
        self._dim = dim
        self.output_dim = output_dim
        self.edge_dim = edge_dim
        self.cutoff = cutoff
        self.width = width
        self.n_conv = n_conv
        self.atom_ref = atom_ref
        self.norm = norm

        if pre_train is None:
            self.embedding_layer = AtomEmbedding(dim)
        else:
            self.embedding_layer = AtomEmbedding(pre_train=pre_train)
        self.rbf_layer = RBFLayer(0, cutoff, width)
        self.edge_embedding_layer = EdgeEmbedding(dim=edge_dim)

        if atom_ref is not None:
            self.e0 = AtomEmbedding(1, pre_train=atom_ref)

        self.conv_layers = nn.ModuleList([
            MultiLevelInteraction(self.rbf_layer._fan_out, dim)
            for i in range(n_conv)
        ])

        self.out_project = nn.Sequential(
            nn.Linear(dim * (self.n_conv + 1), 64),
            nn.Softplus(beta=1, threshold=20),
            nn.Linear(64, output_dim)
        )
        self.readout = SumPooling()
    def set_mean_std(self, mean, std, device="cpu"):
        """Set the mean and std of atom representations for normalization.

        Parameters
        ----------
        mean : list or numpy array
            The mean of labels
        std : list or numpy array
            The std of labels
        device : str or torch.device
            Device for storing the mean and std
        """
        self.mean_per_node = torch.tensor(mean, device=device)
        self.std_per_node = torch.tensor(std, device=device)

    def forward(self, g, atom_types, edge_distances):
        """Predict molecule labels

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)
        atom_types : int64 tensor of shape (B1)
            Types for atoms in the graph(s), B1 for the number of atoms.
        edge_distances : float32 tensor of shape (B2, 1)
            Edge distances, B2 for the number of edges.

        Returns
        -------
        prediction : float32 tensor of shape (B, output_dim)
            Model prediction for the batch of graphs, B for the number
            of graphs, output_dim for the prediction size.
        """
        h = self.embedding_layer(atom_types)
        e = self.edge_embedding_layer(g, atom_types)
        rbf_out = self.rbf_layer(edge_distances)

        all_layer_h = [h]
        for idx in range(self.n_conv):
            h, e = self.conv_layers[idx](g, h, e, rbf_out)
            all_layer_h.append(h)

        # concat multilevel representations
        h = torch.cat(all_layer_h, dim=1)
        h = self.out_project(h)

        if self.atom_ref is not None:
            h_ref = self.e0(atom_types)
            h = h + h_ref

        if self.norm:
            h = h * self.std_per_node + self.mean_per_node
        return self.readout(g, h)
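Finally, a hedged usage sketch for the model above on a toy graph: atom types are atomic numbers, and in practice edge distances come from 3D coordinates (random values within the cutoff here). It assumes the mutable DGLGraph API of this release and the integer tensor division the edge-type pairing relies on.

```python
import torch
from dgl import DGLGraph

g = DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 1, 2], [1, 0, 2, 1])

model = MGCNModel(dim=128, output_dim=1)
atom_types = torch.tensor([6, 8, 1])    # C, O, H as atomic numbers
edge_dists = torch.rand(4, 1) * 5.0     # distances within the default cutoff
print(model(g, atom_types, edge_dists).shape)   # torch.Size([1, 1])
```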