Unverified Commit 828a5e5b authored by Mufei Li, committed by GitHub

[DGL-LifeSci] Migration and Refactor (#1226)

* First commit

* Update

* Update splitters

* Update

* Update

* Update

* Update

* Update

* Update

* Migrate ACNN

* Fix

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Finish classification

* Update

* Fix

* Update

* Update

* Update

* Fix

* Fix

* Fix

* Update

* Update

* Update

* trigger CI

* Fix CI

* Update

* Update

* Update

* Add default values

* Rename

* Update deprecation message
parent e4948c5c
"""GAT-based model for regression and classification on graphs."""
import torch.nn as nn
import torch.nn.functional as F
from .mlp_predictor import MLPPredictor
from ..gnn.gat import GAT
from ..readout.weighted_sum_and_max import WeightedSumAndMax
class GATPredictor(nn.Module):
r"""GAT-based model for regression and classification on graphs.
GAT is introduced in `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__.
This model is based on GAT and can be used for regression and classification on graphs.
After updating node representations, we perform a weighted sum with learnable
weights and max pooling on them and concatenate the output of the two operations,
which is then fed into an MLP for final prediction.
For classification tasks, the output will be logits, i.e.
values before sigmoid or softmax.
Parameters
----------
in_feats : int
Number of input node features
hidden_feats : list of int
``hidden_feats[i]`` gives the output size of an attention head in the i-th GAT layer.
``len(hidden_feats)`` equals the number of GAT layers. By default, we use ``[32, 32]``.
num_heads : list of int
``num_heads[i]`` gives the number of attention heads in the i-th GAT layer.
``len(num_heads)`` equals the number of GAT layers. By default, we use 4 attention heads
for each GAT layer.
feat_drops : list of float
``feat_drops[i]`` gives the dropout applied to the input features in the i-th GAT layer.
``len(feat_drops)`` equals the number of GAT layers. By default, this will be zero for
all GAT layers.
attn_drops : list of float
``attn_drops[i]`` gives the dropout applied to attention values of edges in the i-th GAT
layer. ``len(attn_drops)`` equals the number of GAT layers. By default, this will be zero
for all GAT layers.
    alphas : list of float
        ``alphas[i]`` gives the LeakyReLU negative slope used in the i-th GAT layer.
        ``len(alphas)`` equals the number of GAT layers. By default, this will be 0.2
        for all GAT layers.
    residuals : list of bool
        ``residuals[i]`` decides if a residual connection is used for the i-th GAT layer.
        ``len(residuals)`` equals the number of GAT layers. By default, a residual
        connection is used for each GAT layer.
agg_modes : list of str
The way to aggregate multi-head attention results for each GAT layer, which can be either
'flatten' for concatenating all-head results or 'mean' for averaging all-head results.
``agg_modes[i]`` gives the way to aggregate multi-head attention results for the i-th
GAT layer. ``len(agg_modes)`` equals the number of GAT layers. By default, we flatten
multi-head results for intermediate GAT layers and compute mean of multi-head results
for the last GAT layer.
activations : list of activation function or None
``activations[i]`` gives the activation function applied to the aggregated multi-head
results for the i-th GAT layer. ``len(activations)`` equals the number of GAT layers.
By default, ELU is applied for intermediate GAT layers and no activation is applied
for the last GAT layer.
classifier_hidden_feats : int
Size of hidden graph representations in the classifier. Default to 128.
classifier_dropout : float
The probability for dropout in the classifier. Default to 0.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
"""
def __init__(self, in_feats, hidden_feats=None, num_heads=None, feat_drops=None, attn_drops=None,
alphas=None, residuals=None, agg_modes=None, activations=None,
classifier_hidden_feats=128, classifier_dropout=0., n_tasks=1):
super(GATPredictor, self).__init__()
self.gnn = GAT(in_feats=in_feats,
hidden_feats=hidden_feats,
num_heads=num_heads,
feat_drops=feat_drops,
attn_drops=attn_drops,
alphas=alphas,
residuals=residuals,
agg_modes=agg_modes,
activations=activations)
if self.gnn.agg_modes[-1] == 'flatten':
gnn_out_feats = self.gnn.hidden_feats[-1] * self.gnn.num_heads[-1]
else:
gnn_out_feats = self.gnn.hidden_feats[-1]
self.readout = WeightedSumAndMax(gnn_out_feats)
self.predict = MLPPredictor(2 * gnn_out_feats, classifier_hidden_feats,
n_tasks, classifier_dropout)
def forward(self, bg, feats):
"""Graph-level regression/soft classification.
Parameters
----------
bg : DGLGraph
DGLGraph for a batch of graphs.
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which must match
in_feats in initialization
Returns
-------
FloatTensor of shape (B, n_tasks)
* Predictions on graphs
* B for the number of graphs in the batch
"""
node_feats = self.gnn(bg, feats)
graph_feats = self.readout(bg, node_feats)
return self.predict(graph_feats)
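# Usage sketch (illustrative, not part of the diff; the 74-dim node
# featurization and graph construction below are arbitrary assumptions):
#
#     import dgl
#     import torch
#
#     g = dgl.DGLGraph()
#     g.add_nodes(3)
#     g.add_edges([0, 1, 1, 2], [1, 0, 2, 1])  # bidirectional path graph
#     bg = dgl.batch([g, g])
#     model = GATPredictor(in_feats=74, n_tasks=2)
#     feats = torch.randn(bg.number_of_nodes(), 74)
#     model(bg, feats)  # FloatTensor of shape (2, 2)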
"""GCN-based model for regression and classification on graphs."""
import torch.nn as nn
from .mlp_predictor import MLPPredictor
from ..gnn.gcn import GCN
from ..readout.weighted_sum_and_max import WeightedSumAndMax
class GCNPredictor(nn.Module):
"""GCN-based model for regression and classification on graphs.
GCN is introduced in `Semi-Supervised Classification with Graph Convolutional Networks
<https://arxiv.org/abs/1609.02907>`__. This model is based on GCN and can be used
for regression and classification on graphs.
After updating node representations, we perform a weighted sum with learnable
weights and max pooling on them and concatenate the output of the two operations,
which is then fed into an MLP for final prediction.
For classification tasks, the output will be logits, i.e.
values before sigmoid or softmax.
Parameters
----------
in_feats : int
Number of input node features.
hidden_feats : list of int
``hidden_feats[i]`` gives the size of node representations after the i-th GCN layer.
``len(hidden_feats)`` equals the number of GCN layers. By default, we use
``[64, 64]``.
activation : list of activation functions or None
If None, no activation will be applied. If not None, ``activation[i]`` gives the
activation function to be used for the i-th GCN layer. ``len(activation)`` equals
the number of GCN layers. By default, ReLU is applied for all GCN layers.
residual : list of bool
``residual[i]`` decides if residual connection is to be used for the i-th GCN layer.
``len(residual)`` equals the number of GCN layers. By default, residual connection
is performed for each GCN layer.
batchnorm : list of bool
``batchnorm[i]`` decides if batch normalization is to be applied on the output of
the i-th GCN layer. ``len(batchnorm)`` equals the number of GCN layers. By default,
batch normalization is applied for all GCN layers.
dropout : list of float
``dropout[i]`` decides the dropout probability on the output of the i-th GCN layer.
``len(dropout)`` equals the number of GCN layers. By default, no dropout is
performed for all layers.
classifier_hidden_feats : int
Size of hidden graph representations in the classifier. Default to 128.
classifier_dropout : float
The probability for dropout in the classifier. Default to 0.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
"""
def __init__(self, in_feats, hidden_feats=None, activation=None, residual=None, batchnorm=None,
dropout=None, classifier_hidden_feats=128, classifier_dropout=0., n_tasks=1):
super(GCNPredictor, self).__init__()
self.gnn = GCN(in_feats=in_feats,
hidden_feats=hidden_feats,
activation=activation,
residual=residual,
batchnorm=batchnorm,
dropout=dropout)
gnn_out_feats = self.gnn.hidden_feats[-1]
self.readout = WeightedSumAndMax(gnn_out_feats)
self.predict = MLPPredictor(2 * gnn_out_feats, classifier_hidden_feats,
n_tasks, classifier_dropout)
def forward(self, bg, feats):
"""Graph-level regression/soft classification.
Parameters
----------
bg : DGLGraph
DGLGraph for a batch of graphs.
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which must match
in_feats in initialization
Returns
-------
FloatTensor of shape (B, n_tasks)
* Predictions on graphs
* B for the number of graphs in the batch
"""
node_feats = self.gnn(bg, feats)
graph_feats = self.readout(bg, node_feats)
return self.predict(graph_feats)
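# Usage sketch (illustrative, not part of the diff; the feature and hidden
# sizes are arbitrary assumptions):
#
#     import dgl
#     import torch
#
#     g = dgl.DGLGraph()
#     g.add_nodes(3)
#     g.add_edges([0, 1, 1, 2], [1, 0, 2, 1])
#     bg = dgl.batch([g, g])
#     model = GCNPredictor(in_feats=74, hidden_feats=[64, 64], n_tasks=1)
#     feats = torch.randn(bg.number_of_nodes(), 74)
#     model(bg, feats)  # FloatTensor of shape (2, 1)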
"""JTNN Module"""
from .jtnn_vae import DGLJTNNVAE
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W0703, C0200, R1710
import rdkit.Chem as Chem
from collections import defaultdict
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
MST_MAX_WEIGHT = 100
MAX_NCAND = 2000
def set_atommap(mol, num=0):
for atom in mol.GetAtoms():
atom.SetAtomMapNum(num)
def get_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
Chem.Kekulize(mol)
return mol
def get_smiles(mol):
return Chem.MolToSmiles(mol, kekuleSmiles=True)
def decode_stereo(smiles2D):
mol = Chem.MolFromSmiles(smiles2D)
dec_isomers = list(EnumerateStereoisomers(mol))
dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(
mol, isomericSmiles=True)) for mol in dec_isomers]
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True)
for mol in dec_isomers]
chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(
atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
if len(chiralN) > 0:
for mol in dec_isomers:
for idx in chiralN:
mol.GetAtomWithIdx(idx).SetChiralTag(
Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))
return smiles3D
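# Example (illustrative): alanine has one unassigned stereocenter, so the
# returned list contains one isomeric SMILES per enantiomer.
#
#     decode_stereo('CC(N)C(=O)O')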
def sanitize(mol):
try:
smiles = get_smiles(mol)
mol = get_mol(smiles)
except Exception:
return None
return mol
def copy_atom(atom):
new_atom = Chem.Atom(atom.GetSymbol())
new_atom.SetFormalCharge(atom.GetFormalCharge())
new_atom.SetAtomMapNum(atom.GetAtomMapNum())
return new_atom
def copy_edit_mol(mol):
new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
for atom in mol.GetAtoms():
new_atom = copy_atom(atom)
new_mol.AddAtom(new_atom)
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
bt = bond.GetBondType()
new_mol.AddBond(a1, a2, bt)
return new_mol
def get_clique_mol(mol, atoms):
smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
new_mol = copy_edit_mol(new_mol).GetMol()
new_mol = sanitize(new_mol) # We assume this is not None
return new_mol
def tree_decomp(mol):
n_atoms = mol.GetNumAtoms()
if n_atoms == 1:
return [[0]], []
cliques = []
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
if not bond.IsInRing():
cliques.append([a1, a2])
ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
cliques.extend(ssr)
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Merge Rings with intersection > 2 atoms
for i in range(len(cliques)):
if len(cliques[i]) <= 2:
continue
for atom in cliques[i]:
for j in nei_list[atom]:
if i >= j or len(cliques[j]) <= 2:
continue
inter = set(cliques[i]) & set(cliques[j])
if len(inter) > 2:
cliques[i].extend(cliques[j])
cliques[i] = list(set(cliques[i]))
cliques[j] = []
cliques = [c for c in cliques if len(c) > 0]
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Build edges and add singleton cliques
edges = defaultdict(int)
for atom in range(n_atoms):
if len(nei_list[atom]) <= 1:
continue
cnei = nei_list[atom]
bonds = [c for c in cnei if len(cliques[c]) == 2]
rings = [c for c in cnei if len(cliques[c]) > 4]
        # In general, a singleton should be added if len(cnei) >= 3, but the
        # case of 1 bond + 2 rings is currently not handled.
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = 1
elif len(rings) > 2: # Multiple (n>2) complex rings
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = MST_MAX_WEIGHT - 1
else:
for i in range(len(cnei)):
for j in range(i + 1, len(cnei)):
c1, c2 = cnei[i], cnei[j]
inter = set(cliques[c1]) & set(cliques[c2])
if edges[(c1, c2)] < len(inter):
# cnei[i] < cnei[j] by construction
edges[(c1, c2)] = len(inter)
edges = [u + (MST_MAX_WEIGHT - v,) for u, v in edges.items()]
if len(edges) == 0:
return cliques, edges
# Compute Maximum Spanning Tree
row, col, data = list(zip(*edges))
n_clique = len(cliques)
clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
junc_tree = minimum_spanning_tree(clique_graph)
row, col = junc_tree.nonzero()
edges = [(row[i], col[i]) for i in range(len(row))]
return (cliques, edges)
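# Example (illustrative): decompose aspirin. The benzene ring becomes one
# clique, every acyclic bond becomes a two-atom clique, and ``edges`` holds
# the junction-tree edges selected by the spanning tree.
#
#     mol = get_mol('CC(=O)Oc1ccccc1C(=O)O')
#     cliques, edges = tree_decomp(mol)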
def atom_equal(a1, a2):
return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()
# Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
if reverse:
b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
else:
b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
prev_nids = [node['nid'] for node in prev_nodes]
for nei_node in prev_nodes + neighbors:
nei_id, nei_mol = nei_node['nid'], nei_node['mol']
amap = nei_amap[nei_id]
for atom in nei_mol.GetAtoms():
if atom.GetIdx() not in amap:
new_atom = copy_atom(atom)
amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)
if nei_mol.GetNumBonds() == 0:
nei_atom = nei_mol.GetAtomWithIdx(0)
ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
else:
for bond in nei_mol.GetBonds():
a1 = amap[bond.GetBeginAtom().GetIdx()]
a2 = amap[bond.GetEndAtom().GetIdx()]
if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
ctr_mol.AddBond(a1, a2, bond.GetBondType())
elif nei_id in prev_nids: # father node overrides
ctr_mol.RemoveBond(a1, a2)
ctr_mol.AddBond(a1, a2, bond.GetBondType())
return ctr_mol
def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
ctr_mol = copy_edit_mol(ctr_mol)
nei_amap = {nei['nid']: {} for nei in prev_nodes + neighbors}
for nei_id, ctr_atom, nei_atom in amap_list:
nei_amap[nei_id][nei_atom] = ctr_atom
ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
return ctr_mol.GetMol()
# This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
nei_mol, nei_idx = nei_node['mol'], nei_node['nid']
att_confs = []
black_list = [atom_idx for nei_id, atom_idx,
_ in amap if nei_id in singletons]
ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx()
not in black_list]
ctr_bonds = [bond for bond in ctr_mol.GetBonds()]
if nei_mol.GetNumBonds() == 0: # neighbor singleton
nei_atom = nei_mol.GetAtomWithIdx(0)
used_list = [atom_idx for _, atom_idx, _ in amap]
for atom in ctr_atoms:
if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
att_confs.append(new_amap)
elif nei_mol.GetNumBonds() == 1: # neighbor is a bond
bond = nei_mol.GetBondWithIdx(0)
bond_val = int(bond.GetBondTypeAsDouble())
b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()
for atom in ctr_atoms:
# Optimize if atom is carbon (other atoms may change valence)
if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
continue
if atom_equal(atom, b1):
new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
att_confs.append(new_amap)
elif atom_equal(atom, b2):
new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
att_confs.append(new_amap)
else:
# intersection is an atom
for a1 in ctr_atoms:
for a2 in nei_mol.GetAtoms():
if atom_equal(a1, a2):
# Optimize if atom is carbon (other atoms may change
# valence)
if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
continue
new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
att_confs.append(new_amap)
        # intersection is a bond
if ctr_mol.GetNumBonds() > 1:
for b1 in ctr_bonds:
for b2 in nei_mol.GetBonds():
if ring_bond_equal(b1, b2):
new_amap = amap + [(nei_idx,
b1.GetBeginAtom().GetIdx(),
b2.GetBeginAtom().GetIdx()),
(nei_idx,
b1.GetEndAtom().GetIdx(),
b2.GetEndAtom().GetIdx())]
att_confs.append(new_amap)
if ring_bond_equal(b1, b2, reverse=True):
new_amap = amap + [(nei_idx,
b1.GetBeginAtom().GetIdx(),
b2.GetEndAtom().GetIdx()),
(nei_idx,
b1.GetEndAtom().GetIdx(),
b2.GetBeginAtom().GetIdx())]
att_confs.append(new_amap)
return att_confs
# Try rings first: Speed-Up
def enum_assemble_nx(node, neighbors, prev_nodes=None, prev_amap=None):
if prev_nodes is None:
prev_nodes = []
if prev_amap is None:
prev_amap = []
all_attach_confs = []
singletons = [nei_node['nid'] for nei_node in neighbors +
prev_nodes if nei_node['mol'].GetNumAtoms() == 1]
def search(cur_amap, depth):
if len(all_attach_confs) > MAX_NCAND:
return None
if depth == len(neighbors):
all_attach_confs.append(cur_amap)
return None
nei_node = neighbors[depth]
cand_amap = enum_attach_nx(node['mol'], nei_node, cur_amap, singletons)
cand_smiles = set()
candidates = []
for amap in cand_amap:
cand_mol = local_attach_nx(
node['mol'], neighbors[:depth + 1], prev_nodes, amap)
cand_mol = sanitize(cand_mol)
if cand_mol is None:
continue
smiles = get_smiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
candidates.append(amap)
if len(candidates) == 0:
return []
for new_amap in candidates:
search(new_amap, depth + 1)
search(prev_amap, 0)
cand_smiles = set()
candidates = []
for amap in all_attach_confs:
cand_mol = local_attach_nx(node['mol'], neighbors, prev_nodes, amap)
cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
smiles = Chem.MolToSmiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
Chem.Kekulize(cand_mol)
candidates.append((smiles, cand_mol, amap))
return candidates
# Only used for debugging purposes
def dfs_assemble_nx(
graph,
cur_mol,
global_amap,
fa_amap,
cur_node_id,
fa_node_id):
cur_node = graph.nodes_dict[cur_node_id]
fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_id = [nei for nei in graph[cur_node_id]
if graph.nodes_dict[nei]['nid'] != fa_nid]
children = [graph.nodes_dict[nei] for nei in children_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid, a2, a1)
for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(
graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return
cand_smiles, _, cand_amap = zip(*cands)
label_idx = cand_smiles.index(cur_node['label'])
label_amap = cand_amap[label_idx]
for nei_id, ctr_atom, nei_atom in label_amap:
if nei_id == fa_nid:
continue
global_amap[nei_id][nei_atom] = global_amap[cur_node['nid']][ctr_atom]
# father is already attached
cur_mol = attach_mols_nx(cur_mol, children, [], global_amap)
for nei_node_id, nei_node in zip(children_id, children):
if not nei_node['is_leaf']:
dfs_assemble_nx(graph, cur_mol, global_amap,
label_amap, nei_node_id, cur_node_id)
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W1508
# pylint: disable=redefined-outer-name
import os
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import dgl.function as DGLF
from dgl import DGLGraph, mean_nodes
from .nnutils import cuda
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM = 5
MAX_NB = 10
PAPER = os.getenv('PAPER', False)
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
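# E.g. onek_encoding_unk('Si', ELEM_LIST) is True only at the 'Si' position,
# while any symbol outside ELEM_LIST maps to the trailing 'unknown' slot:
#
#     onek_encoding_unk('Xe', ELEM_LIST)  # True only at the 'unknown' slot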
# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with highest score.
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(),
[0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(),
[-1, -2, 1, 2, 0])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
return torch.Tensor([bt == Chem.rdchem.BondType.SINGLE,
bt == Chem.rdchem.BondType.DOUBLE,
bt == Chem.rdchem.BondType.TRIPLE,
bt == Chem.rdchem.BondType.AROMATIC,
bond.IsInRing()])
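# Sanity check (illustrative): the featurizers above match the declared
# dimensions, i.e. len(ELEM_LIST) + 6 + 5 + 1 == ATOM_FDIM and the five
# bond-type/ring indicators give BOND_FDIM.
#
#     mol = Chem.MolFromSmiles('CC')
#     assert atom_features(mol.GetAtomWithIdx(0)).shape[0] == ATOM_FDIM
#     assert bond_features(mol.GetBondWithIdx(0)).shape[0] == BOND_FDIM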
def mol2dgl_single(cand_batch):
cand_graphs = []
tree_mess_source_edges = [] # map these edges from trees to...
tree_mess_target_edges = [] # these edges on candidate graphs
tree_mess_target_nodes = []
n_nodes = 0
atom_x = []
bond_x = []
for mol, mol_tree, ctr_node_id in cand_batch:
n_atoms = mol.GetNumAtoms()
g = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
g.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
a1 = bond.GetBeginAtom()
a2 = bond.GetEndAtom()
begin_idx = a1.GetIdx()
end_idx = a2.GetIdx()
features = bond_features(bond)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
# Tree node ID in the batch
x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if mol_tree.has_edge_between(x_bid, y_bid):
tree_mess_target_edges.append(
(begin_idx + n_nodes, end_idx + n_nodes))
tree_mess_source_edges.append((x_bid, y_bid))
tree_mess_target_nodes.append(end_idx + n_nodes)
if mol_tree.has_edge_between(y_bid, x_bid):
tree_mess_target_edges.append(
(end_idx + n_nodes, begin_idx + n_nodes))
tree_mess_source_edges.append((y_bid, x_bid))
tree_mess_target_nodes.append(begin_idx + n_nodes)
n_nodes += n_atoms
g.add_edges(bond_src, bond_dst)
cand_graphs.append(g)
return cand_graphs, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
torch.LongTensor(tree_mess_source_edges), \
torch.LongTensor(tree_mess_target_edges), \
torch.LongTensor(tree_mess_target_nodes)
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')
class LoopyBPUpdate(nn.Module):
def __init__(self, hidden_size):
super(LoopyBPUpdate, self).__init__()
self.hidden_size = hidden_size
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(self, node):
msg_input = node.data['msg_input']
msg_delta = self.W_h(node.data['accum_msg'] + node.data['alpha'])
msg = torch.relu(msg_input + msg_delta)
return {'msg': msg}
if PAPER:
mpn_gather_msg = [
DGLF.copy_edge(edge='msg', out='msg'),
DGLF.copy_edge(edge='alpha', out='alpha')
]
else:
mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
if PAPER:
mpn_gather_reduce = [
DGLF.sum(msg='msg', out='m'),
DGLF.sum(msg='alpha', out='accum_alpha'),
]
else:
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')
class GatherUpdate(nn.Module):
def __init__(self, hidden_size):
super(GatherUpdate, self).__init__()
self.hidden_size = hidden_size
self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
def forward(self, node):
if PAPER:
#m = node['m']
m = node.data['m'] + node.data['accum_alpha']
else:
m = node.data['m'] + node.data['alpha']
return {
'h': torch.relu(self.W_o(torch.cat([node.data['x'], m], 1))),
}
class DGLJTMPN(nn.Module):
def __init__(self, hidden_size, depth):
nn.Module.__init__(self)
self.depth = depth
self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)
self.loopy_bp_updater = LoopyBPUpdate(hidden_size)
self.gather_updater = GatherUpdate(hidden_size)
self.hidden_size = hidden_size
self.n_samples_total = 0
self.n_nodes_total = 0
self.n_edges_total = 0
self.n_passes = 0
def forward(self, cand_batch, mol_tree_batch):
cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = cand_batch
n_samples = len(cand_graphs)
cand_line_graph = cand_graphs.line_graph(
backtracking=False, shared=True)
n_nodes = cand_graphs.number_of_nodes()
n_edges = cand_graphs.number_of_edges()
cand_graphs = self.run(
cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
tree_mess_tgt_nodes, mol_tree_batch)
g_repr = mean_nodes(cand_graphs, 'h')
self.n_samples_total += n_samples
self.n_nodes_total += n_nodes
self.n_edges_total += n_edges
self.n_passes += 1
return g_repr
def run(self, cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
tree_mess_tgt_nodes, mol_tree_batch):
n_nodes = cand_graphs.number_of_nodes()
cand_graphs.apply_edges(
func=lambda edges: {'src_x': edges.src['x']},
)
bond_features = cand_line_graph.ndata['x']
source_features = cand_line_graph.ndata['src_x']
features = torch.cat([source_features, bond_features], 1)
msg_input = self.W_i(features)
cand_line_graph.ndata.update({
'msg_input': msg_input,
'msg': torch.relu(msg_input),
'accum_msg': torch.zeros_like(msg_input),
})
zero_node_state = bond_features.new(n_nodes, self.hidden_size).zero_()
cand_graphs.ndata.update({
'm': zero_node_state.clone(),
'h': zero_node_state.clone(),
})
cand_graphs.edata['alpha'] = \
cuda(torch.zeros(cand_graphs.number_of_edges(), self.hidden_size))
cand_graphs.ndata['alpha'] = zero_node_state
if tree_mess_src_edges.shape[0] > 0:
if PAPER:
src_u, src_v = tree_mess_src_edges.unbind(1)
tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
alpha = mol_tree_batch.edges[src_u, src_v].data['m']
cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
else:
src_u, src_v = tree_mess_src_edges.unbind(1)
alpha = mol_tree_batch.edges[src_u, src_v].data['m']
node_idx = (tree_mess_tgt_nodes
.to(device=zero_node_state.device)[:, None]
.expand_as(alpha))
node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
cand_graphs.ndata['alpha'] = node_alpha
cand_graphs.apply_edges(
func=lambda edges: {'alpha': edges.src['alpha']},
)
for i in range(self.depth - 1):
cand_line_graph.update_all(
mpn_loopy_bp_msg,
mpn_loopy_bp_reduce,
self.loopy_bp_updater,
)
cand_graphs.update_all(
mpn_gather_msg,
mpn_gather_reduce,
self.gather_updater,
)
return cand_graphs
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as DGLF
from dgl import batch, dfs_labeled_edges_generator
from .chemutils import enum_assemble_nx, get_mol
from .mol_tree_nx import DGLMolTree
from .nnutils import GRUUpdate, cuda
MAX_NB = 8
MAX_DECODE_LEN = 100
def dfs_order(forest, roots):
edges = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)
for e, l in zip(*edges):
        # We exploit the fact that for molecule trees the reverse edge ID
        # equals the forward edge ID xor 1. In general, reverse edges should
        # be located with find_edges().
yield e ^ l, l
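# For example, if a forward tree edge (u, v) has ID 4, DGLMolTree added its
# reverse (v, u) immediately after it with ID 5 == 4 ^ 1, so ``e ^ l`` maps
# each labeled traversal step to the right edge.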
dec_tree_node_msg = DGLF.copy_edge(edge='m', out='m')
dec_tree_node_reduce = DGLF.sum(msg='m', out='h')
def dec_tree_node_update(nodes):
return {'new': nodes.data['new'].clone().zero_()}
dec_tree_edge_msg = [DGLF.copy_src(
src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
dec_tree_edge_reduce = [
DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]
def have_slots(fa_slots, ch_slots):
if len(fa_slots) > 2 and len(ch_slots) > 2:
return True
matches = []
for i, s1 in enumerate(fa_slots):
a1, c1, h1 = s1
for j, s2 in enumerate(ch_slots):
a2, c2, h2 = s2
if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
matches.append((i, j))
if len(matches) == 0:
return False
fa_match, ch_match = list(zip(*matches))
if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2: # never remove atom from ring
fa_slots.pop(fa_match[0])
if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2: # never remove atom from ring
ch_slots.pop(ch_match[0])
return True
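# Example (illustrative): two carbon slots can only match if their hydrogen
# counts sum to at least 4, so the first call succeeds and the second fails.
#
#     have_slots([('C', 0, 3)], [('C', 0, 1)])  # True
#     have_slots([('C', 0, 1)], [('C', 0, 1)])  # False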
def can_assemble(mol_tree, u, v_node_dict):
u_node_dict = mol_tree.nodes_dict[u]
u_neighbors = mol_tree.successors(u)
u_neighbors_node_dict = [
mol_tree.nodes_dict[_u]
for _u in u_neighbors
if _u in mol_tree.nodes_dict
]
neis = u_neighbors_node_dict + [v_node_dict]
for i, nei in enumerate(neis):
nei['nid'] = i
neighbors = [nei for nei in neis if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in neis if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(u_node_dict, neighbors)
return len(cands) > 0
def create_node_dict(smiles, clique=None):
if clique is None:
clique = []
return dict(
smiles=smiles,
mol=get_mol(smiles),
clique=clique,
)
class DGLJTNNDecoder(nn.Module):
def __init__(self, vocab, hidden_size, latent_size, embedding=None):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.vocab_size = vocab.size()
self.vocab = vocab
if embedding is None:
self.embedding = nn.Embedding(self.vocab_size, hidden_size)
else:
self.embedding = embedding
self.dec_tree_edge_update = GRUUpdate(hidden_size)
self.W = nn.Linear(latent_size + hidden_size, hidden_size)
self.U = nn.Linear(latent_size + 2 * hidden_size, hidden_size)
self.W_o = nn.Linear(hidden_size, self.vocab_size)
self.U_s = nn.Linear(hidden_size, 1)
def forward(self, mol_trees, tree_vec):
'''
The training procedure which computes the prediction loss given the
ground truth tree
'''
mol_tree_batch = batch(mol_trees)
mol_tree_batch_lg = mol_tree_batch.line_graph(
backtracking=False, shared=True)
n_trees = len(mol_trees)
return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)
def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
            # whether it's a newly generated node
'new': cuda(torch.ones(n_nodes).bool()),
})
mol_tree_batch.edata.update({
's': cuda(torch.zeros(n_edges, self.hidden_size)),
'm': cuda(torch.zeros(n_edges, self.hidden_size)),
'r': cuda(torch.zeros(n_edges, self.hidden_size)),
'z': cuda(torch.zeros(n_edges, self.hidden_size)),
'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
})
mol_tree_batch.apply_edges(
func=lambda edges: {
'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
)
# input tensors for stop prediction (p) and label prediction (q)
p_inputs = []
p_targets = []
q_inputs = []
q_targets = []
# Predict root
mol_tree_batch.pull(
root_ids,
dec_tree_node_msg,
dec_tree_node_reduce,
dec_tree_node_update,
)
# Extract hidden states and store them for stop/label prediction
h = mol_tree_batch.nodes[root_ids].data['h']
x = mol_tree_batch.nodes[root_ids].data['x']
p_inputs.append(torch.cat([x, h, tree_vec], 1))
# If the out degree is 0 we don't generate any edges at all
root_out_degrees = mol_tree_batch.out_degrees(root_ids)
q_inputs.append(torch.cat([h, tree_vec], 1))
q_targets.append(mol_tree_batch.nodes[root_ids].data['wid'])
# Traverse the tree and predict on children
for eid, p in dfs_order(mol_tree_batch, root_ids):
u, v = mol_tree_batch.find_edges(eid)
p_target_list = torch.zeros_like(root_out_degrees)
p_target_list[root_out_degrees > 0] = 1 - p
p_target_list = p_target_list[root_out_degrees >= 0]
p_targets.append(p_target_list.clone().detach())
root_out_degrees -= (root_out_degrees == 0).long()
root_out_degrees -= torch.tensor(np.isin(root_ids,
v).astype('int64'))
mol_tree_batch_lg.pull(
eid,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
)
is_new = mol_tree_batch.nodes[v].data['new']
mol_tree_batch.pull(
v,
dec_tree_node_msg,
dec_tree_node_reduce,
dec_tree_node_update,
)
# Extract
n_repr = mol_tree_batch.nodes[v].data
h = n_repr['h']
x = n_repr['x']
tree_vec_set = tree_vec[root_out_degrees >= 0]
wid = n_repr['wid']
p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
# Only newly generated nodes are needed for label prediction
# NOTE: The following works since the uncomputed messages are zeros.
q_input = torch.cat([h, tree_vec_set], 1)[is_new]
q_target = wid[is_new]
if q_input.shape[0] > 0:
q_inputs.append(q_input)
q_targets.append(q_target)
p_targets.append(torch.zeros((root_out_degrees == 0).sum()).long())
# Batch compute the stop/label prediction losses
p_inputs = torch.cat(p_inputs, 0)
p_targets = cuda(torch.cat(p_targets, 0))
q_inputs = torch.cat(q_inputs, 0)
q_targets = torch.cat(q_targets, 0)
q = self.W_o(torch.relu(self.W(q_inputs)))
p = self.U_s(torch.relu(self.U(p_inputs)))[:, 0]
p_loss = F.binary_cross_entropy_with_logits(
p, p_targets.float(), reduction='sum'
) / n_trees
q_loss = F.cross_entropy(q, q_targets, reduction='sum') / n_trees
p_acc = ((p > 0).long() == p_targets).sum().float() / \
p_targets.shape[0]
q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]
self.q_inputs = q_inputs
self.q_targets = q_targets
self.q = q
self.p_inputs = p_inputs
self.p_targets = p_targets
self.p = p
return q_loss, p_loss, q_acc, p_acc
def decode(self, mol_vec):
assert mol_vec.shape[0] == 1
mol_tree = DGLMolTree(None)
init_hidden = cuda(torch.zeros(1, self.hidden_size))
root_hidden = torch.cat([init_hidden, mol_vec], 1)
root_hidden = F.relu(self.W(root_hidden))
root_score = self.W_o(root_hidden)
_, root_wid = torch.max(root_score, 1)
root_wid = root_wid.view(1)
mol_tree.add_nodes(1) # root
mol_tree.nodes[0].data['wid'] = root_wid
mol_tree.nodes[0].data['x'] = self.embedding(root_wid)
mol_tree.nodes[0].data['h'] = init_hidden
mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0]))
mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
self.vocab.get_smiles(root_wid))
stack, trace = [], []
stack.append((0, self.vocab.get_slots(root_wid)))
all_nodes = {0: root_node_dict}
first = True
new_node_id = 0
new_edge_id = 0
for step in range(MAX_DECODE_LEN):
u, u_slots = stack[-1]
udata = mol_tree.nodes[u].data
x = udata['x']
h = udata['h']
# Predict stop
p_input = torch.cat([x, h, mol_vec], 1)
p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
backtrack = (p_score.item() < 0.5)
if not backtrack:
# Predict next clique. Note that the prediction may fail due
# to lack of assemblable components
mol_tree.add_nodes(1)
new_node_id += 1
v = new_node_id
mol_tree.add_edges(u, v)
uv = new_edge_id
new_edge_id += 1
if first:
mol_tree.edata.update({
's': cuda(torch.zeros(1, self.hidden_size)),
'm': cuda(torch.zeros(1, self.hidden_size)),
'r': cuda(torch.zeros(1, self.hidden_size)),
'z': cuda(torch.zeros(1, self.hidden_size)),
'src_x': cuda(torch.zeros(1, self.hidden_size)),
'dst_x': cuda(torch.zeros(1, self.hidden_size)),
'rm': cuda(torch.zeros(1, self.hidden_size)),
'accum_rm': cuda(torch.zeros(1, self.hidden_size)),
})
first = False
mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x']
                # Keeping dst_x as zeros is fine since h on the new edge
                # doesn't depend on it.
                # DGL doesn't dynamically maintain a line graph.
mol_tree_lg = mol_tree.line_graph(
backtracking=False, shared=True)
mol_tree_lg.pull(
uv,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update.update_zm,
)
mol_tree.pull(
v,
dec_tree_node_msg,
dec_tree_node_reduce,
)
vdata = mol_tree.nodes[v].data
h_v = vdata['h']
q_input = torch.cat([h_v, mol_vec], 1)
q_score = torch.softmax(
self.W_o(torch.relu(self.W(q_input))), -1)
_, sort_wid = torch.sort(q_score, 1, descending=True)
sort_wid = sort_wid.squeeze()
next_wid = None
for wid in sort_wid.tolist()[:5]:
slots = self.vocab.get_slots(wid)
cand_node_dict = create_node_dict(
self.vocab.get_smiles(wid))
if (have_slots(u_slots, slots) and can_assemble(mol_tree, u, cand_node_dict)):
next_wid = wid
next_slots = slots
next_node_dict = cand_node_dict
break
if next_wid is None:
                    # Failed to add an actual child; v is a spurious node
                    # and we mark it.
vdata['fail'] = cuda(torch.tensor([1]))
backtrack = True
else:
next_wid = cuda(torch.tensor([next_wid]))
vdata['wid'] = next_wid
vdata['x'] = self.embedding(next_wid)
mol_tree.nodes_dict[v] = next_node_dict
all_nodes[v] = next_node_dict
stack.append((v, next_slots))
mol_tree.add_edge(v, u)
vu = new_edge_id
new_edge_id += 1
mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data['x']
mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data['x']
mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data['x']
# DGL doesn't dynamically maintain a line graph.
mol_tree_lg = mol_tree.line_graph(
backtracking=False, shared=True)
mol_tree_lg.apply_nodes(
self.dec_tree_edge_update.update_r,
uv
)
if backtrack:
if len(stack) == 1:
break # At root, terminate
pu, _ = stack[-2]
u_pu = mol_tree.edge_id(u, pu)
mol_tree_lg.pull(
u_pu,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
)
mol_tree.pull(
pu,
dec_tree_node_msg,
dec_tree_node_reduce,
)
stack.pop()
effective_nodes = mol_tree.filter_nodes(
lambda nodes: nodes.data['fail'] != 1)
effective_nodes, _ = torch.sort(effective_nodes)
return mol_tree, all_nodes, effective_nodes
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import numpy as np
import torch
import torch.nn as nn
import dgl.function as DGLF
from dgl import batch, bfs_edges_generator
from .nnutils import GRUUpdate, cuda
MAX_NB = 8
def level_order(forest, roots):
edges = bfs_edges_generator(forest, roots)
_, leaves = forest.find_edges(edges[-1])
edges_back = bfs_edges_generator(forest, roots, reverse=True)
yield from reversed(edges_back)
yield from edges
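# level_order thus yields one leaf-to-root sweep (the reversed backward BFS
# levels) followed by one root-to-leaf sweep, which suffices for the tree
# message passing below.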
enc_tree_msg = [DGLF.copy_src(src='m', out='m'),
DGLF.copy_src(src='rm', out='rm')]
enc_tree_reduce = [DGLF.sum(msg='m', out='s'),
DGLF.sum(msg='rm', out='accum_rm')]
enc_tree_gather_msg = DGLF.copy_edge(edge='m', out='m')
enc_tree_gather_reduce = DGLF.sum(msg='m', out='m')
class EncoderGatherUpdate(nn.Module):
def __init__(self, hidden_size):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.W = nn.Linear(2 * hidden_size, hidden_size)
def forward(self, nodes):
x = nodes.data['x']
m = nodes.data['m']
return {
'h': torch.relu(self.W(torch.cat([x, m], 1))),
}
class DGLJTNNEncoder(nn.Module):
def __init__(self, vocab, hidden_size, embedding=None):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.vocab_size = vocab.size()
self.vocab = vocab
if embedding is None:
self.embedding = nn.Embedding(self.vocab_size, hidden_size)
else:
self.embedding = embedding
self.enc_tree_update = GRUUpdate(hidden_size)
self.enc_tree_gather_update = EncoderGatherUpdate(hidden_size)
def forward(self, mol_trees):
mol_tree_batch = batch(mol_trees)
# Build line graph to prepare for belief propagation
mol_tree_batch_lg = mol_tree_batch.line_graph(
backtracking=False, shared=True)
return self.run(mol_tree_batch, mol_tree_batch_lg)
def run(self, mol_tree_batch, mol_tree_batch_lg):
        # Tree roots are designated as node 0 in each tree, so in the batched
        # graph we can find the corresponding node IDs from node_offset.
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
# Assign structure embeddings to tree nodes
mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
})
# Initialize the intermediate variables according to Eq (4)-(8).
# Also initialize the src_x and dst_x fields.
# TODO: context?
mol_tree_batch.edata.update({
's': cuda(torch.zeros(n_edges, self.hidden_size)),
'm': cuda(torch.zeros(n_edges, self.hidden_size)),
'r': cuda(torch.zeros(n_edges, self.hidden_size)),
'z': cuda(torch.zeros(n_edges, self.hidden_size)),
'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
})
# Send the source/destination node features to edges
mol_tree_batch.apply_edges(
func=lambda edges: {
'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
)
# Message passing
        # We exploit the fact that the reduce function is a sum of incoming
        # messages, and the uncomputed messages are zero vectors. Essentially,
        # we can always compute s_ij as the sum of incoming m_ij, whether or
        # not m_ij has actually been computed.
for eid in level_order(mol_tree_batch, root_ids):
#eid = mol_tree_batch.edge_ids(u, v)
mol_tree_batch_lg.pull(
eid,
enc_tree_msg,
enc_tree_reduce,
self.enc_tree_update,
)
# Readout
mol_tree_batch.update_all(
enc_tree_gather_msg,
enc_tree_gather_reduce,
self.enc_tree_gather_update,
)
root_vecs = mol_tree_batch.nodes[root_ids].data['h']
return mol_tree_batch, root_vecs
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import copy
import rdkit.Chem as Chem
def get_slots(smiles):
mol = Chem.MolFromSmiles(smiles)
return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs())
for atom in mol.GetAtoms()]
class Vocab(object):
def __init__(self, smiles_list):
self.vocab = smiles_list
self.vmap = {x: i for i, x in enumerate(self.vocab)}
self.slots = [get_slots(smiles) for smiles in self.vocab]
def get_index(self, smiles):
return self.vmap[smiles]
def get_smiles(self, idx):
return self.vocab[idx]
def get_slots(self, idx):
return copy.deepcopy(self.slots[idx])
def size(self):
return len(self.vocab)
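# Usage sketch (illustrative three-entry vocabulary):
#
#     vocab = Vocab(['C', 'CC', 'C1=CC=CC=C1'])
#     vocab.get_index('CC')  # 1
#     vocab.get_slots(2)     # one (symbol, formal charge, #Hs) triple per atom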
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import numpy as np
import rdkit.Chem as Chem
from dgl import DGLGraph
from .chemutils import (decode_stereo, enum_assemble_nx, get_clique_mol,
get_mol, get_smiles, set_atommap, tree_decomp)
class DGLMolTree(DGLGraph):
def __init__(self, smiles):
DGLGraph.__init__(self)
self.nodes_dict = {}
if smiles is None:
return
self.smiles = smiles
self.mol = get_mol(smiles)
# Stereo Generation
mol = Chem.MolFromSmiles(smiles)
self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
self.smiles2D = Chem.MolToSmiles(mol)
self.stereo_cands = decode_stereo(self.smiles2D)
# cliques: a list of list of atom indices
cliques, edges = tree_decomp(self.mol)
root = 0
for i, c in enumerate(cliques):
cmol = get_clique_mol(self.mol, c)
csmiles = get_smiles(cmol)
self.nodes_dict[i] = dict(
smiles=csmiles,
mol=get_mol(csmiles),
clique=c,
)
if min(c) == 0:
root = i
self.add_nodes(len(cliques))
# The clique with atom ID 0 becomes root
if root > 0:
for attr in self.nodes_dict[0]:
self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
self.nodes_dict[root][attr], self.nodes_dict[0][attr]
src = np.zeros((len(edges) * 2,), dtype='int')
dst = np.zeros((len(edges) * 2,), dtype='int')
for i, (_x, _y) in enumerate(edges):
x = 0 if _x == root else root if _x == 0 else _x
y = 0 if _y == root else root if _y == 0 else _y
src[2 * i] = x
dst[2 * i] = y
src[2 * i + 1] = y
dst[2 * i + 1] = x
self.add_edges(src, dst)
for i in self.nodes_dict:
self.nodes_dict[i]['nid'] = i + 1
if self.out_degree(i) > 1: # Leaf node mol is not marked
set_atommap(self.nodes_dict[i]['mol'],
self.nodes_dict[i]['nid'])
self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)
def treesize(self):
return self.number_of_nodes()
def _recover_node(self, i, original_mol):
node = self.nodes_dict[i]
clique = []
clique.extend(node['clique'])
if not node['is_leaf']:
for cidx in node['clique']:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])
for j in self.successors(i).numpy():
nei_node = self.nodes_dict[j]
clique.extend(nei_node['clique'])
if nei_node['is_leaf']: # Leaf node, no need to mark
continue
for cidx in nei_node['clique']:
                # allow singleton nodes to override the atom mapping
if cidx not in node['clique'] or len(nei_node['clique']) == 1:
atom = original_mol.GetAtomWithIdx(cidx)
atom.SetAtomMapNum(nei_node['nid'])
clique = list(set(clique))
label_mol = get_clique_mol(original_mol, clique)
node['label'] = Chem.MolToSmiles(
Chem.MolFromSmiles(get_smiles(label_mol)))
node['label_mol'] = get_mol(node['label'])
for cidx in clique:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
return node['label']
def _assemble_node(self, i):
neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(self.nodes_dict[i], neighbors)
if len(cands) > 0:
self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = list(
zip(*cands))
self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
self.nodes_dict[i]['cand_mols'] = list(
self.nodes_dict[i]['cand_mols'])
else:
self.nodes_dict[i]['cands'] = []
self.nodes_dict[i]['cand_mols'] = []
def recover(self):
for i in self.nodes_dict:
self._recover_node(i, self.mol)
def assemble(self):
for i in self.nodes_dict:
self._assemble_node(i)
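# Usage sketch (illustrative): build the junction tree for aspirin.
#
#     tree = DGLMolTree('CC(=O)Oc1ccccc1C(=O)O')
#     tree.treesize()  # number of cliques in the decomposition
#     tree.recover()   # annotate each node with its ground-truth label
#     tree.assemble()  # enumerate candidate attachment configurations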
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
def create_var(tensor, requires_grad=None):
if requires_grad is None:
return Variable(tensor)
else:
return Variable(tensor, requires_grad=requires_grad)
def cuda(tensor):
if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
return tensor.cuda()
else:
return tensor
class GRUUpdate(nn.Module):
def __init__(self, hidden_size):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.W_z = nn.Linear(2 * hidden_size, hidden_size)
self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
self.U_r = nn.Linear(hidden_size, hidden_size)
self.W_h = nn.Linear(2 * hidden_size, hidden_size)
def update_zm(self, node):
src_x = node.data['src_x']
s = node.data['s']
rm = node.data['accum_rm']
z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
m = (1 - z) * s + z * m
return {'m': m, 'z': z}
def update_r(self, node, zm=None):
dst_x = node.data['dst_x']
m = node.data['m'] if zm is None else zm['m']
r_1 = self.W_r(dst_x)
r_2 = self.U_r(m)
r = torch.sigmoid(r_1 + r_2)
return {'r': r, 'rm': r * m}
def forward(self, node):
dic = self.update_zm(node)
dic.update(self.update_r(node, zm=dic))
return dic
def move_dgl_to_cuda(g):
g.ndata.update({k: cuda(g.ndata[k]) for k in g.ndata})
g.edata.update({k: cuda(g.edata[k]) for k in g.edata})
"""MGCN"""
import torch.nn as nn
from ..gnn import MGCNGNN
from ..readout import MLPNodeReadout
__all__ = ['MGCNPredictor']
class MGCNPredictor(nn.Module):
"""MGCN for for regression and classification on graphs.
MGCN is introduced in `Molecular Property Prediction: A Multilevel Quantum Interactions
Modeling Perspective <https://arxiv.org/abs/1906.11081>`__.
Parameters
----------
feats : int
Size for the node and edge embeddings to learn. Default to 128.
n_layers : int
Number of gnn layers to use. Default to 3.
classifier_hidden_feats : int
Size for hidden representations in the classifier. Default to 64.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
num_node_types : int
Number of node types to embed. Default to 100.
num_edge_types : int
Number of edge types to embed. Default to 3000.
    cutoff : float
        Largest center in RBF expansion. Default to 5.0.
    gap : float
        Difference between two adjacent centers in RBF expansion. Default to 1.0.
"""
def __init__(self, feats=128, n_layers=3, classifier_hidden_feats=64, n_tasks=1,
num_node_types=100, num_edge_types=3000, cutoff=5.0, gap=1.0):
super(MGCNPredictor, self).__init__()
self.gnn = MGCNGNN(feats=feats,
n_layers=n_layers,
num_node_types=num_node_types,
num_edge_types=num_edge_types,
cutoff=cutoff,
gap=gap)
self.readout = MLPNodeReadout(node_feats=(n_layers + 1) * feats,
hidden_feats=classifier_hidden_feats,
graph_feats=n_tasks,
activation=nn.Softplus(beta=1, threshold=20))
def forward(self, g, node_types, edge_dists):
"""Graph-level regression/soft classification.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_types : int64 tensor of shape (V)
Node types to embed, V for the number of nodes.
edge_dists : float32 tensor of shape (E, 1)
Distances between end nodes of edges, E for the number of edges.
Returns
-------
float32 tensor of shape (G, n_tasks)
Prediction for the graphs in the batch. G for the number of graphs.
"""
node_feats = self.gnn(g, node_types, edge_dists)
return self.readout(g, node_feats)
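# Usage sketch (illustrative, not part of the diff; node types and edge
# distances below are random placeholders rather than chemistry-derived
# values):
#
#     import dgl
#     import torch
#
#     g = dgl.DGLGraph()
#     g.add_nodes(3)
#     g.add_edges([0, 1, 1, 2], [1, 0, 2, 1])
#     model = MGCNPredictor(n_tasks=1)
#     node_types = torch.LongTensor([0, 1, 0])
#     edge_dists = torch.rand(g.number_of_edges(), 1)
#     model(g, node_types, edge_dists)  # float32 tensor of shape (1, 1)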
"""
Readout functions for computing molecular representations
out of node and edge representations.
"""
from .attentivefp_readout import *
from .weighted_sum_and_max import *
from .mlp_readout import *