Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2ee3c78c
Commit
2ee3c78c
authored
Aug 14, 2019
by
VoVAllen
Committed by
Mufei Li
Aug 14, 2019
Browse files
[Dataset] Tox21 (#760)
* tox21 * fix ci * fix ci * fix urls to url * add doc * remove binary
parent
17b60e1a
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
462 additions
and
8 deletions
+462
-8
python/dgl/data/__init__.py
python/dgl/data/__init__.py
+1
-0
python/dgl/data/chem/__init__.py
python/dgl/data/chem/__init__.py
+1
-0
python/dgl/data/chem/csv_dataset.py
python/dgl/data/chem/csv_dataset.py
+101
-0
python/dgl/data/chem/tox21.py
python/dgl/data/chem/tox21.py
+87
-0
python/dgl/data/chem/utils.py
python/dgl/data/chem/utils.py
+200
-0
python/dgl/data/utils.py
python/dgl/data/utils.py
+72
-8
No files found.
python/dgl/data/__init__.py
View file @
2ee3c78c
...
...
@@ -11,6 +11,7 @@ from .reddit import RedditDataset
from
.ppi
import
PPIDataset
from
.tu
import
TUDataset
from
.gindt
import
GINDataset
from
.chem
import
Tox21
def
register_data_args
(
parser
):
...
...
python/dgl/data/chem/__init__.py
0 → 100644
View file @
2ee3c78c
from
.tox21
import
Tox21
\ No newline at end of file
python/dgl/data/chem/csv_dataset.py
0 → 100644
View file @
2ee3c78c
from
__future__
import
absolute_import
import
dgl.backend
as
F
import
numpy
as
np
import
os
import
pickle
import
sys
from
dgl
import
DGLGraph
from
.utils
import
smile2graph
from
..utils
import
download
,
get_download_dir
,
_get_dgl_url
,
Subset
class CSVDataset(object):
    """General dataset for molecules stored in a csv file / ``pandas.DataFrame``.

    In data pre-processing, non-existing labels are set to 0 and a mask is
    returned holding 1 where a label exists and 0 where it is missing.

    All molecules are converted into DGLGraphs. After the first-time
    construction, the DGLGraphs are pickled to ``cache_file_path`` so that
    they do not need to be reconstructed every time.

    Parameters
    ----------
    df: pandas.DataFrame
        Dataframe including smiles and labels. Can be loaded by
        pandas.read_csv(file_path). One column includes smiles and the other
        columns hold labels. Column names other than the smiles column are
        considered as task names.
    smile2graph: callable, str -> DGLGraph
        A function turning a smiles string into a DGLGraph. The default one
        can be found at python/dgl/data/chem/utils.py named smile2graph.
    smile_column: str
        Name of the column that holds the smiles strings.
    cache_file_path: str
        Path to store the preprocessed graphs.
    """

    def __init__(self, df, smile2graph=smile2graph, smile_column='smiles',
                 cache_file_path="csvdata_dglgraph.pkl"):
        # RDKit is an optional dependency; only warn here, the actual failure
        # (if any) surfaces when smile2graph is called.
        if 'rdkit' not in sys.modules:
            from ...base import dgl_warning
            dgl_warning("Please install RDKit (Recommended Version is 2018.09.3)")

        self.df = df
        self.smiles = self.df[smile_column].tolist()
        # Every column except the smiles column is treated as one task.
        self.task_names = self.df.columns.drop([smile_column]).tolist()
        self.cache_file_path = cache_file_path
        self._pre_process(smile2graph)

    def _pre_process(self, smile2graph):
        """Pre-process the dataset

        * Convert molecules from smiles format into DGLGraphs
          and featurize their atoms
        * Set missing labels to be 0 and use a binary masking
          matrix to mask them

        Parameters
        ----------
        smile2graph: callable, str -> DGLGraph
            Function used to build one graph per smiles string.
        """
        if os.path.exists(self.cache_file_path):
            # DGLGraphs have been constructed before, reload them
            print('Loading previously saved dgl graphs...')
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # cache_file_path at files this process (or a trusted one) wrote.
            # NOTE(review): the cache is not checked against self.smiles, so a
            # stale file built from a different dataframe would be used as-is.
            with open(self.cache_file_path, 'rb') as f:
                self.graphs = pickle.load(f)
        else:
            self.graphs = [smile2graph(s) for s in self.smiles]
            with open(self.cache_file_path, 'wb') as f:
                pickle.dump(self.graphs, f)

        _label_values = self.df[self.task_names].values
        # np.nan_to_num will also turn inf into a very large number.
        # Cast to float32 so labels match the dtype documented in __getitem__.
        self.labels = F.zerocopy_from_numpy(
            np.nan_to_num(_label_values).astype(np.float32))
        # The parentheses matter: attribute access/calls bind tighter than the
        # unary ``~``, so without them the inversion would be applied to the
        # float32 array and raise a TypeError.
        self.mask = F.zerocopy_from_numpy(
            (~np.isnan(_label_values)).astype(np.float32))

    def __getitem__(self, item):
        """Get the ith datapoint

        Returns
        -------
        str
            SMILES for the ith datapoint
        DGLGraph
            DGLGraph for the ith datapoint
        Tensor of dtype float32
            Labels of the datapoint for all tasks
        Tensor of dtype float32
            Mask of the datapoint for all tasks: 1 where the label exists,
            0 where it was missing
        """
        return self.smiles[item], self.graphs[item], \
            self.labels[item], self.mask[item]

    def __len__(self):
        """Length of Dataset

        Return
        ------
        int
            Length of Dataset
        """
        return len(self.smiles)
python/dgl/data/chem/tox21.py
0 → 100644
View file @
2ee3c78c
import
numpy
as
np
import
sys
from
.csv_dataset
import
CSVDataset
from
.utils
import
smile2graph
from
..utils
import
get_download_dir
,
download
,
_get_dgl_url
,
Subset
try
:
import
pandas
as
pd
except
ImportError
:
pass
class Tox21(CSVDataset):
    """Tox21 dataset.

    The Toxicology in the 21st Century (https://tripod.nih.gov/tox21/challenge/)
    initiative created a public database measuring toxicity of compounds, which
    has been used in the 2014 Tox21 Data Challenge. The dataset contains qualitative
    toxicity measurements for 8014 compounds on 12 different targets, including nuclear
    receptors and stress response pathways. Each target results in a binary label.

    A common issue for multi-task prediction is that some datapoints are not labeled for
    all tasks. This is also the case for Tox21. In data pre-processing, we set non-existing
    labels to be 0 so that they can be placed in tensors, and a mask marks which labels
    exist so that the missing ones can be masked out in loss computation.

    All molecules are converted into DGLGraphs. After the first-time construction,
    the DGLGraphs will be saved for reloading so that we do not need to reconstruct
    them every time.

    Parameters
    ----------
    smile2graph: callable, str -> DGLGraph
        A function turning a smiles string into a DGLGraph. The default one
        can be found at python/dgl/data/chem/utils.py named smile2graph.
    """
    # Path of the compressed csv, relative to the DGL data repository URL.
    _url = 'dataset/tox21.csv.gz'

    def __init__(self, smile2graph=smile2graph):
        # pandas is imported optionally at module level; warn when it is absent.
        # NOTE(review): execution continues after the warning, so the
        # pd.read_csv call below would then fail with a NameError.
        if 'pandas' not in sys.modules:
            from ...base import dgl_warning
            dgl_warning("Please install pandas")

        data_path = get_download_dir() + '/tox21.csv.gz'
        download(_get_dgl_url(self._url), path=data_path)
        df = pd.read_csv(data_path)
        # Keep the molecule ids aside; they are neither smiles nor a task.
        self.id = df['mol_id']
        df = df.drop(columns=['mol_id'])

        # Explicit two-argument super(), matching the style used by the
        # other classes of this package.
        super(Tox21, self).__init__(df, smile2graph,
                                    cache_file_path="tox21_dglgraph.pkl")
        self._weight_balancing()

    def _weight_balancing(self):
        """Perform re-balancing for each task.

        It's quite common that the number of positive samples and the
        number of negative samples are significantly different. To compensate
        for the class imbalance issue, we can weight each datapoint in
        loss computation.

        In particular, for each task we will set the weight of negative samples
        to be 1 and the weight of positive samples to be the number of negative
        samples divided by the number of positive samples.

        If weight balancing is performed, one attribute will be affected:

        * self._task_pos_weights is set, which holds the positive sample
          weight for each task.
        """
        # NOTE(review): self.labels and self.mask are built by CSVDataset with
        # F.zerocopy_from_numpy, i.e. they are backend tensors; np.sum relies on
        # the backend tensor supporting numpy's reduction call — confirm.
        num_pos = np.sum(self.labels, axis=0)
        num_indices = np.sum(self.mask, axis=0)
        # NOTE(review): a task without any positive sample divides by zero here.
        self._task_pos_weights = (num_indices - num_pos) / num_pos

    @property
    def task_pos_weights(self):
        """Get weights for positive samples on each task

        Returns
        -------
        array-like
            Weight of positive samples for every task, as computed by
            ``_weight_balancing``.
        """
        return self._task_pos_weights
python/dgl/data/chem/utils.py
0 → 100644
View file @
2ee3c78c
import
dgl.backend
as
F
import
numpy
as
np
import
os
import
pickle
from
dgl
import
DGLGraph
try
:
from
rdkit
import
Chem
from
rdkit.Chem
import
rdmolfiles
,
rdmolops
except
ImportError
:
pass
def one_hot_encoding(x, allowable_set):
    """One-hot encoding.

    Parameters
    ----------
    x : str, int or Chem.rdchem.HybridizationType
        Value to encode.
    allowable_set : list
        Candidate values; its elements should be of the same type as x.

    Returns
    -------
    list
        One boolean per element of ``allowable_set``; the i-th entry is
        ``x == allowable_set[i]``. All entries are False when x matches
        none of the candidates.
    """
    return [x == candidate for candidate in allowable_set]
class BaseAtomFeaturizer(object):
    """An abstract class for atom featurizers

    All atom featurizers that map a molecule to atom features should subclass it.
    All subclasses should overwrite ``_featurize_atom``, which featurizes a single
    atom and ``__call__``, which featurizes all atoms in a molecule.
    """

    def _featurize_atom(self, atom):
        """Featurize a single atom; must be overwritten by subclasses.

        Raises
        ------
        NotImplementedError
            Always, since the base class provides no featurization.
        """
        # Raise (rather than return) the exception so that a subclass that
        # forgot to overwrite this method fails loudly at the call site.
        raise NotImplementedError

    def __call__(self, mol):
        """Featurize all atoms in a molecule; must be overwritten by subclasses.

        Raises
        ------
        NotImplementedError
            Always, since the base class provides no featurization.
        """
        raise NotImplementedError
class DefaultAtomFeaturizer(BaseAtomFeaturizer):
    """A default featurizer for atoms.

    The atom features include:

    * **One hot encoding of the atom type**. The supported atom types include
      ``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``, ``Cl``, ``Br``, ``Mg``,
      ``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``, ``K``, ``Tl``,
      ``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
      ``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``, ``Zr``,
      ``Cr``, ``Pt``, ``Hg``, ``Pb``.
    * **One hot encoding of the atom degree**. The supported possibilities
      include ``0 - 10``.
    * **One hot encoding of the implicit valence of the atom**
      (``atom.GetImplicitValence()``). The supported possibilities include ``0 - 6``.
    * **Formal charge of the atom**.
    * **Number of radical electrons of the atom**.
    * **One hot encoding of the atom hybridization**. The supported possibilities include
      ``SP``, ``SP2``, ``SP3``, ``SP3D``, ``SP3D2``.
    * **Whether the atom is aromatic**.
    * **One hot encoding of the number of total Hs on the atom**. The supported possibilities
      include ``0 - 4``.

    A value outside the supported possibilities of a one-hot group is encoded
    as all zeros for that group.

    Parameters
    ----------
    atom_data_field : str
        Name for storing atom features in DGLGraphs, default to be 'h'.
    """

    # Allowable sets for the one-hot groups. They never change, so they are
    # built once here instead of once per featurized atom.
    _ATOM_TYPES = [
        'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
        'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag',
        'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni',
        'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb',
    ]
    _DEGREES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    _IMPLICIT_VALENCES = [0, 1, 2, 3, 4, 5, 6]
    _TOTAL_NUM_HS = [0, 1, 2, 3, 4]

    def __init__(self, atom_data_field='h'):
        super(DefaultAtomFeaturizer, self).__init__()
        self.atom_data_field = atom_data_field

    @property
    def feat_size(self):
        """Returns feature size"""
        # 43 atom types + 11 degrees + 7 implicit valences + formal charge
        # + radical electrons + 5 hybridizations + aromaticity + 5 total Hs.
        return 74

    def _featurize_atom(self, atom):
        """Featurize an atom

        Parameters
        ----------
        atom : rdkit.Chem.rdchem.Atom

        Returns
        -------
        results : list
            List of feature values, including boolean values and numbers
        """
        # The hybridization set is built here rather than at class level
        # because it needs RDKit's ``Chem``, whose import is optional.
        hybridization_type = Chem.rdchem.HybridizationType
        hybridizations = [
            hybridization_type.SP,
            hybridization_type.SP2,
            hybridization_type.SP3,
            hybridization_type.SP3D,
            hybridization_type.SP3D2,
        ]

        # The concatenation order defines the feature layout; keep it stable.
        results = one_hot_encoding(atom.GetSymbol(), self._ATOM_TYPES) + \
            one_hot_encoding(atom.GetDegree(), self._DEGREES) + \
            one_hot_encoding(atom.GetImplicitValence(), self._IMPLICIT_VALENCES) + \
            [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \
            one_hot_encoding(atom.GetHybridization(), hybridizations) + \
            [atom.GetIsAromatic()] + \
            one_hot_encoding(atom.GetTotalNumHs(), self._TOTAL_NUM_HS)
        return results

    def __call__(self, mol):
        """Featurize a molecule

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Maps ``self.atom_data_field`` to the atom features of shape (N, 74),
            where N is the number of atoms in the molecule
        """
        num_atoms = mol.GetNumAtoms()
        atom_features = [self._featurize_atom(mol.GetAtomWithIdx(i))
                         for i in range(num_atoms)]
        # Booleans and ints are unified into one float32 matrix before being
        # handed over to the backend.
        atom_features = np.stack(atom_features)
        atom_features = F.zerocopy_from_numpy(atom_features.astype(np.float32))

        return {self.atom_data_field: atom_features}
def smile2graph(smile, add_self_loop=False, atom_featurizer=None, bond_featurizer=None):
    """Convert SMILES into a DGLGraph.

    The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
    **i** th node in the returned DGLGraph.

    The **i** th bond in the molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the
    **(2i)**-th and **(2i+1)**-th edges in the returned DGLGraph. The **(2i)**-th and
    **(2i+1)**-th edges will be separately from **u** to **v** and **v** to **u**, where
    **u** is ``bond.GetBeginAtomIdx()`` and **v** is ``bond.GetEndAtomIdx()``.

    If self loops are added, the last **n** edges will separately be self loops for
    atoms ``0, 1, ..., n-1``.

    Parameters
    ----------
    smile : str
        String of SMILES
    add_self_loop : bool
        Whether to add self loops in DGLGraphs.
    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for atoms in a molecule, which can be used to update
        ndata for a DGLGraph.
    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for bonds in a molecule, which can be used to update
        edata for a DGLGraph.

    Returns
    -------
    DGLGraph
        Graph with one node per atom and two directed edges per bond.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smile`` into a molecule.
    """
    mol = Chem.MolFromSmiles(smile)
    # RDKit signals a parse failure by returning None; fail here with a clear
    # message instead of an obscure error inside the canonical ranking below.
    if mol is None:
        raise ValueError('Failed to parse SMILES string: {!r}'.format(smile))

    # Renumber the atoms with RDKit's canonical ranking before building the graph.
    new_order = rdmolfiles.CanonicalRankAtoms(mol)
    mol = rdmolops.RenumberAtoms(mol, new_order)

    g = DGLGraph()
    num_atoms = mol.GetNumAtoms()
    g.add_nodes(num_atoms)

    # Every bond contributes the directed edges u -> v and v -> u, in that order.
    src_list = []
    dst_list = []
    num_bonds = mol.GetNumBonds()
    for i in range(num_bonds):
        bond = mol.GetBondWithIdx(i)
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        src_list.extend([u, v])
        dst_list.extend([v, u])
    g.add_edges(src_list, dst_list)

    # Self loops are appended after all bond edges.
    if add_self_loop:
        nodes = g.nodes()
        g.add_edges(nodes, nodes)

    # Featurization
    if atom_featurizer is not None:
        g.ndata.update(atom_featurizer(mol))

    if bond_featurizer is not None:
        g.edata.update(bond_featurizer(mol))

    return g
python/dgl/data/utils.py
View file @
2ee3c78c
"""Dataset utilities."""
from
__future__
import
absolute_import
import
os
,
sys
import
os
import
sys
import
hashlib
import
warnings
import
zipfile
import
tarfile
import
numpy
as
np
try
:
import
requests
except
ImportError
:
...
...
@@ -13,7 +15,9 @@ except ImportError:
pass
requests
=
requests_failed_to_import
__all__
=
[
'download'
,
'check_sha1'
,
'extract_archive'
,
'get_download_dir'
]
__all__
=
[
'download'
,
'check_sha1'
,
'extract_archive'
,
'get_download_dir'
,
'Subset'
,
'split_dataset'
]
def
_get_dgl_url
(
file_url
):
"""Get DGL online url for download."""
...
...
@@ -24,6 +28,25 @@ def _get_dgl_url(file_url):
return
repo_url
+
file_url
def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
    """Split a dataset into consecutive subsets of the given fractions.

    Parameters
    ----------
    dataset
        Object supporting ``len(dataset)``; it is wrapped, not copied.
    frac_list : list of float, optional
        Fraction of the data for each subset; must sum to 1.
        Defaults to ``[0.8, 0.1, 0.1]``.
    shuffle : bool
        If True, the indices are randomly permuted before being split;
        otherwise the original order is kept.
    random_state : optional
        Seed passed to ``np.random.RandomState`` when ``shuffle`` is True.

    Returns
    -------
    list of Subset
        One ``Subset`` per entry of ``frac_list``. Sizes are rounded down,
        except for the last subset, which takes all remaining datapoints.
    """
    fractions = np.array([0.8, 0.1, 0.1] if frac_list is None else frac_list)
    assert np.allclose(np.sum(fractions), 1.), \
        'Expect frac_list sum to 1, got {:.4f}'.format(np.sum(fractions))

    num_data = len(dataset)
    lengths = (num_data * fractions).astype(int)
    # Give whatever the rounding left over to the last subset.
    lengths[-1] = num_data - np.sum(lengths[:-1])

    if shuffle:
        indices = np.random.RandomState(seed=random_state).permutation(num_data)
    else:
        indices = np.arange(num_data)

    # Walk through the index array, cutting one slice per requested length.
    subsets = []
    start = 0
    for length in lengths:
        stop = start + length
        subsets.append(Subset(dataset, indices[start:stop]))
        start = stop
    return subsets
def
download
(
url
,
path
=
None
,
overwrite
=
False
,
sha1_hash
=
None
,
retries
=
5
,
verify_ssl
=
True
):
"""Download a given URL.
...
...
@@ -77,18 +100,18 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
# Disable pyling too broad Exception
# pylint: disable=W0703
try
:
print
(
'Downloading %s from %s...'
%
(
fname
,
url
))
print
(
'Downloading %s from %s...'
%
(
fname
,
url
))
r
=
requests
.
get
(
url
,
stream
=
True
,
verify
=
verify_ssl
)
if
r
.
status_code
!=
200
:
raise
RuntimeError
(
"Failed downloading url %s"
%
url
)
raise
RuntimeError
(
"Failed downloading url %s"
%
url
)
with
open
(
fname
,
'wb'
)
as
f
:
for
chunk
in
r
.
iter_content
(
chunk_size
=
1024
):
if
chunk
:
# filter out keep-alive new chunks
if
chunk
:
# filter out keep-alive new chunks
f
.
write
(
chunk
)
if
sha1_hash
and
not
check_sha1
(
fname
,
sha1_hash
):
raise
UserWarning
(
'File {} is downloaded but the content hash does not match.'
\
' The repo may be outdated or download may be incomplete. '
\
'If the "repo_url" is overridden, consider switching to '
\
raise
UserWarning
(
'File {} is downloaded but the content hash does not match.'
' The repo may be outdated or download may be incomplete. '
'If the "repo_url" is overridden, consider switching to '
'the default repo.'
.
format
(
fname
))
break
except
Exception
as
e
:
...
...
@@ -101,6 +124,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
return
fname
def
check_sha1
(
filename
,
sha1_hash
):
"""Check whether the sha1 hash of the file content matches the expected hash.
...
...
@@ -128,6 +152,7 @@ def check_sha1(filename, sha1_hash):
return
sha1
.
hexdigest
()
==
sha1_hash
def
extract_archive
(
file
,
target_dir
):
"""Extract archive file.
...
...
@@ -150,6 +175,7 @@ def extract_archive(file, target_dir):
archive
.
extractall
(
path
=
target_dir
)
archive
.
close
()
def
get_download_dir
():
"""Get the absolute path to the download directory.
...
...
@@ -163,3 +189,41 @@ def get_download_dir():
if
not
os
.
path
.
exists
(
dirname
):
os
.
makedirs
(
dirname
)
return
dirname
class Subset(object):
    """View of a dataset restricted to the specified indices

    Code adapted from PyTorch.

    Parameters
    ----------
    dataset
        dataset[i] should return the ith datapoint
    indices : list
        List of datapoint indices to construct the subset
    """

    def __init__(self, dataset, indices):
        self.indices = indices
        self.dataset = dataset

    def __getitem__(self, item):
        """Get the datapoint indexed by item

        Returns
        -------
        tuple
            datapoint
        """
        # Translate the position inside the subset into a position
        # inside the wrapped dataset.
        original_index = self.indices[item]
        return self.dataset[original_index]

    def __len__(self):
        """Get subset size

        Returns
        -------
        int
            Number of datapoints in the subset
        """
        return len(self.indices)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment