"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "32fa459ed7c812c79e847145004061f21b7ac0d9"
Unverified Commit 7ec78bb6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Graphbolt] change dataset method to property. (#6023)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 4135b1bd
...@@ -31,22 +31,27 @@ class Dataset: ...@@ -31,22 +31,27 @@ class Dataset:
generate a subgraph. generate a subgraph.
""" """
@property
def train_sets(self) -> List[ItemSet] or List[ItemSetDict]: def train_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the training sets.""" """Return the training sets."""
raise NotImplementedError raise NotImplementedError
@property
def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]: def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the validation sets.""" """Return the validation sets."""
raise NotImplementedError raise NotImplementedError
@property
def test_sets(self) -> List[ItemSet] or List[ItemSetDict]: def test_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the test sets.""" """Return the test sets."""
raise NotImplementedError raise NotImplementedError
@property
def graph(self) -> object: def graph(self) -> object:
"""Return the graph.""" """Return the graph."""
raise NotImplementedError raise NotImplementedError
@property
def feature(self) -> Dict[object, FeatureStore]: def feature(self) -> Dict[object, FeatureStore]:
"""Return the feature.""" """Return the feature."""
raise NotImplementedError raise NotImplementedError
......
...@@ -281,22 +281,27 @@ class OnDiskDataset(Dataset): ...@@ -281,22 +281,27 @@ class OnDiskDataset(Dataset):
self._validation_sets = self._init_tvt_sets(self._meta.validation_sets) self._validation_sets = self._init_tvt_sets(self._meta.validation_sets)
self._test_sets = self._init_tvt_sets(self._meta.test_sets) self._test_sets = self._init_tvt_sets(self._meta.test_sets)
@property
def train_sets(self) -> List[ItemSet] or List[ItemSetDict]: def train_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the training set.""" """Return the training set."""
return self._train_sets return self._train_sets
@property
def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]: def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the validation set.""" """Return the validation set."""
return self._validation_sets return self._validation_sets
@property
def test_sets(self) -> List[ItemSet] or List[ItemSetDict]: def test_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the test set.""" """Return the test set."""
return self._test_sets return self._test_sets
@property
def graph(self) -> object: def graph(self) -> object:
"""Return the graph.""" """Return the graph."""
return self._graph return self._graph
@property
def feature(self) -> Dict[Tuple, TorchBasedFeatureStore]: def feature(self) -> Dict[Tuple, TorchBasedFeatureStore]:
"""Return the feature.""" """Return the feature."""
return self._feature return self._feature
......
...@@ -106,7 +106,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): ...@@ -106,7 +106,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
dataset = gb.OnDiskDataset(yaml_file) dataset = gb.OnDiskDataset(yaml_file)
# Verify train set. # Verify train set.
train_sets = dataset.train_sets() train_sets = dataset.train_sets
assert len(train_sets) == 2 assert len(train_sets) == 2
for train_set in train_sets: for train_set in train_sets:
assert len(train_set) == 1000 assert len(train_set) == 1000
...@@ -117,7 +117,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): ...@@ -117,7 +117,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
train_sets = None train_sets = None
# Verify validation set. # Verify validation set.
validation_sets = dataset.validation_sets() validation_sets = dataset.validation_sets
assert len(validation_sets) == 2 assert len(validation_sets) == 2
for validation_set in validation_sets: for validation_set in validation_sets:
assert len(validation_set) == 1000 assert len(validation_set) == 1000
...@@ -128,7 +128,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): ...@@ -128,7 +128,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
validation_sets = None validation_sets = None
# Verify test set. # Verify test set.
test_sets = dataset.test_sets() test_sets = dataset.test_sets
assert len(test_sets) == 2 assert len(test_sets) == 2
for test_set in test_sets: for test_set in test_sets:
assert len(test_set) == 1000 assert len(test_set) == 1000
...@@ -151,9 +151,9 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): ...@@ -151,9 +151,9 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
f.write(yaml_content) f.write(yaml_content)
dataset = gb.OnDiskDataset(yaml_file) dataset = gb.OnDiskDataset(yaml_file)
assert dataset.train_sets() is not None assert dataset.train_sets is not None
assert dataset.validation_sets() is None assert dataset.validation_sets is None
assert dataset.test_sets() is None assert dataset.test_sets is None
dataset = None dataset = None
...@@ -209,7 +209,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -209,7 +209,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
dataset = gb.OnDiskDataset(yaml_file) dataset = gb.OnDiskDataset(yaml_file)
# Verify train set. # Verify train set.
train_sets = dataset.train_sets() train_sets = dataset.train_sets
assert len(train_sets) == 2 assert len(train_sets) == 2
for train_set in train_sets: for train_set in train_sets:
assert len(train_set) == 1000 assert len(train_set) == 1000
...@@ -221,7 +221,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -221,7 +221,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
train_sets = None train_sets = None
# Verify validation set. # Verify validation set.
validation_sets = dataset.validation_sets() validation_sets = dataset.validation_sets
assert len(validation_sets) == 2 assert len(validation_sets) == 2
for validation_set in validation_sets: for validation_set in validation_sets:
assert len(validation_set) == 1000 assert len(validation_set) == 1000
...@@ -233,7 +233,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -233,7 +233,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
validation_sets = None validation_sets = None
# Verify test set. # Verify test set.
test_sets = dataset.test_sets() test_sets = dataset.test_sets
assert len(test_sets) == 2 assert len(test_sets) == 2
for test_set in test_sets: for test_set in test_sets:
assert len(test_set) == 1000 assert len(test_set) == 1000
...@@ -299,7 +299,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): ...@@ -299,7 +299,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
dataset = gb.OnDiskDataset(yaml_file) dataset = gb.OnDiskDataset(yaml_file)
# Verify train set. # Verify train set.
train_sets = dataset.train_sets() train_sets = dataset.train_sets
assert len(train_sets) == 2 assert len(train_sets) == 2
for train_set in train_sets: for train_set in train_sets:
assert len(train_set) == 1000 assert len(train_set) == 1000
...@@ -315,7 +315,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): ...@@ -315,7 +315,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
train_sets = None train_sets = None
# Verify validation set. # Verify validation set.
validation_sets = dataset.validation_sets() validation_sets = dataset.validation_sets
assert len(validation_sets) == 2 assert len(validation_sets) == 2
for validation_set in validation_sets: for validation_set in validation_sets:
assert len(validation_set) == 1000 assert len(validation_set) == 1000
...@@ -331,7 +331,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): ...@@ -331,7 +331,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
validation_sets = None validation_sets = None
# Verify test set. # Verify test set.
test_sets = dataset.test_sets() test_sets = dataset.test_sets
assert len(test_sets) == 2 assert len(test_sets) == 2
for test_set in test_sets: for test_set in test_sets:
assert len(test_set) == 1000 assert len(test_set) == 1000
...@@ -401,7 +401,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label(): ...@@ -401,7 +401,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
dataset = gb.OnDiskDataset(yaml_file) dataset = gb.OnDiskDataset(yaml_file)
# Verify train set. # Verify train set.
train_sets = dataset.train_sets() train_sets = dataset.train_sets
assert len(train_sets) == 2 assert len(train_sets) == 2
for train_set in train_sets: for train_set in train_sets:
assert len(train_set) == 1000 assert len(train_set) == 1000
...@@ -418,7 +418,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label(): ...@@ -418,7 +418,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
train_sets = None train_sets = None
# Verify validation set. # Verify validation set.
validation_sets = dataset.validation_sets() validation_sets = dataset.validation_sets
assert len(validation_sets) == 2 assert len(validation_sets) == 2
for validation_set in validation_sets: for validation_set in validation_sets:
assert len(validation_set) == 1000 assert len(validation_set) == 1000
...@@ -435,7 +435,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label(): ...@@ -435,7 +435,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
validation_sets = None validation_sets = None
# Verify test set. # Verify test set.
test_sets = dataset.test_sets() test_sets = dataset.test_sets
assert len(test_sets) == 2 assert len(test_sets) == 2
for test_set in test_sets: for test_set in test_sets:
assert len(test_set) == 1000 assert len(test_set) == 1000
...@@ -507,7 +507,7 @@ def test_OnDiskDataset_Feature_heterograph(): ...@@ -507,7 +507,7 @@ def test_OnDiskDataset_Feature_heterograph():
dataset = gb.OnDiskDataset(yaml_file) dataset = gb.OnDiskDataset(yaml_file)
# Verify feature data storage. # Verify feature data storage.
feature_data = dataset.feature() feature_data = dataset.feature
assert len(feature_data) == 4 assert len(feature_data) == 4
# Verify node feature data. # Verify node feature data.
...@@ -595,7 +595,7 @@ def test_OnDiskDataset_Feature_homograph(): ...@@ -595,7 +595,7 @@ def test_OnDiskDataset_Feature_homograph():
dataset = gb.OnDiskDataset(yaml_file) dataset = gb.OnDiskDataset(yaml_file)
# Verify feature data storage. # Verify feature data storage.
feature_data = dataset.feature() feature_data = dataset.feature
assert len(feature_data) == 4 assert len(feature_data) == 4
# Verify node feature data. # Verify node feature data.
...@@ -661,7 +661,7 @@ def test_OnDiskDataset_Graph_homogeneous(): ...@@ -661,7 +661,7 @@ def test_OnDiskDataset_Graph_homogeneous():
f.write(yaml_content) f.write(yaml_content)
dataset = gb.OnDiskDataset(yaml_file) dataset = gb.OnDiskDataset(yaml_file)
graph2 = dataset.graph() graph2 = dataset.graph
assert graph.num_nodes == graph2.num_nodes assert graph.num_nodes == graph2.num_nodes
assert graph.num_edges == graph2.num_edges assert graph.num_edges == graph2.num_edges
...@@ -703,7 +703,7 @@ def test_OnDiskDataset_Graph_heterogeneous(): ...@@ -703,7 +703,7 @@ def test_OnDiskDataset_Graph_heterogeneous():
f.write(yaml_content) f.write(yaml_content)
dataset = gb.OnDiskDataset(yaml_file) dataset = gb.OnDiskDataset(yaml_file)
graph2 = dataset.graph() graph2 = dataset.graph
assert graph.num_nodes == graph2.num_nodes assert graph.num_nodes == graph2.num_nodes
assert graph.num_edges == graph2.num_edges assert graph.num_edges == graph2.num_edges
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment