Commit 2ee3c78c authored by VoVAllen's avatar VoVAllen Committed by Mufei Li
Browse files

[Dataset] Tox21 (#760)

* tox21

* fix ci

* fix ci

* fix urls to url

* add doc

* remove binary
parent 17b60e1a
...@@ -11,6 +11,7 @@ from .reddit import RedditDataset ...@@ -11,6 +11,7 @@ from .reddit import RedditDataset
from .ppi import PPIDataset from .ppi import PPIDataset
from .tu import TUDataset from .tu import TUDataset
from .gindt import GINDataset from .gindt import GINDataset
from .chem import Tox21
def register_data_args(parser): def register_data_args(parser):
......
from .tox21 import Tox21
\ No newline at end of file
from __future__ import absolute_import
import dgl.backend as F
import numpy as np
import os
import pickle
import sys
from dgl import DGLGraph
from .utils import smile2graph
from ..utils import download, get_download_dir, _get_dgl_url, Subset
class CSVDataset(object):
    """CSVDataset

    This is a general class for loading data from csv or pd.DataFrame.

    In data pre-processing, we set non-existing labels to be 0 and return a
    mask with 1 where a label exists.

    All molecules are converted into DGLGraphs. After the first-time
    construction, the DGLGraphs will be saved for reloading so that we do not
    need to reconstruct them every time.

    Parameters
    ----------
    df: pandas.DataFrame
        Dataframe including smiles and labels. Can be loaded by
        pandas.read_csv(file_path). One column includes smiles and the other
        columns are for labels. Column names other than the smiles column are
        considered as task names.
    smile2graph: callable, str -> DGLGraph
        A function turning smiles into a DGLGraph. The default one can be
        found at python/dgl/data/chem/utils.py named with smile2graph.
    smile_column: str
        Column name that includes smiles.
    cache_file_path: str
        Path to store the preprocessed DGLGraphs (pickle file).
    """
    def __init__(self, df, smile2graph=smile2graph, smile_column='smiles',
                 cache_file_path="csvdata_dglgraph.pkl"):
        if 'rdkit' not in sys.modules:
            from ...base import dgl_warning
            dgl_warning("Please install RDKit (Recommended Version is 2018.09.3)")
        self.df = df
        self.smiles = self.df[smile_column].tolist()
        # Every non-smiles column is treated as a prediction task.
        self.task_names = self.df.columns.drop([smile_column]).tolist()
        self.cache_file_path = cache_file_path
        self._pre_process(smile2graph)

    def _pre_process(self, smile2graph):
        """Pre-process the dataset

        * Convert molecules from smiles format into DGLGraphs
          and featurize their atoms
        * Set missing labels to be 0 and use a binary masking
          matrix to mask them
        """
        if os.path.exists(self.cache_file_path):
            # DGLGraphs have been constructed before, reload them
            print('Loading previously saved dgl graphs...')
            with open(self.cache_file_path, 'rb') as f:
                self.graphs = pickle.load(f)
        else:
            self.graphs = []
            for s in self.smiles:
                self.graphs.append(smile2graph(s))
            with open(self.cache_file_path, 'wb') as f:
                pickle.dump(self.graphs, f)
        _label_values = self.df[self.task_names].values
        # np.nan_to_num will also turn inf into a very large number
        self.labels = F.zerocopy_from_numpy(np.nan_to_num(_label_values))
        # Fix: invert the boolean nan-mask BEFORE casting to float32.
        # The previous form, ~np.isnan(...).astype(np.float32), applied the
        # bitwise invert to a float array (attribute access binds tighter
        # than ~), which raises a TypeError in NumPy.
        self.mask = F.zerocopy_from_numpy(
            (~np.isnan(_label_values)).astype(np.float32))

    def __getitem__(self, item):
        """Get the ith datapoint

        Returns
        -------
        str
            SMILES for the ith datapoint
        DGLGraph
            DGLGraph for the ith datapoint
        Tensor of dtype float32
            Labels of the datapoint for all tasks
        Tensor of dtype float32
            Binary mask indicating which task labels exist for the datapoint
        """
        return self.smiles[item], self.graphs[item], self.labels[item], self.mask[item]

    def __len__(self):
        """Length of Dataset

        Returns
        -------
        int
            Length of Dataset
        """
        return len(self.smiles)
import numpy as np
import sys
from .csv_dataset import CSVDataset
from .utils import smile2graph
from ..utils import get_download_dir, download, _get_dgl_url, Subset
try:
import pandas as pd
except ImportError:
pass
class Tox21(CSVDataset):
    """Tox21 dataset.

    The Toxicology in the 21st Century (https://tripod.nih.gov/tox21/challenge/)
    initiative created a public database measuring toxicity of compounds, which
    has been used in the 2014 Tox21 Data Challenge. The dataset contains
    qualitative toxicity measurements for 8014 compounds on 12 different
    targets, including nuclear receptors and stress response pathways. Each
    target results in a binary label.

    A common issue for multi-task prediction is that some datapoints are not
    labeled for all tasks. This is also the case for Tox21. In data
    pre-processing, we set non-existing labels to be 0 so that they can be
    placed in tensors and used for masking in loss computation. See examples
    below for more details.

    All molecules are converted into DGLGraphs. After the first-time
    construction, the DGLGraphs will be saved for reloading so that we do not
    need to reconstruct them everytime.

    Parameters
    ----------
    smile2graph: callable, str -> DGLGraph
        A function turns smiles into a DGLGraph. Default one can be found
        at python/dgl/data/chem/utils.py named with smile2graph.
    """
    # Fix: the class docstring must be the FIRST statement in the class body.
    # Previously `_url` preceded the string literal, turning the intended
    # docstring into a dead expression and leaving Tox21.__doc__ as None.
    _url = 'dataset/tox21.csv.gz'

    def __init__(self, smile2graph=smile2graph):
        if 'pandas' not in sys.modules:
            from ...base import dgl_warning
            dgl_warning("Please install pandas")
        data_path = get_download_dir() + '/tox21.csv.gz'
        download(_get_dgl_url(self._url), path=data_path)
        df = pd.read_csv(data_path)
        # Keep molecule ids separately; they are identifiers, not task labels.
        self.id = df['mol_id']
        df = df.drop(columns=['mol_id'])
        super(Tox21, self).__init__(df, smile2graph,
                                    cache_file_path="tox21_dglgraph.pkl")
        self._weight_balancing()

    def _weight_balancing(self):
        """Perform re-balancing for each task.

        It's quite common that the number of positive samples and the
        number of negative samples are significantly different. To compensate
        for the class imbalance issue, we can weight each datapoint in
        loss computation.

        In particular, for each task we will set the weight of negative samples
        to be 1 and the weight of positive samples to be the number of negative
        samples divided by the number of positive samples.

        If weight balancing is performed, one attribute will be affected:

        * self._task_pos_weights is set, which is a list of positive sample
          weights for each task.
        """
        num_pos = np.sum(self.labels, axis=0)
        num_indices = np.sum(self.mask, axis=0)
        # NOTE(review): assumes every task has at least one positive sample;
        # a task with zero positives would yield an inf weight here.
        self._task_pos_weights = (num_indices - num_pos) / num_pos

    @property
    def task_pos_weights(self):
        """Get weights for positive samples on each task

        Returns
        -------
        numpy array
            Weight of positive samples for each task
        """
        return self._task_pos_weights
import dgl.backend as F
import numpy as np
import os
import pickle
from dgl import DGLGraph
try:
from rdkit import Chem
from rdkit.Chem import rdmolfiles, rdmolops
except ImportError:
pass
def one_hot_encoding(x, allowable_set):
    """One-hot encoding.

    Parameters
    ----------
    x : str, int or Chem.rdchem.HybridizationType
    allowable_set : list
        The elements of the allowable_set should be of the
        same type as x.

    Returns
    -------
    list
        List of boolean values where at most one value is True.
        If the i-th value is True, then we must have
        x == allowable_set[i].
    """
    # Compare x against each candidate; True only where it matches.
    return [x == candidate for candidate in allowable_set]
class BaseAtomFeaturizer(object):
    """An abstract class for atom featurizers.

    All atom featurizers that map a molecule to atom features should subclass
    it. All subclasses should overwrite ``_featurize_atom``, which featurizes
    a single atom, and ``__call__``, which featurizes all atoms in a molecule.
    """
    def _featurize_atom(self, atom):
        """Featurize a single atom. Must be overridden by subclasses."""
        # Fix: the original did `return NotImplementedError`, silently handing
        # callers the exception class instead of raising it.
        raise NotImplementedError

    def __call__(self, mol):
        """Featurize all atoms in a molecule. Must be overridden by subclasses."""
        raise NotImplementedError
class DefaultAtomFeaturizer(BaseAtomFeaturizer):
    """A default featurizer for atoms.

    The atom features include:

    * **One hot encoding of the atom type**. The supported atom types include
      ``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``, ``Cl``, ``Br``, ``Mg``,
      ``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``, ``K``, ``Tl``,
      ``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
      ``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``,
      ``Zr``, ``Cr``, ``Pt``, ``Hg``, ``Pb``.
    * **One hot encoding of the atom degree**, possibilities ``0 - 10``.
    * **One hot encoding of the number of implicit Hs**, possibilities ``0 - 6``.
    * **Formal charge of the atom**.
    * **Number of radical electrons of the atom**.
    * **One hot encoding of the atom hybridization**, possibilities
      ``SP``, ``SP2``, ``SP3``, ``SP3D``, ``SP3D2``.
    * **Whether the atom is aromatic**.
    * **One hot encoding of the number of total Hs**, possibilities ``0 - 4``.

    Parameters
    ----------
    atom_data_field : str
        Name for storing atom features in DGLGraphs, default to be 'h'.
    """
    # Supported atom symbols, in one-hot encoding order.
    _ATOM_TYPES = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br',
                   'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V',
                   'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se',
                   'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd',
                   'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb']

    def __init__(self, atom_data_field='h'):
        super(DefaultAtomFeaturizer, self).__init__()
        self.atom_data_field = atom_data_field

    @property
    def feat_size(self):
        """Returns feature size"""
        # 43 atom types + 11 degrees + 7 implicit Hs + 2 scalars
        # + 5 hybridizations + 1 aromatic flag + 5 total Hs = 74
        return 74

    def _featurize_atom(self, atom):
        """Featurize an atom

        Parameters
        ----------
        atom : rdkit.Chem.rdchem.Atom

        Returns
        -------
        feats : list
            List of feature values, including boolean values and numbers
        """
        hybridizations = [Chem.rdchem.HybridizationType.SP,
                          Chem.rdchem.HybridizationType.SP2,
                          Chem.rdchem.HybridizationType.SP3,
                          Chem.rdchem.HybridizationType.SP3D,
                          Chem.rdchem.HybridizationType.SP3D2]
        # Feature order must stay fixed: type, degree, implicit Hs,
        # charge/radicals, hybridization, aromaticity, total Hs.
        feats = one_hot_encoding(atom.GetSymbol(), self._ATOM_TYPES)
        feats += one_hot_encoding(atom.GetDegree(),
                                  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        feats += one_hot_encoding(atom.GetImplicitValence(),
                                  [0, 1, 2, 3, 4, 5, 6])
        feats += [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()]
        feats += one_hot_encoding(atom.GetHybridization(), hybridizations)
        feats.append(atom.GetIsAromatic())
        feats += one_hot_encoding(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
        return feats

    def __call__(self, mol):
        """Featurize a molecule

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Atom features of shape (N, 74),
            where N is the number of atoms in the molecule
        """
        per_atom = [self._featurize_atom(mol.GetAtomWithIdx(i))
                    for i in range(mol.GetNumAtoms())]
        stacked = np.stack(per_atom).astype(np.float32)
        return {self.atom_data_field: F.zerocopy_from_numpy(stacked)}
def smile2graph(smile, add_self_loop=False, atom_featurizer=None, bond_featurizer=None):
    """Convert SMILES into a DGLGraph.

    The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
    **i** th node in the returned DGLGraph.

    The **i** th bond in the molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the
    **(2i)**-th and **(2i+1)**-th edges in the returned DGLGraph. The **(2i)**-th and
    **(2i+1)**-th edges will be separately from **u** to **v** and **v** to **u**, where
    **u** is ``bond.GetBeginAtomIdx()`` and **v** is ``bond.GetEndAtomIdx()``.

    If self loops are added, the last **n** edges will separately be self loops for
    atoms ``0, 1, ..., n-1``.

    Parameters
    ----------
    smile : str
        String of SMILES
    add_self_loop : bool
        Whether to add self loops in DGLGraphs.
    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for atoms in a molecule, which can be used to update
        ndata for a DGLGraph.
    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for bonds in a molecule, which can be used to update
        edata for a DGLGraph.

    Returns
    -------
    DGLGraph
        Graph with one node per atom and a pair of directed edges per bond.

    Raises
    ------
    ValueError
        If RDKit fails to parse the SMILES string.
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        # RDKit returns None (rather than raising) for unparsable SMILES;
        # fail loudly here instead of crashing later with an AttributeError.
        raise ValueError('Failed to parse SMILES string: {}'.format(smile))
    # Renumber atoms with their canonical ranks so the node ordering is
    # deterministic for a given molecule.
    new_order = rdmolfiles.CanonicalRankAtoms(mol)
    mol = rdmolops.RenumberAtoms(mol, new_order)
    g = DGLGraph()
    num_atoms = mol.GetNumAtoms()
    g.add_nodes(num_atoms)
    # Each chemical bond becomes a pair of directed edges (u->v and v->u).
    src_list = []
    dst_list = []
    num_bonds = mol.GetNumBonds()
    for i in range(num_bonds):
        bond = mol.GetBondWithIdx(i)
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        src_list.extend([u, v])
        dst_list.extend([v, u])
    g.add_edges(src_list, dst_list)
    if add_self_loop:
        nodes = g.nodes()
        g.add_edges(nodes, nodes)
    # Featurization
    if atom_featurizer is not None:
        g.ndata.update(atom_featurizer(mol))
    if bond_featurizer is not None:
        g.edata.update(bond_featurizer(mol))
    return g
"""Dataset utilities.""" """Dataset utilities."""
from __future__ import absolute_import from __future__ import absolute_import
import os, sys import os
import sys
import hashlib import hashlib
import warnings import warnings
import zipfile import zipfile
import tarfile import tarfile
import numpy as np
try: try:
import requests import requests
except ImportError: except ImportError:
...@@ -13,7 +15,9 @@ except ImportError: ...@@ -13,7 +15,9 @@ except ImportError:
pass pass
requests = requests_failed_to_import requests = requests_failed_to_import
__all__ = ['download', 'check_sha1', 'extract_archive', 'get_download_dir'] __all__ = ['download', 'check_sha1', 'extract_archive',
'get_download_dir', 'Subset', 'split_dataset']
def _get_dgl_url(file_url): def _get_dgl_url(file_url):
"""Get DGL online url for download.""" """Get DGL online url for download."""
...@@ -24,6 +28,25 @@ def _get_dgl_url(file_url): ...@@ -24,6 +28,25 @@ def _get_dgl_url(file_url):
return repo_url + file_url return repo_url + file_url
def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
    """Split a dataset into consecutive subsets by fraction.

    Parameters
    ----------
    dataset
        ``len(dataset)`` gives the size and ``dataset[i]`` the ith datapoint.
    frac_list : list of float, optional
        Fraction of the dataset for each split; must sum to 1.
        Defaults to ``[0.8, 0.1, 0.1]``.
    shuffle : bool
        Whether to shuffle the indices before splitting.
    random_state : int or None
        Seed for the shuffle when ``shuffle`` is True.

    Returns
    -------
    list of Subset
        One Subset per entry in frac_list, in order.
    """
    if frac_list is None:
        frac_list = [0.8, 0.1, 0.1]
    frac_list = np.array(frac_list)
    assert np.allclose(np.sum(frac_list), 1.), \
        'Expect frac_list sum to 1, got {:.4f}'.format(
            np.sum(frac_list))
    num_data = len(dataset)
    lengths = (num_data * frac_list).astype(int)
    # Assign the rounding remainder to the last split so lengths sum
    # to num_data exactly.
    lengths[-1] = num_data - np.sum(lengths[:-1])
    if shuffle:
        indices = np.random.RandomState(
            seed=random_state).permutation(num_data)
    else:
        indices = np.arange(num_data)
    subsets = []
    start = 0
    for length in lengths:
        subsets.append(Subset(dataset, indices[start:start + length]))
        start += length
    return subsets
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download a given URL. """Download a given URL.
...@@ -77,18 +100,18 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ ...@@ -77,18 +100,18 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
# Disable pyling too broad Exception # Disable pyling too broad Exception
# pylint: disable=W0703 # pylint: disable=W0703
try: try:
print('Downloading %s from %s...'%(fname, url)) print('Downloading %s from %s...' % (fname, url))
r = requests.get(url, stream=True, verify=verify_ssl) r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200: if r.status_code != 200:
raise RuntimeError("Failed downloading url %s"%url) raise RuntimeError("Failed downloading url %s" % url)
with open(fname, 'wb') as f: with open(fname, 'wb') as f:
for chunk in r.iter_content(chunk_size=1024): for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks if chunk: # filter out keep-alive new chunks
f.write(chunk) f.write(chunk)
if sha1_hash and not check_sha1(fname, sha1_hash): if sha1_hash and not check_sha1(fname, sha1_hash):
raise UserWarning('File {} is downloaded but the content hash does not match.'\ raise UserWarning('File {} is downloaded but the content hash does not match.'
' The repo may be outdated or download may be incomplete. '\ ' The repo may be outdated or download may be incomplete. '
'If the "repo_url" is overridden, consider switching to '\ 'If the "repo_url" is overridden, consider switching to '
'the default repo.'.format(fname)) 'the default repo.'.format(fname))
break break
except Exception as e: except Exception as e:
...@@ -101,6 +124,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ ...@@ -101,6 +124,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
return fname return fname
def check_sha1(filename, sha1_hash): def check_sha1(filename, sha1_hash):
"""Check whether the sha1 hash of the file content matches the expected hash. """Check whether the sha1 hash of the file content matches the expected hash.
...@@ -128,6 +152,7 @@ def check_sha1(filename, sha1_hash): ...@@ -128,6 +152,7 @@ def check_sha1(filename, sha1_hash):
return sha1.hexdigest() == sha1_hash return sha1.hexdigest() == sha1_hash
def extract_archive(file, target_dir): def extract_archive(file, target_dir):
"""Extract archive file. """Extract archive file.
...@@ -150,6 +175,7 @@ def extract_archive(file, target_dir): ...@@ -150,6 +175,7 @@ def extract_archive(file, target_dir):
archive.extractall(path=target_dir) archive.extractall(path=target_dir)
archive.close() archive.close()
def get_download_dir(): def get_download_dir():
"""Get the absolute path to the download directory. """Get the absolute path to the download directory.
...@@ -163,3 +189,41 @@ def get_download_dir(): ...@@ -163,3 +189,41 @@ def get_download_dir():
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
return dirname return dirname
class Subset(object):
    """A view over a dataset restricted to the given indices.

    Code adapted from PyTorch.

    Parameters
    ----------
    dataset
        dataset[i] should return the ith datapoint
    indices : list
        List of datapoint indices to construct the subset
    """
    def __init__(self, dataset, indices):
        # Only references are stored; datapoints are fetched lazily
        # on each __getitem__ call.
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, item):
        """Fetch the datapoint at position ``item`` within the subset.

        Returns
        -------
        tuple
            datapoint
        """
        # Translate the subset-local position into a dataset index.
        original_idx = self.indices[item]
        return self.dataset[original_idx]

    def __len__(self):
        """Size of the subset.

        Returns
        -------
        int
            Number of datapoints in the subset
        """
        return len(self.indices)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment