Unverified commit 25217dc6, authored by Rhett Ying and committed by GitHub
Browse files

[GraphBolt] add misc metadata for Dataset (#5976)

parent 92a46d12
...@@ -50,3 +50,18 @@ class Dataset: ...@@ -50,3 +50,18 @@ class Dataset:
def feature(self) -> Dict[object, FeatureStore]: def feature(self) -> Dict[object, FeatureStore]:
"""Return the feature.""" """Return the feature."""
raise NotImplementedError raise NotImplementedError
@property
def dataset_name(self) -> str:
    """Name of the dataset.

    The base class provides no implementation; subclasses are expected to
    override this property.

    Raises
    ------
    NotImplementedError
        Always, when accessed on the base class.
    """
    raise NotImplementedError
@property
def num_classes(self) -> int:
    """Number of classes in the dataset.

    The base class provides no implementation; subclasses are expected to
    override this property.

    Raises
    ------
    NotImplementedError
        Always, when accessed on the base class.
    """
    raise NotImplementedError
@property
def num_labels(self) -> int:
    """Number of labels in the dataset.

    The base class provides no implementation; subclasses are expected to
    override this property.

    Raises
    ------
    NotImplementedError
        Always, when accessed on the base class.
    """
    raise NotImplementedError
...@@ -29,6 +29,9 @@ class OnDiskDataset(Dataset): ...@@ -29,6 +29,9 @@ class OnDiskDataset(Dataset):
.. code-block:: yaml .. code-block:: yaml
dataset_name: graphbolt_test
num_classes: 10
num_labels: 10
graph_topology: graph_topology:
type: CSCSamplingGraph type: CSCSamplingGraph
path: graph_topology/csc_sampling_graph.tar path: graph_topology/csc_sampling_graph.tar
...@@ -70,6 +73,9 @@ class OnDiskDataset(Dataset): ...@@ -70,6 +73,9 @@ class OnDiskDataset(Dataset):
def __init__(self, path: str) -> None: def __init__(self, path: str) -> None:
with open(path, "r") as f: with open(path, "r") as f:
self._meta = OnDiskMetaData.parse_raw(f.read(), proto="yaml") self._meta = OnDiskMetaData.parse_raw(f.read(), proto="yaml")
self._dataset_name = self._meta.dataset_name
self._num_classes = self._meta.num_classes
self._num_labels = self._meta.num_labels
self._graph = self._load_graph(self._meta.graph_topology) self._graph = self._load_graph(self._meta.graph_topology)
self._feature = load_feature_stores(self._meta.feature_data) self._feature = load_feature_stores(self._meta.feature_data)
self._train_sets = self._init_tvt_sets(self._meta.train_sets) self._train_sets = self._init_tvt_sets(self._meta.train_sets)
...@@ -96,6 +102,21 @@ class OnDiskDataset(Dataset): ...@@ -96,6 +102,21 @@ class OnDiskDataset(Dataset):
"""Return the feature.""" """Return the feature."""
return self._feature return self._feature
@property
def dataset_name(self) -> str:
    """Dataset name as read from the metadata YAML.

    Returns
    -------
    str
        The value of the ``dataset_name`` field captured at construction
        time. The field is optional in the metadata and defaults to
        ``None`` when it is not specified.
    """
    return self._dataset_name
@property
def num_classes(self) -> int:
    """Number of classes as read from the metadata YAML.

    Returns
    -------
    int
        The value of the ``num_classes`` field captured at construction
        time. The field is optional in the metadata and defaults to
        ``None`` when it is not specified, so callers should be prepared
        for ``None``.
    """
    return self._num_classes
@property
def num_labels(self) -> int:
    """Number of labels as read from the metadata YAML.

    Returns
    -------
    int
        The value of the ``num_labels`` field captured at construction
        time. The field is optional in the metadata and defaults to
        ``None`` when it is not specified, so callers should be prepared
        for ``None``.
    """
    return self._num_labels
def _load_graph( def _load_graph(
self, graph_topology: OnDiskGraphTopology self, graph_topology: OnDiskGraphTopology
) -> CSCSamplingGraph: ) -> CSCSamplingGraph:
......
...@@ -71,6 +71,9 @@ class OnDiskMetaData(pydantic_yaml.YamlModel): ...@@ -71,6 +71,9 @@ class OnDiskMetaData(pydantic_yaml.YamlModel):
is a list of list of ``OnDiskTVTSet``. is a list of list of ``OnDiskTVTSet``.
""" """
dataset_name: Optional[str] = None
num_classes: Optional[int] = None
num_labels: Optional[int] = None
graph_topology: Optional[OnDiskGraphTopology] = None graph_topology: Optional[OnDiskGraphTopology] = None
feature_data: Optional[List[OnDiskFeatureData]] = [] feature_data: Optional[List[OnDiskFeatureData]] = []
train_sets: Optional[List[List[OnDiskTVTSet]]] = [] train_sets: Optional[List[List[OnDiskTVTSet]]] = []
......
...@@ -20,3 +20,9 @@ def test_Dataset(): ...@@ -20,3 +20,9 @@ def test_Dataset():
_ = dataset.graph() _ = dataset.graph()
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
_ = dataset.feature() _ = dataset.feature()
with pytest.raises(NotImplementedError):
_ = dataset.dataset_name
with pytest.raises(NotImplementedError):
_ = dataset.num_classes
with pytest.raises(NotImplementedError):
_ = dataset.num_labels
...@@ -712,3 +712,38 @@ def test_OnDiskDataset_Graph_heterogeneous(): ...@@ -712,3 +712,38 @@ def test_OnDiskDataset_Graph_heterogeneous():
assert torch.equal(graph.type_per_edge, graph2.type_per_edge) assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id
assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id
def test_OnDiskDataset_Metadata():
    """Check the metadata properties exposed by ``OnDiskDataset``.

    Two YAML configurations are written to the same temporary file and
    loaded in turn: one that sets ``dataset_name``, ``num_classes`` and
    ``num_labels``, and one that sets only ``dataset_name``. In the second
    case the two numeric properties are expected to be ``None``.
    """
    with tempfile.TemporaryDirectory() as test_dir:
        # Both scenarios reuse one path; each write overwrites the last.
        yaml_file = os.path.join(test_dir, "test.yaml")

        def load_from(content):
            # Persist the YAML text, then build a dataset from that file.
            with open(yaml_file, "w") as f:
                f.write(content)
            return gb.OnDiskDataset(yaml_file)

        dataset_name = "graphbolt_test"
        num_classes = 10
        num_labels = 9

        # All metadata fields are specified.
        dataset = load_from(
            f"""
            dataset_name: {dataset_name}
            num_classes: {num_classes}
            num_labels: {num_labels}
        """
        )
        assert dataset.dataset_name == dataset_name
        assert dataset.num_classes == num_classes
        assert dataset.num_labels == num_labels

        # Only dataset_name is specified.
        dataset = load_from(
            f"""
            dataset_name: {dataset_name}
        """
        )
        assert dataset.dataset_name == dataset_name
        assert dataset.num_classes is None
        assert dataset.num_labels is None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment