Unverified commit ae3102d3, authored by Mufei Li and committed by GitHub
Browse files

[Dataset] Migration for Chemistry Datasets (#926)

* Update

* Update
parent c37076df
...@@ -76,8 +76,8 @@ def main(args): ...@@ -76,8 +76,8 @@ def main(args):
classifier_hidden_feats=args['classifier_hidden_feats'], classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=dataset.n_tasks) n_tasks=dataset.n_tasks)
loss_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor( loss_criterion = BCEWithLogitsLoss(pos_weight=dataset.task_pos_weights.to(args['device']),
dataset.task_pos_weights).to(args['device']), reduction='none') reduction='none')
optimizer = Adam(model.parameters(), lr=args['lr']) optimizer = Adam(model.parameters(), lr=args['lr'])
stopper = EarlyStopping(patience=args['patience']) stopper = EarlyStopping(patience=args['patience'])
model.to(args['device']) model.to(args['device'])
......
from .utils import * from .utils import *
from .csv_dataset import CSVDataset from .csv_dataset import MoleculeCSVDataset
from .tox21 import Tox21 from .tox21 import Tox21
from .alchemy import TencentAlchemyDataset from .alchemy import TencentAlchemyDataset
...@@ -11,7 +11,7 @@ import zipfile ...@@ -11,7 +11,7 @@ import zipfile
from collections import defaultdict from collections import defaultdict
from .utils import mol_to_complete_graph from .utils import mol_to_complete_graph
from ..utils import download, get_download_dir, _get_dgl_url from ..utils import download, get_download_dir, _get_dgl_url, save_graphs, load_graphs
from ... import backend as F from ... import backend as F
try: try:
...@@ -172,7 +172,7 @@ class TencentAlchemyDataset(object): ...@@ -172,7 +172,7 @@ class TencentAlchemyDataset(object):
file_dir = osp.join(get_download_dir(), 'Alchemy_data') file_dir = osp.join(get_download_dir(), 'Alchemy_data')
if not from_raw: if not from_raw:
file_name = "%s_processed" % (mode) file_name = "%s_processed_dgl" % (mode)
else: else:
file_name = "%s_single_sdf" % (mode) file_name = "%s_single_sdf" % (mode)
self.file_dir = pathlib.Path(file_dir, file_name) self.file_dir = pathlib.Path(file_dir, file_name)
...@@ -189,10 +189,11 @@ class TencentAlchemyDataset(object): ...@@ -189,10 +189,11 @@ class TencentAlchemyDataset(object):
def _load(self): def _load(self):
if not self.from_raw: if not self.from_raw:
with open(osp.join(self.file_dir, "%s_graphs.pkl" % self.mode), "rb") as f: self.graphs, label_dict = load_graphs(osp.join(self.file_dir, "%s_graphs.bin" % self.mode))
self.graphs = pickle.load(f) self.labels = label_dict['labels']
with open(osp.join(self.file_dir, "%s_labels.pkl" % self.mode), "rb") as f: with open(osp.join(self.file_dir, "%s_smiles.txt" % self.mode), 'r') as f:
self.labels = pickle.load(f) smiles_ = f.readlines()
self.smiles = [s.strip() for s in smiles_]
else: else:
print('Start preprocessing dataset...') print('Start preprocessing dataset...')
target_file = pathlib.Path(self.file_dir, "%s_target.csv" % self.mode) target_file = pathlib.Path(self.file_dir, "%s_target.csv" % self.mode)
...@@ -201,7 +202,7 @@ class TencentAlchemyDataset(object): ...@@ -201,7 +202,7 @@ class TencentAlchemyDataset(object):
index_col=0, index_col=0,
usecols=['gdb_idx',] + ['property_%d' % x for x in range(12)]) usecols=['gdb_idx',] + ['property_%d' % x for x in range(12)])
self.target = self.target[['property_%d' % x for x in range(12)]] self.target = self.target[['property_%d' % x for x in range(12)]]
self.graphs, self.labels = [], [] self.graphs, self.labels, self.smiles = [], [], []
supp = Chem.SDMolSupplier(osp.join(self.file_dir, self.mode + ".sdf")) supp = Chem.SDMolSupplier(osp.join(self.file_dir, self.mode + ".sdf"))
cnt = 0 cnt = 0
...@@ -211,16 +212,17 @@ class TencentAlchemyDataset(object): ...@@ -211,16 +212,17 @@ class TencentAlchemyDataset(object):
print('Processing molecule {:d}/{:d}'.format(cnt, dataset_size)) print('Processing molecule {:d}/{:d}'.format(cnt, dataset_size))
graph = mol_to_complete_graph(mol, atom_featurizer=alchemy_nodes, graph = mol_to_complete_graph(mol, atom_featurizer=alchemy_nodes,
bond_featurizer=alchemy_edges) bond_featurizer=alchemy_edges)
smile = Chem.MolToSmiles(mol) smiles = Chem.MolToSmiles(mol)
graph.smile = smile self.smiles.append(smiles)
self.graphs.append(graph) self.graphs.append(graph)
label = F.tensor(np.array(label[1].tolist()).astype(np.float32)) label = F.tensor(np.array(label[1].tolist()).astype(np.float32))
self.labels.append(label) self.labels.append(label)
with open(osp.join(self.file_dir, "%s_graphs.pkl" % self.mode), "wb") as f: save_graphs(osp.join(self.file_dir, "%s_graphs.bin" % self.mode), self.graphs,
pickle.dump(self.graphs, f) labels={'labels': F.stack(self.labels, dim=0)})
with open(osp.join(self.file_dir, "%s_labels.pkl" % self.mode), "wb") as f: with open(osp.join(self.file_dir, "%s_smiles.txt" % self.mode), 'w') as f:
pickle.dump(self.labels, f) for s in self.smiles:
f.write(s + '\n')
self.set_mean_and_std() self.set_mean_and_std()
print(len(self.graphs), "loaded!") print(len(self.graphs), "loaded!")
...@@ -242,8 +244,7 @@ class TencentAlchemyDataset(object): ...@@ -242,8 +244,7 @@ class TencentAlchemyDataset(object):
Tensor of dtype float32 Tensor of dtype float32
Labels of the datapoint for all tasks Labels of the datapoint for all tasks
""" """
g, l = self.graphs[item], self.labels[item] return self.smiles[item], self.graphs[item], self.labels[item]
return g.smile, g, l
def __len__(self): def __len__(self):
"""Length of the dataset """Length of the dataset
......
...@@ -3,17 +3,17 @@ from __future__ import absolute_import ...@@ -3,17 +3,17 @@ from __future__ import absolute_import
import dgl.backend as F import dgl.backend as F
import numpy as np import numpy as np
import os import os
import pickle
import sys import sys
from dgl import DGLGraph
from .utils import smile_to_bigraph from .utils import smile_to_bigraph
from ..utils import save_graphs, load_graphs
from ... import backend as F
from ...graph import DGLGraph
class MoleculeCSVDataset(object):
"""MoleculeCSVDataset
class CSVDataset(object): This is a general class for loading molecular data from csv or pd.DataFrame.
"""CSVDataset
This is a general class for loading data from csv or pd.DataFrame.
In data pre-processing, we set non-existing labels to be 0, In data pre-processing, we set non-existing labels to be 0,
and returning mask with 1 where label exists. and returning mask with 1 where label exists.
...@@ -36,7 +36,7 @@ class CSVDataset(object): ...@@ -36,7 +36,7 @@ class CSVDataset(object):
Path to store the preprocessed data Path to store the preprocessed data
""" """
def __init__(self, df, smile_to_graph=smile_to_bigraph, smile_column='smiles', def __init__(self, df, smile_to_graph=smile_to_bigraph, smile_column='smiles',
cache_file_path="csvdata_dglgraph.pkl"): cache_file_path="csvdata_dglgraph.bin"):
if 'rdkit' not in sys.modules: if 'rdkit' not in sys.modules:
from ...base import dgl_warning from ...base import dgl_warning
dgl_warning( dgl_warning(
...@@ -64,17 +64,21 @@ class CSVDataset(object): ...@@ -64,17 +64,21 @@ class CSVDataset(object):
if os.path.exists(self.cache_file_path): if os.path.exists(self.cache_file_path):
# DGLGraphs have been constructed before, reload them # DGLGraphs have been constructed before, reload them
print('Loading previously saved dgl graphs...') print('Loading previously saved dgl graphs...')
with open(self.cache_file_path, 'rb') as f: self.graphs, label_dict = load_graphs(self.cache_file_path)
self.graphs = pickle.load(f) self.labels = label_dict['labels']
self.mask = label_dict['mask']
else: else:
self.graphs = [smile_to_graph(s) for s in self.smiles] print('Processing dgl graphs from scratch...')
with open(self.cache_file_path, 'wb') as f: self.graphs = []
pickle.dump(self.graphs, f) for i, s in enumerate(self.smiles):
print('Processing molecule {:d}/{:d}'.format(i+1, len(self)))
self.graphs.append(smile_to_graph(s))
_label_values = self.df[self.task_names].values _label_values = self.df[self.task_names].values
# np.nan_to_num will also turn inf into a very large number # np.nan_to_num will also turn inf into a very large number
self.labels = np.nan_to_num(_label_values).astype(np.float32) self.labels = F.zerocopy_from_numpy(np.nan_to_num(_label_values).astype(np.float32))
self.mask = (~np.isnan(_label_values)).astype(np.float32) self.mask = F.zerocopy_from_numpy((~np.isnan(_label_values)).astype(np.float32))
save_graphs(self.cache_file_path, self.graphs,
labels={'labels': self.labels, 'mask': self.mask})
def __getitem__(self, item): def __getitem__(self, item):
"""Get datapoint with index """Get datapoint with index
...@@ -95,9 +99,7 @@ class CSVDataset(object): ...@@ -95,9 +99,7 @@ class CSVDataset(object):
Tensor of dtype float32 Tensor of dtype float32
Binary masks indicating the existence of labels for all tasks Binary masks indicating the existence of labels for all tasks
""" """
return self.smiles[item], self.graphs[item], \ return self.smiles[item], self.graphs[item], self.labels[item], self.mask[item]
F.zerocopy_from_numpy(self.labels[item]), \
F.zerocopy_from_numpy(self.mask[item])
def __len__(self): def __len__(self):
"""Length of the dataset """Length of the dataset
......
import numpy as np import numpy as np
import sys import sys
from .csv_dataset import CSVDataset from .csv_dataset import MoleculeCSVDataset
from .utils import smile_to_bigraph from .utils import smile_to_bigraph
from ..utils import get_download_dir, download, _get_dgl_url from ..utils import get_download_dir, download, _get_dgl_url
from ... import backend as F
try: try:
import pandas as pd import pandas as pd
except ImportError: except ImportError:
pass pass
class Tox21(CSVDataset): class Tox21(MoleculeCSVDataset):
"""Tox21 dataset. """Tox21 dataset.
The Toxicology in the 21st Century (https://tripod.nih.gov/tox21/challenge/) The Toxicology in the 21st Century (https://tripod.nih.gov/tox21/challenge/)
...@@ -46,7 +47,7 @@ class Tox21(CSVDataset): ...@@ -46,7 +47,7 @@ class Tox21(CSVDataset):
df = df.drop(columns=['mol_id']) df = df.drop(columns=['mol_id'])
super().__init__(df, smile_to_graph, cache_file_path="tox21_dglgraph.pkl") super().__init__(df, smile_to_graph, cache_file_path="tox21_dglgraph.bin")
self._weight_balancing() self._weight_balancing()
...@@ -67,8 +68,8 @@ class Tox21(CSVDataset): ...@@ -67,8 +68,8 @@ class Tox21(CSVDataset):
* self._task_pos_weights is set, which is a list of positive sample weights * self._task_pos_weights is set, which is a list of positive sample weights
for each task. for each task.
""" """
num_pos = np.sum(self.labels, axis=0) num_pos = F.sum(self.labels, dim=0)
num_indices = np.sum(self.mask, axis=0) num_indices = F.sum(self.mask, dim=0)
self._task_pos_weights = (num_indices - num_pos) / num_pos self._task_pos_weights = (num_indices - num_pos) / num_pos
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment