Unverified Commit 828a5e5b authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[DGL-LifeSci] Migration and Refactor (#1226)

* First commit

* Update

* Update splitters

* Update

* Update

* Update

* Update

* Update

* Update

* Migrate ACNN

* Fix

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Finish classification

* Update

* Fix

* Update

* Update

* Update

* Fix

* Fix

* Fix

* Update

* Update

* Update

* trigger CI

* Fix CI

* Update

* Update

* Update

* Add default values

* Rename

* Update deprecation message
parent e4948c5c
import os
import shutil
import torch
from dgl.data.utils import download, _get_dgl_url, extract_archive
from dgllife.utils.complex_to_graph import *
from dgllife.utils.rdkit_utils import load_molecule
def remove_dir(dir):
    """Delete the directory tree at ``dir`` if it exists, ignoring OS errors.

    Parameters
    ----------
    dir : str
        Path of the directory to remove.
    """
    # Nothing to do when the path is not an existing directory.
    if not os.path.isdir(dir):
        return
    try:
        shutil.rmtree(dir)
    except OSError:
        # Best-effort cleanup: a failed removal must not fail the caller.
        return
def test_acnn_graph_construction_and_featurization():
    """Exercise ``ACNN_graph_construction_and_featurization`` on example molecules.

    The example archive is downloaded into ``tmp1`` and extracted into
    ``tmp2``; a pocket and a ligand are loaded from it. The constructed graph
    is then checked for the default arguments and for the
    ``max_num_ligand_atoms``/``max_num_protein_atoms``, ``neighbor_cutoff``,
    ``max_num_neighbors`` and ``strip_hydrogens`` options.
    """
    remove_dir('tmp1')
    remove_dir('tmp2')

    try:
        url = _get_dgl_url('dgllife/example_mols.tar.gz')
        local_path = 'tmp1/example_mols.tar.gz'
        download(url, path=local_path)
        extract_archive(local_path, 'tmp2')

        pocket_mol, pocket_coords = load_molecule(
            'tmp2/example_mols/example.pdb', remove_hs=True)
        ligand_mol, ligand_coords = load_molecule(
            'tmp2/example_mols/example.pdbqt', remove_hs=True)
        pocket_mol_with_h, pocket_coords_with_h = load_molecule(
            'tmp2/example_mols/example.pdb', remove_hs=False)
    finally:
        # The molecules are in memory by now; always drop the scratch
        # directories, even when downloading or loading failed.
        remove_dir('tmp1')
        remove_dir('tmp2')

    # Test default case
    g = ACNN_graph_construction_and_featurization(ligand_mol,
                                                  pocket_mol,
                                                  ligand_coords,
                                                  pocket_coords)
    assert g.ntypes == ['protein_atom', 'ligand_atom']
    assert g.etypes == ['protein', 'ligand', 'complex', 'complex', 'complex', 'complex']
    assert g.number_of_nodes('protein_atom') == 286
    assert g.number_of_nodes('ligand_atom') == 21

    assert g.number_of_edges('protein') == 3432
    assert g.number_of_edges('ligand') == 252
    assert g.number_of_edges(('protein_atom', 'complex', 'protein_atom')) == 3349
    assert g.number_of_edges(('ligand_atom', 'complex', 'ligand_atom')) == 131
    assert g.number_of_edges(('protein_atom', 'complex', 'ligand_atom')) == 121
    assert g.number_of_edges(('ligand_atom', 'complex', 'protein_atom')) == 83

    assert 'atomic_number' in g.nodes['protein_atom'].data
    assert 'atomic_number' in g.nodes['ligand_atom'].data
    # Without padding every node is a real atom, so the masks are all ones.
    assert torch.allclose(g.nodes['protein_atom'].data['mask'],
                          torch.ones(g.number_of_nodes('protein_atom'), 1))
    assert torch.allclose(g.nodes['ligand_atom'].data['mask'],
                          torch.ones(g.number_of_nodes('ligand_atom'), 1))
    assert 'distance' in g.edges['protein'].data
    assert 'distance' in g.edges['ligand'].data
    assert 'distance' in g.edges[('protein_atom', 'complex', 'protein_atom')].data
    assert 'distance' in g.edges[('ligand_atom', 'complex', 'ligand_atom')].data
    assert 'distance' in g.edges[('protein_atom', 'complex', 'ligand_atom')].data
    assert 'distance' in g.edges[('ligand_atom', 'complex', 'protein_atom')].data

    # Test max_num_ligand_atoms and max_num_protein_atoms
    max_num_ligand_atoms = 30
    max_num_protein_atoms = 300
    g = ACNN_graph_construction_and_featurization(ligand_mol,
                                                  pocket_mol,
                                                  ligand_coords,
                                                  pocket_coords,
                                                  max_num_ligand_atoms=max_num_ligand_atoms,
                                                  max_num_protein_atoms=max_num_protein_atoms)
    assert g.number_of_nodes('ligand_atom') == max_num_ligand_atoms
    assert g.number_of_nodes('protein_atom') == max_num_protein_atoms
    # The mask is expected to be 1 for the real atoms and 0 for the padding.
    ligand_mask = torch.zeros(max_num_ligand_atoms, 1)
    ligand_mask[:ligand_mol.GetNumAtoms(), :] = 1.
    assert torch.allclose(ligand_mask, g.nodes['ligand_atom'].data['mask'])
    protein_mask = torch.zeros(max_num_protein_atoms, 1)
    protein_mask[:pocket_mol.GetNumAtoms(), :] = 1.
    assert torch.allclose(protein_mask, g.nodes['protein_atom'].data['mask'])

    # Test neighbor_cutoff
    neighbor_cutoff = 6.
    g = ACNN_graph_construction_and_featurization(ligand_mol,
                                                  pocket_mol,
                                                  ligand_coords,
                                                  pocket_coords,
                                                  neighbor_cutoff=neighbor_cutoff)
    assert g.number_of_edges('protein') == 3405
    assert g.number_of_edges('ligand') == 193
    assert g.number_of_edges(('protein_atom', 'complex', 'protein_atom')) == 3331
    assert g.number_of_edges(('ligand_atom', 'complex', 'ligand_atom')) == 123
    assert g.number_of_edges(('protein_atom', 'complex', 'ligand_atom')) == 119
    assert g.number_of_edges(('ligand_atom', 'complex', 'protein_atom')) == 82

    # Test max_num_neighbors
    g = ACNN_graph_construction_and_featurization(ligand_mol,
                                                  pocket_mol,
                                                  ligand_coords,
                                                  pocket_coords,
                                                  max_num_neighbors=6)
    assert g.number_of_edges('protein') == 1716
    assert g.number_of_edges('ligand') == 126
    assert g.number_of_edges(('protein_atom', 'complex', 'protein_atom')) == 1691
    assert g.number_of_edges(('ligand_atom', 'complex', 'ligand_atom')) == 86
    assert g.number_of_edges(('protein_atom', 'complex', 'ligand_atom')) == 40
    assert g.number_of_edges(('ligand_atom', 'complex', 'protein_atom')) == 25

    # Test strip_hydrogens
    # The hydrogen-containing pocket is passed as both ligand and protein.
    g = ACNN_graph_construction_and_featurization(pocket_mol_with_h,
                                                  pocket_mol_with_h,
                                                  pocket_coords_with_h,
                                                  pocket_coords_with_h,
                                                  strip_hydrogens=True)
    assert g.number_of_nodes('ligand_atom') != pocket_mol_with_h.GetNumAtoms()
    assert g.number_of_nodes('protein_atom') != pocket_mol_with_h.GetNumAtoms()
    non_h_atomic_numbers = []
    for i in range(pocket_mol_with_h.GetNumAtoms()):
        atom = pocket_mol_with_h.GetAtomWithIdx(i)
        if atom.GetSymbol() != 'H':
            non_h_atomic_numbers.append(atom.GetAtomicNum())
    non_h_atomic_numbers = torch.tensor(non_h_atomic_numbers).float().reshape(-1, 1)
    assert torch.allclose(non_h_atomic_numbers, g.nodes['ligand_atom'].data['atomic_number'])
    assert torch.allclose(non_h_atomic_numbers, g.nodes['protein_atom'].data['atomic_number'])
# Allow running this test module directly as a script.
if __name__ == '__main__':
    test_acnn_graph_construction_and_featurization()
import os
import torch
import torch.nn as nn
from dgllife.utils import EarlyStopping
def remove_file(fname):
    """Delete the regular file ``fname`` if it exists, ignoring OS errors.

    Parameters
    ----------
    fname : str
        Path of the file to remove.
    """
    # Skip anything that is not an existing regular file.
    if not os.path.isfile(fname):
        return
    try:
        os.remove(fname)
    except OSError:
        # Best-effort cleanup: a failed removal must not fail the caller.
        return
def test_early_stopping_high():
    """Check ``EarlyStopping`` with ``mode='higher'`` and ``patience=1``.

    The test verifies that a checkpoint is written on the first step and on
    an improved (larger) score, and that a worse score triggers stopping
    without overwriting the checkpoint. The checkpoint file ``test.pkl`` is
    removed even when an assertion fails.
    """
    try:
        model1 = nn.Linear(2, 3)
        stopper = EarlyStopping(mode='higher',
                                patience=1,
                                filename='test.pkl')

        # Save model in the first step
        stopper.step(1., model1)
        model1.weight.data = model1.weight.data + 1
        model2 = nn.Linear(2, 3)
        stopper.load_checkpoint(model2)
        # The checkpoint predates the weight change above.
        assert not torch.allclose(model1.weight, model2.weight)

        # Save model checkpoint with performance improvement
        model1.weight.data = model1.weight.data + 1
        stopper.step(2., model1)
        stopper.load_checkpoint(model2)
        assert torch.allclose(model1.weight, model2.weight)

        # Stop when no improvement observed
        model1.weight.data = model1.weight.data + 1
        assert stopper.step(0.5, model1)
        stopper.load_checkpoint(model2)
        assert not torch.allclose(model1.weight, model2.weight)
    finally:
        # Always delete the checkpoint so a failed run leaves nothing behind.
        remove_file('test.pkl')
def test_early_stopping_low():
    """Check ``EarlyStopping`` with ``mode='lower'`` and ``patience=1``.

    The test verifies that a checkpoint is written on the first step and on
    an improved (smaller) score, and that a worse score triggers stopping
    without overwriting the checkpoint. The checkpoint file ``test.pkl`` is
    removed even when an assertion fails.
    """
    try:
        model1 = nn.Linear(2, 3)
        stopper = EarlyStopping(mode='lower',
                                patience=1,
                                filename='test.pkl')

        # Save model in the first step
        stopper.step(1., model1)
        model1.weight.data = model1.weight.data + 1
        model2 = nn.Linear(2, 3)
        stopper.load_checkpoint(model2)
        # The checkpoint predates the weight change above.
        assert not torch.allclose(model1.weight, model2.weight)

        # Save model checkpoint with performance improvement
        model1.weight.data = model1.weight.data + 1
        stopper.step(0.5, model1)
        stopper.load_checkpoint(model2)
        assert torch.allclose(model1.weight, model2.weight)

        # Stop when no improvement observed
        model1.weight.data = model1.weight.data + 1
        assert stopper.step(2, model1)
        stopper.load_checkpoint(model2)
        assert not torch.allclose(model1.weight, model2.weight)
    finally:
        # Always delete the checkpoint so a failed run leaves nothing behind.
        remove_file('test.pkl')
# Run both early-stopping tests when the module is executed as a script.
if __name__ == '__main__':
    for run_test in (test_early_stopping_high, test_early_stopping_low):
        run_test()
import numpy as np
import torch
from dgllife.utils.eval import *
def test_Meter():
    """Check every ``Meter`` metric with and without a mask.

    For each metric, the dedicated method and ``compute_metric`` are compared
    against hard-coded per-task scores and against their ``'mean'`` and
    ``'sum'`` reductions.
    """
    label = torch.tensor([[0., 1.],
                          [0., 1.],
                          [1., 0.]])
    pred = torch.tensor([[0.5, 0.5],
                         [0., 1.],
                         [1., 0.]])
    mask = torch.tensor([[1., 0.],
                         [0., 1.],
                         [1., 1.]])
    label_mean, label_std = label.mean(dim=0), label.std(dim=0)

    def check(method_name, metric_name, expected, use_mask, normalized):
        # Build a fresh meter, feed it one batch and run the six assertions.
        meter = Meter(label_mean, label_std) if normalized else Meter()
        if use_mask:
            meter.update(label, pred, mask)
        else:
            meter.update(label, pred)
        method = getattr(meter, method_name)
        assert method() == expected
        assert method('mean') == np.mean(expected)
        assert method('sum') == np.sum(expected)
        assert meter.compute_metric(metric_name) == expected
        assert meter.compute_metric(metric_name, 'mean') == np.mean(expected)
        assert meter.compute_metric(metric_name, 'sum') == np.sum(expected)

    # pearson r2
    check('pearson_r2', 'r2', [0.7499999999999999, 0.7499999999999999],
          use_mask=False, normalized=True)
    check('pearson_r2', 'r2', [1.0, 1.0],
          use_mask=True, normalized=True)

    # mae
    check('mae', 'mae', [0.1666666716337204, 0.1666666716337204],
          use_mask=False, normalized=False)
    check('mae', 'mae', [0.25, 0.0],
          use_mask=True, normalized=False)

    # rmse
    check('rmse', 'rmse', [0.22125875529784111, 0.5937311018897714],
          use_mask=False, normalized=True)
    check('rmse', 'rmse', [0.1337071188699867, 0.5019903799993205],
          use_mask=True, normalized=True)

    # roc auc score
    check('roc_auc_score', 'roc_auc_score', [1.0, 0.75],
          use_mask=False, normalized=False)
    check('roc_auc_score', 'roc_auc_score', [1.0, 1.0],
          use_mask=True, normalized=False)
# Allow running this test module directly as a script.
if __name__ == '__main__':
    test_Meter()
import torch
from dgllife.utils.featurizers import *
from rdkit import Chem
def test_one_hot_encoding():
    """Check ``one_hot_encoding`` with and without ``encode_unknown``.

    All calls share one ``allowable_set`` list and run in a fixed order.
    """
    x = 1.
    allowable_set = [0., 1., 2.]
    # NOTE(review): the third case expects four entries without
    # encode_unknown; presumably the preceding encode_unknown=True call
    # extended allowable_set in place -- confirm against one_hot_encoding.
    cases = (
        (x, {}, [0, 1, 0]),
        (x, {'encode_unknown': True}, [0, 1, 0, 0]),
        (x, {}, [0, 1, 0, 0]),
        (x, {'encode_unknown': True}, [0, 1, 0, 0]),
        (-1, {'encode_unknown': True}, [0, 0, 0, 1]),
    )
    for value, options, expected in cases:
        assert one_hot_encoding(value, allowable_set, **options) == expected
def test_mol1():
    """Return the RDKit molecule parsed from the SMILES string ``'CCO'``."""
    smiles = 'CCO'
    return Chem.MolFromSmiles(smiles)
def test_mol2():
    """Return the RDKit molecule parsed from ``'C1=CC2=CC=CC=CC2=C1'``."""
    smiles = 'C1=CC2=CC=CC=CC2=C1'
    return Chem.MolFromSmiles(smiles)
def test_atom_type_one_hot():
    """Check ``atom_type_one_hot`` on atoms 0 and 2 of the first test molecule."""
    mol = test_mol1()
    for idx, expected in ((0, [1, 0]), (2, [0, 1])):
        # A fresh allowable list is passed on every call, as in a plain call.
        assert atom_type_one_hot(mol.GetAtomWithIdx(idx), ['C', 'O']) == expected
def test_atomic_number_one_hot():
    """Check ``atomic_number_one_hot`` on atoms 0 and 2 of the first test molecule."""
    mol = test_mol1()
    for idx, expected in ((0, [1, 0]), (2, [0, 1])):
        assert atomic_number_one_hot(mol.GetAtomWithIdx(idx), [6, 8]) == expected
def test_atomic_number():
    """Check ``atomic_number`` on atoms 0 and 2 of the first test molecule."""
    mol = test_mol1()
    for idx, expected in ((0, [6]), (2, [8])):
        assert atomic_number(mol.GetAtomWithIdx(idx)) == expected
def test_atom_degree_one_hot():
    """Check ``atom_degree_one_hot`` on atoms 0 and 1 of the first test molecule."""
    mol = test_mol1()
    for idx, expected in ((0, [0, 1, 0]), (1, [0, 0, 1])):
        assert atom_degree_one_hot(mol.GetAtomWithIdx(idx), [0, 1, 2]) == expected
def test_atom_degree():
    """Check ``atom_degree`` on atoms 0 and 1 of the first test molecule."""
    mol = test_mol1()
    for idx, expected in ((0, [1]), (1, [2])):
        assert atom_degree(mol.GetAtomWithIdx(idx)) == expected
def test_atom_total_degree_one_hot():
    """Check ``atom_total_degree_one_hot`` on atoms 0 and 2 of the first test molecule."""
    mol = test_mol1()
    for idx, expected in ((0, [0, 0, 1]), (2, [0, 1, 0])):
        assert atom_total_degree_one_hot(mol.GetAtomWithIdx(idx), [0, 2, 4]) == expected
def test_atom_total_degree():
    """Check ``atom_total_degree`` on atoms 0 and 2 of the first test molecule."""
    mol = test_mol1()
    for idx, expected in ((0, [4]), (2, [2])):
        assert atom_total_degree(mol.GetAtomWithIdx(idx)) == expected
def test_atom_implicit_valence_one_hot():
    """Check ``atom_implicit_valence_one_hot`` on atoms 0 and 1 of the first test molecule."""
    mol = test_mol1()
    for idx, expected in ((0, [0, 0, 1]), (1, [0, 1, 0])):
        assert atom_implicit_valence_one_hot(mol.GetAtomWithIdx(idx), [1, 2, 3]) == expected
def test_atom_implicit_valence():
    """Check ``atom_implicit_valence`` on atoms 0 and 1 of the first test molecule."""
    mol = test_mol1()
    for idx, expected in ((0, [3]), (1, [2])):
        assert atom_implicit_valence(mol.GetAtomWithIdx(idx)) == expected
def test_atom_hybridization_one_hot():
    """Check ``atom_hybridization_one_hot`` on atom 0 of the first test molecule."""
    first_atom = test_mol1().GetAtomWithIdx(0)
    encoding = atom_hybridization_one_hot(first_atom)
    assert encoding == [0, 0, 1, 0, 0]
def test_atom_total_num_H_one_hot():
    """Check ``atom_total_num_H_one_hot`` on atoms 0 and 1 of the first test molecule."""
    mol = test_mol1()
    for idx, expected in ((0, [0, 0, 0, 1, 0]), (1, [0, 0, 1, 0, 0])):
        assert atom_total_num_H_one_hot(mol.GetAtomWithIdx(idx)) == expected
def test_atom_total_num_H():
    """Check ``atom_total_num_H`` on atoms 0 and 1 of the first test molecule."""
    mol = test_mol1()
    for idx, expected in ((0, [3]), (1, [2])):
        assert atom_total_num_H(mol.GetAtomWithIdx(idx)) == expected
def test_atom_formal_charge_one_hot():
    """Check ``atom_formal_charge_one_hot`` on atom 0 of the first test molecule."""
    first_atom = test_mol1().GetAtomWithIdx(0)
    encoding = atom_formal_charge_one_hot(first_atom)
    assert encoding == [0, 0, 1, 0, 0]
def test_atom_formal_charge():
    """Check ``atom_formal_charge`` on atom 0 of the first test molecule."""
    first_atom = test_mol1().GetAtomWithIdx(0)
    charge = atom_formal_charge(first_atom)
    assert charge == [0]
def test_atom_num_radical_electrons_one_hot():
    """Check ``atom_num_radical_electrons_one_hot`` on atom 0 of the first test molecule."""
    first_atom = test_mol1().GetAtomWithIdx(0)
    encoding = atom_num_radical_electrons_one_hot(first_atom)
    assert encoding == [1, 0, 0, 0, 0]
def test_atom_num_radical_electrons():
    """Check ``atom_num_radical_electrons`` on atom 0 of the first test molecule."""
    first_atom = test_mol1().GetAtomWithIdx(0)
    count = atom_num_radical_electrons(first_atom)
    assert count == [0]
def test_atom_is_aromatic_one_hot():
    """Check ``atom_is_aromatic_one_hot`` on atom 0 of both test molecules."""
    for build_mol, expected in ((test_mol1, [1, 0]), (test_mol2, [0, 1])):
        mol = build_mol()
        assert atom_is_aromatic_one_hot(mol.GetAtomWithIdx(0)) == expected
def test_atom_is_aromatic():
    """Check ``atom_is_aromatic`` on atom 0 of both test molecules."""
    for build_mol, expected in ((test_mol1, [0]), (test_mol2, [1])):
        mol = build_mol()
        assert atom_is_aromatic(mol.GetAtomWithIdx(0)) == expected
def test_atom_chiral_tag_one_hot():
    """Check ``atom_chiral_tag_one_hot`` on atom 0 of the first test molecule."""
    first_atom = test_mol1().GetAtomWithIdx(0)
    encoding = atom_chiral_tag_one_hot(first_atom)
    assert encoding == [1, 0, 0, 0]
def test_atom_mass():
    """Check that ``atom_mass`` returns the RDKit atom mass scaled by 0.01."""
    mol = test_mol1()
    for idx in (0, 1):
        atom = mol.GetAtomWithIdx(idx)
        assert atom_mass(atom) == [atom.GetMass() * 0.01]
def test_concat_featurizer():
    """Check that ``ConcatFeaturizer`` concatenates its featurizers' outputs in order."""
    featurizer = ConcatFeaturizer([atom_is_aromatic_one_hot, atom_chiral_tag_one_hot])
    expectations = (
        (test_mol1, [1, 0, 1, 0, 0, 0]),
        (test_mol2, [0, 1, 1, 0, 0, 0]),
    )
    for build_mol, expected in expectations:
        mol = build_mol()
        assert featurizer(mol.GetAtomWithIdx(0)) == expected
class TestAtomFeaturizer(BaseAtomFeaturizer):
    """Atom featurizer fixture producing two feature fields, ``'h1'`` and ``'h2'``."""

    def __init__(self):
        # 'h1' concatenates total-degree and formal-charge one-hot encodings.
        degree_and_charge = ConcatFeaturizer([atom_total_degree_one_hot,
                                              atom_formal_charge_one_hot])
        # 'h2' holds the radical-electron one-hot encoding only.
        radical_electrons = ConcatFeaturizer([atom_num_radical_electrons_one_hot])
        super(TestAtomFeaturizer, self).__init__(
            featurizer_funcs={'h1': degree_and_charge, 'h2': radical_electrons})
def test_base_atom_featurizer():
    """Check the per-field feature tensors produced by ``TestAtomFeaturizer``.

    Both comparisons are asserted; a bare ``torch.allclose`` call would
    discard its result and let the test pass on wrong features.
    """
    test_featurizer = TestAtomFeaturizer()
    mol = test_mol1()
    feats = test_featurizer(mol)
    # One row per atom: total-degree one-hot followed by formal-charge one-hot.
    assert torch.allclose(
        feats['h1'],
        torch.tensor([[0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
                      [0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
                      [0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.]]))
    # One row per atom: radical-electron one-hot.
    assert torch.allclose(
        feats['h2'],
        torch.tensor([[1., 0., 0., 0., 0.],
                      [1., 0., 0., 0., 0.],
                      [1., 0., 0., 0., 0.]]))
def test_canonical_atom_featurizer():
    """Check the single ``'h'`` feature field of ``CanonicalAtomFeaturizer``.

    The comparison is asserted; a bare ``torch.allclose`` call would discard
    its result and let the test pass on wrong features.
    """
    test_featurizer = CanonicalAtomFeaturizer()
    mol = test_mol1()
    feats = test_featurizer(mol)
    assert list(feats.keys()) == ['h']
    # One 74-entry row per atom of the test molecule.
    expected = torch.tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                              0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
                              0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
                              0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
                              1., 0.],
                             [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                              0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
                              0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
                              0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
                              0., 0.],
                             [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                              0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
                              0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
                              0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
                              0., 0.]])
    assert torch.allclose(feats['h'], expected)
def test_bond_type_one_hot():
    """Check ``bond_type_one_hot`` on bond 0 of both test molecules."""
    for build_mol, expected in ((test_mol1, [1, 0, 0, 0]), (test_mol2, [0, 0, 0, 1])):
        mol = build_mol()
        assert bond_type_one_hot(mol.GetBondWithIdx(0)) == expected
def test_bond_is_conjugated_one_hot():
    """Check ``bond_is_conjugated_one_hot`` on bond 0 of both test molecules."""
    for build_mol, expected in ((test_mol1, [1, 0]), (test_mol2, [0, 1])):
        mol = build_mol()
        assert bond_is_conjugated_one_hot(mol.GetBondWithIdx(0)) == expected
def test_bond_is_conjugated():
    """Check ``bond_is_conjugated`` on bond 0 of both test molecules."""
    for build_mol, expected in ((test_mol1, [0]), (test_mol2, [1])):
        mol = build_mol()
        assert bond_is_conjugated(mol.GetBondWithIdx(0)) == expected
def test_bond_is_in_ring_one_hot():
    """Check ``bond_is_in_ring_one_hot`` on bond 0 of both test molecules."""
    for build_mol, expected in ((test_mol1, [1, 0]), (test_mol2, [0, 1])):
        mol = build_mol()
        assert bond_is_in_ring_one_hot(mol.GetBondWithIdx(0)) == expected
def test_bond_is_in_ring():
    """Check ``bond_is_in_ring`` on bond 0 of both test molecules."""
    for build_mol, expected in ((test_mol1, [0]), (test_mol2, [1])):
        mol = build_mol()
        assert bond_is_in_ring(mol.GetBondWithIdx(0)) == expected
def test_bond_stereo_one_hot():
    """Check ``bond_stereo_one_hot`` on bond 0 of the first test molecule."""
    first_bond = test_mol1().GetBondWithIdx(0)
    encoding = bond_stereo_one_hot(first_bond)
    assert encoding == [1, 0, 0, 0, 0, 0]
class TestBondFeaturizer(BaseBondFeaturizer):
    """Bond featurizer fixture producing two feature fields, ``'h1'`` and ``'h2'``."""

    def __init__(self):
        # 'h1' concatenates the ring-membership and conjugation indicators.
        ring_and_conjugation = ConcatFeaturizer([bond_is_in_ring, bond_is_conjugated])
        # 'h2' holds the stereo one-hot encoding only.
        stereo = ConcatFeaturizer([bond_stereo_one_hot])
        super(TestBondFeaturizer, self).__init__(
            featurizer_funcs={'h1': ring_and_conjugation, 'h2': stereo})
def test_base_bond_featurizer():
    """Check the per-field feature tensors produced by ``TestBondFeaturizer``."""
    featurizer = TestBondFeaturizer()
    bond_feats = featurizer(test_mol1())
    # Four feature rows are expected, all identical within each field.
    expected_h1 = torch.zeros(4, 2)
    expected_h2 = torch.tensor([[1., 0., 0., 0., 0., 0.]] * 4)
    assert torch.allclose(bond_feats['h1'], expected_h1)
    assert torch.allclose(bond_feats['h2'], expected_h2)
def test_canonical_bond_featurizer():
    """Check the ``'e'`` feature field produced by ``CanonicalBondFeaturizer``."""
    featurizer = CanonicalBondFeaturizer()
    bond_feats = featurizer(test_mol1())
    # All four expected feature rows are identical.
    expected_row = [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]
    expected = torch.tensor([expected_row] * 4)
    assert torch.allclose(bond_feats['e'], expected)
# Run every featurizer test in declaration order when executed as a script.
if __name__ == '__main__':
    all_tests = (
        test_one_hot_encoding,
        test_atom_type_one_hot,
        test_atomic_number_one_hot,
        test_atomic_number,
        test_atom_degree_one_hot,
        test_atom_degree,
        test_atom_total_degree_one_hot,
        test_atom_total_degree,
        test_atom_implicit_valence_one_hot,
        test_atom_implicit_valence,
        test_atom_hybridization_one_hot,
        test_atom_total_num_H_one_hot,
        test_atom_total_num_H,
        test_atom_formal_charge_one_hot,
        test_atom_formal_charge,
        test_atom_num_radical_electrons_one_hot,
        test_atom_num_radical_electrons,
        test_atom_is_aromatic_one_hot,
        test_atom_is_aromatic,
        test_atom_chiral_tag_one_hot,
        test_atom_mass,
        test_concat_featurizer,
        test_base_atom_featurizer,
        test_canonical_atom_featurizer,
        test_bond_type_one_hot,
        test_bond_is_conjugated_one_hot,
        test_bond_is_conjugated,
        test_bond_is_in_ring_one_hot,
        test_bond_is_in_ring,
        test_bond_stereo_one_hot,
        test_base_bond_featurizer,
        test_canonical_bond_featurizer,
    )
    for run_test in all_tests:
        run_test()
import numpy as np
import torch
from dgllife.utils.featurizers import *
from dgllife.utils.mol_to_graph import *
from rdkit import Chem
# SMILES strings shared by the graph-construction tests in this module.
# Molecule used for the edge-layout and complete-graph checks.
test_smiles1 = 'CCO'
# Ring-containing molecule used for the node/edge featurizer checks.
test_smiles2 = 'Fc1ccccc1'
class TestAtomFeaturizer(BaseAtomFeaturizer):
    """Atom featurizer fixture storing the atomic number under ``'hv'``."""

    def __init__(self):
        atomic_number_only = ConcatFeaturizer([atomic_number])
        super(TestAtomFeaturizer, self).__init__(
            featurizer_funcs={'hv': atomic_number_only})
class TestBondFeaturizer(BaseBondFeaturizer):
    """Bond featurizer fixture storing the in-ring indicator under ``'he'``."""

    def __init__(self):
        in_ring_only = ConcatFeaturizer([bond_is_in_ring])
        super(TestBondFeaturizer, self).__init__(
            featurizer_funcs={'he': in_ring_only})
def test_smiles_to_bigraph():
    """Check ``smiles_to_bigraph`` edge layout and featurizer outputs."""
    # With self loops: compare the exact source/destination node order.
    looped = smiles_to_bigraph(test_smiles1, add_self_loop=True)
    sources, destinations = looped.edges()
    assert torch.allclose(sources, torch.LongTensor([0, 2, 2, 1, 0, 1, 2]))
    assert torch.allclose(destinations, torch.LongTensor([2, 0, 1, 2, 0, 1, 2]))

    # Without self loops: compare the node and edge features.
    featurized = smiles_to_bigraph(test_smiles2, add_self_loop=False,
                                   node_featurizer=TestAtomFeaturizer(),
                                   edge_featurizer=TestBondFeaturizer())
    # One node with atomic number 9 followed by six with atomic number 6.
    expected_hv = torch.tensor([[9.]] + [[6.]] * 6)
    # Two edges outside any ring followed by twelve ring edges.
    expected_he = torch.tensor([[0.]] * 2 + [[1.]] * 12)
    assert torch.allclose(featurized.ndata['hv'], expected_hv)
    assert torch.allclose(featurized.edata['he'], expected_he)
def test_mol_to_bigraph():
    """Check ``mol_to_bigraph`` edge layout and featurizer outputs."""
    # With self loops: compare the exact source/destination node order.
    looped = mol_to_bigraph(Chem.MolFromSmiles(test_smiles1), add_self_loop=True)
    sources, destinations = looped.edges()
    assert torch.allclose(sources, torch.LongTensor([0, 2, 2, 1, 0, 1, 2]))
    assert torch.allclose(destinations, torch.LongTensor([2, 0, 1, 2, 0, 1, 2]))

    # Test the case without self loops.
    ring_mol = Chem.MolFromSmiles(test_smiles2)
    featurized = mol_to_bigraph(ring_mol, add_self_loop=False,
                                node_featurizer=TestAtomFeaturizer(),
                                edge_featurizer=TestBondFeaturizer())
    # One node with atomic number 9 followed by six with atomic number 6.
    expected_hv = torch.tensor([[9.]] + [[6.]] * 6)
    # Two edges outside any ring followed by twelve ring edges.
    expected_he = torch.tensor([[0.]] * 2 + [[1.]] * 12)
    assert torch.allclose(featurized.ndata['hv'], expected_hv)
    assert torch.allclose(featurized.edata['he'], expected_he)
def test_smiles_to_complete_graph():
    """Check ``smiles_to_complete_graph`` edges and node features without self loops."""
    graph = smiles_to_complete_graph(test_smiles1, add_self_loop=False,
                                     node_featurizer=TestAtomFeaturizer())
    sources, destinations = graph.edges()
    # Every ordered pair of distinct nodes is expected exactly once.
    expected_src = torch.LongTensor([0, 0, 1, 1, 2, 2])
    expected_dst = torch.LongTensor([1, 2, 0, 2, 0, 1])
    assert torch.allclose(sources, expected_src)
    assert torch.allclose(destinations, expected_dst)
    assert torch.allclose(graph.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))
def test_mol_to_complete_graph():
    """Check ``mol_to_complete_graph`` edges and node features without self loops."""
    featurizer = TestAtomFeaturizer()
    mol = Chem.MolFromSmiles(test_smiles1)
    graph = mol_to_complete_graph(mol, add_self_loop=False,
                                  node_featurizer=featurizer)
    sources, destinations = graph.edges()
    # Every ordered pair of distinct nodes is expected exactly once.
    expected_src = torch.LongTensor([0, 0, 1, 1, 2, 2])
    expected_dst = torch.LongTensor([1, 2, 0, 2, 0, 1])
    assert torch.allclose(sources, expected_src)
    assert torch.allclose(destinations, expected_dst)
    assert torch.allclose(graph.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))
def test_k_nearest_neighbors():
    """Check ``k_nearest_neighbors`` on six points with cutoff 1 and at most 2 neighbors."""
    coordinates = np.array([[0.1, 0.1, 0.1],
                            [0.2, 0.1, 0.1],
                            [0.15, 0.15, 0.1],
                            [0.1, 0.15, 0.16],
                            [1.2, 0.1, 0.1],
                            [1.3, 0.2, 0.1]])
    srcs, dsts, dists = k_nearest_neighbors(coordinates, 1., 2)

    # The distinct distances that occur in the expected result.
    d_02 = 0.07071067811865474
    d_03 = 0.07810249675906654
    d_12 = 0.07071067811865477
    d_01 = 0.1
    d_45 = 0.14142135623730956

    assert srcs == [2, 3, 2, 0, 0, 1, 0, 2, 5, 4]
    assert dsts == [0, 0, 1, 1, 2, 2, 3, 3, 4, 5]
    assert dists == [d_02, d_03, d_12, d_01, d_02, d_12, d_03, d_03, d_45, d_45]
# Run every graph-construction test in order when executed as a script.
if __name__ == '__main__':
    all_tests = (
        test_smiles_to_bigraph,
        test_mol_to_bigraph,
        test_smiles_to_complete_graph,
        test_mol_to_complete_graph,
        test_k_nearest_neighbors,
    )
    for run_test in all_tests:
        run_test()
import numpy as np
import os
import shutil
from dgl.data.utils import download, _get_dgl_url, extract_archive
from dgllife.utils.rdkit_utils import get_mol_3D_coordinates, load_molecule
from rdkit import Chem
from rdkit.Chem import AllChem
def test_get_mol_3D_coordinates():
    """Check ``get_mol_3D_coordinates`` before and after a conformer is embedded."""
    molecule = Chem.MolFromSmiles('CCO')

    # No conformer has been generated yet, so no coordinates are expected.
    missing = get_mol_3D_coordinates(molecule)
    assert missing is None

    # Embed and optimize a conformer, then expect one (x, y, z) row per atom.
    AllChem.EmbedMolecule(molecule)
    AllChem.MMFFOptimizeMolecule(molecule)
    positions = get_mol_3D_coordinates(molecule)
    assert isinstance(positions, np.ndarray)
    expected_shape = (molecule.GetNumAtoms(), 3)
    assert positions.shape == expected_shape
def remove_dir(dir):
    """Remove the directory tree ``dir`` when present; OS errors are ignored.

    Parameters
    ----------
    dir : str
        Path of the directory to remove.
    """
    exists = os.path.isdir(dir)
    if exists:
        try:
            shutil.rmtree(dir)
        except OSError:
            # Cleanup is best-effort; a failure here is deliberately ignored.
            pass
    return None
def test_load_molecule():
    """Check ``load_molecule`` on the example ``.sdf``, ``.mol2``, ``.pdbqt`` and ``.pdb`` files.

    The example archive is downloaded into ``tmp1`` and extracted into
    ``tmp2``. Both scratch directories are removed afterwards, even when a
    download, load or assertion fails.
    """
    remove_dir('tmp1')
    remove_dir('tmp2')

    try:
        url = _get_dgl_url('dgllife/example_mols.tar.gz')
        local_path = 'tmp1/example_mols.tar.gz'
        download(url, path=local_path)
        extract_archive(local_path, 'tmp2')

        load_molecule('tmp2/example_mols/example.sdf')
        load_molecule('tmp2/example_mols/example.mol2', use_conformation=False, sanitize=True)
        load_molecule('tmp2/example_mols/example.pdbqt', calc_charges=True)
        mol, _ = load_molecule('tmp2/example_mols/example.pdb', remove_hs=True)
        # With remove_hs=True every remaining atom must be a heavy atom.
        assert mol.GetNumAtoms() == mol.GetNumHeavyAtoms()
    finally:
        # Always drop the scratch directories so a failed run leaves no residue.
        remove_dir('tmp1')
        remove_dir('tmp2')
# Run both RDKit utility tests when the module is executed as a script.
if __name__ == '__main__':
    for run_test in (test_get_mol_3D_coordinates, test_load_molecule):
        run_test()
import torch
from dgllife.utils.splitters import *
from rdkit import Chem
class TestDataset(object):
    """Minimal in-memory dataset of five molecules used by the splitter tests."""

    def __init__(self):
        # SMILES strings of the datapoints.
        self.smiles = ['CCO', 'C1CCCCC1', 'O1CCOCC1', 'C1CCCC2C1CCCC2', 'N#N']
        # RDKit molecules, in the same order as the SMILES strings.
        self.mols = list(map(Chem.MolFromSmiles, self.smiles))
        # Two integer labels per molecule: 0..2N-1 reshaped into N rows.
        num_mols = len(self.smiles)
        self.labels = torch.arange(2 * num_mols).reshape(num_mols, -1)

    def __getitem__(self, item):
        """Return the ``(smiles, mol)`` pair stored at index ``item``."""
        smiles, mol = self.smiles[item], self.mols[item]
        return smiles, mol

    def __len__(self):
        """Return the number of molecules in the dataset."""
        return len(self.smiles)
def test_consecutive_splitter(dataset=None):
    """Smoke-test both split methods of ``ConsecutiveSplitter``.

    Parameters
    ----------
    dataset : TestDataset, optional
        Dataset to split. A fresh ``TestDataset`` is built when omitted, so
        pytest does not look for a ``dataset`` fixture when collecting the test.
    """
    if dataset is None:
        dataset = TestDataset()
    ConsecutiveSplitter.train_val_test_split(dataset)
    ConsecutiveSplitter.k_fold_split(dataset)
def test_random_splitter(dataset=None):
    """Smoke-test both split methods of ``RandomSplitter``.

    Parameters
    ----------
    dataset : TestDataset, optional
        Dataset to split. A fresh ``TestDataset`` is built when omitted, so
        pytest does not look for a ``dataset`` fixture when collecting the test.
    """
    if dataset is None:
        dataset = TestDataset()
    # Only the train/val/test split is seeded here.
    RandomSplitter.train_val_test_split(dataset, random_state=0)
    RandomSplitter.k_fold_split(dataset)
def test_molecular_weight_splitter(dataset=None):
    """Smoke-test both split methods of ``MolecularWeightSplitter``.

    Parameters
    ----------
    dataset : TestDataset, optional
        Dataset to split. A fresh ``TestDataset`` is built when omitted, so
        pytest does not look for a ``dataset`` fixture when collecting the test.
    """
    if dataset is None:
        dataset = TestDataset()
    MolecularWeightSplitter.train_val_test_split(dataset)
    # The k-fold call is given the pre-built RDKit molecules explicitly.
    MolecularWeightSplitter.k_fold_split(dataset, mols=dataset.mols)
def test_scaffold_splitter(dataset=None):
    """Smoke-test both split methods of ``ScaffoldSplitter``.

    Parameters
    ----------
    dataset : TestDataset, optional
        Dataset to split. A fresh ``TestDataset`` is built when omitted, so
        pytest does not look for a ``dataset`` fixture when collecting the test.
    """
    if dataset is None:
        dataset = TestDataset()
    ScaffoldSplitter.train_val_test_split(dataset, include_chirality=True)
    # The k-fold call is given the pre-built RDKit molecules explicitly.
    ScaffoldSplitter.k_fold_split(dataset, mols=dataset.mols)
def test_single_task_stratified_splitter(dataset=None):
    """Smoke-test both split methods of ``SingleTaskStratifiedSplitter``.

    Parameters
    ----------
    dataset : TestDataset, optional
        Dataset to split. A fresh ``TestDataset`` is built when omitted, so
        pytest does not look for a ``dataset`` fixture when collecting the test.
    """
    if dataset is None:
        dataset = TestDataset()
    # Both calls receive the label tensor and the value 1 as the third argument.
    SingleTaskStratifiedSplitter.train_val_test_split(dataset, dataset.labels, 1)
    SingleTaskStratifiedSplitter.k_fold_split(dataset, dataset.labels, 1)
# Run every splitter test on one shared dataset when executed as a script.
if __name__ == '__main__':
    dataset = TestDataset()
    all_tests = (
        test_consecutive_splitter,
        test_random_splitter,
        test_molecular_weight_splitter,
        test_scaffold_splitter,
        test_single_task_stratified_splitter,
    )
    for run_test in all_tests:
        run_test(dataset)
......@@ -2,7 +2,7 @@ import torch
from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, extract_archive, get_download_dir
from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
from .mol_tree_nx import DGLMolTree
from .mol_tree import Vocab
......@@ -11,8 +11,6 @@ from .jtmpn import mol2dgl_single as mol2dgl_dec
from .jtmpn import ATOM_FDIM as ATOM_FDIM_DEC
from .jtmpn import BOND_FDIM as BOND_FDIM_DEC
_url = 'https://s3-ap-southeast-1.amazonaws.com/dgl-data-cn/dataset/jtnn.zip'
def _unpack_field(examples, field):
return [e[field] for e in examples]
......@@ -28,7 +26,8 @@ class JTNNDataset(Dataset):
def __init__(self, data, vocab, training=True):
self.dir = get_download_dir()
self.zip_file_path='{}/jtnn.zip'.format(self.dir)
download(_url, path=self.zip_file_path)
download(_get_dgl_url('dglls/jtnn.zip'), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/jtnn'.format(self.dir))
print('Loading data...')
data_file = '{}/jtnn/{}.txt'.format(self.dir, data)
......
......@@ -130,7 +130,9 @@ and compare their property distributions against the training molecule property
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg/DGMG_ZINC_canonical_dist.png)
Download it with `wget https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg/eval_jupyter.ipynb`.
Download it with `wget https://s3.us-west-2.amazonaws.com/dgl-data/dglls/dgmg/eval_jupyter.ipynb` from the S3
bucket in the U.S., or with `wget https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dglls/dgmg/eval_jupyter.ipynb` from the
S3 bucket in China.
### Pre-trained models
......
import torch
from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, extract_archive, get_download_dir
import os
import torch
from torch.utils.data import Dataset
from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
from .mol_tree import Vocab, DGLMolTree
from .chemutils import mol2dgl_dec, mol2dgl_enc
......@@ -16,8 +17,6 @@ MAX_NB = 10
PAPER = os.getenv('PAPER', False)
_url = 'https://s3-ap-southeast-1.amazonaws.com/dgl-data-cn/dataset/jtnn.zip'
def _unpack_field(examples, field):
return [e[field] for e in examples]
......@@ -33,7 +32,7 @@ class JTNNDataset(Dataset):
def __init__(self, data, vocab, training=True):
self.dir = get_download_dir()
self.zip_file_path='{}/jtnn.zip'.format(self.dir)
download(_url, path=self.zip_file_path)
download(_get_dgl_url('dglls/jtnn.zip'), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/jtnn'.format(self.dir))
print('Loading data...')
if data in ['train', 'test']:
......
......@@ -117,8 +117,40 @@ We provide a jupyter notebook for performing the visualization and you can downl
## Dataset Customization
To customize your own dataset, see the instructions
[here](https://github.com/dmlc/dgl/tree/master/python/dgl/data/chem).
Generally we follow the practice of PyTorch.
A Dataset class should implement the `__getitem__(self, index)` and `__len__(self)` methods:
```python
class CustomDataset(object):
def __init__(self):
pass
def __getitem__(self, index):
"""
Parameters
----------
index : int
Index for the datapoint.
Returns
-------
str
SMILES for the molecule
DGLGraph
Constructed DGLGraph for the molecule
1D Tensor of dtype float32
Labels of the datapoint
"""
return self.smiles[index], self.graphs[index], self.labels[index]
def __len__(self):
return len(self.smiles)
```
We provide various methods for graph construction in `dgl.data.chem.utils.mol_to_graph`. If your dataset can
be converted to a pandas dataframe, e.g. a .csv file, you may use `MoleculeCSVDataset` in
`dgl.data.chem.datasets.csv_dataset`.
## References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
......
"""Decorator for deprecation message.
This is used in migrating the chem related code to DGL-LifeSci.
Todo(Mufei): remove it in v0.5.
The code is adapted from
https://stackoverflow.com/questions/2536307/
decorators-in-the-python-standard-lib-deprecated-specifically/48632082#48632082.
"""
import warnings
def deprecated(message, mode='func'):
    """Build a decorator that issues a ``DeprecationWarning`` on every call.

    The returned decorator wraps a callable so that each invocation first
    emits a formatted deprecation warning and then forwards all arguments to
    the original callable. The wrapper keeps the wrapped callable's metadata
    (``__name__``, ``__doc__``, ...) via ``functools.wraps``.

    Parameters
    ----------
    message : str
        Text appended to the warning, e.g. a pointer to the replacement.
    mode : str
        'func' for function and 'class' for class. In 'func' mode the warning
        names the wrapped callable; in 'class' mode it reads "The class is
        deprecated ...".

    Return
    ------
    callable
        Decorator to apply to the deprecated callable.
    """
    # Imported locally so the module-level imports stay untouched.
    from functools import wraps

    assert mode in ['func', 'class']

    def deprecated_decorator(func):
        @wraps(func)
        def deprecated_func(*args, **kwargs):
            # stacklevel=2 attributes the warning to the caller's line.
            if mode == 'func':
                warnings.warn("{} is deprecated and will be removed from dgl in v0.5. {}".format(
                    func.__name__, message), category=DeprecationWarning, stacklevel=2)
            else:
                warnings.warn("The class is deprecated and "
                              "will be removed from dgl in v0.5. {}".format(message),
                              category=DeprecationWarning, stacklevel=2)
            # Installs the 'default' action for DeprecationWarning in the
            # process-wide warning filters; note that this runs after the
            # warning above has already been issued.
            warnings.simplefilter('default', DeprecationWarning)
            return func(*args, **kwargs)
        return deprecated_func
    return deprecated_decorator
......@@ -13,6 +13,7 @@ from ..utils import mol_to_complete_graph, atom_type_one_hot, \
atom_hybridization_one_hot, atom_is_aromatic
from ...utils import download, get_download_dir, _get_dgl_url, save_graphs, load_graphs
from .... import backend as F
from ....contrib.deprecation import deprecated
try:
import pandas as pd
......@@ -147,11 +148,6 @@ class TencentAlchemyDataset(object):
'dev', 'valid' or 'test', separately for training, validation and test.
Default to be 'dev'. Note that 'test' is not available as the Alchemy
contest is ongoing.
from_raw : bool
Whether to process the dataset from scratch or use a
processed one for faster speed. If you use different ways
to featurize atoms or bonds, you should set this to be True.
Default to be False.
mol_to_graph: callable, str -> DGLGraph
A function turning an RDKit molecule instance into a DGLGraph.
Default to :func:`dgl.data.chem.mol_to_complete_graph`.
......@@ -174,11 +170,17 @@ class TencentAlchemyDataset(object):
excluding the self loops. We store the distance between the end atoms under the name
``"distance"`` and store the edge features under the name ``"e_feat"``. The edge
features represent one hot encoding of edge types (bond types and non-bond edges).
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
"""
def __init__(self, mode='dev', from_raw=False,
@deprecated('Import TencentAlchemyDataset from dgllife.data.alchemy instead.', 'class')
def __init__(self, mode='dev',
mol_to_graph=mol_to_complete_graph,
node_featurizer=alchemy_nodes,
edge_featurizer=alchemy_edges):
edge_featurizer=alchemy_edges,
load=True):
if mode == 'test':
raise ValueError('The test mode is not supported before '
'the Alchemy contest finishes.')
......@@ -189,13 +191,14 @@ class TencentAlchemyDataset(object):
self.mode = mode
# Construct DGLGraphs from raw data or use the preprocessed data
self.from_raw = from_raw
self.load = load
file_dir = osp.join(get_download_dir(), 'Alchemy_data')
if not from_raw:
if load:
file_name = "%s_processed_dgl" % (mode)
else:
file_name = "%s_single_sdf" % (mode)
self.file_dir = pathlib.Path(file_dir, file_name)
self._url = 'dataset/alchemy/'
......@@ -209,7 +212,7 @@ class TencentAlchemyDataset(object):
self._load(mol_to_graph, node_featurizer, edge_featurizer)
def _load(self, mol_to_graph, node_featurizer, edge_featurizer):
if not self.from_raw:
if self.load:
self.graphs, label_dict = load_graphs(osp.join(self.file_dir, "%s_graphs.bin" % self.mode))
self.labels = label_dict['labels']
with open(osp.join(self.file_dir, "%s_smiles.txt" % self.mode), 'r') as f:
......
......@@ -6,17 +6,18 @@ import sys
from ...utils import save_graphs, load_graphs
from .... import backend as F
from ....contrib.deprecation import deprecated
class MoleculeCSVDataset(object):
"""MoleculeCSVDataset
This is a general class for loading molecular data from csv or pd.DataFrame.
This is a general class for loading molecular data from pandas.DataFrame.
In data pre-processing, we set non-existing labels to be 0,
and return a mask with 1 where the label exists.
All molecules are converted into DGLGraphs. After the first-time construction, the
DGLGraphs will be saved for reloading so that we do not need to reconstruct them every time.
DGLGraphs can be saved for reloading so that we do not need to reconstruct them every time.
Parameters
----------
......@@ -35,22 +36,33 @@ class MoleculeCSVDataset(object):
smiles_column: str
Name of the column that contains the SMILES strings.
cache_file_path: str
Path to store the preprocessed data.
Path to store the preprocessed DGLGraphs. For example, this can be ``'dglgraph.bin'``.
task_names : list of str or None
Columns in the data frame corresponding to real-valued labels. If None, we assume
all columns except the smiles_column are labels. Default to None.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
"""
@deprecated('Import MoleculeCSVDataset from dgllife.data instead.', 'class')
def __init__(self, df, smiles_to_graph, node_featurizer, edge_featurizer,
smiles_column, cache_file_path):
smiles_column, cache_file_path, task_names=None, load=True):
if 'rdkit' not in sys.modules:
from ....base import dgl_warning
dgl_warning(
"Please install RDKit (Recommended Version is 2018.09.3)")
self.df = df
self.smiles = self.df[smiles_column].tolist()
self.task_names = self.df.columns.drop([smiles_column]).tolist()
if task_names is None:
self.task_names = self.df.columns.drop([smiles_column]).tolist()
else:
self.task_names = task_names
self.n_tasks = len(self.task_names)
self.cache_file_path = cache_file_path
self._pre_process(smiles_to_graph, node_featurizer, edge_featurizer)
self._pre_process(smiles_to_graph, node_featurizer, edge_featurizer, load)
def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer):
def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer, load):
"""Pre-process the dataset
* Convert molecules from smiles format into DGLGraphs
......@@ -68,8 +80,12 @@ class MoleculeCSVDataset(object):
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
"""
if os.path.exists(self.cache_file_path):
if os.path.exists(self.cache_file_path) and load:
# DGLGraphs have been constructed before, reload them
print('Loading previously saved dgl graphs...')
self.graphs, label_dict = load_graphs(self.cache_file_path)
......
......@@ -6,6 +6,7 @@ import pandas as pd
from ..utils import multiprocess_load_molecules, ACNN_graph_construction_and_featurization
from ...utils import get_download_dir, download, _get_dgl_url, extract_archive
from .... import backend as F
from ....contrib.deprecation import deprecated
class PDBBind(object):
"""PDBbind dataset processed by MoleculeNet.
......@@ -67,6 +68,7 @@ class PDBBind(object):
Number of worker processes to use. If None,
then we will use the number of CPUs in the system. Default to 64.
"""
@deprecated('Import PDBBind from dgllife.data instead.', 'class')
def __init__(self, subset, load_binding_pocket=True, add_hydrogens=False,
sanitize=False, calc_charges=False, remove_hs=False, use_conformation=True,
construct_graph_and_featurize=ACNN_graph_construction_and_featurization,
......
......@@ -4,6 +4,8 @@ import sys
from .csv_dataset import MoleculeCSVDataset
from ..utils import smiles_to_bigraph
from ...utils import get_download_dir, download, _get_dgl_url
from ....base import dgl_warning
from ....contrib.deprecation import deprecated
class PubChemBioAssayAromaticity(MoleculeCSVDataset):
"""Subset of PubChem BioAssay Dataset for aromaticity prediction.
......@@ -27,12 +29,15 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset):
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to pre-process from scratch. Default to True.
"""
@deprecated('Import PubChemBioAssayAromaticity from dgllife.data instead.', 'class')
def __init__(self, smiles_to_graph=smiles_to_bigraph,
node_featurizer=None,
edge_featurizer=None):
node_featurizer=None, edge_featurizer=None, load=True):
if 'pandas' not in sys.modules:
from ....base import dgl_warning
dgl_warning("Please install pandas")
self._url = 'dataset/pubchem_bioassay_aromaticity.csv'
......@@ -40,5 +45,6 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset):
download(_get_dgl_url(self._url), path=data_path)
df = pd.read_csv(data_path)
super(PubChemBioAssayAromaticity, self).__init__(df, smiles_to_graph, node_featurizer, edge_featurizer,
"cano_smiles", "pubchem_aromaticity_dglgraph.bin")
super(PubChemBioAssayAromaticity, self).__init__(
df, smiles_to_graph, node_featurizer, edge_featurizer, "cano_smiles",
"pubchem_aromaticity_dglgraph.bin", load=load)
......@@ -4,6 +4,8 @@ from .csv_dataset import MoleculeCSVDataset
from ..utils import smiles_to_bigraph
from ...utils import get_download_dir, download, _get_dgl_url
from .... import backend as F
from ....base import dgl_warning
from ....contrib.deprecation import deprecated
try:
import pandas as pd
......@@ -38,12 +40,17 @@ class Tox21(MoleculeCSVDataset):
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
"""
@deprecated('Import Tox21 from dgllife.data instead.', 'class')
def __init__(self, smiles_to_graph=smiles_to_bigraph,
node_featurizer=None,
edge_featurizer=None):
edge_featurizer=None,
load=True):
if 'pandas' not in sys.modules:
from ....base import dgl_warning
dgl_warning("Please install pandas")
self._url = 'dataset/tox21.csv.gz'
......@@ -55,7 +62,7 @@ class Tox21(MoleculeCSVDataset):
df = df.drop(columns=['mol_id'])
super(Tox21, self).__init__(df, smiles_to_graph, node_featurizer, edge_featurizer,
"smiles", "tox21_dglgraph.bin")
"smiles", "tox21_dglgraph.bin", load=load)
self._weight_balancing()
def _weight_balancing(self):
......
......@@ -4,6 +4,7 @@ import numpy as np
from ..utils import k_nearest_neighbors
from .... import graph, bipartite, hetero_from_relations
from .... import backend as F
from ....contrib.deprecation import deprecated
__all__ = ['ACNN_graph_construction_and_featurization']
......@@ -49,6 +50,7 @@ def get_atomic_numbers(mol, indices):
atomic_numbers.append(atom.GetAtomicNum())
return atomic_numbers
@deprecated('Import it from dgllife.utils instead.')
def ACNN_graph_construction_and_featurization(ligand_mol,
protein_mol,
ligand_coordinates,
......@@ -72,11 +74,13 @@ def ACNN_graph_construction_and_featurization(ligand_mol,
protein_coordinates : Float Tensor of shape (V2, 3)
Atom coordinates in a protein.
max_num_ligand_atoms : int or None
Maximum number of atoms in ligands for zero padding.
If None, no zero padding will be performed. Default to None.
Maximum number of atoms in ligands for zero padding, which should be no smaller than
ligand_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
Default to None.
max_num_protein_atoms : int or None
Maximum number of atoms in proteins for zero padding.
If None, no zero padding will be performed. Default to None.
Maximum number of atoms in proteins for zero padding, which should be no smaller than
protein_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
Default to None.
neighbor_cutoff : float
Distance cutoff to define 'neighboring'. Default to 12.
max_num_neighbors : int
......@@ -86,6 +90,12 @@ def ACNN_graph_construction_and_featurization(ligand_mol,
"""
assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
if max_num_ligand_atoms is not None:
assert max_num_ligand_atoms >= ligand_mol.GetNumAtoms(), \
'Expect max_num_ligand_atoms to be no smaller than ligand_mol.GetNumAtoms()'
if max_num_protein_atoms is not None:
assert max_num_protein_atoms >= protein_mol.GetNumAtoms(), \
'Expect max_num_protein_atoms to be no smaller than protein_mol.GetNumAtoms()'
if strip_hydrogens:
# Remove hydrogen atoms and their corresponding coordinates
......
......@@ -4,6 +4,7 @@ import numpy as np
from collections import defaultdict
from .... import backend as F
from ....contrib.deprecation import deprecated
try:
from rdkit import Chem
......@@ -44,6 +45,7 @@ __all__ = ['one_hot_encoding',
'BaseBondFeaturizer',
'CanonicalBondFeaturizer']
@deprecated('Import it from dgllife.utils instead.')
def one_hot_encoding(x, allowable_set, encode_unknown=False):
"""One-hot encoding.
......@@ -77,6 +79,7 @@ def one_hot_encoding(x, allowable_set, encode_unknown=False):
# Atom featurization
#################################################################
@deprecated('Import it from dgllife.utils instead.')
def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the type of an atom.
......@@ -106,6 +109,7 @@ def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb']
return one_hot_encoding(atom.GetSymbol(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the atomic number of an atom.
......@@ -128,6 +132,7 @@ def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
allowable_set = list(range(1, 101))
return one_hot_encoding(atom.GetAtomicNum(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atomic_number(atom):
"""Get the atomic number for an atom.
......@@ -143,6 +148,7 @@ def atomic_number(atom):
"""
return [atom.GetAtomicNum()]
@deprecated('Import it from dgllife.utils instead.')
def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the degree of an atom.
......@@ -172,6 +178,7 @@ def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
allowable_set = list(range(11))
return one_hot_encoding(atom.GetDegree(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_degree(atom):
"""Get the degree of an atom.
......@@ -194,6 +201,7 @@ def atom_degree(atom):
"""
return [atom.GetDegree()]
@deprecated('Import it from dgllife.utils instead.')
def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the degree of an atom including Hs.
......@@ -215,6 +223,7 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
allowable_set = list(range(6))
return one_hot_encoding(atom.GetTotalDegree(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_total_degree(atom):
"""The degree of an atom including Hs.
......@@ -229,6 +238,7 @@ def atom_total_degree(atom):
"""
return [atom.GetTotalDegree()]
@deprecated('Import it from dgllife.utils instead.')
def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the implicit valences of an atom.
......@@ -251,6 +261,7 @@ def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False
allowable_set = list(range(7))
return one_hot_encoding(atom.GetImplicitValence(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_implicit_valence(atom):
"""Get the implicit valence of an atom.
......@@ -266,6 +277,7 @@ def atom_implicit_valence(atom):
"""
return [atom.GetImplicitValence()]
@deprecated('Import it from dgllife.utils instead.')
def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the hybridization of an atom.
......@@ -294,6 +306,7 @@ def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
Chem.rdchem.HybridizationType.SP3D2]
return one_hot_encoding(atom.GetHybridization(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the total number of Hs of an atom.
......@@ -316,6 +329,7 @@ def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
allowable_set = list(range(5))
return one_hot_encoding(atom.GetTotalNumHs(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_total_num_H(atom):
"""Get the total number of Hs of an atom.
......@@ -331,6 +345,7 @@ def atom_total_num_H(atom):
"""
return [atom.GetTotalNumHs()]
@deprecated('Import it from dgllife.utils instead.')
def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the formal charge of an atom.
......@@ -353,6 +368,7 @@ def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
allowable_set = list(range(-2, 3))
return one_hot_encoding(atom.GetFormalCharge(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_formal_charge(atom):
"""Get formal charge for an atom.
......@@ -368,6 +384,7 @@ def atom_formal_charge(atom):
"""
return [atom.GetFormalCharge()]
@deprecated('Import it from dgllife.utils instead.')
def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the number of radical electrons of an atom.
......@@ -390,6 +407,7 @@ def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=
allowable_set = list(range(5))
return one_hot_encoding(atom.GetNumRadicalElectrons(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_num_radical_electrons(atom):
"""Get the number of radical electrons for an atom.
......@@ -405,6 +423,7 @@ def atom_num_radical_electrons(atom):
"""
return [atom.GetNumRadicalElectrons()]
@deprecated('Import it from dgllife.utils instead.')
def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the atom is aromatic.
......@@ -427,6 +446,7 @@ def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
allowable_set = [False, True]
return one_hot_encoding(atom.GetIsAromatic(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_is_aromatic(atom):
"""Get whether the atom is aromatic.
......@@ -442,6 +462,7 @@ def atom_is_aromatic(atom):
"""
return [atom.GetIsAromatic()]
@deprecated('Import it from dgllife.utils instead.')
def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the chiral tag of an atom.
......@@ -462,6 +483,7 @@ def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
Chem.rdchem.ChiralType.CHI_OTHER]
return one_hot_encoding(atom.GetChiralTag(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_mass(atom, coef=0.01):
"""Get the mass of an atom and scale it.
......@@ -490,6 +512,7 @@ class ConcatFeaturizer(object):
``func(data_type) -> list of float or bool or int``. The resulting order of
the features will follow that of the functions in the list.
"""
@deprecated('Import ConcatFeaturizer from dgllife.utils instead.', 'class')
def __init__(self, func_list):
self.func_list = func_list
......@@ -541,6 +564,7 @@ class BaseAtomFeaturizer(object):
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])}
"""
@deprecated('Import BaseAtomFeaturizer from dgllife.utils instead.', 'class')
def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs
if feat_sizes is None:
......@@ -627,6 +651,7 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
atom_data_field : str
Name for storing atom features in DGLGraphs, default to be 'h'.
"""
@deprecated('Import CanonicalAtomFeaturizer from dgllife.utils instead.', 'class')
def __init__(self, atom_data_field='h'):
super(CanonicalAtomFeaturizer, self).__init__(
featurizer_funcs={atom_data_field: ConcatFeaturizer(
......@@ -640,6 +665,7 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
atom_total_num_H_one_hot]
)})
@deprecated('Import it from dgllife.utils instead.')
def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for the type of a bond.
......@@ -667,6 +693,7 @@ def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
Chem.rdchem.BondType.AROMATIC]
return one_hot_encoding(bond.GetBondType(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the bond is conjugated.
Parameters
......@@ -687,6 +714,7 @@ def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
allowable_set = [False, True]
return one_hot_encoding(bond.GetIsConjugated(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def bond_is_conjugated(bond):
"""Get whether the bond is conjugated.
Parameters
......@@ -700,6 +728,7 @@ def bond_is_conjugated(bond):
"""
return [bond.GetIsConjugated()]
@deprecated('Import it from dgllife.utils instead.')
def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the bond is in a ring of any size.
Parameters
......@@ -720,6 +749,7 @@ def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
allowable_set = [False, True]
return one_hot_encoding(bond.IsInRing(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def bond_is_in_ring(bond):
"""Get whether the bond is in a ring of any size.
Parameters
......@@ -733,6 +763,7 @@ def bond_is_in_ring(bond):
"""
return [bond.IsInRing()]
@deprecated('Import it from dgllife.utils instead.')
def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for the stereo configuration of a bond.
Parameters
......@@ -795,6 +826,7 @@ class BaseBondFeaturizer(object):
[1., 0., 0., 0.]]),
'in_ring': tensor([[0.], [0.], [0.], [0.]])}
"""
@deprecated('Import BaseBondFeaturizer from dgllife.utils instead.', 'class')
def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs
if feat_sizes is None:
......@@ -867,6 +899,7 @@ class CanonicalBondFeaturizer(BaseBondFeaturizer):
**We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
self loops.**
"""
@deprecated('Import CanonicalBondFeaturizer from dgllife.utils instead.', 'class')
def __init__(self, bond_data_field='e'):
super(CanonicalBondFeaturizer, self).__init__(
featurizer_funcs={bond_data_field: ConcatFeaturizer(
......
......@@ -4,6 +4,7 @@ import numpy as np
from functools import partial
from .... import DGLGraph
from ....contrib.deprecation import deprecated
try:
import mdtraj
......@@ -19,6 +20,7 @@ __all__ = ['mol_to_graph',
'mol_to_complete_graph',
'k_nearest_neighbors']
@deprecated('Import it from dgllife.utils instead.')
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
......@@ -102,6 +104,7 @@ def construct_bigraph_from_mol(mol, add_self_loop=False):
return g
@deprecated('Import it from dgllife.utils instead.')
def mol_to_bigraph(mol, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
......@@ -128,6 +131,7 @@ def mol_to_bigraph(mol, add_self_loop=False,
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer)
@deprecated('Import it from dgllife.utils instead.')
def smiles_to_bigraph(smiles, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
......@@ -192,6 +196,7 @@ def construct_complete_graph_from_mol(mol, add_self_loop=False):
return g
@deprecated('Import it from dgllife.utils instead.')
def mol_to_complete_graph(mol, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
......@@ -218,6 +223,7 @@ def mol_to_complete_graph(mol, add_self_loop=False,
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer)
@deprecated('Import it from dgllife.utils instead.')
def smiles_to_complete_graph(smiles, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
......@@ -244,6 +250,7 @@ def smiles_to_complete_graph(smiles, add_self_loop=False,
mol = Chem.MolFromSmiles(smiles)
return mol_to_complete_graph(mol, add_self_loop, node_featurizer, edge_featurizer)
@deprecated('Import it from dgllife.utils instead.')
def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors):
"""Find k nearest neighbors for each atom based on the 3D coordinates.
......@@ -259,8 +266,14 @@ def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors):
Returns
-------
neighbor_list : dict(int -> list of ints)
Mapping atom indices to their k nearest neighbors.
Returns
-------
srcs : list of int
Source nodes.
dsts : list of int
Destination nodes.
distances : list of float
Distances between the end nodes.
"""
num_atoms = coordinates.shape[0]
traj = mdtraj.Trajectory(coordinates.reshape((1, num_atoms, 3)), None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment