"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d0c54e5563c3245b57d2b374e8e334da77305c05"
Unverified Commit 4ae13bd2 authored by Yuchen's avatar Yuchen Committed by GitHub
Browse files

[bugfix] Fix force_reload parameter of FraudDataset (#3210)

* enable force_reload of FraudDataset

* rewrite hash_key in FraudDataset
parent b1319200
......@@ -46,6 +46,10 @@ class FraudDataset(DGLBuiltinDataset):
validation set size of the dataset, and the
size of testing set is (1 - train_size - val_size)
Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
Attributes
----------
......@@ -84,8 +88,9 @@ class FraudDataset(DGLBuiltinDataset):
'yelp': 'review',
'amazon': 'user'
}
def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7,
             val_size=0.1, force_reload=False, verbose=True):
    """Initialize a fraud dataset.

    Parameters
    ----------
    name : str
        Dataset name; must be 'yelp' or 'amazon'.
    raw_dir : str, optional
        Directory that stores the downloaded raw data. Default: None.
    random_seed : int
        Seed used for the random train/validation/test split. Default: 717.
    train_size : float
        Fraction of nodes used for training. Default: 0.7.
    val_size : float
        Fraction of nodes used for validation. Default: 0.1.
    force_reload : bool
        Whether to re-process the dataset even if a cached copy exists.
        Default: False.
    verbose : bool
        Whether to print out progress information. Default: True.
    """
    assert name in ['yelp', 'amazon'], "only supports 'yelp', or 'amazon'"
    url = _get_dgl_url(self.file_urls[name])
    self.seed = random_seed
    # NOTE(review): one line here was hidden by the diff hunk; `self.train_size`
    # is read later in `process()`, so it must be set here — confirm upstream.
    self.train_size = train_size
    self.val_size = val_size
    # hash_key makes the on-disk cache file name depend on the split
    # configuration, so changing seed/sizes does not silently reuse a
    # cache built for a different split.
    super(FraudDataset, self).__init__(name=name,
                                       url=url,
                                       raw_dir=raw_dir,
                                       hash_key=(random_seed, train_size, val_size),
                                       force_reload=force_reload,
                                       verbose=verbose)
def process(self):
    """Process raw data into a heterograph, node labels and split masks.

    Loads the MATLAB ``.mat`` file for this dataset, builds one
    relation-typed edge set per entry in ``self.relations[self.name]``,
    attaches node features/labels, and computes the random split.
    """
    file_path = os.path.join(self.raw_path, self.file_names[self.name])
    # assumes the .mat file contains 'features' (sparse), 'label', and one
    # sparse adjacency matrix per relation name — TODO confirm schema
    data = io.loadmat(file_path)
    node_features = data['features'].todense()
    node_labels = data['label']
    graph_data = {}
    for relation in self.relations[self.name]:
        # COO form gives the (src, dst) index arrays directly
        adj = data[relation].tocoo()
        row, col = adj.row, adj.col
        # all relations connect nodes of the same (single) node type
        graph_data[(self.node_name[self.name], relation, self.node_name[self.name])] = (row, col)
    g = heterograph(graph_data)
    g.ndata['feature'] = F.tensor(node_features)
    # labels arrive as a row vector; transpose to one label per node
    g.ndata['label'] = F.tensor(node_labels.T)
    self.graph = g
    self._random_split(g.ndata['feature'], self.seed, self.train_size, self.val_size)
def __getitem__(self, idx):
    r"""Get the graph object.

    Parameters
    ----------
    idx : int
        Item index; must be 0 since this dataset holds a single graph.

    Returns
    -------
    DGLGraph
        The heterograph built by :meth:`process` (or loaded from cache).
    """
    assert idx == 0, "This dataset has only one graph"
    return self.graph
def __len__(self):
    """Return the number of data examples (delegates to the graph)."""
    g = self.graph
    return len(g)
@property
def num_classes(self):
    """Number of label classes; always 2 (binary classification).

    Returns
    -------
    int
    """
    return 2
def save(self):
    """Save the processed graph to directory ``self.save_path``.

    The file name embeds ``self.hash`` (derived from the seed/split
    configuration) so caches for different splits can coexist.
    """
    # the old un-hashed path line was stale diff residue and is removed
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
    save_graphs(str(graph_path), self.graph)
def load(self):
    """Load the cached processed graph from directory ``self.save_path``.

    Reads the hash-suffixed file written by :meth:`save`; the saved
    list holds a single graph, which becomes ``self.graph``.
    """
    # the old un-hashed path line was stale diff residue and is removed
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
    graph_list, _ = load_graphs(str(graph_path))
    g = graph_list[0]
    self.graph = g
def has_cache(self):
    """Return True if a processed cache file exists in ``self.save_path``.

    Checks for the hash-suffixed file name, so a cache built with a
    different seed/split configuration does not count as a hit.
    """
    # the old un-hashed path line was stale diff residue and is removed
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
    return os.path.exists(graph_path)
def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1):
    """Split the dataset into training, validation and testing sets.

    Parameters
    ----------
    x
        Node feature tensor; only ``x.shape[0]`` (node count) is used.
    seed : int
        Seed for the deterministic permutation. Default: 717.
    train_size : float
        Fraction of (labeled) nodes for training. Default: 0.7.
    val_size : float
        Fraction of (labeled) nodes for validation. Default: 0.1.
    """
    assert 0 <= train_size + val_size <= 1, \
        "The sum of valid training set size and validation set size " \
        "must between 0 and 1 (inclusive)."
    N = x.shape[0]
    index = np.arange(N)
    if self.name == 'amazon':
        # 0-3304 are unlabeled nodes
        index = np.arange(3305, N)
    # dedicated RandomState keeps the split reproducible without touching
    # the global numpy RNG state
    index = np.random.RandomState(seed).permutation(index)
    train_idx = index[:int(train_size * len(index))]
    val_idx = index[len(index) - int(val_size * len(index)):]
    # NOTE(review): the method continues past this diff hunk (test indices
    # and mask assignment are not visible here) — see full file upstream.
......@@ -243,6 +251,10 @@ class FraudYelpDataset(FraudDataset):
validation set size of the dataset, and the
size of testing set is (1 - train_size - val_size)
Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
Examples
--------
......@@ -252,13 +264,16 @@ class FraudYelpDataset(FraudDataset):
>>> feat = graph.ndata['feature']
>>> label = graph.ndata['label']
"""
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
             val_size=0.1, force_reload=False, verbose=True):
    """Initialize the YelpChi fraud dataset.

    Parameters
    ----------
    raw_dir : str, optional
        Directory that stores the downloaded raw data. Default: None.
    random_seed : int
        Seed used for the random split. Default: 717.
    train_size : float
        Fraction of nodes used for training. Default: 0.7.
    val_size : float
        Fraction of nodes used for validation. Default: 0.1.
    force_reload : bool
        Whether to re-process the dataset even if cached. Default: False.
    verbose : bool
        Whether to print out progress information. Default: True.
    """
    # forward everything (including the new force_reload/verbose flags)
    # to the shared FraudDataset constructor with name fixed to 'yelp'
    super(FraudYelpDataset, self).__init__(name='yelp',
                                           raw_dir=raw_dir,
                                           random_seed=random_seed,
                                           train_size=train_size,
                                           val_size=val_size,
                                           force_reload=force_reload,
                                           verbose=verbose)
class FraudAmazonDataset(FraudDataset):
......@@ -312,6 +327,10 @@ class FraudAmazonDataset(FraudDataset):
validation set size of the dataset, and the
size of testing set is (1 - train_size - val_size)
Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
Examples
--------
......@@ -321,10 +340,13 @@ class FraudAmazonDataset(FraudDataset):
>>> feat = graph.ndata['feature']
>>> label = graph.ndata['label']
"""
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
             val_size=0.1, force_reload=False, verbose=True):
    """Initialize the Amazon fraud dataset.

    Parameters
    ----------
    raw_dir : str, optional
        Directory that stores the downloaded raw data. Default: None.
    random_seed : int
        Seed used for the random split. Default: 717.
    train_size : float
        Fraction of nodes used for training. Default: 0.7.
    val_size : float
        Fraction of nodes used for validation. Default: 0.1.
    force_reload : bool
        Whether to re-process the dataset even if cached. Default: False.
    verbose : bool
        Whether to print out progress information. Default: True.
    """
    # forward everything (including the new force_reload/verbose flags)
    # to the shared FraudDataset constructor with name fixed to 'amazon'
    super(FraudAmazonDataset, self).__init__(name='amazon',
                                             raw_dir=raw_dir,
                                             random_seed=random_seed,
                                             train_size=train_size,
                                             val_size=val_size,
                                             force_reload=force_reload,
                                             verbose=verbose)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment