Unverified commit 4ae13bd2, authored by Yuchen, committed by GitHub
Browse files

[bugfix] Fix force_reload parameter of FraudDataset (#3210)

* enable force_reload of FraudDataset

* rewrite hash_key in FraudDataset
parent b1319200
...@@ -46,6 +46,10 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -46,6 +46,10 @@ class FraudDataset(DGLBuiltinDataset):
validation set size of the dataset, and the validation set size of the dataset, and the
size of testing set is (1 - train_size - val_size) size of testing set is (1 - train_size - val_size)
Default: 0.1 Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
Attributes Attributes
---------- ----------
...@@ -84,8 +88,9 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -84,8 +88,9 @@ class FraudDataset(DGLBuiltinDataset):
'yelp': 'review', 'yelp': 'review',
'amazon': 'user' 'amazon': 'user'
} }
def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7, val_size=0.1): def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True):
assert name in ['yelp', 'amazon'], "only supports 'yelp', or 'amazon'" assert name in ['yelp', 'amazon'], "only supports 'yelp', or 'amazon'"
url = _get_dgl_url(self.file_urls[name]) url = _get_dgl_url(self.file_urls[name])
self.seed = random_seed self.seed = random_seed
...@@ -93,29 +98,32 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -93,29 +98,32 @@ class FraudDataset(DGLBuiltinDataset):
self.val_size = val_size self.val_size = val_size
super(FraudDataset, self).__init__(name=name, super(FraudDataset, self).__init__(name=name,
url=url, url=url,
raw_dir=raw_dir) raw_dir=raw_dir,
hash_key=(random_seed, train_size, val_size),
force_reload=force_reload,
verbose=verbose)
def process(self):
    """Process the raw data into a graph, node labels and splitting masks.

    Loads ``self.file_names[self.name]`` from ``self.raw_path``, builds one
    edge type per entry of ``self.relations[self.name]`` (source and
    destination nodes both use the type ``self.node_name[self.name]``),
    attaches the ``'feature'`` and ``'label'`` node data, stores the graph in
    ``self.graph`` and finally calls ``self._random_split``.
    """
    file_path = os.path.join(self.raw_path, self.file_names[self.name])
    data = io.loadmat(file_path)
    # ``toarray`` gives a plain ndarray; ``todense`` would return the legacy
    # ``np.matrix`` type with the same 2-D contents.
    node_features = data['features'].toarray()
    node_labels = data['label']

    # Every relation connects nodes of the same single node type.
    ntype = self.node_name[self.name]
    graph_data = {}
    for relation in self.relations[self.name]:
        adj = data[relation].tocoo()
        graph_data[(ntype, relation, ntype)] = (adj.row, adj.col)

    g = heterograph(graph_data)
    g.ndata['feature'] = F.tensor(node_features)
    g.ndata['label'] = F.tensor(node_labels.T)
    self.graph = g
    self._random_split(g.ndata['feature'], self.seed, self.train_size, self.val_size)
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph object r""" Get graph object
...@@ -137,11 +145,11 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -137,11 +145,11 @@ class FraudDataset(DGLBuiltinDataset):
""" """
assert idx == 0, "This dataset has only one graph" assert idx == 0, "This dataset has only one graph"
return self.graph return self.graph
def __len__(self):
    """Return the number of data examples (graphs) in the dataset.

    The dataset holds a single graph (``__getitem__`` only accepts index 0),
    so the length is always 1 rather than the size of that graph.
    """
    return 1
@property @property
def num_classes(self): def num_classes(self):
"""Number of classes. """Number of classes.
...@@ -151,37 +159,37 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -151,37 +159,37 @@ class FraudDataset(DGLBuiltinDataset):
int int
""" """
return 2 return 2
def save(self):
    """Save the processed graph to the directory ``self.save_path``.

    The file name embeds ``self.hash`` (presumably derived from the
    ``hash_key`` handed to the base constructor -- confirm), so caches built
    with different hash values do not overwrite each other.
    """
    graph_path = os.path.join(
        self.save_path,
        self.name + '_dgl_graph_{}.bin'.format(self.hash))
    save_graphs(str(graph_path), self.graph)
def load(self):
    """Load the processed graph from the directory ``self.save_path``.

    Reads the hash-suffixed file that ``save`` writes and stores the first
    graph of the loaded list in ``self.graph``.
    """
    graph_path = os.path.join(
        self.save_path,
        self.name + '_dgl_graph_{}.bin'.format(self.hash))
    graph_list, _ = load_graphs(str(graph_path))
    self.graph = graph_list[0]
def has_cache(self):
    """Check whether processed data exists in ``self.save_path``.

    Returns
    -------
    bool
        True if the hash-suffixed graph file for ``self.name`` and
        ``self.hash`` exists on disk, False otherwise.
    """
    graph_path = os.path.join(
        self.save_path,
        self.name + '_dgl_graph_{}.bin'.format(self.hash))
    return os.path.exists(graph_path)
def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1): def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1):
"""split the dataset into training set, validation set and testing set""" """split the dataset into training set, validation set and testing set"""
assert 0 <= train_size + val_size <= 1, \ assert 0 <= train_size + val_size <= 1, \
"The sum of valid training set size and validation set size " \ "The sum of valid training set size and validation set size " \
"must between 0 and 1 (inclusive)." "must between 0 and 1 (inclusive)."
N = x.shape[0] N = x.shape[0]
index = np.arange(N) index = np.arange(N)
if self.name == 'amazon': if self.name == 'amazon':
# 0-3304 are unlabeled nodes # 0-3304 are unlabeled nodes
index = np.arange(3305, N) index = np.arange(3305, N)
index = np.random.RandomState(seed).permutation(index) index = np.random.RandomState(seed).permutation(index)
train_idx = index[:int(train_size * len(index))] train_idx = index[:int(train_size * len(index))]
val_idx = index[len(index) - int(val_size * len(index)):] val_idx = index[len(index) - int(val_size * len(index)):]
...@@ -243,6 +251,10 @@ class FraudYelpDataset(FraudDataset): ...@@ -243,6 +251,10 @@ class FraudYelpDataset(FraudDataset):
validation set size of the dataset, and the validation set size of the dataset, and the
size of testing set is (1 - train_size - val_size) size of testing set is (1 - train_size - val_size)
Default: 0.1 Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
Examples Examples
-------- --------
...@@ -252,13 +264,16 @@ class FraudYelpDataset(FraudDataset): ...@@ -252,13 +264,16 @@ class FraudYelpDataset(FraudDataset):
>>> feat = graph.ndata['feature'] >>> feat = graph.ndata['feature']
>>> label = graph.ndata['label'] >>> label = graph.ndata['label']
""" """
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
             val_size=0.1, force_reload=False, verbose=True):
    """Create the Yelp variant of the fraud dataset.

    All arguments are forwarded unchanged to ``FraudDataset.__init__`` with
    ``name='yelp'``; ``force_reload`` and ``verbose`` are passed through so
    callers can control them.
    """
    super(FraudYelpDataset, self).__init__(name='yelp',
                                           raw_dir=raw_dir,
                                           random_seed=random_seed,
                                           train_size=train_size,
                                           val_size=val_size,
                                           force_reload=force_reload,
                                           verbose=verbose)
class FraudAmazonDataset(FraudDataset): class FraudAmazonDataset(FraudDataset):
...@@ -312,6 +327,10 @@ class FraudAmazonDataset(FraudDataset): ...@@ -312,6 +327,10 @@ class FraudAmazonDataset(FraudDataset):
validation set size of the dataset, and the validation set size of the dataset, and the
size of testing set is (1 - train_size - val_size) size of testing set is (1 - train_size - val_size)
Default: 0.1 Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
Examples Examples
-------- --------
...@@ -321,10 +340,13 @@ class FraudAmazonDataset(FraudDataset): ...@@ -321,10 +340,13 @@ class FraudAmazonDataset(FraudDataset):
>>> feat = graph.ndata['feature'] >>> feat = graph.ndata['feature']
>>> label = graph.ndata['label'] >>> label = graph.ndata['label']
""" """
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
             val_size=0.1, force_reload=False, verbose=True):
    """Create the Amazon variant of the fraud dataset.

    All arguments are forwarded unchanged to ``FraudDataset.__init__`` with
    ``name='amazon'``; ``force_reload`` and ``verbose`` are passed through so
    callers can control them.
    """
    super(FraudAmazonDataset, self).__init__(name='amazon',
                                             raw_dir=raw_dir,
                                             random_seed=random_seed,
                                             train_size=train_size,
                                             val_size=val_size,
                                             force_reload=force_reload,
                                             verbose=verbose)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment