OpenDAS / dgl — Commit 36c7b771 (unverified)
Authored Jun 05, 2020 by Mufei Li, committed by GitHub Jun 05, 2020

[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci
* Remove doc

Parent: 94c67203
Changes: 20 changed files, with 0 additions and 3463 deletions.
apps/life_sci/python/dgllife/model/model_zoo/attentivefp_predictor.py   +0  -87
apps/life_sci/python/dgllife/model/model_zoo/dgmg.py                    +0  -830
apps/life_sci/python/dgllife/model/model_zoo/gat_predictor.py           +0  -112
apps/life_sci/python/dgllife/model/model_zoo/gcn_predictor.py           +0  -90
apps/life_sci/python/dgllife/model/model_zoo/gin_predictor.py           +0  -122
apps/life_sci/python/dgllife/model/model_zoo/jtnn/__init__.py           +0  -2
apps/life_sci/python/dgllife/model/model_zoo/jtnn/chemutils.py          +0  -377
apps/life_sci/python/dgllife/model/model_zoo/jtnn/jtmpn.py              +0  -256
apps/life_sci/python/dgllife/model/model_zoo/jtnn/jtnn_dec.py           +0  -390
apps/life_sci/python/dgllife/model/model_zoo/jtnn/jtnn_enc.py           +0  -122
apps/life_sci/python/dgllife/model/model_zoo/jtnn/jtnn_vae.py           +0  -327
apps/life_sci/python/dgllife/model/model_zoo/jtnn/mol_tree.py           +0  -26
apps/life_sci/python/dgllife/model/model_zoo/jtnn/mol_tree_nx.py        +0  -126
apps/life_sci/python/dgllife/model/model_zoo/jtnn/mpn.py                +0  -181
apps/life_sci/python/dgllife/model/model_zoo/jtnn/nnutils.py            +0  -54
apps/life_sci/python/dgllife/model/model_zoo/mgcn_predictor.py          +0  -69
apps/life_sci/python/dgllife/model/model_zoo/mlp_predictor.py           +0  -49
apps/life_sci/python/dgllife/model/model_zoo/mpnn_predictor.py          +0  -81
apps/life_sci/python/dgllife/model/model_zoo/schnet_predictor.py        +0  -64
apps/life_sci/python/dgllife/model/model_zoo/weave_predictor.py         +0  -98
apps/life_sci/python/dgllife/model/model_zoo/attentivefp_predictor.py  deleted 100644 → 0

"""AttentiveFP"""
# pylint: disable= no-member, arguments-differ, invalid-name

import torch.nn as nn

from ..gnn import AttentiveFPGNN
from ..readout import AttentiveFPReadout

__all__ = ['AttentiveFPPredictor']

# pylint: disable=W0221
class AttentiveFPPredictor(nn.Module):
    """AttentiveFP for regression and classification on graphs.

    AttentiveFP is introduced in `Pushing the Boundaries of Molecular Representation for Drug
    Discovery with the Graph Attention Mechanism.
    <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__

    Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    edge_feat_size : int
        Size for the input edge features.
    num_layers : int
        Number of GNN layers. Default to 2.
    num_timesteps : int
        Times of updating the graph representations with GRU. Default to 2.
    graph_feat_size : int
        Size for the learned graph representations. Default to 200.
    n_tasks : int
        Number of tasks, which is also the output size. Default to 1.
    dropout : float
        Probability for performing the dropout. Default to 0.
    """
    def __init__(self,
                 node_feat_size,
                 edge_feat_size,
                 num_layers=2,
                 num_timesteps=2,
                 graph_feat_size=200,
                 n_tasks=1,
                 dropout=0.):
        super(AttentiveFPPredictor, self).__init__()

        self.gnn = AttentiveFPGNN(node_feat_size=node_feat_size,
                                  edge_feat_size=edge_feat_size,
                                  num_layers=num_layers,
                                  graph_feat_size=graph_feat_size,
                                  dropout=dropout)
        self.readout = AttentiveFPReadout(feat_size=graph_feat_size,
                                          num_timesteps=num_timesteps,
                                          dropout=dropout)
        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(graph_feat_size, n_tasks)
        )

    def forward(self, g, node_feats, edge_feats, get_node_weight=False):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_feat_size)
            Input edge features. E for the number of edges.
        get_node_weight : bool
            Whether to get the weights of atoms during readout. Default to False.

        Returns
        -------
        float32 tensor of shape (G, n_tasks)
            Prediction for the graphs in the batch. G for the number of graphs.
        node_weights : list of float32 tensor of shape (V, 1), optional
            This is returned when ``get_node_weight`` is ``True``.
            The list has a length ``num_timesteps`` and ``node_weights[i]``
            gives the node weights in the i-th update.
        """
        node_feats = self.gnn(g, node_feats, edge_feats)
        if get_node_weight:
            g_feats, node_weights = self.readout(g, node_feats, get_node_weight)
            return self.predict(g_feats), node_weights
        else:
            g_feats = self.readout(g, node_feats, get_node_weight)
            return self.predict(g_feats)
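Since the diff only shows the deleted source, here is a minimal usage sketch of the predictor above, for reference. The import path (dgllife.model, the package's new home after this move), the feature sizes, and the toy graph are all assumptions, and the incremental DGLGraph construction assumes a DGL version contemporary with this commit:

import dgl
import torch
from dgllife.model import AttentiveFPPredictor  # assumed post-move import path

# A toy 3-node cycle standing in for a featurized molecular graph.
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])

model = AttentiveFPPredictor(node_feat_size=39, edge_feat_size=10, n_tasks=1)
node_feats = torch.randn(g.number_of_nodes(), 39)  # must match node_feat_size
edge_feats = torch.randn(g.number_of_edges(), 10)  # must match edge_feat_size
preds = model(g, node_feats, edge_feats)           # shape (1, 1): one graph, one task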
apps/life_sci/python/dgllife/model/model_zoo/dgmg.py  deleted 100644 → 0

# pylint: disable=C0103, W0622, R1710, W0104, E1101, W0221, C0411
"""
Learning Deep Generative Models of Graphs
https://arxiv.org/pdf/1803.03324.pdf
"""
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from dgl import DGLGraph
from functools import partial
from rdkit import Chem
from torch.distributions import Categorical

__all__ = ['DGMG']

class MoleculeEnv(object):
    """MDP environment for generating molecules.

    Parameters
    ----------
    atom_types : list
        E.g. ['C', 'N']
    bond_types : list
        E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
    """
    def __init__(self, atom_types, bond_types):
        super(MoleculeEnv, self).__init__()

        self.atom_types = atom_types
        self.bond_types = bond_types

        self.atom_type_to_id = dict()
        self.bond_type_to_id = dict()

        for id, a_type in enumerate(atom_types):
            self.atom_type_to_id[a_type] = id

        for id, b_type in enumerate(bond_types):
            self.bond_type_to_id[b_type] = id

    def get_decision_sequence(self, mol, atom_order):
        """Extract a decision sequence with which DGMG can generate the
        molecule with a specified atom order.

        Parameters
        ----------
        mol : Chem.rdchem.Mol
        atom_order : list
            Specifies a mapping between the original atom
            indices and the new atom indices. In particular,
            atom_order[i] is re-labeled as i.

        Returns
        -------
        decisions : list
            decisions[i] is a 2-tuple (i, j)
            - If i = 0, j specifies either the type of the atom to add
              self.atom_types[j] or termination with j = len(self.atom_types)
            - If i = 1, j specifies either the type of the bond to add
              self.bond_types[j] or termination with j = len(self.bond_types)
            - If i = 2, j specifies the destination atom id for the bond to add.
              With the formulation of DGMG, j must be created before the decision.
        """
        decisions = []
        old2new = dict()

        for new_id, old_id in enumerate(atom_order):
            atom = mol.GetAtomWithIdx(old_id)
            a_type = atom.GetSymbol()
            decisions.append((0, self.atom_type_to_id[a_type]))
            for bond in atom.GetBonds():
                u = bond.GetBeginAtomIdx()
                v = bond.GetEndAtomIdx()
                if v == old_id:
                    u, v = v, u
                if v in old2new:
                    decisions.append((1, self.bond_type_to_id[bond.GetBondType()]))
                    decisions.append((2, old2new[v]))
            decisions.append((1, len(self.bond_types)))
            old2new[old_id] = new_id
        decisions.append((0, len(self.atom_types)))
        return decisions

    def reset(self, rdkit_mol=False):
        """Setup for generating a new molecule

        Parameters
        ----------
        rdkit_mol : bool
            Whether to keep a Chem.rdchem.Mol object so
            that we know what molecule is being generated
        """
        self.dgl_graph = DGLGraph()
        # If there are some features for nodes and edges,
        # zero tensors will be set for those of new nodes and edges.
        self.dgl_graph.set_n_initializer(dgl.frame.zero_initializer)
        self.dgl_graph.set_e_initializer(dgl.frame.zero_initializer)

        self.mol = None
        if rdkit_mol:
            # RWMol is a molecule class that is intended to be edited.
            self.mol = Chem.RWMol(Chem.MolFromSmiles(''))

    def num_atoms(self):
        """Get the number of atoms for the current molecule.

        Returns
        -------
        int
        """
        return self.dgl_graph.number_of_nodes()

    def add_atom(self, type):
        """Add an atom of the specified type.

        Parameters
        ----------
        type : int
            Should be in the range of [0, len(self.atom_types) - 1]
        """
        self.dgl_graph.add_nodes(1)
        if self.mol is not None:
            self.mol.AddAtom(Chem.Atom(self.atom_types[type]))

    def add_bond(self, u, v, type, bi_direction=True):
        """Add a bond of the specified type between atom u and v.

        Parameters
        ----------
        u : int
            Index for the first atom
        v : int
            Index for the second atom
        type : int
            Index for the bond type
        bi_direction : bool
            Whether to add edges for both directions in the DGLGraph.
            If not, we will only add the edge (u, v).
        """
        if bi_direction:
            self.dgl_graph.add_edges([u, v], [v, u])
        else:
            self.dgl_graph.add_edge(u, v)

        if self.mol is not None:
            self.mol.AddBond(u, v, self.bond_types[type])

    def get_current_smiles(self):
        """Get the generated molecule in SMILES

        Returns
        -------
        s : str
            SMILES
        """
        assert self.mol is not None, 'Expect a Chem.rdchem.Mol object initialized.'
        s = Chem.MolToSmiles(self.mol)
        return s

class GraphEmbed(nn.Module):
    """Compute a molecule representation out of atom representations.

    Parameters
    ----------
    node_hidden_size : int
        Size of atom representation
    """
    def __init__(self, node_hidden_size):
        super(GraphEmbed, self).__init__()

        # Setting from the paper
        self.graph_hidden_size = 2 * node_hidden_size

        # Embed graphs
        self.node_gating = nn.Sequential(
            nn.Linear(node_hidden_size, 1),
            nn.Sigmoid()
        )
        self.node_to_graph = nn.Linear(node_hidden_size,
                                       self.graph_hidden_size)

    def forward(self, g):
        """
        Parameters
        ----------
        g : DGLGraph
            Current molecule graph

        Returns
        -------
        tensor of dtype float32 and shape (1, self.graph_hidden_size)
            Computed representation for the current molecule graph
        """
        if g.number_of_nodes() == 0:
            # Use a zero tensor for an empty molecule.
            return torch.zeros(1, self.graph_hidden_size)
        else:
            # Node features are stored as hv in ndata.
            hvs = g.ndata['hv']
            return (self.node_gating(hvs) *
                    self.node_to_graph(hvs)).sum(0, keepdim=True)

class GraphProp(nn.Module):
    """Perform message passing over a molecule graph and update its atom representations.

    Parameters
    ----------
    num_prop_rounds : int
        Number of message passing rounds for each time
    node_hidden_size : int
        Size of atom representation
    edge_hidden_size : int
        Size of bond representation
    """
    def __init__(self, num_prop_rounds, node_hidden_size, edge_hidden_size):
        super(GraphProp, self).__init__()

        self.num_prop_rounds = num_prop_rounds

        # Setting from the paper
        self.node_activation_hidden_size = 2 * node_hidden_size

        message_funcs = []
        self.reduce_funcs = []
        node_update_funcs = []

        for t in range(num_prop_rounds):
            # input being [hv, hu, xuv]
            message_funcs.append(nn.Linear(2 * node_hidden_size + edge_hidden_size,
                                           self.node_activation_hidden_size))
            self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
            node_update_funcs.append(
                nn.GRUCell(self.node_activation_hidden_size, node_hidden_size))

        self.message_funcs = nn.ModuleList(message_funcs)
        self.node_update_funcs = nn.ModuleList(node_update_funcs)

    def dgmg_msg(self, edges):
        """For an edge u->v, send a message concat([h_u, x_uv])

        Parameters
        ----------
        edges : batch of edges

        Returns
        -------
        dict
            Dictionary containing messages for the edge batch,
            with the messages being tensors of shape (B, F1),
            B for the number of edges and F1 for the message size.
        """
        return {'m': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}

    def dgmg_reduce(self, nodes, round):
        """Aggregate messages.

        Parameters
        ----------
        nodes : batch of nodes
        round : int
            Update round

        Returns
        -------
        dict
            Dictionary containing aggregated messages for each node
            in the batch, with the messages being tensors of shape
            (B, F2), B for the number of nodes and F2 for the aggregated
            message size
        """
        hv_old = nodes.data['hv']
        m = nodes.mailbox['m']
        # Make copies of original atom representations to match the
        # number of messages.
        message = torch.cat([
            hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
        node_activation = (self.message_funcs[round](message)).sum(1)

        return {'a': node_activation}

    def forward(self, g):
        """
        Parameters
        ----------
        g : DGLGraph
        """
        if g.number_of_edges() == 0:
            return
        else:
            for t in range(self.num_prop_rounds):
                g.update_all(message_func=self.dgmg_msg,
                             reduce_func=self.reduce_funcs[t])
                g.ndata['hv'] = self.node_update_funcs[t](
                    g.ndata['a'], g.ndata['hv'])

class AddNode(nn.Module):
    """Stop or add an atom of a particular type.

    Parameters
    ----------
    env : MoleculeEnv
        Environment for generating molecules
    graph_embed_func : callable taking g as input
        Function for computing molecule representation
    node_hidden_size : int
        Size of atom representation
    dropout : float
        Probability for dropout
    """
    def __init__(self, env, graph_embed_func, node_hidden_size, dropout):
        super(AddNode, self).__init__()

        self.env = env
        n_node_types = len(env.atom_types)

        self.graph_op = {'embed': graph_embed_func}

        self.stop = n_node_types
        self.add_node = nn.Sequential(
            nn.Linear(graph_embed_func.graph_hidden_size,
                      graph_embed_func.graph_hidden_size),
            nn.Dropout(p=dropout),
            nn.Linear(graph_embed_func.graph_hidden_size, n_node_types + 1)
        )

        # If a node is added, initialize its hv
        self.node_type_embed = nn.Embedding(n_node_types, node_hidden_size)
        self.initialize_hv = nn.Linear(node_hidden_size +
                                       graph_embed_func.graph_hidden_size,
                                       node_hidden_size)

        self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def _initialize_node_repr(self, g, node_type, graph_embed):
        """Initialize atom representation

        Parameters
        ----------
        g : DGLGraph
        node_type : int
            Index for the type of the new atom
        graph_embed : tensor of dtype float32
            Molecule representation
        """
        num_nodes = g.number_of_nodes()
        hv_init = torch.cat([
            self.node_type_embed(torch.LongTensor([node_type])),
            graph_embed], dim=1)
        hv_init = self.dropout(hv_init)
        hv_init = self.initialize_hv(hv_init)
        g.nodes[num_nodes - 1].data['hv'] = hv_init
        g.nodes[num_nodes - 1].data['a'] = self.init_node_activation

    def prepare_log_prob(self, compute_log_prob):
        """Setup for returning log likelihood

        Parameters
        ----------
        compute_log_prob : bool
            Whether to compute log likelihood
        """
        if compute_log_prob:
            self.log_prob = []
        self.compute_log_prob = compute_log_prob

    def forward(self, action=None):
        """
        Parameters
        ----------
        action : None or int
            If None, a new action will be sampled. If not None,
            teacher forcing will be used to enforce the decision of the
            corresponding action.

        Returns
        -------
        stop : bool
            Whether we stop adding new atoms
        """
        g = self.env.dgl_graph

        graph_embed = self.graph_op['embed'](g)

        logits = self.add_node(graph_embed).view(1, -1)
        probs = F.softmax(logits, dim=1)

        if action is None:
            action = Categorical(probs).sample().item()
        stop = bool(action == self.stop)

        if not stop:
            self.env.add_atom(action)
            self._initialize_node_repr(g, action, graph_embed)

        if self.compute_log_prob:
            sample_log_prob = F.log_softmax(logits, dim=1)[:, action:action + 1]
            self.log_prob.append(sample_log_prob)

        return stop

class AddEdge(nn.Module):
    """Stop or add a bond of a particular type.

    Parameters
    ----------
    env : MoleculeEnv
        Environment for generating molecules
    graph_embed_func : callable taking g as input
        Function for computing molecule representation
    node_hidden_size : int
        Size of atom representation
    dropout : float
        Probability for dropout
    """
    def __init__(self, env, graph_embed_func, node_hidden_size, dropout):
        super(AddEdge, self).__init__()

        self.env = env
        n_bond_types = len(env.bond_types)
        self.stop = n_bond_types
        self.graph_op = {'embed': graph_embed_func}

        self.add_edge = nn.Sequential(
            nn.Linear(graph_embed_func.graph_hidden_size + node_hidden_size,
                      graph_embed_func.graph_hidden_size + node_hidden_size),
            nn.Dropout(p=dropout),
            nn.Linear(graph_embed_func.graph_hidden_size + node_hidden_size,
                      n_bond_types + 1)
        )

    def prepare_log_prob(self, compute_log_prob):
        """Setup for returning log likelihood

        Parameters
        ----------
        compute_log_prob : bool
            Whether to compute log likelihood
        """
        if compute_log_prob:
            self.log_prob = []
        self.compute_log_prob = compute_log_prob

    def forward(self, action=None):
        """
        Parameters
        ----------
        action : None or int
            If None, a new action will be sampled. If not None,
            teacher forcing will be used to enforce the decision of the
            corresponding action.

        Returns
        -------
        stop : bool
            Whether we stop adding new bonds
        action : int
            The type for the new bond
        """
        g = self.env.dgl_graph

        graph_embed = self.graph_op['embed'](g)
        src_embed = g.nodes[g.number_of_nodes() - 1].data['hv']

        logits = self.add_edge(torch.cat([graph_embed, src_embed], dim=1))
        probs = F.softmax(logits, dim=1)

        if action is None:
            action = Categorical(probs).sample().item()
        stop = bool(action == self.stop)

        if self.compute_log_prob:
            sample_log_prob = F.log_softmax(logits, dim=1)[:, action:action + 1]
            self.log_prob.append(sample_log_prob)

        return stop, action

class ChooseDestAndUpdate(nn.Module):
    """Choose the atom to connect for the new bond.

    Parameters
    ----------
    env : MoleculeEnv
        Environment for generating molecules
    graph_prop_func : callable taking g as input
        Function for performing message passing
        and updating atom representations
    node_hidden_size : int
        Size of atom representation
    dropout : float
        Probability for dropout
    """
    def __init__(self, env, graph_prop_func, node_hidden_size, dropout):
        super(ChooseDestAndUpdate, self).__init__()

        self.env = env
        n_bond_types = len(self.env.bond_types)
        # To be used for one-hot encoding of bond type
        self.bond_embedding = torch.eye(n_bond_types)

        self.graph_op = {'prop': graph_prop_func}
        self.choose_dest = nn.Sequential(
            nn.Linear(2 * node_hidden_size + n_bond_types,
                      2 * node_hidden_size + n_bond_types),
            nn.Dropout(p=dropout),
            nn.Linear(2 * node_hidden_size + n_bond_types, 1)
        )

    def _initialize_edge_repr(self, g, src_list, dest_list, edge_embed):
        """Initialize bond representation

        Parameters
        ----------
        g : DGLGraph
        src_list : list of int
            source atoms for new bonds
        dest_list : list of int
            destination atoms for new bonds
        edge_embed : 2D tensor of dtype float32
            Embeddings for the new bonds
        """
        g.edges[src_list, dest_list].data['he'] = edge_embed.expand(len(src_list), -1)

    def prepare_log_prob(self, compute_log_prob):
        """Setup for returning log likelihood

        Parameters
        ----------
        compute_log_prob : bool
            Whether to compute log likelihood
        """
        if compute_log_prob:
            self.log_prob = []
        self.compute_log_prob = compute_log_prob

    def forward(self, bond_type, dest):
        """
        Parameters
        ----------
        bond_type : int
            The type for the new bond
        dest : int or None
            If None, a new action will be sampled. If not None,
            teacher forcing will be used to enforce the decision of the
            corresponding action.
        """
        g = self.env.dgl_graph

        src = g.number_of_nodes() - 1
        possible_dests = range(src)

        src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
        possible_dests_embed = g.nodes[possible_dests].data['hv']
        edge_embed = self.bond_embedding[bond_type: bond_type + 1]

        dests_scores = self.choose_dest(
            torch.cat([possible_dests_embed,
                       src_embed_expand,
                       edge_embed.expand(src, -1)], dim=1)).view(1, -1)
        dests_probs = F.softmax(dests_scores, dim=1)

        if dest is None:
            dest = Categorical(dests_probs).sample().item()

        if not g.has_edge_between(src, dest):
            # For undirected graphs, we add edges for both directions
            # so that we can perform graph propagation.
            src_list = [src, dest]
            dest_list = [dest, src]
            self.env.add_bond(src, dest, bond_type)
            self._initialize_edge_repr(g, src_list, dest_list, edge_embed)

            # Perform message passing when new bonds are added.
            self.graph_op['prop'](g)

        if self.compute_log_prob:
            if dests_probs.nelement() > 1:
                self.log_prob.append(
                    F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])

def weights_init(m):
    '''Function to initialize weights for models

    Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5

    Usage:
        model = Model()
        model.apply(weight_init)
    '''
    if isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        init.normal_(m.bias.data)
    elif isinstance(m, nn.GRUCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)

def dgmg_message_weight_init(m):
    """Weight initialization for graph propagation module

    These are suggested by the author. This should only be used for
    the message passing functions, i.e. fe's in the paper.
    """
    def _weight_init(m):
        if isinstance(m, nn.Linear):
            init.normal_(m.weight.data, std=1. / 10)
            init.normal_(m.bias.data, std=1. / 10)
        else:
            raise ValueError('Expected the input to be of type nn.Linear!')

    if isinstance(m, nn.ModuleList):
        for layer in m:
            layer.apply(_weight_init)
    else:
        m.apply(_weight_init)

class DGMG(nn.Module):
    """DGMG model

    `Learning Deep Generative Models of Graphs <https://arxiv.org/abs/1803.03324>`__

    Users only need to initialize an instance of this class.

    Parameters
    ----------
    atom_types : list
        E.g. ['C', 'N'].
    bond_types : list
        E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC].
    node_hidden_size : int
        Size of atom representation. Default to 128.
    num_prop_rounds : int
        Number of message passing rounds for each time. Default to 2.
    dropout : float
        Probability for dropout. Default to 0.2.
    """
    def __init__(self, atom_types, bond_types, node_hidden_size=128,
                 num_prop_rounds=2, dropout=0.2):
        super(DGMG, self).__init__()

        self.env = MoleculeEnv(atom_types, bond_types)

        # Graph embedding module
        self.graph_embed = GraphEmbed(node_hidden_size)

        # Graph propagation module
        # For one-hot encoding, edge_hidden_size is just the number of bond types
        self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size,
                                    len(self.env.bond_types))

        # Actions
        self.add_node_agent = AddNode(
            self.env, self.graph_embed, node_hidden_size, dropout)
        self.add_edge_agent = AddEdge(
            self.env, self.graph_embed, node_hidden_size, dropout)
        self.choose_dest_agent = ChooseDestAndUpdate(
            self.env, self.graph_prop, node_hidden_size, dropout)

        # Weight initialization
        self.init_weights()

    def init_weights(self):
        """Initialize model weights"""
        self.graph_embed.apply(weights_init)
        self.graph_prop.apply(weights_init)
        self.add_node_agent.apply(weights_init)
        self.add_edge_agent.apply(weights_init)
        self.choose_dest_agent.apply(weights_init)

        self.graph_prop.message_funcs.apply(dgmg_message_weight_init)

    def count_step(self):
        """Increment the step by 1."""
        self.step_count += 1

    def prepare_log_prob(self, compute_log_prob):
        """Setup for returning log likelihood

        Parameters
        ----------
        compute_log_prob : bool
            Whether to compute log likelihood
        """
        self.compute_log_prob = compute_log_prob
        self.add_node_agent.prepare_log_prob(compute_log_prob)
        self.add_edge_agent.prepare_log_prob(compute_log_prob)
        self.choose_dest_agent.prepare_log_prob(compute_log_prob)

    def add_node_and_update(self, a=None):
        """Decide whether to add a new atom.
        If a new atom should be added, update the graph.

        Parameters
        ----------
        a : None or int
            If None, a new action will be sampled. If not None,
            teacher forcing will be used to enforce the decision of the
            corresponding action.
        """
        self.count_step()
        return self.add_node_agent(a)

    def add_edge_or_not(self, a=None):
        """Decide whether to add a new bond.

        Parameters
        ----------
        a : None or int
            If None, a new action will be sampled. If not None,
            teacher forcing will be used to enforce the decision of the
            corresponding action.
        """
        self.count_step()
        return self.add_edge_agent(a)

    def choose_dest_and_update(self, bond_type, a=None):
        """Choose destination and connect it to the latest atom.
        Add edges for both directions and update the graph.

        Parameters
        ----------
        bond_type : int
            The type of the new bond to add
        a : None or int
            If None, a new action will be sampled. If not None,
            teacher forcing will be used to enforce the decision of the
            corresponding action.
        """
        self.count_step()
        self.choose_dest_agent(bond_type, a)

    def get_log_prob(self):
        """Compute the log likelihood for the decision sequence,
        typically corresponding to the generation of a molecule.

        Returns
        -------
        torch.tensor consisting of a float only
        """
        return torch.cat(self.add_node_agent.log_prob).sum() \
               + torch.cat(self.add_edge_agent.log_prob).sum() \
               + torch.cat(self.choose_dest_agent.log_prob).sum()

    def teacher_forcing(self, actions):
        """Generate a molecule according to a sequence of actions.

        Parameters
        ----------
        actions : list of 2-tuples of int
            actions[t] gives (i, j), the action to execute by DGMG at timestep t.
            - If i = 0, j specifies either the type of the atom to add or termination
            - If i = 1, j specifies either the type of the bond to add or termination
            - If i = 2, j specifies the destination atom id for the bond to add.
              With the formulation of DGMG, j must be created before the decision.
        """
        stop_node = self.add_node_and_update(a=actions[self.step_count][1])
        while not stop_node:
            # A new atom was just added.
            stop_edge, bond_type = self.add_edge_or_not(a=actions[self.step_count][1])
            while not stop_edge:
                # A new bond is to be added.
                self.choose_dest_and_update(bond_type, a=actions[self.step_count][1])
                stop_edge, bond_type = self.add_edge_or_not(a=actions[self.step_count][1])
            stop_node = self.add_node_and_update(a=actions[self.step_count][1])

    def rollout(self, max_num_steps):
        """Sample a molecule from the distribution learned by DGMG."""
        stop_node = self.add_node_and_update()
        while (not stop_node) and (self.step_count <= max_num_steps):
            stop_edge, bond_type = self.add_edge_or_not()
            if self.env.num_atoms() == 1:
                stop_edge = True
            while (not stop_edge) and (self.step_count <= max_num_steps):
                self.choose_dest_and_update(bond_type)
                stop_edge, bond_type = self.add_edge_or_not()
            stop_node = self.add_node_and_update()

    def forward(self, actions=None, rdkit_mol=False,
                compute_log_prob=False, max_num_steps=400):
        """
        Parameters
        ----------
        actions : list of 2-tuples or None.
            If actions are not None, generate a molecule according to actions.
            Otherwise, a molecule will be generated based on sampled actions.
        rdkit_mol : bool
            Whether to maintain a Chem.rdchem.Mol object. This brings extra
            computational cost, but is necessary if we are interested in
            knowing the generated molecule.
        compute_log_prob : bool
            Whether to compute log likelihood
        max_num_steps : int
            Maximum number of steps allowed. This only comes into effect
            during inference and prevents the model from never terminating.

        Returns
        -------
        torch.tensor consisting of a float only, optional
            The log likelihood for the actions taken
        str, optional
            The generated molecule in the form of SMILES
        """
        # Initialize an empty molecule
        self.step_count = 0
        self.env.reset(rdkit_mol=rdkit_mol)
        self.prepare_log_prob(compute_log_prob)

        if actions is not None:
            # A sequence of decisions is given, use teacher forcing
            self.teacher_forcing(actions)
        else:
            # Sample a molecule from the distribution learned by DGMG
            self.rollout(max_num_steps)

        if compute_log_prob and rdkit_mol:
            return self.get_log_prob(), self.env.get_current_smiles()

        if compute_log_prob:
            return self.get_log_prob()

        if rdkit_mol:
            return self.env.get_current_smiles()
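For reference, a minimal sampling sketch for the DGMG model above. The import path is an assumption (the class moves to the external DGL-LifeSci repository in this commit); the constructor arguments mirror the docstring's examples, and the call follows the forward signature shown above. An untrained model will sample arbitrary decisions, so this only demonstrates the interface:

import torch
from rdkit import Chem
from dgllife.model import DGMG  # assumed post-move import path

model = DGMG(atom_types=['C', 'N', 'O'],
             bond_types=[Chem.rdchem.BondType.SINGLE,
                         Chem.rdchem.BondType.DOUBLE,
                         Chem.rdchem.BondType.TRIPLE])
model.eval()
with torch.no_grad():
    # actions=None triggers rollout(); rdkit_mol=True makes the environment
    # track a Chem.RWMol so the sampled molecule comes back as SMILES.
    smiles = model(rdkit_mol=True, max_num_steps=400)
print(smiles)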
apps/life_sci/python/dgllife/model/model_zoo/gat_predictor.py  deleted 100644 → 0

"""GAT-based model for regression and classification on graphs."""
# pylint: disable= no-member, arguments-differ, invalid-name

import torch.nn as nn

from .mlp_predictor import MLPPredictor
from ..gnn.gat import GAT
from ..readout.weighted_sum_and_max import WeightedSumAndMax

# pylint: disable=W0221
class GATPredictor(nn.Module):
    r"""GAT-based model for regression and classification on graphs.

    GAT is introduced in `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__.
    This model is based on GAT and can be used for regression and classification on graphs.

    After updating node representations, we perform a weighted sum with learnable
    weights and max pooling on them and concatenate the output of the two operations,
    which is then fed into an MLP for final prediction.

    For classification tasks, the output will be logits, i.e.
    values before sigmoid or softmax.

    Parameters
    ----------
    in_feats : int
        Number of input node features
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the output size of an attention head in the i-th GAT layer.
        ``len(hidden_feats)`` equals the number of GAT layers. By default, we use ``[32, 32]``.
    num_heads : list of int
        ``num_heads[i]`` gives the number of attention heads in the i-th GAT layer.
        ``len(num_heads)`` equals the number of GAT layers. By default, we use 4 attention heads
        for each GAT layer.
    feat_drops : list of float
        ``feat_drops[i]`` gives the dropout applied to the input features in the i-th GAT layer.
        ``len(feat_drops)`` equals the number of GAT layers. By default, this will be zero for
        all GAT layers.
    attn_drops : list of float
        ``attn_drops[i]`` gives the dropout applied to attention values of edges in the i-th GAT
        layer. ``len(attn_drops)`` equals the number of GAT layers. By default, this will be zero
        for all GAT layers.
    alphas : list of float
        Hyperparameters in LeakyReLU, which are the slopes for negative values. ``alphas[i]``
        gives the slope for negative value in the i-th GAT layer. ``len(alphas)`` equals the
        number of GAT layers. By default, this will be 0.2 for all GAT layers.
    residuals : list of bool
        ``residual[i]`` decides if residual connection is to be used for the i-th GAT layer.
        ``len(residual)`` equals the number of GAT layers. By default, residual connection
        is performed for each GAT layer.
    agg_modes : list of str
        The way to aggregate multi-head attention results for each GAT layer, which can be either
        'flatten' for concatenating all-head results or 'mean' for averaging all-head results.
        ``agg_modes[i]`` gives the way to aggregate multi-head attention results for the i-th
        GAT layer. ``len(agg_modes)`` equals the number of GAT layers. By default, we flatten
        multi-head results for intermediate GAT layers and compute mean of multi-head results
        for the last GAT layer.
    activations : list of activation function or None
        ``activations[i]`` gives the activation function applied to the aggregated multi-head
        results for the i-th GAT layer. ``len(activations)`` equals the number of GAT layers.
        By default, ELU is applied for intermediate GAT layers and no activation is applied
        for the last GAT layer.
    classifier_hidden_feats : int
        Size of hidden graph representations in the classifier. Default to 128.
    classifier_dropout : float
        The probability for dropout in the classifier. Default to 0.
    n_tasks : int
        Number of tasks, which is also the output size. Default to 1.
    """
    def __init__(self, in_feats, hidden_feats=None, num_heads=None, feat_drops=None,
                 attn_drops=None, alphas=None, residuals=None, agg_modes=None,
                 activations=None, classifier_hidden_feats=128, classifier_dropout=0.,
                 n_tasks=1):
        super(GATPredictor, self).__init__()

        self.gnn = GAT(in_feats=in_feats,
                       hidden_feats=hidden_feats,
                       num_heads=num_heads,
                       feat_drops=feat_drops,
                       attn_drops=attn_drops,
                       alphas=alphas,
                       residuals=residuals,
                       agg_modes=agg_modes,
                       activations=activations)

        if self.gnn.agg_modes[-1] == 'flatten':
            gnn_out_feats = self.gnn.hidden_feats[-1] * self.gnn.num_heads[-1]
        else:
            gnn_out_feats = self.gnn.hidden_feats[-1]
        self.readout = WeightedSumAndMax(gnn_out_feats)
        self.predict = MLPPredictor(2 * gnn_out_feats, classifier_hidden_feats,
                                    n_tasks, classifier_dropout)

    def forward(self, bg, feats):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of graphs.
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match
              in_feats in initialization

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
            * Predictions on graphs
            * B for the number of graphs in the batch
        """
        node_feats = self.gnn(bg, feats)
        graph_feats = self.readout(bg, node_feats)
        return self.predict(graph_feats)
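A usage sketch under the same assumptions as before (post-move dgllife.model import path, toy graph, arbitrary feature size):

import dgl
import torch
from dgllife.model import GATPredictor  # assumed post-move import path

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

# Defaults: two GAT layers of head size 32 with 4 heads each; the last
# layer averages heads, so the readout sees 32-dim node representations.
model = GATPredictor(in_feats=74, n_tasks=2)
feats = torch.randn(g.number_of_nodes(), 74)  # must match in_feats
preds = model(g, feats)                       # shape (1, 2)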
apps/life_sci/python/dgllife/model/model_zoo/gcn_predictor.py  deleted 100644 → 0

"""GCN-based model for regression and classification on graphs."""
# pylint: disable= no-member, arguments-differ, invalid-name

import torch.nn as nn

from .mlp_predictor import MLPPredictor
from ..gnn.gcn import GCN
from ..readout.weighted_sum_and_max import WeightedSumAndMax

# pylint: disable=W0221
class GCNPredictor(nn.Module):
    """GCN-based model for regression and classification on graphs.

    GCN is introduced in `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`__. This model is based on GCN and can be used
    for regression and classification on graphs.

    After updating node representations, we perform a weighted sum with learnable
    weights and max pooling on them and concatenate the output of the two operations,
    which is then fed into an MLP for final prediction.

    For classification tasks, the output will be logits, i.e.
    values before sigmoid or softmax.

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the size of node representations after the i-th GCN layer.
        ``len(hidden_feats)`` equals the number of GCN layers. By default, we use
        ``[64, 64]``.
    activation : list of activation functions or None
        If None, no activation will be applied. If not None, ``activation[i]`` gives the
        activation function to be used for the i-th GCN layer. ``len(activation)`` equals
        the number of GCN layers. By default, ReLU is applied for all GCN layers.
    residual : list of bool
        ``residual[i]`` decides if residual connection is to be used for the i-th GCN layer.
        ``len(residual)`` equals the number of GCN layers. By default, residual connection
        is performed for each GCN layer.
    batchnorm : list of bool
        ``batchnorm[i]`` decides if batch normalization is to be applied on the output of
        the i-th GCN layer. ``len(batchnorm)`` equals the number of GCN layers. By default,
        batch normalization is applied for all GCN layers.
    dropout : list of float
        ``dropout[i]`` decides the dropout probability on the output of the i-th GCN layer.
        ``len(dropout)`` equals the number of GCN layers. By default, no dropout is
        performed for all layers.
    classifier_hidden_feats : int
        Size of hidden graph representations in the classifier. Default to 128.
    classifier_dropout : float
        The probability for dropout in the classifier. Default to 0.
    n_tasks : int
        Number of tasks, which is also the output size. Default to 1.
    """
    def __init__(self, in_feats, hidden_feats=None, activation=None, residual=None,
                 batchnorm=None, dropout=None, classifier_hidden_feats=128,
                 classifier_dropout=0., n_tasks=1):
        super(GCNPredictor, self).__init__()

        self.gnn = GCN(in_feats=in_feats,
                       hidden_feats=hidden_feats,
                       activation=activation,
                       residual=residual,
                       batchnorm=batchnorm,
                       dropout=dropout)
        gnn_out_feats = self.gnn.hidden_feats[-1]
        self.readout = WeightedSumAndMax(gnn_out_feats)
        self.predict = MLPPredictor(2 * gnn_out_feats, classifier_hidden_feats,
                                    n_tasks, classifier_dropout)

    def forward(self, bg, feats):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of graphs.
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match
              in_feats in initialization

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
            * Predictions on graphs
            * B for the number of graphs in the batch
        """
        node_feats = self.gnn(bg, feats)
        graph_feats = self.readout(bg, node_feats)
        return self.predict(graph_feats)
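The same sketch for the GCN variant, this time overriding the layer sizes. Note that the classifier receives 2 * gnn_out_feats features because the readout concatenates the weighted-sum and max-pooled representations (import path and sizes are assumptions as before):

import dgl
import torch
from dgllife.model import GCNPredictor  # assumed post-move import path

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

# Three GCN layers with explicit sizes; the MLP classifier then sees
# 2 * 128 = 256 graph features from WeightedSumAndMax.
model = GCNPredictor(in_feats=74, hidden_feats=[64, 64, 128], n_tasks=1)
feats = torch.randn(g.number_of_nodes(), 74)
preds = model(g, feats)  # shape (1, 1)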
apps/life_sci/python/dgllife/model/model_zoo/gin_predictor.py  deleted 100644 → 0

"""GIN-based model for regression and classification on graphs."""
# pylint: disable= no-member, arguments-differ, invalid-name

import dgl
import torch.nn as nn

from dgl.nn.pytorch.glob import GlobalAttentionPooling, SumPooling, \
    AvgPooling, MaxPooling

from ..gnn.gin import GIN

__all__ = ['GINPredictor']

# pylint: disable=W0221
class GINPredictor(nn.Module):
    """GIN-based model for regression and classification on graphs.

    GIN was first introduced in `How Powerful Are Graph Neural Networks
    <https://arxiv.org/abs/1810.00826>`__ for general graph property
    prediction problems. It was further extended in `Strategies for
    Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
    for pre-training and semi-supervised learning on large-scale datasets.

    For classification tasks, the output will be logits, i.e. values before
    sigmoid or softmax.

    Parameters
    ----------
    num_node_emb_list : list of int
        num_node_emb_list[i] gives the number of items to embed for the
        i-th categorical node feature variables. E.g. num_node_emb_list[0] can be
        the number of atom types and num_node_emb_list[1] can be the number of
        atom chirality types.
    num_edge_emb_list : list of int
        num_edge_emb_list[i] gives the number of items to embed for the
        i-th categorical edge feature variables. E.g. num_edge_emb_list[0] can be
        the number of bond types and num_edge_emb_list[1] can be the number of
        bond direction types.
    num_layers : int
        Number of GIN layers to use. Default to 5.
    emb_dim : int
        The size of each embedding vector. Default to 300.
    JK : str
        JK for jumping knowledge as in `Representation Learning on Graphs with
        Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It decides
        how we are going to combine the all-layer node representations for the final output.
        There can be four options for this argument, ``'concat'``, ``'last'``, ``'max'`` and
        ``'sum'``. Default to 'last'.

        * ``'concat'``: concatenate the output node representations from all GIN layers
        * ``'last'``: use the node representations from the last GIN layer
        * ``'max'``: apply max pooling to the node representations across all GIN layers
        * ``'sum'``: sum the output node representations from all GIN layers
    dropout : float
        Dropout to apply to the output of each GIN layer. Default to 0.5.
    readout : str
        Readout for computing graph representations out of node representations, which
        can be ``'sum'``, ``'mean'``, ``'max'``, or ``'attention'``. Default to 'mean'.
    n_tasks : int
        Number of tasks, which is also the output size. Default to 1.
    """
    def __init__(self, num_node_emb_list, num_edge_emb_list, num_layers=5,
                 emb_dim=300, JK='last', dropout=0.5, readout='mean', n_tasks=1):
        super(GINPredictor, self).__init__()

        if num_layers < 2:
            raise ValueError('Number of GNN layers must be greater '
                             'than 1, got {:d}'.format(num_layers))

        self.gnn = GIN(num_node_emb_list=num_node_emb_list,
                       num_edge_emb_list=num_edge_emb_list,
                       num_layers=num_layers,
                       emb_dim=emb_dim,
                       JK=JK,
                       dropout=dropout)

        if readout == 'sum':
            self.readout = SumPooling()
        elif readout == 'mean':
            self.readout = AvgPooling()
        elif readout == 'max':
            self.readout = MaxPooling()
        elif readout == 'attention':
            if JK == 'concat':
                self.readout = GlobalAttentionPooling(
                    gate_nn=nn.Linear((num_layers + 1) * emb_dim, 1))
            else:
                self.readout = GlobalAttentionPooling(
                    gate_nn=nn.Linear(emb_dim, 1))
        else:
            raise ValueError("Expect readout to be 'sum', 'mean', "
                             "'max' or 'attention', got {}".format(readout))

        if JK == 'concat':
            self.predict = nn.Linear((num_layers + 1) * emb_dim, n_tasks)
        else:
            self.predict = nn.Linear(emb_dim, n_tasks)

    def forward(self, g, categorical_node_feats, categorical_edge_feats):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        categorical_node_feats : list of LongTensor of shape (N)
            * Input categorical node features
            * len(categorical_node_feats) should be the same as len(num_node_emb_list)
            * N is the total number of nodes in the batch of graphs
        categorical_edge_feats : list of LongTensor of shape (E)
            * Input categorical edge features
            * len(categorical_edge_feats) should be the same as
              len(num_edge_emb_list) in the arguments
            * E is the total number of edges in the batch of graphs

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
            * Predictions on graphs
            * B for the number of graphs in the batch
        """
        node_feats = self.gnn(g, categorical_node_feats, categorical_edge_feats)
        graph_feats = self.readout(g, node_feats)
        return self.predict(graph_feats)
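Unlike the GAT/GCN predictors, GINPredictor consumes lists of categorical features that it embeds internally. A sketch under the usual assumptions (import path and vocabulary sizes are illustrative, not from this commit):

import dgl
import torch
from dgllife.model import GINPredictor  # assumed post-move import path

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])

# Two categorical node variables (e.g. 120 atom types, 3 chirality tags)
# and two categorical edge variables (e.g. 6 bond types, 3 bond directions).
model = GINPredictor(num_node_emb_list=[120, 3],
                     num_edge_emb_list=[6, 3],
                     n_tasks=1)
node_feats = [torch.zeros(g.number_of_nodes(), dtype=torch.long),
              torch.zeros(g.number_of_nodes(), dtype=torch.long)]
edge_feats = [torch.zeros(g.number_of_edges(), dtype=torch.long),
              torch.zeros(g.number_of_edges(), dtype=torch.long)]
preds = model(g, node_feats, edge_feats)  # shape (1, 1)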
apps/life_sci/python/dgllife/model/model_zoo/jtnn/__init__.py  deleted 100644 → 0

"""JTNN Module"""
from .jtnn_vae import DGLJTNNVAE
apps/life_sci/python/dgllife/model/model_zoo/jtnn/chemutils.py  deleted 100644 → 0

# pylint: disable=C0111, C0103, E1101, W0611, W0612, W0703, C0200, R1710, I1101, R1721
from collections import defaultdict

import rdkit.Chem as Chem
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree

MST_MAX_WEIGHT = 100
MAX_NCAND = 2000

def set_atommap(mol, num=0):
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(num)

def get_mol(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    Chem.Kekulize(mol)
    return mol

def get_smiles(mol):
    return Chem.MolToSmiles(mol, kekuleSmiles=True)

def decode_stereo(smiles2D):
    mol = Chem.MolFromSmiles(smiles2D)
    dec_isomers = list(EnumerateStereoisomers(mol))

    dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True))
                   for mol in dec_isomers]
    smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True)
                for mol in dec_isomers]

    chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms()
               if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
    if len(chiralN) > 0:
        for mol in dec_isomers:
            for idx in chiralN:
                mol.GetAtomWithIdx(idx).SetChiralTag(
                    Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
            smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))

    return smiles3D

def sanitize(mol):
    try:
        smiles = get_smiles(mol)
        mol = get_mol(smiles)
    except Exception:
        return None
    return mol

def copy_atom(atom):
    new_atom = Chem.Atom(atom.GetSymbol())
    new_atom.SetFormalCharge(atom.GetFormalCharge())
    new_atom.SetAtomMapNum(atom.GetAtomMapNum())
    return new_atom

def copy_edit_mol(mol):
    new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
    for atom in mol.GetAtoms():
        new_atom = copy_atom(atom)
        new_mol.AddAtom(new_atom)
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        bt = bond.GetBondType()
        new_mol.AddBond(a1, a2, bt)
    return new_mol

def get_clique_mol(mol, atoms):
    smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
    new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
    new_mol = copy_edit_mol(new_mol).GetMol()
    new_mol = sanitize(new_mol)
    # We assume this is not None
    return new_mol

def tree_decomp(mol):
    n_atoms = mol.GetNumAtoms()
    if n_atoms == 1:
        return [[0]], []

    cliques = []
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if not bond.IsInRing():
            cliques.append([a1, a2])

    ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
    cliques.extend(ssr)

    nei_list = [[] for i in range(n_atoms)]
    for i in range(len(cliques)):
        for atom in cliques[i]:
            nei_list[atom].append(i)

    # Merge Rings with intersection > 2 atoms
    for i in range(len(cliques)):
        if len(cliques[i]) <= 2:
            continue
        for atom in cliques[i]:
            for j in nei_list[atom]:
                if i >= j or len(cliques[j]) <= 2:
                    continue
                inter = set(cliques[i]) & set(cliques[j])
                if len(inter) > 2:
                    cliques[i].extend(cliques[j])
                    cliques[i] = list(set(cliques[i]))
                    cliques[j] = []

    cliques = [c for c in cliques if len(c) > 0]
    nei_list = [[] for i in range(n_atoms)]
    for i in range(len(cliques)):
        for atom in cliques[i]:
            nei_list[atom].append(i)

    # Build edges and add singleton cliques
    edges = defaultdict(int)
    for atom in range(n_atoms):
        if len(nei_list[atom]) <= 1:
            continue
        cnei = nei_list[atom]
        bonds = [c for c in cnei if len(cliques[c]) == 2]
        rings = [c for c in cnei if len(cliques[c]) > 4]
        # In general, if len(cnei) >= 3, a singleton should be added, but 1
        # bond + 2 ring is currently not dealt with.
        if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, c2)] = 1
        elif len(rings) > 2:
            # Multiple (n>2) complex rings
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, c2)] = MST_MAX_WEIGHT - 1
        else:
            for i in range(len(cnei)):
                for j in range(i + 1, len(cnei)):
                    c1, c2 = cnei[i], cnei[j]
                    inter = set(cliques[c1]) & set(cliques[c2])
                    if edges[(c1, c2)] < len(inter):
                        # cnei[i] < cnei[j] by construction
                        edges[(c1, c2)] = len(inter)

    edges = [u + (MST_MAX_WEIGHT - v,) for u, v in edges.items()]
    if len(edges) == 0:
        return cliques, edges

    # Compute Maximum Spanning Tree
    row, col, data = list(zip(*edges))
    n_clique = len(cliques)
    clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
    junc_tree = minimum_spanning_tree(clique_graph)
    row, col = junc_tree.nonzero()
    edges = [(row[i], col[i]) for i in range(len(row))]
    return (cliques, edges)

def atom_equal(a1, a2):
    return a1.GetSymbol() == a2.GetSymbol() and \
           a1.GetFormalCharge() == a2.GetFormalCharge()

# Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
    b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
    if reverse:
        b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
    else:
        b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
    return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])

def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
    prev_nids = [node['nid'] for node in prev_nodes]
    for nei_node in prev_nodes + neighbors:
        nei_id, nei_mol = nei_node['nid'], nei_node['mol']
        amap = nei_amap[nei_id]
        for atom in nei_mol.GetAtoms():
            if atom.GetIdx() not in amap:
                new_atom = copy_atom(atom)
                amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)

        if nei_mol.GetNumBonds() == 0:
            nei_atom = nei_mol.GetAtomWithIdx(0)
            ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
            ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
        else:
            for bond in nei_mol.GetBonds():
                a1 = amap[bond.GetBeginAtom().GetIdx()]
                a2 = amap[bond.GetEndAtom().GetIdx()]
                if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
                    ctr_mol.AddBond(a1, a2, bond.GetBondType())
                elif nei_id in prev_nids:
                    # father node overrides
                    ctr_mol.RemoveBond(a1, a2)
                    ctr_mol.AddBond(a1, a2, bond.GetBondType())
    return ctr_mol

def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
    ctr_mol = copy_edit_mol(ctr_mol)
    nei_amap = {nei['nid']: {} for nei in prev_nodes + neighbors}

    for nei_id, ctr_atom, nei_atom in amap_list:
        nei_amap[nei_id][nei_atom] = ctr_atom

    ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
    return ctr_mol.GetMol()

# This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
    nei_mol, nei_idx = nei_node['mol'], nei_node['nid']
    att_confs = []
    black_list = [atom_idx for nei_id, atom_idx, _ in amap
                  if nei_id in singletons]
    ctr_atoms = [atom for atom in ctr_mol.GetAtoms()
                 if atom.GetIdx() not in black_list]
    ctr_bonds = [bond for bond in ctr_mol.GetBonds()]

    if nei_mol.GetNumBonds() == 0:
        # neighbor singleton
        nei_atom = nei_mol.GetAtomWithIdx(0)
        used_list = [atom_idx for _, atom_idx, _ in amap]
        for atom in ctr_atoms:
            if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
                new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
                att_confs.append(new_amap)

    elif nei_mol.GetNumBonds() == 1:
        # neighbor is a bond
        bond = nei_mol.GetBondWithIdx(0)
        bond_val = int(bond.GetBondTypeAsDouble())
        b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()

        for atom in ctr_atoms:
            # Optimize if atom is carbon (other atoms may change valence)
            if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
                continue
            if atom_equal(atom, b1):
                new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
                att_confs.append(new_amap)
            elif atom_equal(atom, b2):
                new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
                att_confs.append(new_amap)
    else:
        # intersection is an atom
        for a1 in ctr_atoms:
            for a2 in nei_mol.GetAtoms():
                if atom_equal(a1, a2):
                    # Optimize if atom is carbon (other atoms may change
                    # valence)
                    if a1.GetAtomicNum() == 6 and \
                       a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
                        continue
                    new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
                    att_confs.append(new_amap)

        # intersection is a bond
        if ctr_mol.GetNumBonds() > 1:
            for b1 in ctr_bonds:
                for b2 in nei_mol.GetBonds():
                    if ring_bond_equal(b1, b2):
                        new_amap = amap + [
                            (nei_idx, b1.GetBeginAtom().GetIdx(),
                             b2.GetBeginAtom().GetIdx()),
                            (nei_idx, b1.GetEndAtom().GetIdx(),
                             b2.GetEndAtom().GetIdx())]
                        att_confs.append(new_amap)

                    if ring_bond_equal(b1, b2, reverse=True):
                        new_amap = amap + [
                            (nei_idx, b1.GetBeginAtom().GetIdx(),
                             b2.GetEndAtom().GetIdx()),
                            (nei_idx, b1.GetEndAtom().GetIdx(),
                             b2.GetBeginAtom().GetIdx())]
                        att_confs.append(new_amap)
    return att_confs

# Try rings first: Speed-Up
def enum_assemble_nx(node, neighbors, prev_nodes=None, prev_amap=None):
    if prev_nodes is None:
        prev_nodes = []
    if prev_amap is None:
        prev_amap = []
    all_attach_confs = []
    singletons = [nei_node['nid'] for nei_node in neighbors + prev_nodes
                  if nei_node['mol'].GetNumAtoms() == 1]

    def search(cur_amap, depth):
        if len(all_attach_confs) > MAX_NCAND:
            return None
        if depth == len(neighbors):
            all_attach_confs.append(cur_amap)
            return None

        nei_node = neighbors[depth]
        cand_amap = enum_attach_nx(node['mol'], nei_node, cur_amap, singletons)
        cand_smiles = set()
        candidates = []
        for amap in cand_amap:
            cand_mol = local_attach_nx(
                node['mol'], neighbors[:depth + 1], prev_nodes, amap)
            cand_mol = sanitize(cand_mol)
            if cand_mol is None:
                continue
            smiles = get_smiles(cand_mol)
            if smiles in cand_smiles:
                continue
            cand_smiles.add(smiles)
            candidates.append(amap)

        if len(candidates) == 0:
            return []

        for new_amap in candidates:
            search(new_amap, depth + 1)

    search(prev_amap, 0)
    cand_smiles = set()
    candidates = []
    for amap in all_attach_confs:
        cand_mol = local_attach_nx(node['mol'], neighbors, prev_nodes, amap)
        cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
        smiles = Chem.MolToSmiles(cand_mol)
        if smiles in cand_smiles:
            continue
        cand_smiles.add(smiles)
        Chem.Kekulize(cand_mol)
        candidates.append((smiles, cand_mol, amap))

    return candidates

# Only used for debugging purpose
def dfs_assemble_nx(graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id):
    cur_node = graph.nodes_dict[cur_node_id]
    fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None

    fa_nid = fa_node['nid'] if fa_node is not None else -1
    prev_nodes = [fa_node] if fa_node is not None else []

    children_id = [nei for nei in graph[cur_node_id]
                   if graph.nodes_dict[nei]['nid'] != fa_nid]
    children = [graph.nodes_dict[nei] for nei in children_id]
    neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
    neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
    singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
    neighbors = singletons + neighbors

    cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap
                if nid == cur_node['nid']]
    cands = enum_assemble_nx(graph.nodes_dict[cur_node_id],
                             neighbors, prev_nodes, cur_amap)
    if len(cands) == 0:
        return

    cand_smiles, _, cand_amap = zip(*cands)
    label_idx = cand_smiles.index(cur_node['label'])
    label_amap = cand_amap[label_idx]

    for nei_id, ctr_atom, nei_atom in label_amap:
        if nei_id == fa_nid:
            continue
        global_amap[nei_id][nei_atom] = global_amap[cur_node['nid']][ctr_atom]

    # father is already attached
    cur_mol = attach_mols_nx(cur_mol, children, [], global_amap)
    for nei_node_id, nei_node in zip(children_id, children):
        if not nei_node['is_leaf']:
            dfs_assemble_nx(graph, cur_mol, global_amap, label_amap,
                            nei_node_id, cur_node_id)
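To make the junction-tree decomposition above concrete, here is a small sketch using the file's own get_mol and tree_decomp. The example molecule and the expected-output comment are illustrative guesses, not part of the commit:

# Toluene: one aromatic ring plus one ring-external bond.
mol = get_mol('Cc1ccccc1')
cliques, edges = tree_decomp(mol)
# cliques should contain a 2-atom bond clique and a 6-atom ring clique,
# e.g. [[0, 1], [1, 2, 3, 4, 5, 6]]; edges connects overlapping cliques
# via a maximum spanning tree over the clique graph.
print(cliques, edges)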
apps/life_sci/python/dgllife/model/model_zoo/jtnn/jtmpn.py
deleted
100644 → 0
View file @
94c67203
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W1508, I1101, W0221
# pylint: disable=redefined-outer-name
import os

import rdkit.Chem as Chem
import torch
import torch.nn as nn

import dgl.function as DGLF
from dgl import DGLGraph, mean_nodes

from .nnutils import cuda

ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
             'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']

ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM = 5
MAX_NB = 10

PAPER = os.getenv('PAPER', False)

def onek_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]

# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with the highest score.
def atom_features(atom):
    return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
                         + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
                         + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
                         + [atom.GetIsAromatic()]))

def bond_features(bond):
    bt = bond.GetBondType()
    return torch.Tensor([bt == Chem.rdchem.BondType.SINGLE,
                         bt == Chem.rdchem.BondType.DOUBLE,
                         bt == Chem.rdchem.BondType.TRIPLE,
                         bt == Chem.rdchem.BondType.AROMATIC,
                         bond.IsInRing()])

def mol2dgl_single(cand_batch):
    cand_graphs = []
    tree_mess_source_edges = []  # map these edges from trees to...
    tree_mess_target_edges = []  # these edges on candidate graphs
    tree_mess_target_nodes = []
    n_nodes = 0
    atom_x = []
    bond_x = []

    for mol, mol_tree, ctr_node_id in cand_batch:
        n_atoms = mol.GetNumAtoms()

        g = DGLGraph()

        for i, atom in enumerate(mol.GetAtoms()):
            assert i == atom.GetIdx()
            atom_x.append(atom_features(atom))
        g.add_nodes(n_atoms)

        bond_src = []
        bond_dst = []
        for i, bond in enumerate(mol.GetBonds()):
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            begin_idx = a1.GetIdx()
            end_idx = a2.GetIdx()
            features = bond_features(bond)

            bond_src.append(begin_idx)
            bond_dst.append(end_idx)
            bond_x.append(features)
            bond_src.append(end_idx)
            bond_dst.append(begin_idx)
            bond_x.append(features)

            x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
            # Tree node ID in the batch
            x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
            y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
            if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
                if mol_tree.has_edge_between(x_bid, y_bid):
                    tree_mess_target_edges.append(
                        (begin_idx + n_nodes, end_idx + n_nodes))
                    tree_mess_source_edges.append((x_bid, y_bid))
                    tree_mess_target_nodes.append(end_idx + n_nodes)
                if mol_tree.has_edge_between(y_bid, x_bid):
                    tree_mess_target_edges.append(
                        (end_idx + n_nodes, begin_idx + n_nodes))
                    tree_mess_source_edges.append((y_bid, x_bid))
                    tree_mess_target_nodes.append(begin_idx + n_nodes)

        n_nodes += n_atoms
        g.add_edges(bond_src, bond_dst)
        cand_graphs.append(g)

    return cand_graphs, torch.stack(atom_x), \
        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
        torch.LongTensor(tree_mess_source_edges), \
        torch.LongTensor(tree_mess_target_edges), \
        torch.LongTensor(tree_mess_target_nodes)

mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')

class LoopyBPUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(LoopyBPUpdate, self).__init__()
        self.hidden_size = hidden_size

        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, node):
        msg_input = node.data['msg_input']
        msg_delta = self.W_h(node.data['accum_msg'] + node.data['alpha'])
        msg = torch.relu(msg_input + msg_delta)
        return {'msg': msg}

if PAPER:
    mpn_gather_msg = [
        DGLF.copy_edge(edge='msg', out='msg'),
        DGLF.copy_edge(edge='alpha', out='alpha')
    ]
else:
    mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')

if PAPER:
    mpn_gather_reduce = [
        DGLF.sum(msg='msg', out='m'),
        DGLF.sum(msg='alpha', out='accum_alpha'),
    ]
else:
    mpn_gather_reduce = DGLF.sum(msg='msg', out='m')

class GatherUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(GatherUpdate, self).__init__()
        self.hidden_size = hidden_size

        self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)

    def forward(self, node):
        if PAPER:
            #m = node['m']
            m = node.data['m'] + node.data['accum_alpha']
        else:
            m = node.data['m'] + node.data['alpha']
        return {
            'h': torch.relu(self.W_o(torch.cat([node.data['x'], m], 1))),
        }

class DGLJTMPN(nn.Module):
    def __init__(self, hidden_size, depth):
        nn.Module.__init__(self)

        self.depth = depth

        self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)

        self.loopy_bp_updater = LoopyBPUpdate(hidden_size)
        self.gather_updater = GatherUpdate(hidden_size)
        self.hidden_size = hidden_size

        self.n_samples_total = 0
        self.n_nodes_total = 0
        self.n_edges_total = 0
        self.n_passes = 0

    def forward(self, cand_batch, mol_tree_batch):
        cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, \
            tree_mess_tgt_nodes = cand_batch
        n_samples = len(cand_graphs)

        cand_line_graph = cand_graphs.line_graph(backtracking=False, shared=True)

        n_nodes = cand_graphs.number_of_nodes()
        n_edges = cand_graphs.number_of_edges()

        cand_graphs = self.run(
            cand_graphs, cand_line_graph, tree_mess_src_edges,
            tree_mess_tgt_edges, tree_mess_tgt_nodes, mol_tree_batch)

        g_repr = mean_nodes(cand_graphs, 'h')

        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes
        self.n_edges_total += n_edges
        self.n_passes += 1

        return g_repr

    def run(self, cand_graphs, cand_line_graph, tree_mess_src_edges,
            tree_mess_tgt_edges, tree_mess_tgt_nodes, mol_tree_batch):
        n_nodes = cand_graphs.number_of_nodes()

        cand_graphs.apply_edges(
            func=lambda edges: {'src_x': edges.src['x']},
        )

        bond_features = cand_line_graph.ndata['x']
        source_features = cand_line_graph.ndata['src_x']

        features = torch.cat([source_features, bond_features], 1)
        msg_input = self.W_i(features)
        cand_line_graph.ndata.update({
            'msg_input': msg_input,
            'msg': torch.relu(msg_input),
            'accum_msg': torch.zeros_like(msg_input),
        })
        zero_node_state = bond_features.new(n_nodes, self.hidden_size).zero_()
        cand_graphs.ndata.update({
            'm': zero_node_state.clone(),
            'h': zero_node_state.clone(),
        })

        cand_graphs.edata['alpha'] = \
            cuda(torch.zeros(cand_graphs.number_of_edges(), self.hidden_size))
        cand_graphs.ndata['alpha'] = zero_node_state
        if tree_mess_src_edges.shape[0] > 0:
            if PAPER:
                src_u, src_v = tree_mess_src_edges.unbind(1)
                tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
                alpha = mol_tree_batch.edges[src_u, src_v].data['m']
                cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
            else:
                src_u, src_v = tree_mess_src_edges.unbind(1)
                alpha = mol_tree_batch.edges[src_u, src_v].data['m']
                node_idx = (tree_mess_tgt_nodes
                            .to(device=zero_node_state.device)[:, None]
                            .expand_as(alpha))
                node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
                cand_graphs.ndata['alpha'] = node_alpha
                cand_graphs.apply_edges(
                    func=lambda edges: {'alpha': edges.src['alpha']},
                )

        for i in range(self.depth - 1):
            cand_line_graph.update_all(
                mpn_loopy_bp_msg,
                mpn_loopy_bp_reduce,
                self.loopy_bp_updater,
            )

        cand_graphs.update_all(
            mpn_gather_msg,
            mpn_gather_reduce,
            self.gather_updater,
        )

        return cand_graphs
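As a sanity check on the constants above (not part of the original file): ATOM_FDIM decomposes into the 23-way element one-hot, 6 degree slots, 5 formal-charge slots and a single aromaticity flag, while BOND_FDIM covers the 4 bond types plus a ring-membership flag.

# len(ELEM_LIST) == 23, so ATOM_FDIM == 23 + 6 + 5 + 1 == 35
# BOND_FDIM == 5: SINGLE, DOUBLE, TRIPLE, AROMATIC, IsInRing
assert len(ELEM_LIST) + 6 + 5 + 1 == 35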
apps/life_sci/python/dgllife/model/model_zoo/jtnn/jtnn_dec.py deleted 100644 → 0 View file @ 94c67203
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W0221, E1102
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as DGLF
from dgl import batch, dfs_labeled_edges_generator

from .chemutils import enum_assemble_nx, get_mol
from .mol_tree_nx import DGLMolTree
from .nnutils import GRUUpdate, cuda

MAX_NB = 8
MAX_DECODE_LEN = 100

def dfs_order(forest, roots):
    edges = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)
    for e, l in zip(*edges):
        # I exploited the fact that the reverse edge ID equals 1 xor the
        # forward edge ID for molecule trees. Normally, I should locate
        # reverse edges using find_edges().
        yield e ^ l, l

dec_tree_node_msg = DGLF.copy_edge(edge='m', out='m')
dec_tree_node_reduce = DGLF.sum(msg='m', out='h')

def dec_tree_node_update(nodes):
    return {'new': nodes.data['new'].clone().zero_()}

dec_tree_edge_msg = [DGLF.copy_src(src='m', out='m'),
                     DGLF.copy_src(src='rm', out='rm')]
dec_tree_edge_reduce = [DGLF.sum(msg='m', out='s'),
                        DGLF.sum(msg='rm', out='accum_rm')]

def have_slots(fa_slots, ch_slots):
    if len(fa_slots) > 2 and len(ch_slots) > 2:
        return True
    matches = []
    for i, s1 in enumerate(fa_slots):
        a1, c1, h1 = s1
        for j, s2 in enumerate(ch_slots):
            a2, c2, h2 = s2
            if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
                matches.append((i, j))

    if len(matches) == 0:
        return False

    fa_match, ch_match = list(zip(*matches))
    if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2:
        # never remove atom from ring
        fa_slots.pop(fa_match[0])
    if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2:
        # never remove atom from ring
        ch_slots.pop(ch_match[0])

    return True

def can_assemble(mol_tree, u, v_node_dict):
    u_node_dict = mol_tree.nodes_dict[u]
    u_neighbors = mol_tree.successors(u)
    u_neighbors_node_dict = [
        mol_tree.nodes_dict[_u]
        for _u in u_neighbors
        if _u in mol_tree.nodes_dict
    ]
    neis = u_neighbors_node_dict + [v_node_dict]
    for i, nei in enumerate(neis):
        nei['nid'] = i

    neighbors = [nei for nei in neis if nei['mol'].GetNumAtoms() > 1]
    neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(),
                       reverse=True)
    singletons = [nei for nei in neis if nei['mol'].GetNumAtoms() == 1]
    neighbors = singletons + neighbors
    cands = enum_assemble_nx(u_node_dict, neighbors)
    return len(cands) > 0

def create_node_dict(smiles, clique=None):
    if clique is None:
        clique = []
    return dict(
        smiles=smiles,
        mol=get_mol(smiles),
        clique=clique,
    )

class DGLJTNNDecoder(nn.Module):
    def __init__(self, vocab, hidden_size, latent_size, embedding=None):
        nn.Module.__init__(self)

        self.hidden_size = hidden_size
        self.vocab_size = vocab.size()
        self.vocab = vocab

        if embedding is None:
            self.embedding = nn.Embedding(self.vocab_size, hidden_size)
        else:
            self.embedding = embedding

        self.dec_tree_edge_update = GRUUpdate(hidden_size)

        self.W = nn.Linear(latent_size + hidden_size, hidden_size)
        self.U = nn.Linear(latent_size + 2 * hidden_size, hidden_size)
        self.W_o = nn.Linear(hidden_size, self.vocab_size)
        self.U_s = nn.Linear(hidden_size, 1)

    def forward(self, mol_trees, tree_vec):
        '''
        The training procedure which computes the prediction loss given the
        ground truth tree
        '''
        mol_tree_batch = batch(mol_trees)
        mol_tree_batch_lg = mol_tree_batch.line_graph(backtracking=False,
                                                      shared=True)
        n_trees = len(mol_trees)

        return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)

    def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
        node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
        root_ids = node_offset[:-1]
        n_nodes = mol_tree_batch.number_of_nodes()
        n_edges = mol_tree_batch.number_of_edges()

        mol_tree_batch.ndata.update({
            'x': self.embedding(mol_tree_batch.ndata['wid']),
            'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
            # whether it's a newly generated node
            'new': cuda(torch.ones(n_nodes).bool()),
        })

        mol_tree_batch.edata.update({
            's': cuda(torch.zeros(n_edges, self.hidden_size)),
            'm': cuda(torch.zeros(n_edges, self.hidden_size)),
            'r': cuda(torch.zeros(n_edges, self.hidden_size)),
            'z': cuda(torch.zeros(n_edges, self.hidden_size)),
            'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
            'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
            'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
            'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
        })

        mol_tree_batch.apply_edges(
            func=lambda edges: {'src_x': edges.src['x'],
                                'dst_x': edges.dst['x']},
        )

        # input tensors for stop prediction (p) and label prediction (q)
        p_inputs = []
        p_targets = []
        q_inputs = []
        q_targets = []

        # Predict root
        mol_tree_batch.pull(
            root_ids,
            dec_tree_node_msg,
            dec_tree_node_reduce,
            dec_tree_node_update,
        )

        # Extract hidden states and store them for stop/label prediction
        h = mol_tree_batch.nodes[root_ids].data['h']
        x = mol_tree_batch.nodes[root_ids].data['x']
        p_inputs.append(torch.cat([x, h, tree_vec], 1))
        # If the out degree is 0 we don't generate any edges at all
        root_out_degrees = mol_tree_batch.out_degrees(root_ids)
        q_inputs.append(torch.cat([h, tree_vec], 1))
        q_targets.append(mol_tree_batch.nodes[root_ids].data['wid'])

        # Traverse the tree and predict on children
        for eid, p in dfs_order(mol_tree_batch, root_ids):
            u, v = mol_tree_batch.find_edges(eid)

            p_target_list = torch.zeros_like(root_out_degrees)
            p_target_list[root_out_degrees > 0] = 1 - p
            p_target_list = p_target_list[root_out_degrees >= 0]
            p_targets.append(p_target_list.clone().detach())

            root_out_degrees -= (root_out_degrees == 0).long()
            root_out_degrees -= torch.tensor(np.isin(root_ids, v).astype('int64'))

            mol_tree_batch_lg.pull(
                eid,
                dec_tree_edge_msg,
                dec_tree_edge_reduce,
                self.dec_tree_edge_update,
            )
            is_new = mol_tree_batch.nodes[v].data['new']
            mol_tree_batch.pull(
                v,
                dec_tree_node_msg,
                dec_tree_node_reduce,
                dec_tree_node_update,
            )
            # Extract
            n_repr = mol_tree_batch.nodes[v].data
            h = n_repr['h']
            x = n_repr['x']
            tree_vec_set = tree_vec[root_out_degrees >= 0]
            wid = n_repr['wid']
            p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
            # Only newly generated nodes are needed for label prediction
            # NOTE: The following works since the uncomputed messages are zeros.
            q_input = torch.cat([h, tree_vec_set], 1)[is_new]
            q_target = wid[is_new]
            if q_input.shape[0] > 0:
                q_inputs.append(q_input)
                q_targets.append(q_target)
        p_targets.append(torch.zeros((root_out_degrees == 0).sum()).long())

        # Batch compute the stop/label prediction losses
        p_inputs = torch.cat(p_inputs, 0)
        p_targets = cuda(torch.cat(p_targets, 0))
        q_inputs = torch.cat(q_inputs, 0)
        q_targets = torch.cat(q_targets, 0)

        q = self.W_o(torch.relu(self.W(q_inputs)))
        p = self.U_s(torch.relu(self.U(p_inputs)))[:, 0]

        p_loss = F.binary_cross_entropy_with_logits(
            p, p_targets.float(), reduction='sum'
        ) / n_trees
        q_loss = F.cross_entropy(q, q_targets, reduction='sum') / n_trees
        p_acc = ((p > 0).long() == p_targets).sum().float() / \
            p_targets.shape[0]
        q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]

        self.q_inputs = q_inputs
        self.q_targets = q_targets
        self.q = q
        self.p_inputs = p_inputs
        self.p_targets = p_targets
        self.p = p

        return q_loss, p_loss, q_acc, p_acc

    def decode(self, mol_vec):
        assert mol_vec.shape[0] == 1

        mol_tree = DGLMolTree(None)

        init_hidden = cuda(torch.zeros(1, self.hidden_size))

        root_hidden = torch.cat([init_hidden, mol_vec], 1)
        root_hidden = F.relu(self.W(root_hidden))
        root_score = self.W_o(root_hidden)
        _, root_wid = torch.max(root_score, 1)
        root_wid = root_wid.view(1)

        mol_tree.add_nodes(1)   # root
        mol_tree.nodes[0].data['wid'] = root_wid
        mol_tree.nodes[0].data['x'] = self.embedding(root_wid)
        mol_tree.nodes[0].data['h'] = init_hidden
        mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0]))
        mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
            self.vocab.get_smiles(root_wid))

        stack, trace = [], []
        stack.append((0, self.vocab.get_slots(root_wid)))

        all_nodes = {0: root_node_dict}
        first = True
        new_node_id = 0
        new_edge_id = 0

        for step in range(MAX_DECODE_LEN):
            u, u_slots = stack[-1]
            udata = mol_tree.nodes[u].data
            x = udata['x']
            h = udata['h']

            # Predict stop
            p_input = torch.cat([x, h, mol_vec], 1)
            p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
            backtrack = (p_score.item() < 0.5)

            if not backtrack:
                # Predict the next clique. Note that the prediction may fail
                # due to a lack of assemblable components
                mol_tree.add_nodes(1)
                new_node_id += 1
                v = new_node_id
                mol_tree.add_edges(u, v)
                uv = new_edge_id
                new_edge_id += 1

                if first:
                    mol_tree.edata.update({
                        's': cuda(torch.zeros(1, self.hidden_size)),
                        'm': cuda(torch.zeros(1, self.hidden_size)),
                        'r': cuda(torch.zeros(1, self.hidden_size)),
                        'z': cuda(torch.zeros(1, self.hidden_size)),
                        'src_x': cuda(torch.zeros(1, self.hidden_size)),
                        'dst_x': cuda(torch.zeros(1, self.hidden_size)),
                        'rm': cuda(torch.zeros(1, self.hidden_size)),
                        'accum_rm': cuda(torch.zeros(1, self.hidden_size)),
                    })
                    first = False

                mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x']
                # keeping dst_x 0 is fine as h on new edge doesn't depend on that.

                # DGL doesn't dynamically maintain a line graph.
                mol_tree_lg = mol_tree.line_graph(backtracking=False, shared=True)
                mol_tree_lg.pull(
                    uv,
                    dec_tree_edge_msg,
                    dec_tree_edge_reduce,
                    self.dec_tree_edge_update.update_zm,
                )
                mol_tree.pull(
                    v,
                    dec_tree_node_msg,
                    dec_tree_node_reduce,
                )

                vdata = mol_tree.nodes[v].data
                h_v = vdata['h']
                q_input = torch.cat([h_v, mol_vec], 1)
                q_score = torch.softmax(
                    self.W_o(torch.relu(self.W(q_input))), -1)
                _, sort_wid = torch.sort(q_score, 1, descending=True)
                sort_wid = sort_wid.squeeze()

                next_wid = None
                for wid in sort_wid.tolist()[:5]:
                    slots = self.vocab.get_slots(wid)
                    cand_node_dict = create_node_dict(self.vocab.get_smiles(wid))
                    if (have_slots(u_slots, slots)
                            and can_assemble(mol_tree, u, cand_node_dict)):
                        next_wid = wid
                        next_slots = slots
                        next_node_dict = cand_node_dict
                        break

                if next_wid is None:
                    # Failed to add an actual child; v is a spurious node
                    # and we mark it.
                    vdata['fail'] = cuda(torch.tensor([1]))
                    backtrack = True
                else:
                    next_wid = cuda(torch.tensor([next_wid]))
                    vdata['wid'] = next_wid
                    vdata['x'] = self.embedding(next_wid)
                    mol_tree.nodes_dict[v] = next_node_dict
                    all_nodes[v] = next_node_dict
                    stack.append((v, next_slots))
                    mol_tree.add_edge(v, u)
                    vu = new_edge_id
                    new_edge_id += 1
                    mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data['x']
                    mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data['x']
                    mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data['x']

                    # DGL doesn't dynamically maintain a line graph.
                    mol_tree_lg = mol_tree.line_graph(backtracking=False,
                                                      shared=True)
                    mol_tree_lg.apply_nodes(
                        self.dec_tree_edge_update.update_r,
                        uv
                    )

            if backtrack:
                if len(stack) == 1:
                    break   # At root, terminate

                pu, _ = stack[-2]
                u_pu = mol_tree.edge_id(u, pu)

                mol_tree_lg.pull(
                    u_pu,
                    dec_tree_edge_msg,
                    dec_tree_edge_reduce,
                    self.dec_tree_edge_update,
                )
                mol_tree.pull(
                    pu,
                    dec_tree_node_msg,
                    dec_tree_node_reduce,
                )
                stack.pop()

        effective_nodes = mol_tree.filter_nodes(
            lambda nodes: nodes.data['fail'] != 1)
        effective_nodes, _ = torch.sort(effective_nodes)
        return mol_tree, all_nodes, effective_nodes
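A minimal sketch of driving the decoder at inference time; the vocabulary object and sizes below are placeholders, not values taken from this repository:

import torch

# vocab: a Vocab instance built from cluster SMILES (see mol_tree.py below);
# hidden_size and latent_size are hypothetical.
decoder = DGLJTNNDecoder(vocab, hidden_size=450, latent_size=28)
mol_vec = torch.randn(1, 28)   # decode() asserts a batch of exactly one
mol_tree, all_nodes, effective_nodes = decoder.decode(mol_vec)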
apps/life_sci/python/dgllife/model/model_zoo/jtnn/jtnn_enc.py deleted 100644 → 0 View file @ 94c67203
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W0221
import numpy as np
import torch
import torch.nn as nn

import dgl.function as DGLF
from dgl import batch, bfs_edges_generator

from .nnutils import GRUUpdate, cuda

MAX_NB = 8

def level_order(forest, roots):
    edges = bfs_edges_generator(forest, roots)
    _, leaves = forest.find_edges(edges[-1])
    edges_back = bfs_edges_generator(forest, roots, reverse=True)
    yield from reversed(edges_back)
    yield from edges

enc_tree_msg = [DGLF.copy_src(src='m', out='m'),
                DGLF.copy_src(src='rm', out='rm')]
enc_tree_reduce = [DGLF.sum(msg='m', out='s'),
                   DGLF.sum(msg='rm', out='accum_rm')]
enc_tree_gather_msg = DGLF.copy_edge(edge='m', out='m')
enc_tree_gather_reduce = DGLF.sum(msg='m', out='m')

class EncoderGatherUpdate(nn.Module):
    def __init__(self, hidden_size):
        nn.Module.__init__(self)
        self.hidden_size = hidden_size

        self.W = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, nodes):
        x = nodes.data['x']
        m = nodes.data['m']
        return {
            'h': torch.relu(self.W(torch.cat([x, m], 1))),
        }

class DGLJTNNEncoder(nn.Module):
    def __init__(self, vocab, hidden_size, embedding=None):
        nn.Module.__init__(self)

        self.hidden_size = hidden_size
        self.vocab_size = vocab.size()
        self.vocab = vocab

        if embedding is None:
            self.embedding = nn.Embedding(self.vocab_size, hidden_size)
        else:
            self.embedding = embedding

        self.enc_tree_update = GRUUpdate(hidden_size)
        self.enc_tree_gather_update = EncoderGatherUpdate(hidden_size)

    def forward(self, mol_trees):
        mol_tree_batch = batch(mol_trees)

        # Build line graph to prepare for belief propagation
        mol_tree_batch_lg = mol_tree_batch.line_graph(backtracking=False,
                                                      shared=True)

        return self.run(mol_tree_batch, mol_tree_batch_lg)

    def run(self, mol_tree_batch, mol_tree_batch_lg):
        # Since tree roots are designated to node 0, in the batched graph we
        # can simply find the corresponding node IDs by looking at node_offset
        node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
        root_ids = node_offset[:-1]
        n_nodes = mol_tree_batch.number_of_nodes()
        n_edges = mol_tree_batch.number_of_edges()

        # Assign structure embeddings to tree nodes
        mol_tree_batch.ndata.update({
            'x': self.embedding(mol_tree_batch.ndata['wid']),
            'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
        })

        # Initialize the intermediate variables according to Eq (4)-(8).
        # Also initialize the src_x and dst_x fields.
        # TODO: context?
        mol_tree_batch.edata.update({
            's': cuda(torch.zeros(n_edges, self.hidden_size)),
            'm': cuda(torch.zeros(n_edges, self.hidden_size)),
            'r': cuda(torch.zeros(n_edges, self.hidden_size)),
            'z': cuda(torch.zeros(n_edges, self.hidden_size)),
            'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
            'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
            'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
            'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
        })

        # Send the source/destination node features to edges
        mol_tree_batch.apply_edges(
            func=lambda edges: {'src_x': edges.src['x'],
                                'dst_x': edges.dst['x']},
        )

        # Message passing
        # I exploited the fact that the reduce function is a sum of incoming
        # messages, and the uncomputed messages are zero vectors. Essentially,
        # we can always compute s_ij as the sum of incoming m_ij, no matter
        # if m_ij is actually computed or not.
        for eid in level_order(mol_tree_batch, root_ids):
            #eid = mol_tree_batch.edge_ids(u, v)
            mol_tree_batch_lg.pull(
                eid,
                enc_tree_msg,
                enc_tree_reduce,
                self.enc_tree_update,
            )

        # Readout
        mol_tree_batch.update_all(
            enc_tree_gather_msg,
            enc_tree_gather_reduce,
            self.enc_tree_gather_update,
        )

        root_vecs = mol_tree_batch.nodes[root_ids].data['h']

        return mol_tree_batch, root_vecs
apps/life_sci/python/dgllife/model/model_zoo/jtnn/jtnn_vae.py deleted 100644 → 0 View file @ 94c67203
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200, W0221, E1102
import copy

import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl import batch, unbatch
from dgl.data.utils import get_download_dir

from .chemutils import (attach_mols_nx, copy_edit_mol, decode_stereo,
                        enum_assemble_nx, set_atommap)
from .jtmpn import DGLJTMPN
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .jtnn_dec import DGLJTNNDecoder
from .jtnn_enc import DGLJTNNEncoder
from .mol_tree import Vocab
from .mpn import DGLMPN
from .mpn import mol2dgl_single as mol2dgl_enc
from .nnutils import cuda, move_dgl_to_cuda

class DGLJTNNVAE(nn.Module):
    """
    `Junction Tree Variational Autoencoder for Molecular Graph Generation
    <https://arxiv.org/abs/1802.04364>`__
    """
    def __init__(self, hidden_size, latent_size, depth, vocab=None, vocab_file=None):
        super(DGLJTNNVAE, self).__init__()

        if vocab is None:
            if vocab_file is None:
                vocab_file = '{}/jtnn/{}.txt'.format(get_download_dir(), 'vocab')
            self.vocab = Vocab([x.strip("\r\n ") for x in open(vocab_file)])
        else:
            self.vocab = vocab

        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(self.vocab.size(), hidden_size)
        self.mpn = DGLMPN(hidden_size, depth)
        self.jtnn = DGLJTNNEncoder(self.vocab, hidden_size, self.embedding)
        self.decoder = DGLJTNNDecoder(
            self.vocab, hidden_size, latent_size // 2, self.embedding)
        self.jtmpn = DGLJTMPN(hidden_size, depth)

        self.T_mean = nn.Linear(hidden_size, latent_size // 2)
        self.T_var = nn.Linear(hidden_size, latent_size // 2)
        self.G_mean = nn.Linear(hidden_size, latent_size // 2)
        self.G_var = nn.Linear(hidden_size, latent_size // 2)

        self.n_nodes_total = 0
        self.n_passes = 0
        self.n_edges_total = 0
        self.n_tree_nodes_total = 0

    @staticmethod
    def move_to_cuda(mol_batch):
        for t in mol_batch['mol_trees']:
            move_dgl_to_cuda(t)

        move_dgl_to_cuda(mol_batch['mol_graph_batch'])
        if 'cand_graph_batch' in mol_batch:
            move_dgl_to_cuda(mol_batch['cand_graph_batch'])
        if mol_batch.get('stereo_cand_graph_batch') is not None:
            move_dgl_to_cuda(mol_batch['stereo_cand_graph_batch'])

    def encode(self, mol_batch):
        mol_graphs = mol_batch['mol_graph_batch']
        mol_vec = self.mpn(mol_graphs)

        mol_tree_batch, tree_vec = self.jtnn(mol_batch['mol_trees'])

        self.n_nodes_total += mol_graphs.number_of_nodes()
        self.n_edges_total += mol_graphs.number_of_edges()
        self.n_tree_nodes_total += sum(
            t.number_of_nodes() for t in mol_batch['mol_trees'])
        self.n_passes += 1

        return mol_tree_batch, tree_vec, mol_vec

    def sample(self, tree_vec, mol_vec, e1=None, e2=None):
        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))

        epsilon = cuda(torch.randn(*tree_mean.shape)) if e1 is None else e1
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = cuda(torch.randn(*mol_mean.shape)) if e2 is None else e2
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

        z_mean = torch.cat([tree_mean, mol_mean], 1)
        z_log_var = torch.cat([tree_log_var, mol_log_var], 1)

        return tree_vec, mol_vec, z_mean, z_log_var

    def forward(self, mol_batch, beta=0, e1=None, e2=None):
        self.move_to_cuda(mol_batch)
        mol_trees = mol_batch['mol_trees']
        batch_size = len(mol_trees)

        mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)
        tree_vec, mol_vec, z_mean, z_log_var = self.sample(tree_vec, mol_vec, e1, e2)

        kl_loss = -0.5 * torch.sum(
            1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size

        word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_trees, tree_vec)
        assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
        stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)

        loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss

        return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc

    def assm(self, mol_batch, mol_tree_batch, mol_vec):
        cands = [mol_batch['cand_graph_batch'],
                 mol_batch['tree_mess_src_e'],
                 mol_batch['tree_mess_tgt_e'],
                 mol_batch['tree_mess_tgt_n']]

        cand_vec = self.jtmpn(cands, mol_tree_batch)
        cand_vec = self.G_mean(cand_vec)

        batch_idx = cuda(torch.LongTensor(mol_batch['cand_batch_idx']))
        mol_vec = mol_vec[batch_idx]

        mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
        cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)
        scores = (mol_vec @ cand_vec)[:, 0, 0]

        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for i, mol_tree in enumerate(mol_batch['mol_trees']):
            comp_nodes = [
                node_id for node_id, node in mol_tree.nodes_dict.items()
                if len(node['cands']) > 1 and not node['is_leaf']
            ]
            cnt += len(comp_nodes)
            # segmented accuracy and cross entropy
            for node_id in comp_nodes:
                node = mol_tree.nodes_dict[node_id]
                label = node['cands'].index(node['label'])
                ncand = len(node['cands'])
                cur_score = scores[tot:tot + ncand]
                tot += ncand

                if cur_score[label].item() >= cur_score.max().item():
                    acc += 1

                label = cuda(torch.LongTensor([label]))
                all_loss.append(
                    F.cross_entropy(cur_score.view(1, -1), label, reduction='sum'))

        all_loss = sum(all_loss) / len(mol_batch['mol_trees'])
        return all_loss, acc / cnt

    def stereo(self, mol_batch, mol_vec):
        stereo_cands = mol_batch['stereo_cand_graph_batch']
        batch_idx = mol_batch['stereo_cand_batch_idx']
        labels = mol_batch['stereo_cand_labels']
        lengths = mol_batch['stereo_cand_lengths']

        if len(labels) == 0:
            # Only one stereoisomer exists; do nothing
            return cuda(torch.tensor(0.)), 1.

        batch_idx = cuda(torch.LongTensor(batch_idx))
        stereo_cands = self.mpn(stereo_cands)
        stereo_cands = self.G_mean(stereo_cands)
        stereo_labels = mol_vec[batch_idx]
        scores = F.cosine_similarity(stereo_cands, stereo_labels)

        st, acc = 0, 0
        all_loss = []
        for label, le in zip(labels, lengths):
            cur_scores = scores[st:st + le]
            if cur_scores.data[label].item() >= cur_scores.max().item():
                acc += 1
            label = cuda(torch.LongTensor([label]))
            all_loss.append(
                F.cross_entropy(cur_scores.view(1, -1), label, reduction='sum'))
            st += le

        all_loss = sum(all_loss) / len(labels)
        return all_loss, acc / len(labels)

    def decode(self, tree_vec, mol_vec):
        mol_tree, nodes_dict, effective_nodes = self.decoder.decode(tree_vec)
        effective_nodes_list = effective_nodes.tolist()
        nodes_dict = [nodes_dict[v] for v in effective_nodes_list]

        for i, (node_id, node) in enumerate(
                zip(effective_nodes_list, nodes_dict)):
            node['idx'] = i
            node['nid'] = i + 1
            node['is_leaf'] = True
            if mol_tree.in_degree(node_id) > 1:
                node['is_leaf'] = False
                set_atommap(node['mol'], node['nid'])

        mol_tree_sg = mol_tree.subgraph(effective_nodes)
        mol_tree_sg.copy_from_parent()
        mol_tree_msg, _ = self.jtnn([mol_tree_sg])
        mol_tree_msg = unbatch(mol_tree_msg)[0]
        mol_tree_msg.nodes_dict = nodes_dict

        cur_mol = copy_edit_mol(nodes_dict[0]['mol'])
        global_amap = [{}] + [{} for node in nodes_dict]
        global_amap[1] = {atom.GetIdx(): atom.GetIdx()
                          for atom in cur_mol.GetAtoms()}

        cur_mol = self.dfs_assemble(
            mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None)
        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        if cur_mol is None:
            return None

        smiles2D = Chem.MolToSmiles(cur_mol)
        stereo_cands = decode_stereo(smiles2D)
        if len(stereo_cands) == 1:
            return stereo_cands[0]

        stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]
        stereo_cand_graphs, atom_x, bond_x = \
            zip(*stereo_graphs)
        stereo_cand_graphs = batch(stereo_cand_graphs)
        atom_x = cuda(torch.cat(atom_x))
        bond_x = cuda(torch.cat(bond_x))
        stereo_cand_graphs.ndata['x'] = atom_x
        stereo_cand_graphs.edata['x'] = bond_x
        stereo_cand_graphs.edata['src_x'] = atom_x.new(
            bond_x.shape[0], atom_x.shape[1]).zero_()
        stereo_vecs = self.mpn(stereo_cand_graphs)
        stereo_vecs = self.G_mean(stereo_vecs)
        scores = F.cosine_similarity(stereo_vecs, mol_vec)
        _, max_id = scores.max(0)
        return stereo_cands[max_id.item()]

    def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol,
                     global_amap, fa_amap, cur_node_id, fa_node_id):
        nodes_dict = mol_tree_msg.nodes_dict
        fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None
        cur_node = nodes_dict[cur_node_id]

        fa_nid = fa_node['nid'] if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        children_node_id = [
            v for v in mol_tree_msg.successors(cur_node_id).tolist()
            if nodes_dict[v]['nid'] != fa_nid
        ]
        children = [nodes_dict[v] for v in children_node_id]
        neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(),
                           reverse=True)
        singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid, a2, a1)
                    for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
        cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0:
            return None

        cand_smiles, cand_mols, cand_amap = list(zip(*cands))
        cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]

        cand_graphs, atom_x, bond_x, tree_mess_src_edges, \
            tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(cands)
        cand_graphs = batch(cand_graphs)
        atom_x = cuda(atom_x)
        bond_x = cuda(bond_x)
        cand_graphs.ndata['x'] = atom_x
        cand_graphs.edata['x'] = bond_x
        cand_graphs.edata['src_x'] = atom_x.new(
            bond_x.shape[0], atom_x.shape[1]).zero_()

        cand_vecs = self.jtmpn(
            (cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges,
             tree_mess_tgt_nodes),
            mol_tree_msg,
        )
        cand_vecs = self.G_mean(cand_vecs)
        mol_vec = mol_vec.squeeze()
        scores = cand_vecs @ mol_vec

        _, cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        for i in range(len(cand_idx)):
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = \
                    new_global_amap[cur_node['nid']][ctr_atom]

            cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None:
                continue

            result = True
            for nei_node_id, nei_node in zip(children_node_id, children):
                if nei_node['is_leaf']:
                    continue
                cur_mol = self.dfs_assemble(
                    mol_tree_msg, mol_vec, cur_mol, new_global_amap, pred_amap,
                    nei_node_id, cur_node_id)
                if cur_mol is None:
                    result = False
                    break

            if result:
                return cur_mol

        return None
apps/life_sci/python/dgllife/model/model_zoo/jtnn/mol_tree.py deleted 100644 → 0 View file @ 94c67203
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import copy

import rdkit.Chem as Chem

def get_slots(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs())
            for atom in mol.GetAtoms()]

class Vocab(object):
    def __init__(self, smiles_list):
        self.vocab = smiles_list
        self.vmap = {x: i for i, x in enumerate(self.vocab)}
        self.slots = [get_slots(smiles) for smiles in self.vocab]

    def get_index(self, smiles):
        return self.vmap[smiles]

    def get_smiles(self, idx):
        return self.vocab[idx]

    def get_slots(self, idx):
        return copy.deepcopy(self.slots[idx])

    def size(self):
        return len(self.vocab)
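A small usage sketch for Vocab; the SMILES strings are hypothetical cluster-vocabulary entries, not entries shipped with this repository:

vocab = Vocab(['C', 'CC', 'C1=CC=CC=C1'])
vocab.get_index('CC')    # -> 1
vocab.get_smiles(2)      # -> 'C1=CC=CC=C1'
vocab.get_slots(0)       # -> [('C', 0, 4)]: (symbol, formal charge, total H count)
vocab.size()             # -> 3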
apps/life_sci/python/dgllife/model/model_zoo/jtnn/mol_tree_nx.py deleted 100644 → 0 View file @ 94c67203
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import numpy as np
import rdkit.Chem as Chem

from dgl import DGLGraph

from .chemutils import (decode_stereo, enum_assemble_nx, get_clique_mol,
                        get_mol, get_smiles, set_atommap, tree_decomp)

class DGLMolTree(DGLGraph):
    def __init__(self, smiles):
        DGLGraph.__init__(self)
        self.nodes_dict = {}

        if smiles is None:
            return

        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo Generation
        mol = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(mol)
        self.stereo_cands = decode_stereo(self.smiles2D)

        # cliques: a list of list of atom indices
        cliques, edges = tree_decomp(self.mol)

        root = 0
        for i, c in enumerate(cliques):
            cmol = get_clique_mol(self.mol, c)
            csmiles = get_smiles(cmol)
            self.nodes_dict[i] = dict(
                smiles=csmiles,
                mol=get_mol(csmiles),
                clique=c,
            )
            if min(c) == 0:
                root = i

        self.add_nodes(len(cliques))

        # The clique with atom ID 0 becomes root
        if root > 0:
            for attr in self.nodes_dict[0]:
                self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
                    self.nodes_dict[root][attr], self.nodes_dict[0][attr]

        src = np.zeros((len(edges) * 2,), dtype='int')
        dst = np.zeros((len(edges) * 2,), dtype='int')
        for i, (_x, _y) in enumerate(edges):
            x = 0 if _x == root else root if _x == 0 else _x
            y = 0 if _y == root else root if _y == 0 else _y
            src[2 * i] = x
            dst[2 * i] = y
            src[2 * i + 1] = y
            dst[2 * i + 1] = x
        self.add_edges(src, dst)

        for i in self.nodes_dict:
            self.nodes_dict[i]['nid'] = i + 1
            if self.out_degree(i) > 1:
                # Leaf node mol is not marked
                set_atommap(self.nodes_dict[i]['mol'],
                            self.nodes_dict[i]['nid'])
            self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)

    def treesize(self):
        return self.number_of_nodes()

    def _recover_node(self, i, original_mol):
        node = self.nodes_dict[i]

        clique = []
        clique.extend(node['clique'])

        if not node['is_leaf']:
            for cidx in node['clique']:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])

        for j in self.successors(i).numpy():
            nei_node = self.nodes_dict[j]
            clique.extend(nei_node['clique'])
            if nei_node['is_leaf']:   # Leaf node, no need to mark
                continue
            for cidx in nei_node['clique']:
                # allow singleton nodes to override the atom mapping
                if cidx not in node['clique'] or len(nei_node['clique']) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node['nid'])

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        node['label'] = Chem.MolToSmiles(
            Chem.MolFromSmiles(get_smiles(label_mol)))
        node['label_mol'] = get_mol(node['label'])

        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return node['label']

    def _assemble_node(self, i):
        neighbors = [
            self.nodes_dict[j] for j in self.successors(i).numpy()
            if self.nodes_dict[j]['mol'].GetNumAtoms() > 1
        ]
        neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(),
                           reverse=True)
        singletons = [
            self.nodes_dict[j] for j in self.successors(i).numpy()
            if self.nodes_dict[j]['mol'].GetNumAtoms() == 1
        ]
        neighbors = singletons + neighbors

        cands = enum_assemble_nx(self.nodes_dict[i], neighbors)

        if len(cands) > 0:
            self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = \
                list(zip(*cands))
            self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
            self.nodes_dict[i]['cand_mols'] = list(
                self.nodes_dict[i]['cand_mols'])
        else:
            self.nodes_dict[i]['cands'] = []
            self.nodes_dict[i]['cand_mols'] = []

    def recover(self):
        for i in self.nodes_dict:
            self._recover_node(i, self.mol)

    def assemble(self):
        for i in self.nodes_dict:
            self._assemble_node(i)
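A sketch of the intended lifecycle of a DGLMolTree during data preparation; the SMILES is an arbitrary example:

tree = DGLMolTree('OCCO')     # tree decomposition runs in __init__
tree.treesize()               # number of cliques, i.e. tree nodes
tree.recover()                # attach ground-truth 'label' SMILES to every node
tree.assemble()               # enumerate attachment candidates ('cands') per node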
apps/life_sci/python/dgllife/model/model_zoo/jtnn/mpn.py deleted 100644 → 0 View file @ 94c67203
# pylint: disable=C0111, C0103, E1101, W0611, W0612, I1101, W0221
# pylint: disable=redefined-outer-name
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as DGLF
from dgl import DGLGraph, mean_nodes

from .chemutils import get_mol

ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
             'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']

ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
BOND_FDIM = 5 + 6
MAX_NB = 6

def onek_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]

def atom_features(atom):
    return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
                         + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
                         + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
                         + onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
                         + [atom.GetIsAromatic()]))

def bond_features(bond):
    bt = bond.GetBondType()
    stereo = int(bond.GetStereo())
    fbond = [bt == Chem.rdchem.BondType.SINGLE,
             bt == Chem.rdchem.BondType.DOUBLE,
             bt == Chem.rdchem.BondType.TRIPLE,
             bt == Chem.rdchem.BondType.AROMATIC,
             bond.IsInRing()]
    fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
    return torch.Tensor(fbond + fstereo)

def mol2dgl_single(smiles):
    n_edges = 0

    atom_x = []
    bond_x = []

    mol = get_mol(smiles)
    n_atoms = mol.GetNumAtoms()
    n_bonds = mol.GetNumBonds()
    graph = DGLGraph()
    for i, atom in enumerate(mol.GetAtoms()):
        assert i == atom.GetIdx()
        atom_x.append(atom_features(atom))
    graph.add_nodes(n_atoms)

    bond_src = []
    bond_dst = []
    for i, bond in enumerate(mol.GetBonds()):
        begin_idx = bond.GetBeginAtom().GetIdx()
        end_idx = bond.GetEndAtom().GetIdx()
        features = bond_features(bond)
        bond_src.append(begin_idx)
        bond_dst.append(end_idx)
        bond_x.append(features)
        # set up the reverse direction
        bond_src.append(end_idx)
        bond_dst.append(begin_idx)
        bond_x.append(features)
    graph.add_edges(bond_src, bond_dst)

    n_edges += n_bonds
    return graph, torch.stack(atom_x), \
        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)

mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')

class LoopyBPUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(LoopyBPUpdate, self).__init__()
        self.hidden_size = hidden_size

        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, nodes):
        msg_input = nodes.data['msg_input']
        msg_delta = self.W_h(nodes.data['accum_msg'])
        msg = F.relu(msg_input + msg_delta)
        return {'msg': msg}

mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')

class GatherUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(GatherUpdate, self).__init__()
        self.hidden_size = hidden_size

        self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)

    def forward(self, nodes):
        m = nodes.data['m']
        return {
            'h': F.relu(self.W_o(torch.cat([nodes.data['x'], m], 1))),
        }

class DGLMPN(nn.Module):
    def __init__(self, hidden_size, depth):
        super(DGLMPN, self).__init__()

        self.depth = depth

        self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)

        self.loopy_bp_updater = LoopyBPUpdate(hidden_size)
        self.gather_updater = GatherUpdate(hidden_size)
        self.hidden_size = hidden_size

        self.n_samples_total = 0
        self.n_nodes_total = 0
        self.n_edges_total = 0
        self.n_passes = 0

    def forward(self, mol_graph):
        n_samples = mol_graph.batch_size

        mol_line_graph = mol_graph.line_graph(backtracking=False, shared=True)

        n_nodes = mol_graph.number_of_nodes()
        n_edges = mol_graph.number_of_edges()

        mol_graph = self.run(mol_graph, mol_line_graph)

        # TODO: replace with unbatch or readout
        g_repr = mean_nodes(mol_graph, 'h')

        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes
        self.n_edges_total += n_edges
        self.n_passes += 1

        return g_repr

    def run(self, mol_graph, mol_line_graph):
        n_nodes = mol_graph.number_of_nodes()

        mol_graph.apply_edges(
            func=lambda edges: {'src_x': edges.src['x']},
        )

        e_repr = mol_line_graph.ndata
        bond_features = e_repr['x']
        source_features = e_repr['src_x']

        features = torch.cat([source_features, bond_features], 1)
        msg_input = self.W_i(features)
        mol_line_graph.ndata.update({
            'msg_input': msg_input,
            'msg': F.relu(msg_input),
            'accum_msg': torch.zeros_like(msg_input),
        })
        mol_graph.ndata.update({
            'm': bond_features.new(n_nodes, self.hidden_size).zero_(),
            'h': bond_features.new(n_nodes, self.hidden_size).zero_(),
        })

        for i in range(self.depth - 1):
            mol_line_graph.update_all(
                mpn_loopy_bp_msg,
                mpn_loopy_bp_reduce,
                self.loopy_bp_updater,
            )

        mol_graph.update_all(
            mpn_gather_msg,
            mpn_gather_reduce,
            self.gather_updater,
        )

        return mol_graph
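For orientation, a sketch of the tensor shapes mol2dgl_single produces for benzene; the arithmetic follows directly from ATOM_FDIM = 39 and BOND_FDIM = 11 above:

graph, atom_x, bond_x = mol2dgl_single('c1ccccc1')   # benzene
# 6 atoms, 6 bonds; each bond is inserted in both directions
assert atom_x.shape == (6, 39)
assert bond_x.shape == (12, 11)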
apps/life_sci/python/dgllife/model/model_zoo/jtnn/nnutils.py deleted 100644 → 0 View file @ 94c67203
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W0221
import os

import torch
import torch.nn as nn
from torch.autograd import Variable

def create_var(tensor, requires_grad=None):
    if requires_grad is None:
        return Variable(tensor)
    else:
        return Variable(tensor, requires_grad=requires_grad)

def cuda(tensor):
    if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
        return tensor.cuda()
    else:
        return tensor

class GRUUpdate(nn.Module):
    def __init__(self, hidden_size):
        nn.Module.__init__(self)
        self.hidden_size = hidden_size

        self.W_z = nn.Linear(2 * hidden_size, hidden_size)
        self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
        self.U_r = nn.Linear(hidden_size, hidden_size)
        self.W_h = nn.Linear(2 * hidden_size, hidden_size)

    def update_zm(self, node):
        src_x = node.data['src_x']
        s = node.data['s']
        rm = node.data['accum_rm']
        z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
        m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
        m = (1 - z) * s + z * m
        return {'m': m, 'z': z}

    def update_r(self, node, zm=None):
        dst_x = node.data['dst_x']
        m = node.data['m'] if zm is None else zm['m']
        r_1 = self.W_r(dst_x)
        r_2 = self.U_r(m)
        r = torch.sigmoid(r_1 + r_2)
        return {'r': r, 'rm': r * m}

    def forward(self, node):
        dic = self.update_zm(node)
        dic.update(self.update_r(node, zm=dic))
        return dic

def move_dgl_to_cuda(g):
    g.ndata.update({k: cuda(g.ndata[k]) for k in g.ndata})
    g.edata.update({k: cuda(g.edata[k]) for k in g.edata})
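In math form, GRUUpdate is a transcription of the tree-GRU message update used throughout the encoder and decoder above. Writing x_src for the sender's node feature, s for the summed incoming messages, and accum_rm for the summed gated messages, the two phases compute (notation approximates the JT-VAE paper; this is a reading of the code, not quoted from it):

z = \sigma\big(W_z\,[x_{src}, s]\big)
\tilde{m} = \tanh\big(W_h\,[x_{src}, \textstyle\sum_k r_k \odot m_k]\big)
m = (1 - z) \odot s + z \odot \tilde{m}
r = \sigma\big(W_r\,x_{dst} + U_r\,m\big)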
apps/life_sci/python/dgllife/model/model_zoo/mgcn_predictor.py deleted 100644 → 0 View file @ 94c67203
"""MGCN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
torch.nn
as
nn
from
..gnn
import
MGCNGNN
from
..readout
import
MLPNodeReadout
__all__
=
[
'MGCNPredictor'
]
# pylint: disable=W0221
class
MGCNPredictor
(
nn
.
Module
):
"""MGCN for for regression and classification on graphs.
MGCN is introduced in `Molecular Property Prediction: A Multilevel Quantum Interactions
Modeling Perspective <https://arxiv.org/abs/1906.11081>`__.
Parameters
----------
feats : int
Size for the node and edge embeddings to learn. Default to 128.
n_layers : int
Number of gnn layers to use. Default to 3.
classifier_hidden_feats : int
Size for hidden representations in the classifier. Default to 64.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
num_node_types : int
Number of node types to embed. Default to 100.
num_edge_types : int
Number of edge types to embed. Default to 3000.
cutoff : float
Largest center in RBF expansion. Default to 5.0
gap : float
Difference between two adjacent centers in RBF expansion. Default to 1.0
"""
def
__init__
(
self
,
feats
=
128
,
n_layers
=
3
,
classifier_hidden_feats
=
64
,
n_tasks
=
1
,
num_node_types
=
100
,
num_edge_types
=
3000
,
cutoff
=
5.0
,
gap
=
1.0
):
super
(
MGCNPredictor
,
self
).
__init__
()
self
.
gnn
=
MGCNGNN
(
feats
=
feats
,
n_layers
=
n_layers
,
num_node_types
=
num_node_types
,
num_edge_types
=
num_edge_types
,
cutoff
=
cutoff
,
gap
=
gap
)
self
.
readout
=
MLPNodeReadout
(
node_feats
=
(
n_layers
+
1
)
*
feats
,
hidden_feats
=
classifier_hidden_feats
,
graph_feats
=
n_tasks
,
activation
=
nn
.
Softplus
(
beta
=
1
,
threshold
=
20
))
def
forward
(
self
,
g
,
node_types
,
edge_dists
):
"""Graph-level regression/soft classification.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_types : int64 tensor of shape (V)
Node types to embed, V for the number of nodes.
edge_dists : float32 tensor of shape (E, 1)
Distances between end nodes of edges, E for the number of edges.
Returns
-------
float32 tensor of shape (G, n_tasks)
Prediction for the graphs in the batch. G for the number of graphs.
"""
node_feats
=
self
.
gnn
(
g
,
node_types
,
edge_dists
)
return
self
.
readout
(
g
,
node_feats
)
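A note on the readout sizing above (an observation from the code, not original commentary): the readout consumes (n_layers + 1) * feats features per node, which reflects MGCN's multilevel design of concatenating the initial node embedding with the output of each of the n_layers interaction layers.

# With the defaults (feats=128, n_layers=3), the readout input size is
assert (3 + 1) * 128 == 512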
apps/life_sci/python/dgllife/model/model_zoo/mlp_predictor.py deleted 100644 → 0 View file @ 94c67203
"""MLP for prediction on the output of readout."""
# pylint: disable= no-member, arguments-differ, invalid-name
import
torch.nn
as
nn
# pylint: disable=W0221
class
MLPPredictor
(
nn
.
Module
):
"""Two-layer MLP for regression or soft classification
over multiple tasks from graph representations.
For classification tasks, the output will be logits, i.e.
values before sigmoid or softmax.
Parameters
----------
in_feats : int
Number of input graph features
hidden_feats : int
Number of graph features in hidden layers
n_tasks : int
Number of tasks, which is also the output size.
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
def
__init__
(
self
,
in_feats
,
hidden_feats
,
n_tasks
,
dropout
=
0.
):
super
(
MLPPredictor
,
self
).
__init__
()
self
.
predict
=
nn
.
Sequential
(
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
in_feats
,
hidden_feats
),
nn
.
ReLU
(),
nn
.
BatchNorm1d
(
hidden_feats
),
nn
.
Linear
(
hidden_feats
,
n_tasks
)
)
def
forward
(
self
,
feats
):
"""Make prediction.
Parameters
----------
feats : FloatTensor of shape (B, M3)
* B is the number of graphs in a batch
* M3 is the input graph feature size, must match in_feats in initialization
Returns
-------
FloatTensor of shape (B, n_tasks)
"""
return
self
.
predict
(
feats
)
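A minimal usage sketch with made-up sizes; the input is whatever graph-level representation an upstream readout produced:

import torch

model = MLPPredictor(in_feats=256, hidden_feats=128, n_tasks=2, dropout=0.1)
graph_feats = torch.randn(8, 256)   # e.g. readout output for a batch of 8 graphs
logits = model(graph_feats)         # shape (8, 2); apply sigmoid/softmax yourself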
apps/life_sci/python/dgllife/model/model_zoo/mpnn_predictor.py deleted 100644 → 0 View file @ 94c67203
"""MPNN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
torch.nn
as
nn
from
dgl.nn.pytorch
import
Set2Set
from
..gnn
import
MPNNGNN
__all__
=
[
'MPNNPredictor'
]
# pylint: disable=W0221
class
MPNNPredictor
(
nn
.
Module
):
"""MPNN for regression and classification on graphs.
MPNN is introduced in `Neural Message Passing for Quantum Chemistry
<https://arxiv.org/abs/1704.01212>`__.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_out_feats : int
Size for the output node representations. Default to 64.
edge_hidden_feats : int
Size for the hidden edge representations. Default to 128.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
num_step_message_passing : int
Number of message passing steps. Default to 6.
num_step_set2set : int
Number of set2set steps. Default to 6.
num_layer_set2set : int
Number of set2set layers. Default to 3.
"""
def
__init__
(
self
,
node_in_feats
,
edge_in_feats
,
node_out_feats
=
64
,
edge_hidden_feats
=
128
,
n_tasks
=
1
,
num_step_message_passing
=
6
,
num_step_set2set
=
6
,
num_layer_set2set
=
3
):
super
(
MPNNPredictor
,
self
).
__init__
()
self
.
gnn
=
MPNNGNN
(
node_in_feats
=
node_in_feats
,
node_out_feats
=
node_out_feats
,
edge_in_feats
=
edge_in_feats
,
edge_hidden_feats
=
edge_hidden_feats
,
num_step_message_passing
=
num_step_message_passing
)
self
.
readout
=
Set2Set
(
input_dim
=
node_out_feats
,
n_iters
=
num_step_set2set
,
n_layers
=
num_layer_set2set
)
self
.
predict
=
nn
.
Sequential
(
nn
.
Linear
(
2
*
node_out_feats
,
node_out_feats
),
nn
.
ReLU
(),
nn
.
Linear
(
node_out_feats
,
n_tasks
)
)
def
forward
(
self
,
g
,
node_feats
,
edge_feats
):
"""Graph-level regression/soft classification.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features.
Returns
-------
float32 tensor of shape (G, n_tasks)
Prediction for the graphs in the batch. G for the number of graphs.
"""
node_feats
=
self
.
gnn
(
g
,
node_feats
,
edge_feats
)
graph_feats
=
self
.
readout
(
g
,
node_feats
)
return
self
.
predict
(
graph_feats
)
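A minimal forward-pass sketch, using the mutable DGLGraph API the rest of this commit relies on; graph structure and feature sizes are arbitrary placeholders:

import torch
from dgl import DGLGraph

g = DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
node_feats = torch.randn(4, 16)
edge_feats = torch.randn(4, 8)
model = MPNNPredictor(node_in_feats=16, edge_in_feats=8, n_tasks=1)
pred = model(g, node_feats, edge_feats)   # shape (1, 1) for this single graph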
apps/life_sci/python/dgllife/model/model_zoo/schnet_predictor.py deleted 100644 → 0 View file @ 94c67203
"""SchNet"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
torch.nn
as
nn
from
dgl.nn.pytorch.conv.cfconv
import
ShiftedSoftplus
from
..gnn
import
SchNetGNN
from
..readout
import
MLPNodeReadout
__all__
=
[
'SchNetPredictor'
]
# pylint: disable=W0221
class
SchNetPredictor
(
nn
.
Module
):
"""SchNet for regression and classification on graphs.
SchNet is introduced in `SchNet: A continuous-filter convolutional neural network for
modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.
Parameters
----------
node_feats : int
Size for node representations to learn. Default to 64.
hidden_feats : list of int
``hidden_feats[i]`` gives the size of hidden representations for the i-th interaction
(gnn) layer. ``len(hidden_feats)`` equals the number of interaction (gnn) layers.
Default to ``[64, 64, 64]``.
classifier_hidden_feats : int
Size for hidden representations in the classifier. Default to 64.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
num_node_types : int
Number of node types to embed. Default to 100.
cutoff : float
Largest center in RBF expansion. Default to 30.
gap : float
Difference between two adjacent centers in RBF expansion. Default to 0.1.
"""
def
__init__
(
self
,
node_feats
=
64
,
hidden_feats
=
None
,
classifier_hidden_feats
=
64
,
n_tasks
=
1
,
num_node_types
=
100
,
cutoff
=
30.
,
gap
=
0.1
):
super
(
SchNetPredictor
,
self
).
__init__
()
self
.
gnn
=
SchNetGNN
(
node_feats
,
hidden_feats
,
num_node_types
,
cutoff
,
gap
)
self
.
readout
=
MLPNodeReadout
(
node_feats
,
classifier_hidden_feats
,
n_tasks
,
activation
=
ShiftedSoftplus
())
def
forward
(
self
,
g
,
node_types
,
edge_dists
):
"""Graph-level regression/soft classification.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_types : int64 tensor of shape (V)
Node types to embed, V for the number of nodes.
edge_dists : float32 tensor of shape (E, 1)
Distances between end nodes of edges, E for the number of edges.
Returns
-------
float32 tensor of shape (G, n_tasks)
Prediction for the graphs in the batch. G for the number of graphs.
"""
node_feats
=
self
.
gnn
(
g
,
node_types
,
edge_dists
)
return
self
.
readout
(
g
,
node_feats
)
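A minimal forward-pass sketch under assumed inputs; the graph, type IDs and distances below are placeholders built with the same mutable DGLGraph API used elsewhere in this commit:

import torch
from dgl import DGLGraph

g = DGLGraph()
g.add_nodes(3)
g.add_edges([0, 0, 1], [1, 2, 2])
node_types = torch.LongTensor([8, 1, 1])   # e.g. O, H, H; IDs < num_node_types
edge_dists = torch.rand(3, 1)              # one distance per edge, within cutoff
model = SchNetPredictor(node_feats=64, n_tasks=1)
pred = model(g, node_types, edge_dists)    # shape (1, 1) for this single graph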
apps/life_sci/python/dgllife/model/model_zoo/weave_predictor.py deleted 100644 → 0 View file @ 94c67203
"""Weave"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
..gnn
import
WeaveGNN
from
..readout
import
WeaveGather
__all__
=
[
'WeavePredictor'
]
# pylint: disable=W0221
class
WeavePredictor
(
nn
.
Module
):
r
"""Weave for regression and classification on graphs.
Weave is introduced in `Molecular Graph Convolutions: Moving Beyond Fingerprints
<https://arxiv.org/abs/1603.00856>`__
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
num_gnn_layers : int
Number of GNN (Weave) layers to use. Default to 2.
gnn_hidden_feats : int
Size for the hidden node and edge representations. Default to 50.
gnn_activation : callable
Activation function to be used in GNN (Weave) layers. Default to ReLU.
graph_feats : int
Size for the hidden graph representations. Default to 50.
gaussian_expand : bool
Whether to expand each dimension of node features by gaussian histogram in
computing graph representations. Default to True.
gaussian_memberships : list of 2-tuples
For each tuple, the first and second element separately specifies the mean
and std for constructing a normal distribution. This argument comes into
effect only when ``gaussian_expand==True``. By default, we set this to be
``[(-1.645, 0.283), (-1.080, 0.170), (-0.739, 0.134), (-0.468, 0.118),
(-0.228, 0.114), (0., 0.114), (0.228, 0.114), (0.468, 0.118),
(0.739, 0.134), (1.080, 0.170), (1.645, 0.283)]``.
readout_activation : callable
Activation function to be used in computing graph representations out of
node representations. Default to Tanh.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
"""
def
__init__
(
self
,
node_in_feats
,
edge_in_feats
,
num_gnn_layers
=
2
,
gnn_hidden_feats
=
50
,
gnn_activation
=
F
.
relu
,
graph_feats
=
128
,
gaussian_expand
=
True
,
gaussian_memberships
=
None
,
readout_activation
=
nn
.
Tanh
(),
n_tasks
=
1
):
super
(
WeavePredictor
,
self
).
__init__
()
self
.
gnn
=
WeaveGNN
(
node_in_feats
=
node_in_feats
,
edge_in_feats
=
edge_in_feats
,
num_layers
=
num_gnn_layers
,
hidden_feats
=
gnn_hidden_feats
,
activation
=
gnn_activation
)
self
.
node_to_graph
=
nn
.
Sequential
(
nn
.
Linear
(
gnn_hidden_feats
,
graph_feats
),
readout_activation
,
nn
.
BatchNorm1d
(
graph_feats
)
)
self
.
readout
=
WeaveGather
(
node_in_feats
=
graph_feats
,
gaussian_expand
=
gaussian_expand
,
gaussian_memberships
=
gaussian_memberships
,
activation
=
readout_activation
)
self
.
predict
=
nn
.
Linear
(
graph_feats
,
n_tasks
)
def
forward
(
self
,
g
,
node_feats
,
edge_feats
):
"""Graph-level regression/soft classification.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges.
Returns
-------
float32 tensor of shape (G, n_tasks)
Prediction for the graphs in the batch. G for the number of graphs.
"""
node_feats
=
self
.
gnn
(
g
,
node_feats
,
edge_feats
,
node_only
=
True
)
node_feats
=
self
.
node_to_graph
(
node_feats
)
g_feats
=
self
.
readout
(
g
,
node_feats
)
return
self
.
predict
(
g_feats
)
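A minimal forward-pass sketch under assumed inputs, again with the mutable DGLGraph API and arbitrary feature sizes:

import torch
from dgl import DGLGraph

g = DGLGraph()
g.add_nodes(5)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 4])
node_feats = torch.randn(5, 27)
edge_feats = torch.randn(4, 12)
model = WeavePredictor(node_in_feats=27, edge_in_feats=12, n_tasks=1)
pred = model(g, node_feats, edge_feats)   # shape (1, 1) for this single graph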