Unverified Commit f3b866f4 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Update (#1359)

parent 6c7c4039
......@@ -43,9 +43,11 @@ class MoleculeCSVDataset(object):
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
log_every : int
Print a message every time ``log_every`` molecules are processed. Default to 1000.
"""
def __init__(self, df, smiles_to_graph, node_featurizer, edge_featurizer,
smiles_column, cache_file_path, task_names=None, load=True):
smiles_column, cache_file_path, task_names=None, load=True, log_every=1000):
self.df = df
self.smiles = self.df[smiles_column].tolist()
if task_names is None:
......@@ -54,9 +56,9 @@ class MoleculeCSVDataset(object):
self.task_names = task_names
self.n_tasks = len(self.task_names)
self.cache_file_path = cache_file_path
self._pre_process(smiles_to_graph, node_featurizer, edge_featurizer, load)
self._pre_process(smiles_to_graph, node_featurizer, edge_featurizer, load, log_every)
def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer, load):
def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer, load, log_every):
"""Pre-process the dataset
* Convert molecules from smiles format into DGLGraphs
......@@ -78,6 +80,8 @@ class MoleculeCSVDataset(object):
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
log_every : int
Print a message every time ``log_every`` molecules are processed.
"""
if os.path.exists(self.cache_file_path) and load:
# DGLGraphs have been constructed before, reload them
......@@ -89,7 +93,8 @@ class MoleculeCSVDataset(object):
print('Processing dgl graphs from scratch...')
self.graphs = []
for i, s in enumerate(self.smiles):
print('Processing molecule {:d}/{:d}'.format(i+1, len(self)))
if (i + 1) % log_every == 0:
print('Processing molecule {:d}/{:d}'.format(i+1, len(self)))
self.graphs.append(smiles_to_graph(s, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer))
_label_values = self.df[self.task_names].values
......
......@@ -34,9 +34,11 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset):
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to pre-process from scratch. Default to True.
log_every : int
Print a message every time ``log_every`` molecules are processed. Default to 1000.
"""
def __init__(self, smiles_to_graph=smiles_to_bigraph,
node_featurizer=None, edge_featurizer=None, load=True):
node_featurizer=None, edge_featurizer=None, load=True, log_every=1000):
self._url = 'dataset/pubchem_bioassay_aromaticity.csv'
data_path = get_download_dir() + '/pubchem_bioassay_aromaticity.csv'
download(_get_dgl_url(self._url), path=data_path)
......@@ -44,4 +46,4 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset):
super(PubChemBioAssayAromaticity, self).__init__(
df, smiles_to_graph, node_featurizer, edge_featurizer, "cano_smiles",
"pubchem_aromaticity_dglgraph.bin", load=load)
"pubchem_aromaticity_dglgraph.bin", load=load, log_every=log_every)
......@@ -41,11 +41,14 @@ class Tox21(MoleculeCSVDataset):
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
log_every : int
Print a message every time ``log_every`` molecules are processed. Default to 1000.
"""
def __init__(self, smiles_to_graph=smiles_to_bigraph,
node_featurizer=None,
edge_featurizer=None,
load=True):
load=True,
log_every=1000):
self._url = 'dataset/tox21.csv.gz'
data_path = get_download_dir() + '/tox21.csv.gz'
download(_get_dgl_url(self._url), path=data_path)
......@@ -55,7 +58,8 @@ class Tox21(MoleculeCSVDataset):
df = df.drop(columns=['mol_id'])
super(Tox21, self).__init__(df, smiles_to_graph, node_featurizer, edge_featurizer,
"smiles", "tox21_dglgraph.bin", load=load)
"smiles", "tox21_dglgraph.bin",
load=load, log_every=log_every)
self._weight_balancing()
def _weight_balancing(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment