"git@developer.sourcefind.cn:change/sglang.git" did not exist on "b503881bd214eed7ff6d46e965004d721a7a11fe"
Unverified Commit f3b866f4 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Update (#1359)

parent 6c7c4039
...@@ -43,9 +43,11 @@ class MoleculeCSVDataset(object): ...@@ -43,9 +43,11 @@ class MoleculeCSVDataset(object):
Whether to load the previously pre-processed dataset or pre-process from scratch. Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and ``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True. featurization methods and need to preprocess from scratch. Default to True.
log_every : bool
Print a message every time ``log_every`` molecules are processed. Default to 1000.
""" """
def __init__(self, df, smiles_to_graph, node_featurizer, edge_featurizer, def __init__(self, df, smiles_to_graph, node_featurizer, edge_featurizer,
smiles_column, cache_file_path, task_names=None, load=True): smiles_column, cache_file_path, task_names=None, load=True, log_every=1000):
self.df = df self.df = df
self.smiles = self.df[smiles_column].tolist() self.smiles = self.df[smiles_column].tolist()
if task_names is None: if task_names is None:
...@@ -54,9 +56,9 @@ class MoleculeCSVDataset(object): ...@@ -54,9 +56,9 @@ class MoleculeCSVDataset(object):
self.task_names = task_names self.task_names = task_names
self.n_tasks = len(self.task_names) self.n_tasks = len(self.task_names)
self.cache_file_path = cache_file_path self.cache_file_path = cache_file_path
self._pre_process(smiles_to_graph, node_featurizer, edge_featurizer, load) self._pre_process(smiles_to_graph, node_featurizer, edge_featurizer, load, log_every)
def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer, load): def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer, load, log_every):
"""Pre-process the dataset """Pre-process the dataset
* Convert molecules from smiles format into DGLGraphs * Convert molecules from smiles format into DGLGraphs
...@@ -78,6 +80,8 @@ class MoleculeCSVDataset(object): ...@@ -78,6 +80,8 @@ class MoleculeCSVDataset(object):
Whether to load the previously pre-processed dataset or pre-process from scratch. Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and ``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True. featurization methods and need to preprocess from scratch. Default to True.
log_every : bool
Print a message every time ``log_every`` molecules are processed.
""" """
if os.path.exists(self.cache_file_path) and load: if os.path.exists(self.cache_file_path) and load:
# DGLGraphs have been constructed before, reload them # DGLGraphs have been constructed before, reload them
...@@ -89,6 +93,7 @@ class MoleculeCSVDataset(object): ...@@ -89,6 +93,7 @@ class MoleculeCSVDataset(object):
print('Processing dgl graphs from scratch...') print('Processing dgl graphs from scratch...')
self.graphs = [] self.graphs = []
for i, s in enumerate(self.smiles): for i, s in enumerate(self.smiles):
if (i + 1) % log_every == 0:
print('Processing molecule {:d}/{:d}'.format(i+1, len(self))) print('Processing molecule {:d}/{:d}'.format(i+1, len(self)))
self.graphs.append(smiles_to_graph(s, node_featurizer=node_featurizer, self.graphs.append(smiles_to_graph(s, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer)) edge_featurizer=edge_featurizer))
......
...@@ -34,9 +34,11 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset): ...@@ -34,9 +34,11 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset):
Whether to load the previously pre-processed dataset or pre-process from scratch. Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and ``load`` should be False when we want to try different graph construction and
featurization methods and need to pre-process from scratch. Default to True. featurization methods and need to pre-process from scratch. Default to True.
log_every : bool
Print a message every time ``log_every`` molecules are processed. Default to 1000.
""" """
def __init__(self, smiles_to_graph=smiles_to_bigraph, def __init__(self, smiles_to_graph=smiles_to_bigraph,
node_featurizer=None, edge_featurizer=None, load=True): node_featurizer=None, edge_featurizer=None, load=True, log_every=1000):
self._url = 'dataset/pubchem_bioassay_aromaticity.csv' self._url = 'dataset/pubchem_bioassay_aromaticity.csv'
data_path = get_download_dir() + '/pubchem_bioassay_aromaticity.csv' data_path = get_download_dir() + '/pubchem_bioassay_aromaticity.csv'
download(_get_dgl_url(self._url), path=data_path) download(_get_dgl_url(self._url), path=data_path)
...@@ -44,4 +46,4 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset): ...@@ -44,4 +46,4 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset):
super(PubChemBioAssayAromaticity, self).__init__( super(PubChemBioAssayAromaticity, self).__init__(
df, smiles_to_graph, node_featurizer, edge_featurizer, "cano_smiles", df, smiles_to_graph, node_featurizer, edge_featurizer, "cano_smiles",
"pubchem_aromaticity_dglgraph.bin", load=load) "pubchem_aromaticity_dglgraph.bin", load=load, log_every=log_every)
...@@ -41,11 +41,14 @@ class Tox21(MoleculeCSVDataset): ...@@ -41,11 +41,14 @@ class Tox21(MoleculeCSVDataset):
Whether to load the previously pre-processed dataset or pre-process from scratch. Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and ``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True. featurization methods and need to preprocess from scratch. Default to True.
log_every : bool
Print a message every time ``log_every`` molecules are processed. Default to 1000.
""" """
def __init__(self, smiles_to_graph=smiles_to_bigraph, def __init__(self, smiles_to_graph=smiles_to_bigraph,
node_featurizer=None, node_featurizer=None,
edge_featurizer=None, edge_featurizer=None,
load=True): load=True,
log_every=1000):
self._url = 'dataset/tox21.csv.gz' self._url = 'dataset/tox21.csv.gz'
data_path = get_download_dir() + '/tox21.csv.gz' data_path = get_download_dir() + '/tox21.csv.gz'
download(_get_dgl_url(self._url), path=data_path) download(_get_dgl_url(self._url), path=data_path)
...@@ -55,7 +58,8 @@ class Tox21(MoleculeCSVDataset): ...@@ -55,7 +58,8 @@ class Tox21(MoleculeCSVDataset):
df = df.drop(columns=['mol_id']) df = df.drop(columns=['mol_id'])
super(Tox21, self).__init__(df, smiles_to_graph, node_featurizer, edge_featurizer, super(Tox21, self).__init__(df, smiles_to_graph, node_featurizer, edge_featurizer,
"smiles", "tox21_dglgraph.bin", load=load) "smiles", "tox21_dglgraph.bin",
load=load, log_every=log_every)
self._weight_balancing() self._weight_balancing()
def _weight_balancing(self): def _weight_balancing(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment