Unverified Commit 4ae13bd2 authored by Yuchen's avatar Yuchen Committed by GitHub
Browse files

[bugfix] Fix force_reload parameter of FraudDataset (#3210)

* enable force_reload of FraudDataset

* rewrite hash_key in FraudDataset
parent b1319200
...@@ -46,6 +46,10 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -46,6 +46,10 @@ class FraudDataset(DGLBuiltinDataset):
validation set size of the dataset, and the validation set size of the dataset, and the
size of testing set is (1 - train_size - val_size) size of testing set is (1 - train_size - val_size)
Default: 0.1 Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Attributes Attributes
---------- ----------
...@@ -85,7 +89,8 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -85,7 +89,8 @@ class FraudDataset(DGLBuiltinDataset):
'amazon': 'user' 'amazon': 'user'
} }
def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7,
             val_size=0.1, force_reload=False, verbose=True):
    """Initialize a fraud dataset ('yelp' or 'amazon').

    Parameters
    ----------
    name : str
        Dataset name; must be 'yelp' or 'amazon'.
    raw_dir : str, optional
        Directory that stores the downloaded raw data. Default: None.
    random_seed : int
        Seed for the random train/val/test split. Default: 717.
    train_size : float
        Fraction of nodes used for training. Default: 0.7.
    val_size : float
        Fraction of nodes used for validation; the test set gets
        (1 - train_size - val_size). Default: 0.1.
    force_reload : bool
        Whether to re-process the dataset even when a cache exists.
        Default: False.
    verbose : bool
        Whether to print progress information. Default: True.
    """
    assert name in ['yelp', 'amazon'], "only supports 'yelp', or 'amazon'"
    url = _get_dgl_url(self.file_urls[name])
    self.seed = random_seed
    self.train_size = train_size
    self.val_size = val_size
    # hash_key ties the on-disk cache to the split parameters, so changing
    # seed or split sizes invalidates the cached graph instead of silently
    # reusing a split produced with different parameters.
    super(FraudDataset, self).__init__(name=name,
                                       url=url,
                                       raw_dir=raw_dir,
                                       hash_key=(random_seed, train_size, val_size),
                                       force_reload=force_reload,
                                       verbose=verbose)
def process(self): def process(self):
"""process raw data to graph, labels, splitting masks""" """process raw data to graph, labels, splitting masks"""
...@@ -154,19 +162,19 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -154,19 +162,19 @@ class FraudDataset(DGLBuiltinDataset):
def save(self):
    """Save the processed graph to ``self.save_path``.

    The file name embeds ``self.hash`` (derived from the seed and split
    sizes via ``hash_key``) so caches produced with different split
    parameters do not collide with each other.
    """
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
    save_graphs(str(graph_path), self.graph)
def load(self):
    """Load the processed graph from ``self.save_path``.

    Reads the hash-suffixed cache file written by :meth:`save` and keeps
    the first (and only) graph of the stored list on ``self.graph``.
    """
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
    graph_list, _ = load_graphs(str(graph_path))
    g = graph_list[0]
    self.graph = g
def has_cache(self):
    """Return ``True`` when a processed-graph cache for the current
    parameter hash exists under ``self.save_path``.

    Must probe the same hash-suffixed path that :meth:`save` writes,
    otherwise a parameter change would wrongly hit a stale cache.
    """
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
    return os.path.exists(graph_path)
def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1): def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1):
...@@ -243,6 +251,10 @@ class FraudYelpDataset(FraudDataset): ...@@ -243,6 +251,10 @@ class FraudYelpDataset(FraudDataset):
validation set size of the dataset, and the validation set size of the dataset, and the
size of testing set is (1 - train_size - val_size) size of testing set is (1 - train_size - val_size)
Default: 0.1 Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Examples Examples
-------- --------
...@@ -253,12 +265,15 @@ class FraudYelpDataset(FraudDataset): ...@@ -253,12 +265,15 @@ class FraudYelpDataset(FraudDataset):
>>> label = graph.ndata['label'] >>> label = graph.ndata['label']
""" """
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
             val_size=0.1, force_reload=False, verbose=True):
    """Initialize the Yelp fraud dataset.

    Thin wrapper over :class:`FraudDataset` with ``name='yelp'``; all
    other parameters (including ``force_reload`` and ``verbose``) are
    forwarded unchanged to the base class.
    """
    super(FraudYelpDataset, self).__init__(name='yelp',
                                           raw_dir=raw_dir,
                                           random_seed=random_seed,
                                           train_size=train_size,
                                           val_size=val_size,
                                           force_reload=force_reload,
                                           verbose=verbose)
class FraudAmazonDataset(FraudDataset): class FraudAmazonDataset(FraudDataset):
...@@ -312,6 +327,10 @@ class FraudAmazonDataset(FraudDataset): ...@@ -312,6 +327,10 @@ class FraudAmazonDataset(FraudDataset):
validation set size of the dataset, and the validation set size of the dataset, and the
size of testing set is (1 - train_size - val_size) size of testing set is (1 - train_size - val_size)
Default: 0.1 Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Examples Examples
-------- --------
...@@ -322,9 +341,12 @@ class FraudAmazonDataset(FraudDataset): ...@@ -322,9 +341,12 @@ class FraudAmazonDataset(FraudDataset):
>>> label = graph.ndata['label'] >>> label = graph.ndata['label']
""" """
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
             val_size=0.1, force_reload=False, verbose=True):
    """Initialize the Amazon fraud dataset.

    Thin wrapper over :class:`FraudDataset` with ``name='amazon'``; all
    other parameters (including ``force_reload`` and ``verbose``) are
    forwarded unchanged to the base class.
    """
    super(FraudAmazonDataset, self).__init__(name='amazon',
                                             raw_dir=raw_dir,
                                             random_seed=random_seed,
                                             train_size=train_size,
                                             val_size=val_size,
                                             force_reload=force_reload,
                                             verbose=verbose)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment