Unverified Commit 7ec78bb6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Graphbolt] change dataset method to property. (#6023)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 4135b1bd
......@@ -31,22 +31,27 @@ class Dataset:
generate a subgraph.
"""
@property
def train_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the training sets."""
raise NotImplementedError
@property
def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the validation sets."""
raise NotImplementedError
@property
def test_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the test sets."""
raise NotImplementedError
@property
def graph(self) -> object:
"""Return the graph."""
raise NotImplementedError
@property
def feature(self) -> Dict[object, FeatureStore]:
"""Return the feature."""
raise NotImplementedError
......
......@@ -281,22 +281,27 @@ class OnDiskDataset(Dataset):
self._validation_sets = self._init_tvt_sets(self._meta.validation_sets)
self._test_sets = self._init_tvt_sets(self._meta.test_sets)
@property
def train_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the training set."""
return self._train_sets
@property
def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the validation set."""
return self._validation_sets
@property
def test_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the test set."""
return self._test_sets
@property
def graph(self) -> object:
"""Return the graph."""
return self._graph
@property
def feature(self) -> Dict[Tuple, TorchBasedFeatureStore]:
"""Return the feature."""
return self._feature
......
......@@ -106,7 +106,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
dataset = gb.OnDiskDataset(yaml_file)
# Verify train set.
train_sets = dataset.train_sets()
train_sets = dataset.train_sets
assert len(train_sets) == 2
for train_set in train_sets:
assert len(train_set) == 1000
......@@ -117,7 +117,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
train_sets = None
# Verify validation set.
validation_sets = dataset.validation_sets()
validation_sets = dataset.validation_sets
assert len(validation_sets) == 2
for validation_set in validation_sets:
assert len(validation_set) == 1000
......@@ -128,7 +128,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
validation_sets = None
# Verify test set.
test_sets = dataset.test_sets()
test_sets = dataset.test_sets
assert len(test_sets) == 2
for test_set in test_sets:
assert len(test_set) == 1000
......@@ -151,9 +151,9 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
f.write(yaml_content)
dataset = gb.OnDiskDataset(yaml_file)
assert dataset.train_sets() is not None
assert dataset.validation_sets() is None
assert dataset.test_sets() is None
assert dataset.train_sets is not None
assert dataset.validation_sets is None
assert dataset.test_sets is None
dataset = None
......@@ -209,7 +209,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
dataset = gb.OnDiskDataset(yaml_file)
# Verify train set.
train_sets = dataset.train_sets()
train_sets = dataset.train_sets
assert len(train_sets) == 2
for train_set in train_sets:
assert len(train_set) == 1000
......@@ -221,7 +221,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
train_sets = None
# Verify validation set.
validation_sets = dataset.validation_sets()
validation_sets = dataset.validation_sets
assert len(validation_sets) == 2
for validation_set in validation_sets:
assert len(validation_set) == 1000
......@@ -233,7 +233,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
validation_sets = None
# Verify test set.
test_sets = dataset.test_sets()
test_sets = dataset.test_sets
assert len(test_sets) == 2
for test_set in test_sets:
assert len(test_set) == 1000
......@@ -299,7 +299,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
dataset = gb.OnDiskDataset(yaml_file)
# Verify train set.
train_sets = dataset.train_sets()
train_sets = dataset.train_sets
assert len(train_sets) == 2
for train_set in train_sets:
assert len(train_set) == 1000
......@@ -315,7 +315,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
train_sets = None
# Verify validation set.
validation_sets = dataset.validation_sets()
validation_sets = dataset.validation_sets
assert len(validation_sets) == 2
for validation_set in validation_sets:
assert len(validation_set) == 1000
......@@ -331,7 +331,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
validation_sets = None
# Verify test set.
test_sets = dataset.test_sets()
test_sets = dataset.test_sets
assert len(test_sets) == 2
for test_set in test_sets:
assert len(test_set) == 1000
......@@ -401,7 +401,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
dataset = gb.OnDiskDataset(yaml_file)
# Verify train set.
train_sets = dataset.train_sets()
train_sets = dataset.train_sets
assert len(train_sets) == 2
for train_set in train_sets:
assert len(train_set) == 1000
......@@ -418,7 +418,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
train_sets = None
# Verify validation set.
validation_sets = dataset.validation_sets()
validation_sets = dataset.validation_sets
assert len(validation_sets) == 2
for validation_set in validation_sets:
assert len(validation_set) == 1000
......@@ -435,7 +435,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
validation_sets = None
# Verify test set.
test_sets = dataset.test_sets()
test_sets = dataset.test_sets
assert len(test_sets) == 2
for test_set in test_sets:
assert len(test_set) == 1000
......@@ -507,7 +507,7 @@ def test_OnDiskDataset_Feature_heterograph():
dataset = gb.OnDiskDataset(yaml_file)
# Verify feature data storage.
feature_data = dataset.feature()
feature_data = dataset.feature
assert len(feature_data) == 4
# Verify node feature data.
......@@ -595,7 +595,7 @@ def test_OnDiskDataset_Feature_homograph():
dataset = gb.OnDiskDataset(yaml_file)
# Verify feature data storage.
feature_data = dataset.feature()
feature_data = dataset.feature
assert len(feature_data) == 4
# Verify node feature data.
......@@ -661,7 +661,7 @@ def test_OnDiskDataset_Graph_homogeneous():
f.write(yaml_content)
dataset = gb.OnDiskDataset(yaml_file)
graph2 = dataset.graph()
graph2 = dataset.graph
assert graph.num_nodes == graph2.num_nodes
assert graph.num_edges == graph2.num_edges
......@@ -703,7 +703,7 @@ def test_OnDiskDataset_Graph_heterogeneous():
f.write(yaml_content)
dataset = gb.OnDiskDataset(yaml_file)
graph2 = dataset.graph()
graph2 = dataset.graph
assert graph.num_nodes == graph2.num_nodes
assert graph.num_edges == graph2.num_edges
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment