Commit 4ae13bd2 (unverified), authored by Yuchen and committed by GitHub
Browse files

[bugfix] Fix force_reload parameter of FraudDataset (#3210)

* enable force_reload of FraudDataset

* rewrite hash_key in FraudDataset
parent b1319200
......@@ -46,6 +46,10 @@ class FraudDataset(DGLBuiltinDataset):
validation set size of the dataset, and the
size of testing set is (1 - train_size - val_size)
Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
Attributes
----------
......@@ -85,7 +89,8 @@ class FraudDataset(DGLBuiltinDataset):
'amazon': 'user'
}
def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7, val_size=0.1):
def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True):
assert name in ['yelp', 'amazon'], "only supports 'yelp', or 'amazon'"
url = _get_dgl_url(self.file_urls[name])
self.seed = random_seed
......@@ -93,7 +98,10 @@ class FraudDataset(DGLBuiltinDataset):
self.val_size = val_size
super(FraudDataset, self).__init__(name=name,
url=url,
raw_dir=raw_dir)
raw_dir=raw_dir,
hash_key=(random_seed, train_size, val_size),
force_reload=force_reload,
verbose=verbose)
def process(self):
"""process raw data to graph, labels, splitting masks"""
......@@ -154,19 +162,19 @@ class FraudDataset(DGLBuiltinDataset):
def save(self):
"""save processed data to directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
save_graphs(str(graph_path), self.graph)
def load(self):
"""load processed data from directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
graph_list, _ = load_graphs(str(graph_path))
g = graph_list[0]
self.graph = g
def has_cache(self):
"""check whether there are processed data in `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
return os.path.exists(graph_path)
def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1):
......@@ -243,6 +251,10 @@ class FraudYelpDataset(FraudDataset):
validation set size of the dataset, and the
size of testing set is (1 - train_size - val_size)
Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
Examples
--------
......@@ -253,12 +265,15 @@ class FraudYelpDataset(FraudDataset):
>>> label = graph.ndata['label']
"""
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7, val_size=0.1):
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True):
super(FraudYelpDataset, self).__init__(name='yelp',
raw_dir=raw_dir,
random_seed=random_seed,
train_size=train_size,
val_size=val_size)
val_size=val_size,
force_reload=force_reload,
verbose=verbose)
class FraudAmazonDataset(FraudDataset):
......@@ -312,6 +327,10 @@ class FraudAmazonDataset(FraudDataset):
validation set size of the dataset, and the
size of testing set is (1 - train_size - val_size)
Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
Examples
--------
......@@ -322,9 +341,12 @@ class FraudAmazonDataset(FraudDataset):
>>> label = graph.ndata['label']
"""
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7, val_size=0.1):
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True):
super(FraudAmazonDataset, self).__init__(name='amazon',
raw_dir=raw_dir,
random_seed=random_seed,
train_size=train_size,
val_size=val_size)
val_size=val_size,
force_reload=force_reload,
verbose=verbose)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment