Unverified Commit 828a5e5b authored by Mufei Li, committed by GitHub

[DGL-LifeSci] Migration and Refactor (#1226)

* First commit

* Update splitters

* Migrate ACNN

* Finish classification

* trigger CI

* Fix CI

* Add default values

* Rename

* Update deprecation message
parent e4948c5c
import rdkit.Chem as Chem
import torch
from collections import defaultdict
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from dgl import DGLGraph
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
MST_MAX_WEIGHT = 100
MAX_NCAND = 2000
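# One-hot encode x over allowable_set; any value outside the set is mapped to
# the last entry, which serves as 'unknown'.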
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
def set_atommap(mol, num=0):
for atom in mol.GetAtoms():
atom.SetAtomMapNum(num)
def get_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
Chem.Kekulize(mol)
return mol
def get_smiles(mol):
return Chem.MolToSmiles(mol, kekuleSmiles=True)
def decode_stereo(smiles2D):
mol = Chem.MolFromSmiles(smiles2D)
dec_isomers = list(EnumerateStereoisomers(mol))
dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(
mol, isomericSmiles=True)) for mol in dec_isomers]
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True)
for mol in dec_isomers]
chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(
atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
if len(chiralN) > 0:
for mol in dec_isomers:
for idx in chiralN:
mol.GetAtomWithIdx(idx).SetChiralTag(
Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))
return smiles3D
def sanitize(mol):
try:
smiles = get_smiles(mol)
mol = get_mol(smiles)
except Exception:
return None
return mol
def copy_atom(atom):
new_atom = Chem.Atom(atom.GetSymbol())
new_atom.SetFormalCharge(atom.GetFormalCharge())
new_atom.SetAtomMapNum(atom.GetAtomMapNum())
return new_atom
def copy_edit_mol(mol):
new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
for atom in mol.GetAtoms():
new_atom = copy_atom(atom)
new_mol.AddAtom(new_atom)
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
bt = bond.GetBondType()
new_mol.AddBond(a1, a2, bt)
return new_mol
def get_clique_mol(mol, atoms):
smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
new_mol = copy_edit_mol(new_mol).GetMol()
new_mol = sanitize(new_mol) # We assume this is not None
return new_mol
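# Decompose a molecule into cliques (non-ring bonds plus SSSR rings) and build
# a junction tree over them: edges between overlapping cliques are weighted by
# the number of shared atoms, and the maximum spanning tree is obtained by
# running a minimum spanning tree on MST_MAX_WEIGHT - weight.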
def tree_decomp(mol):
n_atoms = mol.GetNumAtoms()
if n_atoms == 1:
return [[0]], []
cliques = []
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
if not bond.IsInRing():
cliques.append([a1, a2])
ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
cliques.extend(ssr)
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Merge Rings with intersection > 2 atoms
for i in range(len(cliques)):
if len(cliques[i]) <= 2:
continue
for atom in cliques[i]:
for j in nei_list[atom]:
if i >= j or len(cliques[j]) <= 2:
continue
inter = set(cliques[i]) & set(cliques[j])
if len(inter) > 2:
cliques[i].extend(cliques[j])
cliques[i] = list(set(cliques[i]))
cliques[j] = []
cliques = [c for c in cliques if len(c) > 0]
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Build edges and add singleton cliques
edges = defaultdict(int)
for atom in range(n_atoms):
if len(nei_list[atom]) <= 1:
continue
cnei = nei_list[atom]
bonds = [c for c in cnei if len(cliques[c]) == 2]
rings = [c for c in cnei if len(cliques[c]) > 4]
# In general, if len(cnei) >= 3, a singleton should be added, but the case of 1 bond + 2 rings is currently not handled.
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = 1
elif len(rings) > 2: # Multiple (n>2) complex rings
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = MST_MAX_WEIGHT - 1
else:
for i in range(len(cnei)):
for j in range(i + 1, len(cnei)):
c1, c2 = cnei[i], cnei[j]
inter = set(cliques[c1]) & set(cliques[c2])
if edges[(c1, c2)] < len(inter):
# cnei[i] < cnei[j] by construction
edges[(c1, c2)] = len(inter)
edges = [u + (MST_MAX_WEIGHT-v,) for u, v in edges.items()]
if len(edges) == 0:
return cliques, edges
# Compute Maximum Spanning Tree
row, col, data = list(zip(*edges))
n_clique = len(cliques)
clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
junc_tree = minimum_spanning_tree(clique_graph)
row, col = junc_tree.nonzero()
edges = [(row[i], col[i]) for i in range(len(row))]
return (cliques, edges)
def atom_equal(a1, a2):
return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()
# Bond type is not considered because the bonds are all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
if reverse:
b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
else:
b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
prev_nids = [node['nid'] for node in prev_nodes]
for nei_node in prev_nodes + neighbors:
nei_id, nei_mol = nei_node['nid'], nei_node['mol']
amap = nei_amap[nei_id]
for atom in nei_mol.GetAtoms():
if atom.GetIdx() not in amap:
new_atom = copy_atom(atom)
amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)
if nei_mol.GetNumBonds() == 0:
nei_atom = nei_mol.GetAtomWithIdx(0)
ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
else:
for bond in nei_mol.GetBonds():
a1 = amap[bond.GetBeginAtom().GetIdx()]
a2 = amap[bond.GetEndAtom().GetIdx()]
if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
ctr_mol.AddBond(a1, a2, bond.GetBondType())
elif nei_id in prev_nids: # father node overrides
ctr_mol.RemoveBond(a1, a2)
ctr_mol.AddBond(a1, a2, bond.GetBondType())
return ctr_mol
def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
ctr_mol = copy_edit_mol(ctr_mol)
nei_amap = {nei['nid']: {} for nei in prev_nodes + neighbors}
for nei_id, ctr_atom, nei_atom in amap_list:
nei_amap[nei_id][nei_atom] = ctr_atom
ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
return ctr_mol.GetMol()
# This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
nei_mol, nei_idx = nei_node['mol'], nei_node['nid']
att_confs = []
black_list = [atom_idx for nei_id, atom_idx,
_ in amap if nei_id in singletons]
ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx()
not in black_list]
ctr_bonds = [bond for bond in ctr_mol.GetBonds()]
if nei_mol.GetNumBonds() == 0: # neighbor singleton
nei_atom = nei_mol.GetAtomWithIdx(0)
used_list = [atom_idx for _, atom_idx, _ in amap]
for atom in ctr_atoms:
if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
att_confs.append(new_amap)
elif nei_mol.GetNumBonds() == 1: # neighbor is a bond
bond = nei_mol.GetBondWithIdx(0)
bond_val = int(bond.GetBondTypeAsDouble())
b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()
for atom in ctr_atoms:
# Optimize if atom is carbon (other atoms may change valence)
if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
continue
if atom_equal(atom, b1):
new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
att_confs.append(new_amap)
elif atom_equal(atom, b2):
new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
att_confs.append(new_amap)
else:
# intersection is an atom
for a1 in ctr_atoms:
for a2 in nei_mol.GetAtoms():
if atom_equal(a1, a2):
# Optimize if atom is carbon (other atoms may change valence)
if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
continue
new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
att_confs.append(new_amap)
# intersection is a bond
if ctr_mol.GetNumBonds() > 1:
for b1 in ctr_bonds:
for b2 in nei_mol.GetBonds():
if ring_bond_equal(b1, b2):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom(
).GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
att_confs.append(new_amap)
if ring_bond_equal(b1, b2, reverse=True):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom(
).GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
att_confs.append(new_amap)
return att_confs
# Try rings first: Speed-Up
def enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):
all_attach_confs = []
singletons = [nei_node['nid'] for nei_node in neighbors +
prev_nodes if nei_node['mol'].GetNumAtoms() == 1]
def search(cur_amap, depth):
if len(all_attach_confs) > MAX_NCAND:
return
if depth == len(neighbors):
all_attach_confs.append(cur_amap)
return
nei_node = neighbors[depth]
cand_amap = enum_attach_nx(node['mol'], nei_node, cur_amap, singletons)
cand_smiles = set()
candidates = []
for amap in cand_amap:
cand_mol = local_attach_nx(
node['mol'], neighbors[:depth+1], prev_nodes, amap)
cand_mol = sanitize(cand_mol)
if cand_mol is None:
continue
smiles = get_smiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
candidates.append(amap)
if len(candidates) == 0:
return []
for new_amap in candidates:
search(new_amap, depth + 1)
search(prev_amap, 0)
cand_smiles = set()
candidates = []
for amap in all_attach_confs:
cand_mol = local_attach_nx(node['mol'], neighbors, prev_nodes, amap)
cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
smiles = Chem.MolToSmiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
Chem.Kekulize(cand_mol)
candidates.append((smiles, cand_mol, amap))
return candidates
# Only used for debugging purposes
def dfs_assemble_nx(graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id):
cur_node = graph.nodes_dict[cur_node_id]
fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_id = [nei for nei in graph[cur_node_id]
if graph.nodes_dict[nei]['nid'] != fa_nid]
children = [graph.nodes_dict[nei] for nei in children_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid, a2, a1)
for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(
graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return
cand_smiles, _, cand_amap = zip(*cands)
label_idx = cand_smiles.index(cur_node['label'])
label_amap = cand_amap[label_idx]
for nei_id, ctr_atom, nei_atom in label_amap:
if nei_id == fa_nid:
continue
global_amap[nei_id][nei_atom] = global_amap[cur_node['nid']][ctr_atom]
# father is already attached
cur_mol = attach_mols_nx(cur_mol, children, [], global_amap)
for nei_node_id, nei_node in zip(children_id, children):
if not nei_node['is_leaf']:
dfs_assemble_nx(graph, cur_mol, global_amap,
label_amap, nei_node_id, cur_node_id)
def mol2dgl_dec(cand_batch):
# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with highest score.
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(),
[0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(),
[-1, -2, 1, 2, 0])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
return (torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE,
bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC,
bond.IsInRing()]))
cand_graphs = []
tree_mess_source_edges = [] # map these edges from trees to...
tree_mess_target_edges = [] # these edges on candidate graphs
tree_mess_target_nodes = []
n_nodes = 0
atom_x = []
bond_x = []
for mol, mol_tree, ctr_node_id in cand_batch:
n_atoms = mol.GetNumAtoms()
g = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
g.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
a1 = bond.GetBeginAtom()
a2 = bond.GetEndAtom()
begin_idx = a1.GetIdx()
end_idx = a2.GetIdx()
features = bond_features(bond)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
# Tree node ID in the batch
x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if mol_tree.has_edge_between(x_bid, y_bid):
tree_mess_target_edges.append(
(begin_idx + n_nodes, end_idx + n_nodes))
tree_mess_source_edges.append((x_bid, y_bid))
tree_mess_target_nodes.append(end_idx + n_nodes)
if mol_tree.has_edge_between(y_bid, x_bid):
tree_mess_target_edges.append(
(end_idx + n_nodes, begin_idx + n_nodes))
tree_mess_source_edges.append((y_bid, x_bid))
tree_mess_target_nodes.append(begin_idx + n_nodes)
n_nodes += n_atoms
g.add_edges(bond_src, bond_dst)
cand_graphs.append(g)
return cand_graphs, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
torch.LongTensor(tree_mess_source_edges), \
torch.LongTensor(tree_mess_target_edges), \
torch.LongTensor(tree_mess_target_nodes)
def mol2dgl_enc(smiles):
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(),
[0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
+ onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
stereo = int(bond.GetStereo())
fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt ==
Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
return (torch.Tensor(fbond + fstereo))
n_edges = 0
atom_x = []
bond_x = []
mol = get_mol(smiles)
n_atoms = mol.GetNumAtoms()
n_bonds = mol.GetNumBonds()
graph = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
graph.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
begin_idx = bond.GetBeginAtom().GetIdx()
end_idx = bond.GetEndAtom().GetIdx()
features = bond_features(bond)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
# set up the reverse direction
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
graph.add_edges(bond_src, bond_dst)
n_edges += n_bonds
return graph, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
import dgl
import os
import torch
from dgl.data.utils import download, extract_archive, get_download_dir
from torch.utils.data import Dataset
from .mol_tree import Vocab, DGLMolTree
from .chemutils import mol2dgl_dec, mol2dgl_enc
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM_DEC = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM_DEC = 5
MAX_NB = 10
PAPER = os.getenv('PAPER', False)
_url = 'https://s3-ap-southeast-1.amazonaws.com/dgl-data-cn/dataset/jtnn.zip'
def _unpack_field(examples, field):
return [e[field] for e in examples]
def _set_node_id(mol_tree, vocab):
wid = []
for i, node in enumerate(mol_tree.nodes_dict):
mol_tree.nodes_dict[node]['idx'] = i
wid.append(vocab.get_index(mol_tree.nodes_dict[node]['smiles']))
return wid
class JTNNDataset(Dataset):
def __init__(self, data, vocab, training=True):
self.dir = get_download_dir()
self.zip_file_path = '{}/jtnn.zip'.format(self.dir)
download(_url, path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/jtnn'.format(self.dir))
print('Loading data...')
if data in ['train', 'test']:
data_file = '{}/jtnn/{}.txt'.format(self.dir, data)
else:
data_file = data
with open(data_file) as f:
self.data = [line.strip("\r\n ").split()[0] for line in f]
self.vocab_file = '{}/jtnn/{}.txt'.format(self.dir, vocab)
print('Loading finished.')
print('\tNum samples:', len(self.data))
print('\tVocab file:', self.vocab_file)
self.training = training
self.vocab = Vocab([x.strip("\r\n ") for x in open(self.vocab_file)])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
smiles = self.data[idx]
mol_tree = DGLMolTree(smiles)
mol_tree.recover()
mol_tree.assemble()
wid = _set_node_id(mol_tree, self.vocab)
# prebuild the molecule graph
mol_graph, atom_x_enc, bond_x_enc = mol2dgl_enc(mol_tree.smiles)
result = {
'mol_tree': mol_tree,
'mol_graph': mol_graph,
'atom_x_enc': atom_x_enc,
'bond_x_enc': bond_x_enc,
'wid': wid,
}
if not self.training:
return result
# prebuild the candidate graph list
cands = []
for node_id, node in mol_tree.nodes_dict.items():
# fill in ground truth
if node['label'] not in node['cands']:
node['cands'].append(node['label'])
node['cand_mols'].append(node['label_mol'])
if node['is_leaf'] or len(node['cands']) == 1:
continue
cands.extend([(cand, mol_tree, node_id)
for cand in node['cand_mols']])
if len(cands) > 0:
cand_graphs, atom_x_dec, bond_x_dec, tree_mess_src_e, \
tree_mess_tgt_e, tree_mess_tgt_n = mol2dgl_dec(cands)
else:
cand_graphs = []
atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)
bond_x_dec = torch.zeros(0, BOND_FDIM_DEC)
tree_mess_src_e = torch.zeros(0, 2).long()
tree_mess_tgt_e = torch.zeros(0, 2).long()
tree_mess_tgt_n = torch.zeros(0).long()
# prebuild the stereoisomers
cands = mol_tree.stereo_cands
if len(cands) > 1:
if mol_tree.smiles3D not in cands:
cands.append(mol_tree.smiles3D)
stereo_graphs = [mol2dgl_enc(c) for c in cands]
stereo_cand_graphs, stereo_atom_x_enc, stereo_bond_x_enc = \
zip(*stereo_graphs)
stereo_atom_x_enc = torch.cat(stereo_atom_x_enc)
stereo_bond_x_enc = torch.cat(stereo_bond_x_enc)
stereo_cand_label = [(cands.index(mol_tree.smiles3D), len(cands))]
else:
stereo_cand_graphs = []
stereo_atom_x_enc = torch.zeros(0, atom_x_enc.shape[1])
stereo_bond_x_enc = torch.zeros(0, bond_x_enc.shape[1])
stereo_cand_label = []
result.update({
'cand_graphs': cand_graphs,
'atom_x_dec': atom_x_dec,
'bond_x_dec': bond_x_dec,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graphs': stereo_cand_graphs,
'stereo_atom_x_enc': stereo_atom_x_enc,
'stereo_bond_x_enc': stereo_bond_x_enc,
'stereo_cand_label': stereo_cand_label,
})
return result
class JTNNCollator(object):
def __init__(self, vocab, training):
self.vocab = vocab
self.training = training
@staticmethod
def _batch_and_set(graphs, atom_x, bond_x, flatten):
if flatten:
graphs = [g for f in graphs for g in f]
graph_batch = dgl.batch(graphs)
graph_batch.ndata['x'] = atom_x
graph_batch.edata.update({
'x': bond_x,
'src_x': atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_(),
})
return graph_batch
def __call__(self, examples):
# get list of trees
mol_trees = _unpack_field(examples, 'mol_tree')
wid = _unpack_field(examples, 'wid')
for _wid, mol_tree in zip(wid, mol_trees):
mol_tree.ndata['wid'] = torch.LongTensor(_wid)
# TODO: either support pickling or get around ctypes pointers using scipy
# batch molecule graphs
mol_graphs = _unpack_field(examples, 'mol_graph')
atom_x = torch.cat(_unpack_field(examples, 'atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_enc'))
mol_graph_batch = self._batch_and_set(mol_graphs, atom_x, bond_x, False)
result = {
'mol_trees': mol_trees,
'mol_graph_batch': mol_graph_batch,
}
if not self.training:
return result
# batch candidate graphs
cand_graphs = _unpack_field(examples, 'cand_graphs')
cand_batch_idx = []
atom_x = torch.cat(_unpack_field(examples, 'atom_x_dec'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_dec'))
tree_mess_src_e = _unpack_field(examples, 'tree_mess_src_e')
tree_mess_tgt_e = _unpack_field(examples, 'tree_mess_tgt_e')
tree_mess_tgt_n = _unpack_field(examples, 'tree_mess_tgt_n')
n_graph_nodes = 0
n_tree_nodes = 0
for i in range(len(cand_graphs)):
tree_mess_tgt_e[i] += n_graph_nodes
tree_mess_src_e[i] += n_tree_nodes
tree_mess_tgt_n[i] += n_graph_nodes
n_graph_nodes += sum(g.number_of_nodes() for g in cand_graphs[i])
n_tree_nodes += mol_trees[i].number_of_nodes()
cand_batch_idx.extend([i] * len(cand_graphs[i]))
tree_mess_tgt_e = torch.cat(tree_mess_tgt_e)
tree_mess_src_e = torch.cat(tree_mess_src_e)
tree_mess_tgt_n = torch.cat(tree_mess_tgt_n)
cand_graph_batch = self._batch_and_set(cand_graphs, atom_x, bond_x, True)
# batch stereoisomers
stereo_cand_graphs = _unpack_field(examples, 'stereo_cand_graphs')
atom_x = torch.cat(_unpack_field(examples, 'stereo_atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'stereo_bond_x_enc'))
stereo_cand_batch_idx = []
for i in range(len(stereo_cand_graphs)):
stereo_cand_batch_idx.extend([i] * len(stereo_cand_graphs[i]))
if len(stereo_cand_batch_idx) > 0:
stereo_cand_labels = [
(label, length)
for ex in _unpack_field(examples, 'stereo_cand_label')
for label, length in ex
]
stereo_cand_labels, stereo_cand_lengths = zip(*stereo_cand_labels)
stereo_cand_graph_batch = self._batch_and_set(
stereo_cand_graphs, atom_x, bond_x, True)
else:
stereo_cand_labels = []
stereo_cand_lengths = []
stereo_cand_graph_batch = None
stereo_cand_batch_idx = []
result.update({
'cand_graph_batch': cand_graph_batch,
'cand_batch_idx': cand_batch_idx,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graph_batch': stereo_cand_graph_batch,
'stereo_cand_batch_idx': stereo_cand_batch_idx,
'stereo_cand_labels': stereo_cand_labels,
'stereo_cand_lengths': stereo_cand_lengths,
})
return result
import copy
import numpy as np
import rdkit.Chem as Chem
from dgl import DGLGraph
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
set_atommap, enum_assemble_nx, decode_stereo
def get_slots(smiles):
mol = Chem.MolFromSmiles(smiles)
return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()]
class Vocab(object):
def __init__(self, smiles_list):
self.vocab = smiles_list
self.vmap = {x: i for i, x in enumerate(self.vocab)}
self.slots = [get_slots(smiles) for smiles in self.vocab]
def get_index(self, smiles):
return self.vmap[smiles]
def get_smiles(self, idx):
return self.vocab[idx]
def get_slots(self, idx):
return copy.deepcopy(self.slots[idx])
def size(self):
return len(self.vocab)
class DGLMolTree(DGLGraph):
def __init__(self, smiles):
DGLGraph.__init__(self)
self.nodes_dict = {}
if smiles is None:
return
self.smiles = smiles
self.mol = get_mol(smiles)
# Stereo Generation
mol = Chem.MolFromSmiles(smiles)
self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
self.smiles2D = Chem.MolToSmiles(mol)
self.stereo_cands = decode_stereo(self.smiles2D)
# cliques: a list of list of atom indices
cliques, edges = tree_decomp(self.mol)
root = 0
for i, c in enumerate(cliques):
cmol = get_clique_mol(self.mol, c)
csmiles = get_smiles(cmol)
self.nodes_dict[i] = dict(
smiles=csmiles,
mol=get_mol(csmiles),
clique=c,
)
if min(c) == 0:
root = i
self.add_nodes(len(cliques))
# The clique with atom ID 0 becomes root
if root > 0:
for attr in self.nodes_dict[0]:
self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
self.nodes_dict[root][attr], self.nodes_dict[0][attr]
src = np.zeros((len(edges) * 2,), dtype='int')
dst = np.zeros((len(edges) * 2,), dtype='int')
for i, (_x, _y) in enumerate(edges):
x = 0 if _x == root else root if _x == 0 else _x
y = 0 if _y == root else root if _y == 0 else _y
src[2 * i] = x
dst[2 * i] = y
src[2 * i + 1] = y
dst[2 * i + 1] = x
self.add_edges(src, dst)
for i in self.nodes_dict:
self.nodes_dict[i]['nid'] = i + 1
if self.out_degree(i) > 1: # Leaf node mol is not marked
set_atommap(self.nodes_dict[i]['mol'], self.nodes_dict[i]['nid'])
self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)
def treesize(self):
return self.number_of_nodes()
def _recover_node(self, i, original_mol):
node = self.nodes_dict[i]
clique = []
clique.extend(node['clique'])
if not node['is_leaf']:
for cidx in node['clique']:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])
for j in self.successors(i).numpy():
nei_node = self.nodes_dict[j]
clique.extend(nei_node['clique'])
if nei_node['is_leaf']: # Leaf node, no need to mark
continue
for cidx in nei_node['clique']:
# allow a singleton node to override the atom mapping
if cidx not in node['clique'] or len(nei_node['clique']) == 1:
atom = original_mol.GetAtomWithIdx(cidx)
atom.SetAtomMapNum(nei_node['nid'])
clique = list(set(clique))
label_mol = get_clique_mol(original_mol, clique)
node['label'] = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
node['label_mol'] = get_mol(node['label'])
for cidx in clique:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
return node['label']
def _assemble_node(self, i):
neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(self.nodes_dict[i], neighbors)
if len(cands) > 0:
self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = list(zip(*cands))
self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
self.nodes_dict[i]['cand_mols'] = list(self.nodes_dict[i]['cand_mols'])
else:
self.nodes_dict[i]['cands'] = []
self.nodes_dict[i]['cand_mols'] = []
def recover(self):
for i in self.nodes_dict:
self._recover_node(i, self.mol)
def assemble(self):
for i in self.nodes_dict:
self._assemble_node(i)
import argparse
import rdkit
import torch
from dgllife.model import DGLJTNNVAE, load_pretrained
from dgllife.model.model_zoo.jtnn.nnutils import cuda
from torch.utils.data import DataLoader
from jtnn import *
def worker_init_fn(id_):
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)
parser = argparse.ArgumentParser(description="Evaluation for JTNN",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-t", "--train", dest="train",
default='test', help='Training file name')
parser.add_argument("-v", "--vocab", dest="vocab",
default='vocab', help='Vocab file name')
parser.add_argument("-m", "--model", dest="model_path", default=None,
help="Pre-trained model to be loaded for evalutaion. If not specified,"
" would use pre-trained model from model zoo")
parser.add_argument("-w", "--hidden", dest="hidden_size", default=450,
help="Hidden size of representation vector, "
"should be consistent with pre-trained model")
parser.add_argument("-l", "--latent", dest="latent_size", default=56,
help="Latent Size of node(atom) features and edge(atom) features, "
"should be consistent with pre-trained model")
parser.add_argument("-d", "--depth", dest="depth", default=3,
help="Depth of message passing hops, "
"should be consistent with pre-trained model")
args = parser.parse_args()
dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=False)
vocab_file = dataset.vocab_file
hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)
model = DGLJTNNVAE(vocab_file=vocab_file,
hidden_size=hidden_size,
latent_size=latent_size,
depth=depth)
if args.model_path is not None:
model.load_state_dict(torch.load(args.model_path))
else:
model = load_pretrained("JTNN_ZINC")
model = cuda(model)
model.eval()
print("Model #Params: %dK" %
(sum([x.nelement() for x in model.parameters()]) / 1000,))
MAX_EPOCH = 100
PRINT_ITER = 20
def reconstruct():
dataset.training = False
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=0,
collate_fn=JTNNCollator(dataset.vocab, False),
drop_last=True,
worker_init_fn=worker_init_fn)
# Just an example of molecule decoding; in reality you may want to sample
# tree and molecule vectors.
acc = 0.0
tot = 0
with torch.no_grad():
for it, batch in enumerate(dataloader):
gt_smiles = batch['mol_trees'][0].smiles
# print(gt_smiles)
model.move_to_cuda(batch)
try:
_, tree_vec, mol_vec = model.encode(batch)
tree_mean = model.T_mean(tree_vec)
# Following Mueller et al.
tree_log_var = -torch.abs(model.T_var(tree_vec))
mol_mean = model.G_mean(mol_vec)
# Following Mueller et al.
mol_log_var = -torch.abs(model.G_var(mol_vec))
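# Reparameterization: z = mean + std * eps, with std = exp(log_var / 2)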
epsilon = torch.randn(1, model.latent_size // 2).cuda()
tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
epsilon = torch.randn(1, model.latent_size // 2).cuda()
mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
dec_smiles = model.decode(tree_vec, mol_vec)
if dec_smiles == gt_smiles:
acc += 1
tot += 1
except Exception as e:
print("Failed to encode: {}".format(gt_smiles))
print(e)
if it % PRINT_ITER == 1:
print("Progress {}/{}; Current Reconstruction Accuracy: {:.4f}".format(
it, len(dataloader), acc / max(tot, 1)))
return acc / tot
if __name__ == '__main__':
reconstruct_acc = reconstruct()
print("Reconstruction Accuracy: {}".format(reconstruct_acc))
import argparse
import rdkit
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from dgllife.model import DGLJTNNVAE
from dgllife.model.model_zoo.jtnn.nnutils import cuda
from torch.utils.data import DataLoader
from jtnn import *
torch.multiprocessing.set_sharing_strategy('file_system')
def worker_init_fn(id_):
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)
parser = argparse.ArgumentParser(description="Training for JTNN",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-t", "--train", dest="train", default='train', help='Training file name')
parser.add_argument("-v", "--vocab", dest="vocab", default='vocab', help='Vocab file name')
parser.add_argument("-s", "--save_dir", dest="save_path", default='./',
help="Path to save checkpoint models, default to be current working directory")
parser.add_argument("-m", "--model", dest="model_path", default=None,
help="Path to load pre-trained model")
parser.add_argument("-b", "--batch", dest="batch_size", default=40,
help="Batch size")
parser.add_argument("-w", "--hidden", dest="hidden_size", default=200,
help="Size of representation vectors")
parser.add_argument("-l", "--latent", dest="latent_size", default=56,
help="Latent Size of node(atom) features and edge(atom) features")
parser.add_argument("-d", "--depth", dest="depth", default=3,
help="Depth of message passing hops")
parser.add_argument("-z", "--beta", dest="beta", default=1.0,
help="Coefficient of KL Divergence term")
parser.add_argument("-q", "--lr", dest="lr", default=1e-3,
help="Learning Rate")
args = parser.parse_args()
dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=True)
vocab_file = dataset.vocab_file
batch_size = int(args.batch_size)
hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)
beta = float(args.beta)
lr = float(args.lr)
model = DGLJTNNVAE(vocab_file=vocab_file,
hidden_size=hidden_size,
latent_size=latent_size,
depth=depth)
if args.model_path is not None:
model.load_state_dict(torch.load(args.model_path))
else:
for param in model.parameters():
if param.dim() == 1:
nn.init.constant_(param, 0)
else:
nn.init.xavier_normal_(param)
model = cuda(model)
print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
MAX_EPOCH = 100
PRINT_ITER = 20
def train():
dataset.training = True
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
collate_fn=JTNNCollator(dataset.vocab, True),
drop_last=True,
worker_init_fn=worker_init_fn)
for epoch in range(MAX_EPOCH):
word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
for it, batch in enumerate(dataloader):
model.zero_grad()
try:
loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
except Exception:
print([t.smiles for t in batch['mol_trees']])
raise
loss.backward()
optimizer.step()
word_acc += wacc
topo_acc += tacc
assm_acc += sacc
steo_acc += dacc
if (it + 1) % PRINT_ITER == 0:
word_acc = word_acc / PRINT_ITER * 100
topo_acc = topo_acc / PRINT_ITER * 100
assm_acc = assm_acc / PRINT_ITER * 100
steo_acc = steo_acc / PRINT_ITER * 100
print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f" % (
kl_div, word_acc, topo_acc, assm_acc, steo_acc, loss.item()))
word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
sys.stdout.flush()
if (it + 1) % 1500 == 0: # Fast annealing
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(),
args.save_path + "/model.iter-%d-%d" % (epoch, it + 1))
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(), args.save_path + "/model.iter-" + str(epoch))
if __name__ == '__main__':
train()
print('# passes:', model.n_passes)
print('Total # nodes processed:', model.n_nodes_total)
print('Total # edges processed:', model.n_edges_total)
print('Total # tree nodes processed:', model.n_tree_nodes_total)
print('Graph decoder: # passes:', model.jtmpn.n_passes)
print('Graph decoder: Total # candidates processed:', model.jtmpn.n_samples_total)
print('Graph decoder: Total # nodes processed:', model.jtmpn.n_nodes_total)
print('Graph decoder: Total # edges processed:', model.jtmpn.n_edges_total)
print('Graph encoder: # passes:', model.mpn.n_passes)
print('Graph encoder: Total # candidates processed:', model.mpn.n_samples_total)
print('Graph encoder: Total # nodes processed:', model.mpn.n_nodes_total)
print('Graph encoder: Total # edges processed:', model.mpn.n_edges_total)
# Property Prediction
## Classification
Classification tasks require assigning discrete labels to a molecule, e.g., whether a molecule is toxic.
### Datasets
- **Tox21**. The ["Toxicology in the 21st Century" (Tox21)](https://tripod.nih.gov/tox21/challenge/) initiative created
a public database measuring toxicity of compounds, which has been used in the 2014 Tox21 Data Challenge. The dataset
contains qualitative toxicity measurements for 8014 compounds on 12 different targets, including nuclear receptors and
stress response pathways. Each target yields a binary prediction problem. MoleculeNet [1] randomly splits the dataset
into training, validation and test sets with an 80/10/10 ratio. By default we follow their splitting method, as in the sketch below.
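For reference, below is a minimal sketch mirroring the loading and splitting code in `utils.py`; the random seed here is arbitrary.

```python
from dgllife.data import Tox21
from dgllife.utils.featurizers import CanonicalAtomFeaturizer
from dgllife.utils.mol_to_graph import smiles_to_bigraph
from dgllife.utils.splitters import RandomSplitter

# Construct molecular graphs with canonical atom features (stored under 'h')
dataset = Tox21(smiles_to_bigraph, CanonicalAtomFeaturizer())
train_set, val_set, test_set = RandomSplitter.train_val_test_split(
    dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1, random_state=0)
```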
### Models
- **Graph Convolutional Network** [2], [3]. Graph Convolutional Networks (GCNs) are among the most popular graph neural
networks and can be easily extended for graph-level prediction. MoleculeNet [1] reports baseline results of graph
convolutions over multiple datasets. A construction sketch follows this list.
- **Graph Attention Networks** [7]. Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
explicitly modeling the interactions between adjacent atoms.
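For reference, a minimal sketch of constructing the GCN-based classifier with the hyperparameters of the `GCN_Tox21` configuration in `configure.py`; `in_feats=74` is the feature size of `CanonicalAtomFeaturizer` and `n_tasks=12` matches the 12 Tox21 targets.

```python
from dgllife.model import GCNPredictor

model = GCNPredictor(in_feats=74,               # CanonicalAtomFeaturizer size
                     hidden_feats=[64, 64],     # two GCN layers
                     classifier_hidden_feats=64,
                     n_tasks=12)                # 12 binary Tox21 tasks
# Forward pass on a batched graph bg with atom features under 'h':
# logits = model(bg, bg.ndata.pop('h'))
```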
### Usage
Use `classification.py` with arguments
```
-m {GCN, GAT}, MODEL, model to use
-d {Tox21}, DATASET, dataset to use
```
If you want to use the pre-trained model, simply add `-p`.
We use a GPU whenever one is available.
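For example:

```
python classification.py -m GCN -d Tox21       # train a GCN from scratch
python classification.py -m GAT -d Tox21 -p    # evaluate the pre-trained GAT
```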
### Performance
#### GCN on Tox21
| Source | Averaged Test ROC-AUC Score |
| ---------------- | --------------------------- |
| MoleculeNet [1] | 0.829 |
| [DeepChem example](https://github.com/deepchem/deepchem/blob/master/examples/tox21/tox21_tensorgraph_graph_conv.py) | 0.813 |
| Pretrained model | 0.833 |
Note that the dataset is randomly split, so these numbers are only for reference and do not necessarily indicate
a real difference.
#### GAT on Tox21
| Source | Averaged Test ROC-AUC Score |
| ---------------- | --------------------------- |
| Pretrained model | 0.853 |
## Regression
Regression tasks require assigning continuous labels to a molecule, e.g. molecular energy.
### Datasets
- **Alchemy**. The [Alchemy Dataset](https://alchemy.tencent.com/) was introduced by Tencent Quantum Lab to facilitate the development of new
machine learning models useful for chemistry and materials science. The dataset lists 12 quantum mechanical properties of 130,000+ organic
molecules comprising up to 12 heavy atoms (C, N, O, S, F and Cl), sampled from the [GDBMedChem](http://gdb.unibe.ch/downloads/) database.
These properties have been calculated using the open-source computational chemistry program Python-based Simulation of Chemistry Framework
([PySCF](https://github.com/pyscf/pyscf)). The Alchemy dataset expands on the volume and diversity of existing molecular datasets such as QM9; see the loading sketch after this list.
- **PubChem BioAssay Aromaticity**. The dataset was introduced in
[Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism](https://www.ncbi.nlm.nih.gov/pubmed/31408336)
for the task of predicting the number of aromatic atoms in molecules. It was constructed by sampling 3945 molecules with 0-40 aromatic atoms
from the PubChem BioAssay dataset.
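A minimal loading sketch mirroring `utils.py`; since the Alchemy test split has not been released, only the `dev` and `valid` splits are used.

```python
from dgllife.data import TencentAlchemyDataset

train_set = TencentAlchemyDataset(mode='dev')    # training split
val_set = TencentAlchemyDataset(mode='valid')    # validation split
```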
### Models
- **Message Passing Neural Network** [6]. Message Passing Neural Networks (MPNNs) held state-of-the-art performance on
the QM9 dataset for some time. A construction sketch follows this list.
- **SchNet** [4]. SchNet employs continuous filter convolutional layers to model quantum interactions in molecules
without requiring them to lie on grids.
- **Multilevel Graph Convolutional Neural Network** [5]. Multilevel Graph Convolutional Neural Networks (MGCN) are
hierarchical graph neural networks that extract features from conformation and spatial information and then model
multilevel interactions.
- **AttentiveFP** [8]. AttentiveFP combines attention with GRUs for better model capacity and shows competitive
performance across datasets.
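For reference, a minimal sketch of constructing the MPNN used here, with hyperparameters from the `MPNN_Alchemy` configuration in `configure.py` (12 tasks for the 12 quantum mechanical properties):

```python
from dgllife.model import MPNNPredictor

model = MPNNPredictor(node_in_feats=15,
                      edge_in_feats=5,
                      node_out_feats=64,
                      edge_hidden_feats=128,
                      n_tasks=12)
# Forward pass on a batched graph bg, as in regression.py:
# preds = model(bg, bg.ndata.pop('n_feat'), bg.edata.pop('e_feat'))
```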
### Usage
Use `regression.py` with arguments
```
-m {MPNN, SchNet, MGCN, AttentiveFP}, Model to use
-d {Alchemy, Aromaticity}, Dataset to use
```
If you want to use a pre-trained model, simply add `-p`. Currently we only provide a pre-trained AttentiveFP model
on the PubChem BioAssay Aromaticity dataset.
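For example:

```
python regression.py -m MPNN -d Alchemy                # train from scratch
python regression.py -m AttentiveFP -d Aromaticity -p  # use the pre-trained model
```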
### Performance
#### Alchemy
The Alchemy contest is still ongoing. Until the test set is fully released, we only include performance numbers
on the training and validation sets for reference.
| Model | Training MAE | Validation MAE |
| ---------- | ------------ | -------------- |
| SchNet [4] | 0.0651 | 0.0925 |
| MGCN [5] | 0.0582 | 0.0942 |
| MPNN [6] | 0.1004 | 0.1587 |
#### PubChem BioAssay Aromaticity
| Model | Test RMSE |
| --------------- | --------- |
| AttentiveFP [8] | 0.7508 |
Note that the dataset is randomly split, so this number is only for reference.
## Interpretation
The authors of [8] visualize atom weights in the readout phase for possible interpretation, as in the figure below.
We provide a Jupyter notebook for performing the visualization, which you can download with
`wget https://s3.us-west-2.amazonaws.com/dgl-data/dgllife/attentive_fp/atom_weight_visualization.ipynb`
from the s3 bucket in U.S. or
`wget https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dgllife/attentive_fp/atom_weight_visualization.ipynb`
from the s3 bucket in China.
![](https://s3.us-west-2.amazonaws.com/dgl-data/dgllife/attentive_fp_vis_example.png)
## Dataset Customization
Generally we follow the practice of PyTorch.
A dataset class should implement the `__getitem__(self, index)` and `__len__(self)` methods:
```python
class CustomDataset(object):
def __init__(self):
pass
def __getitem__(self, index):
"""
Parameters
----------
index : int
Index for the datapoint.
Returns
-------
str
SMILES for the molecule
DGLGraph
Constructed DGLGraph for the molecule
1D Tensor of dtype float32
Labels of the datapoint
"""
return self.smiles[index], self.graphs[index], self.labels[index]
def __len__(self):
return len(self.smiles)
```
We provide various methods for graph construction in `dgllife.utils.mol_to_graph`. If your dataset can
be converted to a pandas dataframe, e.g. a .csv file, you may use `MoleculeCSVDataset` in
`dgllife.data.csv_dataset`.
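A minimal end-to-end sketch following the interface above; the SMILES strings and labels are made up for illustration, and this assumes `smiles_to_bigraph` accepts a `node_featurizer` argument as in `dgllife.utils.mol_to_graph`.

```python
import torch

from dgllife.utils.featurizers import CanonicalAtomFeaturizer
from dgllife.utils.mol_to_graph import smiles_to_bigraph

class CustomDataset(object):
    def __init__(self, smiles, labels):
        self.smiles = smiles
        featurizer = CanonicalAtomFeaturizer()
        # One DGLGraph per molecule, atom features stored under 'h'
        self.graphs = [smiles_to_bigraph(s, node_featurizer=featurizer)
                       for s in smiles]
        self.labels = torch.tensor(labels, dtype=torch.float32).reshape(-1, 1)

    def __getitem__(self, index):
        return self.smiles[index], self.graphs[index], self.labels[index]

    def __len__(self):
        return len(self.smiles)

# Made-up data; batch datapoints with collate_molgraphs as in utils.py
dataset = CustomDataset(['CCO', 'c1ccccc1'], [0.5, 1.2])
```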
## References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
[2] Duvenaud et al. (2015) Convolutional networks on graphs for learning molecular fingerprints. *Advances in neural
information processing systems (NeurIPS)*, 2224-2232.
[3] Kipf et al. (2017) Semi-Supervised Classification with Graph Convolutional Networks.
*The International Conference on Learning Representations (ICLR)*.
[4] Schütt et al. (2017) SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.
*Advances in Neural Information Processing Systems (NeurIPS)*, 992-1002.
[5] Lu et al. (2019) Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective.
*The 33rd AAAI Conference on Artificial Intelligence*.
[6] Gilmer et al. (2017) Neural Message Passing for Quantum Chemistry. *Proceedings of the 34th International Conference on
Machine Learning*, JMLR. 1263-1272.
[7] Veličković et al. (2018) Graph Attention Networks.
*The International Conference on Learning Representations (ICLR)*.
[8] Xiong et al. (2019) Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
Attention Mechanism. *Journal of Medicinal Chemistry*.
import numpy as np
import torch
from dgllife.model import load_pretrained
from dgllife.utils import EarlyStopping, Meter
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import set_random_seed, load_dataset_for_classification, collate_molgraphs, load_model
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
model.train()
train_meter = Meter()
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
atom_feats = bg.ndata.pop(args['atom_data_field'])
atom_feats, labels, masks = atom_feats.to(args['device']), \
labels.to(args['device']), \
masks.to(args['device'])
logits = model(bg, atom_feats)
# Mask non-existing labels
loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), loss.item()))
train_meter.update(logits, labels, masks)
train_score = np.mean(train_meter.compute_metric(args['metric_name']))
print('epoch {:d}/{:d}, training {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'], train_score))
def run_an_eval_epoch(args, model, data_loader):
model.eval()
eval_meter = Meter()
with torch.no_grad():
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
atom_feats = bg.ndata.pop(args['atom_data_field'])
atom_feats, labels = atom_feats.to(args['device']), labels.to(args['device'])
logits = model(bg, atom_feats)
eval_meter.update(logits, labels, masks)
return np.mean(eval_meter.compute_metric(args['metric_name']))
def main(args):
args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_random_seed(args['random_seed'])
# Interchangeable with other datasets
dataset, train_set, val_set, test_set = load_dataset_for_classification(args)
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
test_loader = DataLoader(test_set, batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
if args['pre_trained']:
args['num_epochs'] = 0
model = load_pretrained(args['exp'])
else:
args['n_tasks'] = dataset.n_tasks
model = load_model(args)
loss_criterion = BCEWithLogitsLoss(pos_weight=dataset.task_pos_weights.to(args['device']),
reduction='none')
optimizer = Adam(model.parameters(), lr=args['lr'])
stopper = EarlyStopping(patience=args['patience'])
model.to(args['device'])
for epoch in range(args['num_epochs']):
# Train
run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)
# Validation and early stop
val_score = run_an_eval_epoch(args, model, val_loader)
early_stop = stopper.step(val_score, model)
print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'],
val_score, args['metric_name'], stopper.best_score))
if early_stop:
break
if not args['pre_trained']:
stopper.load_checkpoint(model)
test_score = run_an_eval_epoch(args, model, test_loader)
print('test {} {:.4f}'.format(args['metric_name'], test_score))
if __name__ == '__main__':
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(description='Molecule Classification')
parser.add_argument('-m', '--model', type=str, choices=['GCN', 'GAT'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str, choices=['Tox21'],
help='Dataset to use')
parser.add_argument('-p', '--pre-trained', action='store_true',
help='Whether to skip training and use a pre-trained model')
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
main(args)
from functools import partial
from dgllife.utils.featurizers import CanonicalAtomFeaturizer, BaseAtomFeaturizer, \
BaseBondFeaturizer, ConcatFeaturizer, atom_type_one_hot, atom_degree_one_hot, \
atom_formal_charge, atom_num_radical_electrons, atom_hybridization_one_hot, \
atom_total_num_H_one_hot
from utils import chirality
GCN_Tox21 = {
'random_seed': 2,
'batch_size': 128,
'lr': 1e-3,
'num_epochs': 100,
'atom_data_field': 'h',
'frac_train': 0.8,
'frac_val': 0.1,
'frac_test': 0.1,
'in_feats': 74,
'gcn_hidden_feats': [64, 64],
'classifier_hidden_feats': 64,
'patience': 10,
'atom_featurizer': CanonicalAtomFeaturizer(),
'metric_name': 'roc_auc_score'
}
GAT_Tox21 = {
'random_seed': 2,
'batch_size': 128,
'lr': 1e-3,
'num_epochs': 100,
'atom_data_field': 'h',
'frac_train': 0.8,
'frac_val': 0.1,
'frac_test': 0.1,
'in_feats': 74,
'gat_hidden_feats': [32, 32],
'classifier_hidden_feats': 64,
'num_heads': [4, 4],
'patience': 10,
'atom_featurizer': CanonicalAtomFeaturizer(),
'metric_name': 'roc_auc_score'
}
MPNN_Alchemy = {
'random_seed': 0,
'batch_size': 16,
'num_epochs': 250,
'node_in_feats': 15,
'node_out_feats': 64,
'edge_in_feats': 5,
'edge_hidden_feats': 128,
'n_tasks': 12,
'lr': 0.0001,
'patience': 50,
'metric_name': 'mae',
'weight_decay': 0
}
SchNet_Alchemy = {
'random_seed': 0,
'batch_size': 16,
'num_epochs': 250,
'node_feats': 64,
'hidden_feats': [64, 64, 64],
'classifier_hidden_feats': 64,
'n_tasks': 12,
'lr': 0.0001,
'patience': 50,
'metric_name': 'mae',
'weight_decay': 0
}
MGCN_Alchemy = {
'random_seed': 0,
'batch_size': 16,
'num_epochs': 250,
'feats': 128,
'n_layers': 3,
'classifier_hidden_feats': 64,
'n_tasks': 12,
'lr': 0.0001,
'patience': 50,
'metric_name': 'mae',
'weight_decay': 0
}
AttentiveFP_Aromaticity = {
'random_seed': 8,
'graph_feat_size': 200,
'num_layers': 2,
'num_timesteps': 2,
'node_feat_size': 39,
'edge_feat_size': 10,
'n_tasks': 1,
'dropout': 0.2,
'weight_decay': 10 ** (-5.0),
'lr': 10 ** (-2.5),
'batch_size': 128,
'num_epochs': 800,
'frac_train': 0.8,
'frac_val': 0.1,
'frac_test': 0.1,
'patience': 80,
'metric_name': 'rmse',
# Follow the atom featurization in the original work
'atom_featurizer': BaseAtomFeaturizer(
featurizer_funcs={'hv': ConcatFeaturizer([
partial(atom_type_one_hot, allowable_set=[
'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
encode_unknown=True),
partial(atom_degree_one_hot, allowable_set=list(range(6))),
atom_formal_charge, atom_num_radical_electrons,
partial(atom_hybridization_one_hot, encode_unknown=True),
lambda atom: [0], # A placeholder for aromatic information,
atom_total_num_H_one_hot, chirality
],
)}
),
'bond_featurizer': BaseBondFeaturizer({
'he': lambda bond: [0 for _ in range(10)]
})
}
experiment_configures = {
'GCN_Tox21': GCN_Tox21,
'GAT_Tox21': GAT_Tox21,
'MPNN_Alchemy': MPNN_Alchemy,
'SchNet_Alchemy': SchNet_Alchemy,
'MGCN_Alchemy': MGCN_Alchemy,
'AttentiveFP_Aromaticity': AttentiveFP_Aromaticity
}
def get_exp_configure(exp_name):
return experiment_configures[exp_name]
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dgllife.model import load_pretrained
from dgllife.utils import EarlyStopping, Meter
from utils import set_random_seed, load_dataset_for_regression, collate_molgraphs, load_model
def regress(args, model, bg):
if args['model'] == 'MPNN':
h = bg.ndata.pop('n_feat')
e = bg.edata.pop('e_feat')
h, e = h.to(args['device']), e.to(args['device'])
return model(bg, h, e)
elif args['model'] in ['SchNet', 'MGCN']:
node_types = bg.ndata.pop('node_type')
edge_distances = bg.edata.pop('distance')
node_types, edge_distances = node_types.to(args['device']), \
edge_distances.to(args['device'])
return model(bg, node_types, edge_distances)
else:
atom_feats, bond_feats = bg.ndata.pop('hv'), bg.edata.pop('he')
atom_feats, bond_feats = atom_feats.to(args['device']), bond_feats.to(args['device'])
return model(bg, atom_feats, bond_feats)
def run_a_train_epoch(args, epoch, model, data_loader,
loss_criterion, optimizer):
model.train()
train_meter = Meter()
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
labels, masks = labels.to(args['device']), masks.to(args['device'])
prediction = regress(args, model, bg)
loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_meter.update(prediction, labels, masks)
total_score = np.mean(train_meter.compute_metric(args['metric_name']))
print('epoch {:d}/{:d}, training {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'], total_score))
def run_an_eval_epoch(args, model, data_loader):
model.eval()
eval_meter = Meter()
with torch.no_grad():
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
labels = labels.to(args['device'])
prediction = regress(args, model, bg)
eval_meter.update(prediction, labels, masks)
total_score = np.mean(eval_meter.compute_metric(args['metric_name']))
return total_score
def main(args):
args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_random_seed(args['random_seed'])
train_set, val_set, test_set = load_dataset_for_regression(args)
train_loader = DataLoader(dataset=train_set,
batch_size=args['batch_size'],
shuffle=True,
collate_fn=collate_molgraphs)
val_loader = DataLoader(dataset=val_set,
batch_size=args['batch_size'],
shuffle=True,
collate_fn=collate_molgraphs)
if test_set is not None:
test_loader = DataLoader(dataset=test_set,
batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
if args['pre_trained']:
args['num_epochs'] = 0
model = load_pretrained(args['exp'])
else:
model = load_model(args)
loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
weight_decay=args['weight_decay'])
stopper = EarlyStopping(mode='lower', patience=args['patience'])
model.to(args['device'])
for epoch in range(args['num_epochs']):
# Train
run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)
# Validation and early stop
val_score = run_an_eval_epoch(args, model, val_loader)
early_stop = stopper.step(val_score, model)
print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'], val_score,
args['metric_name'], stopper.best_score))
if early_stop:
break
if test_set is not None:
if not args['pre_trained']:
stopper.load_checkpoint(model)
test_score = run_an_eval_epoch(args, model, test_loader)
print('test {} {:.4f}'.format(args['metric_name'], test_score))
if __name__ == "__main__":
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(description='Molecule Regression')
parser.add_argument('-m', '--model', type=str,
choices=['MPNN', 'SchNet', 'MGCN', 'AttentiveFP'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str, choices=['Alchemy', 'Aromaticity'],
help='Dataset to use')
parser.add_argument('-p', '--pre-trained', action='store_true',
help='Whether to skip training and use a pre-trained model')
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
main(args)
import dgl
import numpy as np
import random
import torch
from dgllife.utils.featurizers import one_hot_encoding
from dgllife.utils.mol_to_graph import smiles_to_bigraph
from dgllife.utils.splitters import RandomSplitter
def set_random_seed(seed=0):
"""Set random seed.
Parameters
----------
seed : int
Random seed to use
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def load_dataset_for_classification(args):
"""Load dataset for classification tasks.
Parameters
----------
args : dict
Configurations.
Returns
-------
dataset
The whole dataset.
train_set
Subset for training.
val_set
Subset for validation.
test_set
Subset for test.
"""
assert args['dataset'] in ['Tox21']
if args['dataset'] == 'Tox21':
from dgllife.data import Tox21
dataset = Tox21(smiles_to_bigraph, args['atom_featurizer'])
train_set, val_set, test_set = RandomSplitter.train_val_test_split(
dataset, frac_train=args['frac_train'], frac_val=args['frac_val'],
frac_test=args['frac_test'], random_state=args['random_seed'])
return dataset, train_set, val_set, test_set
def load_dataset_for_regression(args):
"""Load dataset for regression tasks.
Parameters
----------
args : dict
Configurations.
Returns
-------
train_set
Subset for training.
val_set
Subset for validation.
test_set
Subset for test.
"""
assert args['dataset'] in ['Alchemy', 'Aromaticity']
if args['dataset'] == 'Alchemy':
from dgllife.data import TencentAlchemyDataset
train_set = TencentAlchemyDataset(mode='dev')
val_set = TencentAlchemyDataset(mode='valid')
test_set = None
if args['dataset'] == 'Aromaticity':
from dgllife.data import PubChemBioAssayAromaticity
dataset = PubChemBioAssayAromaticity(smiles_to_bigraph,
args['atom_featurizer'],
args['bond_featurizer'])
train_set, val_set, test_set = RandomSplitter.train_val_test_split(
dataset, frac_train=args['frac_train'], frac_val=args['frac_val'],
frac_test=args['frac_test'], random_state=args['random_seed'])
return train_set, val_set, test_set
def collate_molgraphs(data):
"""Batching a list of datapoints for dataloader.
Parameters
----------
data : list of 3-tuples or 4-tuples.
Each tuple is for a single datapoint, consisting of
a SMILES, a DGLGraph, all-task labels and optionally
a binary mask indicating the existence of labels.
Returns
-------
smiles : list
List of SMILES strings
bg : BatchedDGLGraph
Batched DGLGraphs
labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and
T is the number of total tasks.
masks : Tensor of dtype float32 and shape (B, T)
Batched datapoint binary mask, indicating the
existence of labels. If binary masks are not
provided, return a tensor with ones.
"""
assert len(data[0]) in [3, 4], \
'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))
if len(data[0]) == 3:
smiles, graphs, labels = map(list, zip(*data))
masks = None
else:
smiles, graphs, labels, masks = map(list, zip(*data))
bg = dgl.batch(graphs)
bg.set_n_initializer(dgl.init.zero_initializer)
bg.set_e_initializer(dgl.init.zero_initializer)
labels = torch.stack(labels, dim=0)
if masks is None:
masks = torch.ones(labels.shape)
else:
masks = torch.stack(masks, dim=0)
return smiles, bg, labels, masks
def load_model(args):
    """Initialize a model based on the configuration."""
    if args['model'] == 'GCN':
        from dgllife.model import GCNPredictor
        model = GCNPredictor(in_feats=args['in_feats'],
                             hidden_feats=args['gcn_hidden_feats'],
                             classifier_hidden_feats=args['classifier_hidden_feats'],
                             n_tasks=args['n_tasks'])
    elif args['model'] == 'GAT':
        from dgllife.model import GATPredictor
        model = GATPredictor(in_feats=args['in_feats'],
                             hidden_feats=args['gat_hidden_feats'],
                             num_heads=args['num_heads'],
                             classifier_hidden_feats=args['classifier_hidden_feats'],
                             n_tasks=args['n_tasks'])
    elif args['model'] == 'AttentiveFP':
        from dgllife.model import AttentiveFPPredictor
        model = AttentiveFPPredictor(node_feat_size=args['node_feat_size'],
                                     edge_feat_size=args['edge_feat_size'],
                                     num_layers=args['num_layers'],
                                     num_timesteps=args['num_timesteps'],
                                     graph_feat_size=args['graph_feat_size'],
                                     n_tasks=args['n_tasks'],
                                     dropout=args['dropout'])
    elif args['model'] == 'SchNet':
        from dgllife.model import SchNetPredictor
        model = SchNetPredictor(node_feats=args['node_feats'],
                                hidden_feats=args['hidden_feats'],
                                classifier_hidden_feats=args['classifier_hidden_feats'],
                                n_tasks=args['n_tasks'])
    elif args['model'] == 'MGCN':
        from dgllife.model import MGCNPredictor
        model = MGCNPredictor(feats=args['feats'],
                              n_layers=args['n_layers'],
                              classifier_hidden_feats=args['classifier_hidden_feats'],
                              n_tasks=args['n_tasks'])
    elif args['model'] == 'MPNN':
        from dgllife.model import MPNNPredictor
        model = MPNNPredictor(node_in_feats=args['node_in_feats'],
                              edge_in_feats=args['edge_in_feats'],
                              node_out_feats=args['node_out_feats'],
                              edge_hidden_feats=args['edge_hidden_feats'],
                              n_tasks=args['n_tasks'])
    else:
        raise ValueError('Unexpected model: {}'.format(args['model']))
    return model
def chirality(atom):
    """One-hot encoding of the atom's CIP code (R/S) plus whether chirality is possible."""
    try:
        return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
               [atom.HasProp('_ChiralityPossible')]
    except Exception:
        # The atom has no CIP code assigned
        return [False, False] + [atom.HasProp('_ChiralityPossible')]
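# A minimal end-to-end sketch (added for illustration, not part of the original
# utilities): wire the helpers above together for Tox21 + GCN. The split
# fractions and hyperparameters below are illustrative assumptions, not
# tuned defaults.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from dgllife.utils.featurizers import CanonicalAtomFeaturizer

    set_random_seed(0)
    atom_featurizer = CanonicalAtomFeaturizer()
    args = {
        'dataset': 'Tox21',
        'atom_featurizer': atom_featurizer,
        'frac_train': 0.8, 'frac_val': 0.1, 'frac_test': 0.1,
        'random_seed': 0,
        'model': 'GCN',
        # CanonicalAtomFeaturizer stores atom features under 'h'
        'in_feats': atom_featurizer.feat_size('h'),
        'gcn_hidden_feats': [64, 64],
        'classifier_hidden_feats': 64,
        'n_tasks': 12  # Tox21 has 12 binary tasks
    }
    dataset, train_set, val_set, test_set = load_dataset_for_classification(args)
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True,
                              collate_fn=collate_molgraphs)
    model = load_model(args)
    smiles, bg, labels, masks = next(iter(train_loader))
    # One logit per molecule per task
    assert model(bg, bg.ndata.pop('h')).shape == (len(smiles), args['n_tasks'])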
# Work Implemented in DGL-LifeSci
## Datasets/Benchmarks
- MoleculeNet: A Benchmark for Molecular Machine Learning [[paper]](https://arxiv.org/abs/1703.00564), [[website]](http://moleculenet.ai/)
- [Tox21 with DGL](https://github.com/dmlc/dgl/tree/master/apps/life_sci/dglls/data/tox21.py)
- [PDBBind with DGL](https://github.com/dmlc/dgl/tree/master/apps/life_sci/dglls/data/pdbbind.py)
- Alchemy: A Quantum Chemistry Dataset for Benchmarking AI Models [[paper]](https://arxiv.org/abs/1906.09427), [[github]](https://github.com/tencent-alchemy/Alchemy)
- [Alchemy with DGL](https://github.com/dmlc/dgl/tree/master/apps/life_sci/dglls/data/alchemy.py)
## Property Prediction
- Semi-Supervised Classification with Graph Convolutional Networks (GCN) [[paper]](https://arxiv.org/abs/1609.02907), [[github]](https://github.com/tkipf/gcn)
- [GCN-Based Predictor with DGL](https://github.com/dmlc/dgl/tree/master/apps/life_sci/dglls/model/model_zoo/gcn_predictor.py)
- [Example for Molecule Classification](https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/property_prediction/classification.py)
- Graph Attention Networks (GAT) [[paper]](https://arxiv.org/abs/1710.10903), [[github]](https://github.com/PetarV-/GAT)
- [GAT-Based Predictor with DGL](https://github.com/dmlc/dgl/tree/master/apps/life_sci/dglls/model/model_zoo/gat_predictor.py)
- [Example for Molecule Classification](https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/property_prediction/classification.py)
- SchNet: A continuous-filter convolutional neural network for modeling quantum interactions [[paper]](https://arxiv.org/abs/1706.08566), [[github]](https://github.com/atomistic-machine-learning/SchNet)
- [SchNet with DGL](https://github.com/dmlc/dgl/tree/master/apps/life_sci/dglls/model/model_zoo/schnet_predictor.py)
- [Example for Molecule Regression](https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/property_prediction/regression.py)
- Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective (MGCN) [[paper]](https://arxiv.org/abs/1906.11081)
- [MGCN with DGL](https://github.com/dmlc/dgl/tree/master/apps/life_sci/dglls/model/model_zoo/mgcn_predictor.py)
- [Example for Molecule Regression](https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/property_prediction/regression.py)
- Neural Message Passing for Quantum Chemistry (MPNN) [[paper]](https://arxiv.org/abs/1704.01212), [[github]](https://github.com/brain-research/mpnn)
- [MPNN with DGL](https://github.com/dmlc/dgl/tree/master/apps/life_sci/dglls/model/model_zoo/mpnn_predictor.py)
- [Example for Molecule Regression](https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/property_prediction/regression.py)
- Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism (AttentiveFP) [[paper]](https://pubs.acs.org/doi/abs/10.1021/acs.jmedchem.9b00959)
- [AttentiveFP with DGL](https://github.com/dmlc/dgl/tree/master/apps/life_sci/dglls/model/model_zoo/attentivefp_predictor.py)
- [Example for Molecule Regression](https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/property_prediction/regression.py)
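Pretrained weights are available for many of the property prediction models above. A minimal sketch (assuming `dgllife` is installed) that scores two molecules with the `GCN_Tox21` checkpoint, mirroring the usage in the test suite:

```python
import dgl
from dgllife.model import load_pretrained
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer

# Featurize two molecules and batch them into a single graph
g1 = smiles_to_bigraph('CO', node_featurizer=CanonicalAtomFeaturizer())
g2 = smiles_to_bigraph('CCO', node_featurizer=CanonicalAtomFeaturizer())
bg = dgl.batch([g1, g2])

# Downloads the checkpoint on first use
model = load_pretrained('GCN_Tox21')
model.eval()
logits = model(bg, bg.ndata.pop('h'))  # one logit per molecule per Tox21 task
```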
## Generative Models
- Learning Deep Generative Models of Graphs (DGMG) [[paper]](https://arxiv.org/abs/1803.03324)
- [DGMG with DGL](https://github.com/dmlc/dgl/tree/master/apps/life_sci/dglls/model/model_zoo/dgmg.py)
- [Example Training Script](https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/generative_models/dgmg)
- Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN) [[paper]](https://arxiv.org/abs/1802.04364)
- [JTNN with DGL](https://github.com/dmlc/dgl/tree/master/apps/life_sci/dglls/model/model_zoo/jtnn)
- [Example Training Script](https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/generative_models/jtnn)
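The generative models likewise ship with pretrained checkpoints. A hedged sketch of sampling a molecule from a DGMG model pretrained on ZINC, following the pattern in the tests:

```python
from dgllife.model import load_pretrained

model = load_pretrained('DGMG_ZINC_canonical')
model.eval()  # sampling a new molecule requires eval mode
smiles = model(rdkit_mol=True)  # SMILES string of a sampled molecule
print(smiles)
```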
## Binding Affinity Prediction
- Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity (ACNN) [[paper]](https://arxiv.org/abs/1703.10603), [[github]](https://github.com/deepchem/deepchem/tree/master/contrib/atomicconv)
- [ACNN with DGL](https://github.com/dmlc/dgl/tree/master/apps/life_sci/dglls/model/model_zoo/acnn.py)
- [Example Training Script](https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/binding_affinity_prediction)
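For binding affinity prediction, a sketch of scoring a protein-ligand complex with ACNN, following the test suite; the two input file paths are placeholders for your own structures:

```python
from dgllife.model.model_zoo.acnn import ACNN
from dgllife.utils.complex_to_graph import ACNN_graph_construction_and_featurization
from dgllife.utils.rdkit_utils import load_molecule

# Load a protein pocket and a ligand (placeholder paths)
pocket_mol, pocket_coords = load_molecule('pocket.pdb', remove_hs=True)
ligand_mol, ligand_coords = load_molecule('ligand.pdbqt', remove_hs=True)

# Build a heterograph for the complex and score it
g = ACNN_graph_construction_and_featurization(ligand_mol, pocket_mol,
                                              ligand_coords, pocket_coords)
model = ACNN()
affinity = model(g)  # tensor of shape (1, 1)
```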
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import dgllife
import sys
from setuptools import find_packages
if '--inplace' in sys.argv:
from distutils.core import setup
else:
from setuptools import setup
setup(
name='dgllife',
version=dgllife.__version__,
description='DGL-based package for Chemistry',
keywords=[
'pytorch',
'dgl',
'graph-neural-networks',
'chemistry',
'drug-discovery'
],
zip_safe=False,
maintainer='DGL Team',
packages=[package for package in find_packages()
if package.startswith('dgllife')],
install_requires=[
'dgl>=0.4',
        'torch>=1',
'scikit-learn>=0.21.2',
'pandas>=0.25.1',
'requests>=2.22.0'
],
url='https://github.com/dmlc/dgl/tree/master/apps/chem',
classifiers=[
'Programming Language :: Python :: 3',
]
)
import os
import pandas as pd
from dgllife.data.csv_dataset import *
from dgllife.utils.featurizers import *
from dgllife.utils.mol_to_graph import *
def test_data_frame():
data = [['CCO', 0, 1], ['CO', 2, 3]]
    df = pd.DataFrame(data, columns=['smiles', 'task1', 'task2'])
return df
def remove_file(fname):
if os.path.isfile(fname):
try:
os.remove(fname)
except OSError:
pass
def test_mol_csv():
df = test_data_frame()
fname = 'test.bin'
dataset = MoleculeCSVDataset(df=df, smiles_to_graph=smiles_to_bigraph,
node_featurizer=CanonicalAtomFeaturizer(),
edge_featurizer=CanonicalBondFeaturizer(),
smiles_column='smiles',
cache_file_path=fname)
assert dataset.task_names == ['task1', 'task2']
smiles, graph, label, mask = dataset[0]
assert label.shape[0] == 2
assert mask.shape[0] == 2
assert 'h' in graph.ndata
assert 'e' in graph.edata
# Test task_names
dataset = MoleculeCSVDataset(df=df, smiles_to_graph=smiles_to_bigraph,
node_featurizer=None,
edge_featurizer=None,
smiles_column='smiles',
cache_file_path=fname,
task_names=['task1'])
assert dataset.task_names == ['task1']
# Test load
dataset = MoleculeCSVDataset(df=df, smiles_to_graph=smiles_to_bigraph,
node_featurizer=CanonicalAtomFeaturizer(),
edge_featurizer=None,
smiles_column='smiles',
cache_file_path=fname,
load=True)
smiles, graph, label, mask = dataset[0]
assert 'h' in graph.ndata
assert 'e' in graph.edata
dataset = MoleculeCSVDataset(df=df, smiles_to_graph=smiles_to_bigraph,
node_featurizer=CanonicalAtomFeaturizer(),
edge_featurizer=None,
smiles_column='smiles',
cache_file_path=fname,
load=False)
smiles, graph, label, mask = dataset[0]
assert 'h' in graph.ndata
assert 'e' not in graph.edata
remove_file(fname)
if __name__ == '__main__':
test_mol_csv()
import os
from dgllife.data import *
def remove_file(fname):
if os.path.isfile(fname):
try:
os.remove(fname)
except OSError:
pass
def test_pubchem_aromaticity():
dataset = PubChemBioAssayAromaticity()
remove_file('pubchem_aromaticity_dglgraph.bin')
def test_tox21():
dataset = Tox21()
remove_file('tox21_dglgraph.bin')
def test_alchemy():
dataset = TencentAlchemyDataset(mode='valid',
node_featurizer=None,
edge_featurizer=None)
dataset = TencentAlchemyDataset(mode='valid',
node_featurizer=None,
edge_featurizer=None,
load=False)
def test_pdbbind():
dataset = PDBBind(subset='core', remove_hs=True)
if __name__ == '__main__':
test_pubchem_aromaticity()
test_tox21()
test_alchemy()
test_pdbbind()
import dgl
import os
import shutil
import torch
from dgl.data.utils import _get_dgl_url, download, extract_archive
from dgllife.model.model_zoo.acnn import ACNN
from dgllife.utils.complex_to_graph import ACNN_graph_construction_and_featurization
from dgllife.utils.rdkit_utils import load_molecule
def remove_dir(dir):
if os.path.isdir(dir):
try:
shutil.rmtree(dir)
except OSError:
pass
def test_acnn():
remove_dir('tmp1')
remove_dir('tmp2')
url = _get_dgl_url('dgllife/example_mols.tar.gz')
local_path = 'tmp1/example_mols.tar.gz'
download(url, path=local_path)
extract_archive(local_path, 'tmp2')
pocket_mol, pocket_coords = load_molecule(
'tmp2/example_mols/example.pdb', remove_hs=True)
ligand_mol, ligand_coords = load_molecule(
'tmp2/example_mols/example.pdbqt', remove_hs=True)
remove_dir('tmp1')
remove_dir('tmp2')
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g1 = ACNN_graph_construction_and_featurization(ligand_mol,
pocket_mol,
ligand_coords,
pocket_coords)
model = ACNN()
model.to(device)
g1.to(device)
assert model(g1).shape == torch.Size([1, 1])
bg = dgl.batch_hetero([g1, g1])
bg.to(device)
assert model(bg).shape == torch.Size([2, 1])
model = ACNN(hidden_sizes=[1, 2],
weight_init_stddevs=[1, 1],
dropouts=[0.1, 0.],
features_to_use=torch.tensor([6., 8.]),
radial=[[12.0], [0.0, 2.0], [4.0]])
model.to(device)
g1.to(device)
assert model(g1).shape == torch.Size([1, 1])
bg = dgl.batch_hetero([g1, g1])
bg.to(device)
assert model(bg).shape == torch.Size([2, 1])
if __name__ == '__main__':
test_acnn()
import torch
from rdkit import Chem
from dgllife.model import DGMG, DGLJTNNVAE
def test_dgmg():
model = DGMG(atom_types=['O', 'Cl', 'C', 'S', 'F', 'Br', 'N'],
bond_types=[Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE],
node_hidden_size=1,
num_prop_rounds=1,
dropout=0.2)
assert model(
actions=[(0, 2), (1, 3), (0, 0), (1, 0), (2, 0), (1, 3), (0, 7)], rdkit_mol=True) == 'CO'
assert model(rdkit_mol=False) is None
model.eval()
assert model(rdkit_mol=True) is not None
model = DGMG(atom_types=['O', 'Cl', 'C', 'S', 'F', 'Br', 'N'],
bond_types=[Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE])
assert model(
actions=[(0, 2), (1, 3), (0, 0), (1, 0), (2, 0), (1, 3), (0, 7)], rdkit_mol=True) == 'CO'
assert model(rdkit_mol=False) is None
model.eval()
assert model(rdkit_mol=True) is not None
def test_jtnn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
model = DGLJTNNVAE(hidden_size=1,
latent_size=2,
depth=1).to(device)
if __name__ == '__main__':
test_dgmg()
test_jtnn()
import dgl
import torch
import torch.nn.functional as F
from dgl import DGLGraph
from dgllife.model.gnn import *
def test_graph1():
"""Graph with node features."""
g = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1)
def test_graph2():
"""Batched graph with node features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1)
def test_graph3():
"""Graph with node and edge features."""
g = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1), \
torch.arange(2 * g.number_of_edges()).float().reshape(-1, 2)
def test_graph4():
"""Batched graph with node and edge features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1), \
torch.arange(2 * bg.number_of_edges()).float().reshape(-1, 2)
def test_graph5():
"""Graph with node types and edge distances."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g1, torch.LongTensor([0, 1, 0]), torch.randn(3, 1)
def test_graph6():
"""Batched graph with node types and edge distances."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.LongTensor([0, 1, 0, 2, 0, 3, 4, 4]), torch.randn(7, 1)
def test_gcn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
# Test default setting
gnn = GCN(in_feats=1).to(device)
assert gnn(g, node_feats).shape == torch.Size([3, 64])
assert gnn(bg, batch_node_feats).shape == torch.Size([8, 64])
# Test configured setting
gnn = GCN(in_feats=1,
hidden_feats=[1, 1],
activation=[F.relu, F.relu],
residual=[True, True],
batchnorm=[True, True],
dropout=[0.2, 0.2]).to(device)
assert gnn(g, node_feats).shape == torch.Size([3, 1])
assert gnn(bg, batch_node_feats).shape == torch.Size([8, 1])
def test_gat():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
# Test default setting
    gnn = GAT(in_feats=1).to(device)
assert gnn(g, node_feats).shape == torch.Size([3, 32])
assert gnn(bg, batch_node_feats).shape == torch.Size([8, 32])
# Test configured setting
gnn = GAT(in_feats=1,
hidden_feats=[1, 1],
num_heads=[2, 3],
feat_drops=[0.1, 0.1],
attn_drops=[0.1, 0.1],
alphas=[0.2, 0.2],
residuals=[True, True],
agg_modes=['flatten', 'mean'],
activations=[None, F.elu]).to(device)
assert gnn(g, node_feats).shape == torch.Size([3, 1])
assert gnn(bg, batch_node_feats).shape == torch.Size([8, 1])
gnn = GAT(in_feats=1,
hidden_feats=[1, 1],
num_heads=[2, 3],
feat_drops=[0.1, 0.1],
attn_drops=[0.1, 0.1],
alphas=[0.2, 0.2],
residuals=[True, True],
agg_modes=['mean', 'flatten'],
activations=[None, F.elu]).to(device)
assert gnn(g, node_feats).shape == torch.Size([3, 3])
assert gnn(bg, batch_node_feats).shape == torch.Size([8, 3])
def test_attentive_fp_gnn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats, edge_feats = test_graph3()
g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
bg, batch_node_feats, batch_edge_feats = test_graph4()
bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
batch_edge_feats.to(device)
# Test AttentiveFPGNN
gnn = AttentiveFPGNN(node_feat_size=1,
edge_feat_size=2,
num_layers=1,
graph_feat_size=1,
dropout=0.).to(device)
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 1])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 1])
def test_schnet_gnn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_types, edge_dists = test_graph5()
g, node_types, edge_dists = g.to(device), node_types.to(device), edge_dists.to(device)
bg, batch_node_types, batch_edge_dists = test_graph6()
bg, batch_node_types, batch_edge_dists = bg.to(device), batch_node_types.to(device), \
batch_edge_dists.to(device)
# Test default setting
gnn = SchNetGNN().to(device)
assert gnn(g, node_types, edge_dists).shape == torch.Size([3, 64])
assert gnn(bg, batch_node_types, batch_edge_dists).shape == torch.Size([8, 64])
# Test configured setting
gnn = SchNetGNN(num_node_types=5,
node_feats=2,
hidden_feats=[3],
cutoff=0.3).to(device)
assert gnn(g, node_types, edge_dists).shape == torch.Size([3, 2])
assert gnn(bg, batch_node_types, batch_edge_dists).shape == torch.Size([8, 2])
def test_mgcn_gnn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_types, edge_dists = test_graph5()
g, node_types, edge_dists = g.to(device), node_types.to(device), edge_dists.to(device)
bg, batch_node_types, batch_edge_dists = test_graph6()
bg, batch_node_types, batch_edge_dists = bg.to(device), batch_node_types.to(device), \
batch_edge_dists.to(device)
# Test default setting
gnn = MGCNGNN().to(device)
assert gnn(g, node_types, edge_dists).shape == torch.Size([3, 512])
assert gnn(bg, batch_node_types, batch_edge_dists).shape == torch.Size([8, 512])
# Test configured setting
gnn = MGCNGNN(feats=2,
n_layers=2,
num_node_types=5,
num_edge_types=150,
cutoff=0.3).to(device)
assert gnn(g, node_types, edge_dists).shape == torch.Size([3, 6])
assert gnn(bg, batch_node_types, batch_edge_dists).shape == torch.Size([8, 6])
def test_mpnn_gnn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats, edge_feats = test_graph3()
g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
bg, batch_node_feats, batch_edge_feats = test_graph4()
bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
batch_edge_feats.to(device)
# Test default setting
gnn = MPNNGNN(node_in_feats=1,
                  edge_in_feats=2).to(device)
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 64])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 64])
# Test configured setting
gnn = MPNNGNN(node_in_feats=1,
edge_in_feats=2,
node_out_feats=2,
edge_hidden_feats=2,
num_step_message_passing=2).to(device)
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 2])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 2])
if __name__ == '__main__':
test_gcn()
test_gat()
test_attentive_fp_gnn()
test_schnet_gnn()
test_mgcn_gnn()
test_mpnn_gnn()
import dgl
import os
import torch
from functools import partial
from dgllife.model import load_pretrained
from dgllife.utils import *
def remove_file(fname):
if os.path.isfile(fname):
try:
os.remove(fname)
except OSError:
pass
def run_dgmg_ChEMBL(model):
assert model(
actions=[(0, 2), (1, 3), (0, 0), (1, 0), (2, 0), (1, 3), (0, 7)],
rdkit_mol=True) == 'CO'
assert model(rdkit_mol=False) is None
model.eval()
assert model(rdkit_mol=True) is not None
def run_dgmg_ZINC(model):
assert model(
actions=[(0, 2), (1, 3), (0, 5), (1, 0), (2, 0), (1, 3), (0, 9)],
rdkit_mol=True) == 'CO'
assert model(rdkit_mol=False) is None
model.eval()
assert model(rdkit_mol=True) is not None
def test_dgmg():
model = load_pretrained('DGMG_ZINC_canonical')
run_dgmg_ZINC(model)
model = load_pretrained('DGMG_ZINC_random')
run_dgmg_ZINC(model)
model = load_pretrained('DGMG_ChEMBL_canonical')
run_dgmg_ChEMBL(model)
model = load_pretrained('DGMG_ChEMBL_random')
run_dgmg_ChEMBL(model)
remove_file('DGMG_ChEMBL_canonical_pre_trained.pth')
remove_file('DGMG_ChEMBL_random_pre_trained.pth')
remove_file('DGMG_ZINC_canonical_pre_trained.pth')
remove_file('DGMG_ZINC_random_pre_trained.pth')
def test_jtnn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
model = load_pretrained('JTNN_ZINC').to(device)
remove_file('JTNN_ZINC_pre_trained.pth')
def test_gcn_tox21():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
node_featurizer = CanonicalAtomFeaturizer()
g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer)
g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer)
bg = dgl.batch([g1, g2])
model = load_pretrained('GCN_Tox21').to(device)
model(bg.to(device), bg.ndata.pop('h').to(device))
model.eval()
model(g1.to(device), g1.ndata.pop('h').to(device))
remove_file('GCN_Tox21_pre_trained.pth')
def test_gat_tox21():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
node_featurizer = CanonicalAtomFeaturizer()
g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer)
g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer)
bg = dgl.batch([g1, g2])
model = load_pretrained('GAT_Tox21').to(device)
model(bg.to(device), bg.ndata.pop('h').to(device))
model.eval()
model(g1.to(device), g1.ndata.pop('h').to(device))
remove_file('GAT_Tox21_pre_trained.pth')
def chirality(atom):
    """One-hot encoding of the atom's CIP code (R/S) plus whether chirality is possible."""
    try:
        return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
               [atom.HasProp('_ChiralityPossible')]
    except Exception:
        # The atom has no CIP code assigned
        return [False, False] + [atom.HasProp('_ChiralityPossible')]
def test_attentivefp_aromaticity():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
node_featurizer = BaseAtomFeaturizer(
featurizer_funcs={'hv': ConcatFeaturizer([
partial(atom_type_one_hot, allowable_set=[
'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
encode_unknown=True),
partial(atom_degree_one_hot, allowable_set=list(range(6))),
atom_formal_charge, atom_num_radical_electrons,
partial(atom_hybridization_one_hot, encode_unknown=True),
            lambda atom: [0],  # a placeholder for aromaticity information
atom_total_num_H_one_hot, chirality
],
)}
)
edge_featurizer = BaseBondFeaturizer({
'he': lambda bond: [0 for _ in range(10)]
})
g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer)
g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer)
bg = dgl.batch([g1, g2])
model = load_pretrained('AttentiveFP_Aromaticity').to(device)
model(bg.to(device), bg.ndata.pop('hv').to(device), bg.edata.pop('he').to(device))
model.eval()
model(g1.to(device), g1.ndata.pop('hv').to(device), g1.edata.pop('he').to(device))
remove_file('AttentiveFP_Aromaticity_pre_trained.pth')
if __name__ == '__main__':
test_dgmg()
test_jtnn()
test_gcn_tox21()
test_gat_tox21()
test_attentivefp_aromaticity()
import dgl
import torch
import torch.nn.functional as F
from dgl import DGLGraph
from dgllife.model.model_zoo import *
def test_graph1():
"""Graph with node features."""
g = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1)
def test_graph2():
"""Batched graph with node features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1)
def test_graph3():
"""Graph with node features and edge features."""
g = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1), \
torch.arange(2 * g.number_of_edges()).float().reshape(-1, 2)
def test_graph4():
"""Batched graph with node features and edge features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1), \
torch.arange(2 * bg.number_of_edges()).float().reshape(-1, 2)
def test_graph5():
"""Graph with node types and edge distances."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g1, torch.LongTensor([0, 1, 0]), torch.randn(3, 1)
def test_graph6():
"""Batched graph with node types and edge distances."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.LongTensor([0, 1, 0, 2, 0, 3, 4, 4]), torch.randn(7, 1)
def test_mlp_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g_feats = torch.tensor([[1.], [2.]]).to(device)
mlp_predictor = MLPPredictor(in_feats=1, hidden_feats=1, n_tasks=2).to(device)
assert mlp_predictor(g_feats).shape == torch.Size([2, 2])
def test_gcn_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
# Test default setting
    gcn_predictor = GCNPredictor(in_feats=1).to(device)
gcn_predictor.eval()
assert gcn_predictor(g, node_feats).shape == torch.Size([1, 1])
gcn_predictor.train()
assert gcn_predictor(bg, batch_node_feats).shape == torch.Size([2, 1])
# Test configured setting
gcn_predictor = GCNPredictor(in_feats=1,
hidden_feats=[1],
activation=[F.relu],
residual=[True],
batchnorm=[True],
dropout=[0.1],
classifier_hidden_feats=1,
classifier_dropout=0.1,
n_tasks=2).to(device)
gcn_predictor.eval()
assert gcn_predictor(g, node_feats).shape == torch.Size([1, 2])
gcn_predictor.train()
assert gcn_predictor(bg, batch_node_feats).shape == torch.Size([2, 2])
def test_gat_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
# Test default setting
gat_predictor = GATPredictor(in_feats=1).to(device)
gat_predictor.eval()
assert gat_predictor(g, node_feats).shape == torch.Size([1, 1])
gat_predictor.train()
assert gat_predictor(bg, batch_node_feats).shape == torch.Size([2, 1])
# Test configured setting
gat_predictor = GATPredictor(in_feats=1,
hidden_feats=[1, 2],
num_heads=[2, 3],
feat_drops=[0.1, 0.1],
attn_drops=[0.1, 0.1],
alphas=[0.1, 0.1],
residuals=[True, True],
agg_modes=['mean', 'flatten'],
activations=[None, F.elu]).to(device)
gat_predictor.eval()
assert gat_predictor(g, node_feats).shape == torch.Size([1, 1])
gat_predictor.train()
assert gat_predictor(bg, batch_node_feats).shape == torch.Size([2, 1])
def test_attentivefp_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats, edge_feats = test_graph3()
g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
bg, batch_node_feats, batch_edge_feats = test_graph4()
bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
batch_edge_feats.to(device)
attentivefp_predictor = AttentiveFPPredictor(node_feat_size=1,
edge_feat_size=2,
num_layers=2,
num_timesteps=1,
graph_feat_size=1,
n_tasks=2).to(device)
assert attentivefp_predictor(g, node_feats, edge_feats).shape == torch.Size([1, 2])
assert attentivefp_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
torch.Size([2, 2])
def test_schnet_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_types, edge_dists = test_graph5()
g, node_types, edge_dists = g.to(device), node_types.to(device), edge_dists.to(device)
bg, batch_node_types, batch_edge_dists = test_graph6()
bg, batch_node_types, batch_edge_dists = bg.to(device), batch_node_types.to(device), \
batch_edge_dists.to(device)
# Test default setting
schnet_predictor = SchNetPredictor().to(device)
assert schnet_predictor(g, node_types, edge_dists).shape == torch.Size([1, 1])
assert schnet_predictor(bg, batch_node_types, batch_edge_dists).shape == \
torch.Size([2, 1])
# Test configured setting
schnet_predictor = SchNetPredictor(node_feats=2,
hidden_feats=[2, 2],
classifier_hidden_feats=3,
n_tasks=3,
num_node_types=5,
cutoff=0.3).to(device)
assert schnet_predictor(g, node_types, edge_dists).shape == torch.Size([1, 3])
assert schnet_predictor(bg, batch_node_types, batch_edge_dists).shape == \
torch.Size([2, 3])
def test_mgcn_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_types, edge_dists = test_graph5()
g, node_types, edge_dists = g.to(device), node_types.to(device), edge_dists.to(device)
bg, batch_node_types, batch_edge_dists = test_graph6()
bg, batch_node_types, batch_edge_dists = bg.to(device), batch_node_types.to(device), \
batch_edge_dists.to(device)
# Test default setting
mgcn_predictor = MGCNPredictor().to(device)
assert mgcn_predictor(g, node_types, edge_dists).shape == torch.Size([1, 1])
assert mgcn_predictor(bg, batch_node_types, batch_edge_dists).shape == \
torch.Size([2, 1])
# Test configured setting
mgcn_predictor = MGCNPredictor(feats=2,
n_layers=2,
classifier_hidden_feats=3,
n_tasks=3,
num_node_types=5,
num_edge_types=150,
cutoff=0.3).to(device)
assert mgcn_predictor(g, node_types, edge_dists).shape == torch.Size([1, 3])
assert mgcn_predictor(bg, batch_node_types, batch_edge_dists).shape == \
torch.Size([2, 3])
def test_mpnn_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats, edge_feats = test_graph3()
g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
bg, batch_node_feats, batch_edge_feats = test_graph4()
bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
batch_edge_feats.to(device)
# Test default setting
mpnn_predictor = MPNNPredictor(node_in_feats=1,
                                   edge_in_feats=2).to(device)
assert mpnn_predictor(g, node_feats, edge_feats).shape == torch.Size([1, 1])
assert mpnn_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
torch.Size([2, 1])
# Test configured setting
mpnn_predictor = MPNNPredictor(node_in_feats=1,
edge_in_feats=2,
node_out_feats=2,
edge_hidden_feats=2,
n_tasks=2,
num_step_message_passing=2,
num_step_set2set=2,
                                   num_layer_set2set=2).to(device)
assert mpnn_predictor(g, node_feats, edge_feats).shape == torch.Size([1, 2])
assert mpnn_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
torch.Size([2, 2])
if __name__ == '__main__':
test_mlp_predictor()
test_gcn_predictor()
test_gat_predictor()
test_attentivefp_predictor()
test_schnet_predictor()
test_mgcn_predictor()
test_mpnn_predictor()
import dgl
import torch
import torch.nn.functional as F
from dgl import DGLGraph
from dgllife.model.readout import *
def test_graph1():
"""Graph with node features"""
g = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g, torch.arange(g.number_of_nodes()).float().reshape(-1, 1)
def test_graph2():
"Batched graph with node features"
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.arange(bg.number_of_nodes()).float().reshape(-1, 1)
def test_weighted_sum_and_max():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
model = WeightedSumAndMax(in_feats=1).to(device)
assert model(g, node_feats).shape == torch.Size([1, 2])
assert model(bg, batch_node_feats).shape == torch.Size([2, 2])
def test_attentive_fp_readout():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
model = AttentiveFPReadout(feat_size=1,
num_timesteps=1).to(device)
assert model(g, node_feats).shape == torch.Size([1, 1])
assert model(bg, batch_node_feats).shape == torch.Size([2, 1])
def test_mlp_readout():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats = test_graph1()
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
model = MLPNodeReadout(node_feats=1,
hidden_feats=2,
graph_feats=3,
activation=F.relu,
mode='sum').to(device)
assert model(g, node_feats).shape == torch.Size([1, 3])
assert model(bg, batch_node_feats).shape == torch.Size([2, 3])
model = MLPNodeReadout(node_feats=1,
hidden_feats=2,
graph_feats=3,
mode='max').to(device)
assert model(g, node_feats).shape == torch.Size([1, 3])
assert model(bg, batch_node_feats).shape == torch.Size([2, 3])
model = MLPNodeReadout(node_feats=1,
hidden_feats=2,
graph_feats=3,
mode='mean').to(device)
assert model(g, node_feats).shape == torch.Size([1, 3])
assert model(bg, batch_node_feats).shape == torch.Size([2, 3])
if __name__ == '__main__':
test_weighted_sum_and_max()
test_attentive_fp_readout()
test_mlp_readout()