Unverified commit 96ddf410, authored by Rhett Ying, committed by GitHub
Browse files

[GraphBolt] refactor FeatureStore and related impl (#6103)

parent 8e86c89c
"""GraphBolt Dataset."""
from typing import Dict
from .feature_store import FeatureStore
from .itemset import ItemSet, ItemSetDict
......@@ -52,7 +50,7 @@ class Dataset:
raise NotImplementedError
@property
def feature(self) -> Dict[object, FeatureStore]:
def feature(self) -> FeatureStore:
"""Return the feature."""
raise NotImplementedError
......
......@@ -11,11 +11,23 @@ class FeatureStore:
def __init__(self):
pass
def read(self, ids: torch.Tensor = None):
def read(
self,
domain: str,
type_name: str,
feature_name: str,
ids: torch.Tensor = None,
):
"""Read from the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
......@@ -27,11 +39,24 @@ class FeatureStore:
"""
raise NotImplementedError
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
def update(
self,
domain: str,
type_name: str,
feature_name: str,
value: torch.Tensor,
ids: torch.Tensor = None,
):
"""Update the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
value : torch.Tensor
The updated value of the feature.
ids : torch.Tensor, optional
......
......@@ -5,7 +5,7 @@ import shutil
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Tuple
from typing import List
import pandas as pd
import torch
......@@ -24,10 +24,7 @@ from .csc_sampling_graph import (
save_csc_sampling_graph,
)
from .ondisk_metadata import OnDiskGraphTopology, OnDiskMetaData, OnDiskTVTSet
from .torch_based_feature_store import (
load_feature_stores,
TorchBasedFeatureStore,
)
from .torch_based_feature_store import TorchBasedFeatureStore
__all__ = ["OnDiskDataset", "preprocess_ondisk_dataset"]
......@@ -281,7 +278,7 @@ class OnDiskDataset(Dataset):
self._num_classes = self._meta.num_classes
self._num_labels = self._meta.num_labels
self._graph = self._load_graph(self._meta.graph_topology)
self._feature = load_feature_stores(self._meta.feature_data)
self._feature = TorchBasedFeatureStore(self._meta.feature_data)
self._train_set = self._init_tvt_set(self._meta.train_set)
self._validation_set = self._init_tvt_set(self._meta.validation_set)
self._test_set = self._init_tvt_set(self._meta.test_set)
......@@ -307,7 +304,7 @@ class OnDiskDataset(Dataset):
return self._graph
@property
def feature(self) -> Dict[Tuple, TorchBasedFeatureStore]:
def feature(self) -> TorchBasedFeatureStore:
"""Return the feature."""
return self._feature
......
......@@ -8,11 +8,11 @@ import torch
from ..feature_store import FeatureStore
from .ondisk_metadata import OnDiskFeatureData
__all__ = ["TorchBasedFeatureStore", "load_feature_stores"]
__all__ = ["TorchBasedFeature", "TorchBasedFeatureStore"]
class TorchBasedFeatureStore(FeatureStore):
r"""Torch based feature store."""
class TorchBasedFeature:
r"""Torch based feature."""
def __init__(self, torch_feature: torch.Tensor):
"""Initialize a torch based feature by a torch feature.
......@@ -28,7 +28,7 @@ class TorchBasedFeatureStore(FeatureStore):
--------
>>> import torch
>>> torch_feat = torch.arange(0, 5)
>>> feature_store = TorchBasedFeatureStore(torch_feat)
>>> feature_store = TorchBasedFeature(torch_feat)
>>> feature_store.read()
tensor([0, 1, 2, 3, 4])
>>> feature_store.read(torch.tensor([0, 1, 2]))
......@@ -43,15 +43,14 @@ class TorchBasedFeatureStore(FeatureStore):
>>> np.save("/tmp/arr.npy", arr)
>>> torch_feat = torch.as_tensor(np.load("/tmp/arr.npy",
... mmap_mode="r+"))
>>> feature_store = TorchBasedFeatureStore(torch_feat)
>>> feature_store = TorchBasedFeature(torch_feat)
>>> feature_store.read()
tensor([0, 1, 2, 3, 4])
>>> feature_store.read(torch.tensor([0, 1, 2]))
tensor([0, 1, 2])
"""
super(TorchBasedFeatureStore, self).__init__()
assert isinstance(torch_feature, torch.Tensor), (
f"torch_feature in TorchBasedFeatureStore must be torch.Tensor, "
f"torch_feature in TorchBasedFeature must be torch.Tensor, "
f"but got {type(torch_feature)}."
)
self._tensor = torch_feature
......@@ -106,7 +105,10 @@ class TorchBasedFeatureStore(FeatureStore):
self._tensor[ids] = value
def load_feature_stores(feat_data: List[OnDiskFeatureData]):
class TorchBasedFeatureStore(FeatureStore):
r"""Torch based feature store."""
def __init__(self, feat_data: List[OnDiskFeatureData]):
r"""Load feature stores from disk.
The feature stores are described by the `feat_data`. The `feat_data` is a
......@@ -144,14 +146,10 @@ def load_feature_stores(feat_data: List[OnDiskFeatureData]):
... gb.OnDiskFeatureData(domain="node", type="paper", name="feat",
... format="numpy", path="/tmp/node_feat.npy", in_memory=False),
... ]
>>> gb.load_feature_stores(feat_data)
... {("edge", "author:writes:paper", "label"):
... <dgl.graphbolt.feature_store.TorchBasedFeatureStore object at
... 0x7ff093cb4df0>, ("node", "paper", "feat"):
... <dgl.graphbolt.feature_store.TorchBasedFeatureStore object at
... 0x7ff093cb4dc0>}
>>> feature_store = gb.TorchBasedFeatureStore(feat_data)
"""
feat_stores = {}
super().__init__()
self._features = {}
for spec in feat_data:
key = (spec.domain, spec.type, spec.name)
if spec.format == "torch":
......@@ -159,12 +157,72 @@ def load_feature_stores(feat_data: List[OnDiskFeatureData]):
f"Pytorch tensor can only be loaded in memory, "
f"but the feature {key} is loaded on disk."
)
feat_stores[key] = TorchBasedFeatureStore(torch.load(spec.path))
self._features[key] = TorchBasedFeature(torch.load(spec.path))
elif spec.format == "numpy":
mmap_mode = "r+" if not spec.in_memory else None
feat_stores[key] = TorchBasedFeatureStore(
self._features[key] = TorchBasedFeature(
torch.as_tensor(np.load(spec.path, mmap_mode=mmap_mode))
)
else:
raise ValueError(f"Unknown feature format {spec.format}")
return feat_stores
def read(
    self,
    domain: str,
    type_name: str,
    feature_name: str,
    ids: torch.Tensor = None,
):
    """Fetch the values of one feature held by this store.

    The feature is looked up by the ``(domain, type_name, feature_name)``
    triple and the request is delegated to that feature's own ``read``.

    Parameters
    ----------
    domain : str
        The domain of the feature such as "node", "edge" or "graph".
    type_name : str
        The node or edge type name.
    feature_name : str
        The feature name.
    ids : torch.Tensor, optional
        The index of the feature. If specified, only the specified indices
        of the feature are read. If None, the entire feature is returned.

    Returns
    -------
    torch.Tensor
        The read feature.
    """
    # Features are keyed by the full (domain, type, name) triple.
    key = (domain, type_name, feature_name)
    feature = self._features[key]
    return feature.read(ids)
def update(
    self,
    domain: str,
    type_name: str,
    feature_name: str,
    value: torch.Tensor,
    ids: torch.Tensor = None,
):
    """Write new values into one feature held by this store.

    The feature is looked up by the ``(domain, type_name, feature_name)``
    triple and ``value``/``ids`` are forwarded to that feature's own
    ``update``.

    Parameters
    ----------
    domain : str
        The domain of the feature such as "node", "edge" or "graph".
    type_name : str
        The node or edge type name.
    feature_name : str
        The feature name.
    value : torch.Tensor
        The updated value of the feature.
    ids : torch.Tensor, optional
        The indices of the feature to update. If specified, only the
        specified indices of the feature will be updated. For the feature,
        the `ids[i]` row is updated to `value[i]`. So the indices and value
        must have the same length. If None, the entire feature will be
        updated.
    """
    # Features are keyed by the full (domain, type, name) triple.
    key = (domain, type_name, feature_name)
    target = self._features[key]
    target.update(value, ids)
def __len__(self):
    """Report how many features this store currently holds.

    Returns
    -------
    int
        The number of entries in the store's internal feature mapping.
    """
    feature_count = len(self._features)
    return feature_count
......@@ -6,10 +6,8 @@ import torch
def get_graphbolt_fetch_func():
feature_store = {
"feature": dgl.graphbolt.TorchBasedFeatureStore(torch.randn(200, 4)),
"label": dgl.graphbolt.TorchBasedFeatureStore(
torch.randint(0, 10, (200,))
),
"feature": dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4)),
"label": dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,))),
}
def fetch_func(data):
......
......@@ -19,7 +19,7 @@ def to_on_disk_tensor(test_dir, name, t):
@pytest.mark.parametrize("in_memory", [True, False])
def test_torch_based_feature_store(in_memory):
def test_torch_based_feature(in_memory):
with tempfile.TemporaryDirectory() as test_dir:
a = torch.tensor([1, 2, 3])
b = torch.tensor([[1, 2, 3], [4, 5, 6]])
......@@ -27,8 +27,8 @@ def test_torch_based_feature_store(in_memory):
a = to_on_disk_tensor(test_dir, "a", a)
b = to_on_disk_tensor(test_dir, "b", b)
feat_store_a = gb.TorchBasedFeatureStore(a)
feat_store_b = gb.TorchBasedFeatureStore(b)
feat_store_a = gb.TorchBasedFeature(a)
feat_store_b = gb.TorchBasedFeature(b)
assert torch.equal(feat_store_a.read(), torch.tensor([1, 2, 3]))
assert torch.equal(
......@@ -71,7 +71,7 @@ def write_tensor_to_disk(dir, name, t, fmt="torch"):
@pytest.mark.parametrize("in_memory", [True, False])
def test_load_feature_stores(in_memory):
def test_torch_based_feature_store(in_memory):
with tempfile.TemporaryDirectory() as test_dir:
a = torch.tensor([1, 2, 3])
b = torch.tensor([2, 5, 3])
......@@ -95,12 +95,12 @@ def test_load_feature_stores(in_memory):
in_memory=in_memory,
),
]
feat_stores = gb.load_feature_stores(feat_data)
feat_store = gb.TorchBasedFeatureStore(feat_data)
assert torch.equal(
feat_stores[("node", "paper", "a")].read(), torch.tensor([1, 2, 3])
feat_store.read("node", "paper", "a"), torch.tensor([1, 2, 3])
)
assert torch.equal(
feat_stores[("edge", "paper-cites-paper", "b")].read(),
feat_store.read("edge", "paper-cites-paper", "b"),
torch.tensor([2, 5, 3]),
)
......@@ -130,6 +130,8 @@ def test_load_feature_stores(in_memory):
in_memory=True,
),
]
feat_stores = gb.load_feature_stores(feat_data)
assert ("node", None, "a") in feat_stores
feat_store = gb.TorchBasedFeatureStore(feat_data)
assert torch.equal(
feat_store.read("node", None, "a"), torch.tensor([1, 2, 3])
)
feat_stores = None
......@@ -27,8 +27,8 @@ def test_DataLoader():
# TODO(BarclayII): temporarily using DGLGraph. Should test using
# GraphBolt's storage as well once issue #5953 is resolved.
graph = dgl.add_reverse_edges(dgl.rand_graph(200, 6000))
features = dgl.graphbolt.TorchBasedFeatureStore(torch.randn(200, 4))
labels = dgl.graphbolt.TorchBasedFeatureStore(torch.randint(0, 10, (200,)))
features = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
labels = dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,)))
minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
subgraph_sampler = dgl.graphbolt.SubgraphSampler(
......
......@@ -647,35 +647,25 @@ def test_OnDiskDataset_Feature_heterograph():
assert len(feature_data) == 4
# Verify node feature data.
node_paper_feat = feature_data[("node", "paper", "feat")]
assert isinstance(node_paper_feat, gb.TorchBasedFeatureStore)
assert torch.equal(
node_paper_feat.read(), torch.tensor(node_data_paper)
feature_data.read("node", "paper", "feat"),
torch.tensor(node_data_paper),
)
node_paper_label = feature_data[("node", "paper", "label")]
assert isinstance(node_paper_label, gb.TorchBasedFeatureStore)
assert torch.equal(
node_paper_label.read(), torch.tensor(node_data_label)
feature_data.read("node", "paper", "label"),
torch.tensor(node_data_label),
)
# Verify edge feature data.
edge_writes_feat = feature_data[("edge", "author:writes:paper", "feat")]
assert isinstance(edge_writes_feat, gb.TorchBasedFeatureStore)
assert torch.equal(
edge_writes_feat.read(), torch.tensor(edge_data_writes)
feature_data.read("edge", "author:writes:paper", "feat"),
torch.tensor(edge_data_writes),
)
edge_writes_label = feature_data[
("edge", "author:writes:paper", "label")
]
assert isinstance(edge_writes_label, gb.TorchBasedFeatureStore)
assert torch.equal(
edge_writes_label.read(), torch.tensor(edge_data_label)
feature_data.read("edge", "author:writes:paper", "label"),
torch.tensor(edge_data_label),
)
node_paper_feat = None
node_paper_label = None
edge_writes_feat = None
edge_writes_label = None
feature_data = None
dataset = None
......@@ -735,25 +725,25 @@ def test_OnDiskDataset_Feature_homograph():
assert len(feature_data) == 4
# Verify node feature data.
node_feat = feature_data[("node", None, "feat")]
assert isinstance(node_feat, gb.TorchBasedFeatureStore)
assert torch.equal(node_feat.read(), torch.tensor(node_data_feat))
node_label = feature_data[("node", None, "label")]
assert isinstance(node_label, gb.TorchBasedFeatureStore)
assert torch.equal(node_label.read(), torch.tensor(node_data_label))
assert torch.equal(
feature_data.read("node", None, "feat"),
torch.tensor(node_data_feat),
)
assert torch.equal(
feature_data.read("node", None, "label"),
torch.tensor(node_data_label),
)
# Verify edge feature data.
edge_feat = feature_data[("edge", None, "feat")]
assert isinstance(edge_feat, gb.TorchBasedFeatureStore)
assert torch.equal(edge_feat.read(), torch.tensor(edge_data_feat))
edge_label = feature_data[("edge", None, "label")]
assert isinstance(edge_label, gb.TorchBasedFeatureStore)
assert torch.equal(edge_label.read(), torch.tensor(edge_data_label))
node_feat = None
node_label = None
edge_feat = None
edge_label = None
assert torch.equal(
feature_data.read("edge", None, "feat"),
torch.tensor(edge_data_feat),
)
assert torch.equal(
feature_data.read("edge", None, "label"),
torch.tensor(edge_data_label),
)
feature_data = None
dataset = None
......
......@@ -10,8 +10,8 @@ def test_DataLoader():
B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N))
graph = gb_test_utils.rand_csc_graph(200, 0.15)
features = dgl.graphbolt.TorchBasedFeatureStore(torch.randn(200, 4))
labels = dgl.graphbolt.TorchBasedFeatureStore(torch.randint(0, 10, (200,)))
features = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
labels = dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,)))
def sampler_func(data):
adjs = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment