"vscode:/vscode.git/clone" did not exist on "cd502b25cf0debac6f98d27a6638ef95208d1ea2"
Unverified commit 96ddf410, authored by Rhett Ying and committed by GitHub
Browse files

[GraphBolt] refactor FeatureStore and related impl (#6103)

parent 8e86c89c
"""GraphBolt Dataset.""" """GraphBolt Dataset."""
from typing import Dict
from .feature_store import FeatureStore from .feature_store import FeatureStore
from .itemset import ItemSet, ItemSetDict from .itemset import ItemSet, ItemSetDict
...@@ -52,7 +50,7 @@ class Dataset: ...@@ -52,7 +50,7 @@ class Dataset:
raise NotImplementedError raise NotImplementedError
@property @property
def feature(self) -> Dict[object, FeatureStore]: def feature(self) -> FeatureStore:
"""Return the feature.""" """Return the feature."""
raise NotImplementedError raise NotImplementedError
......
...@@ -11,11 +11,23 @@ class FeatureStore: ...@@ -11,11 +11,23 @@ class FeatureStore:
def __init__(self): def __init__(self):
pass pass
def read(self, ids: torch.Tensor = None): def read(
self,
domain: str,
type_name: str,
feature_name: str,
ids: torch.Tensor = None,
):
"""Read from the feature store. """Read from the feature store.
Parameters Parameters
---------- ----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
ids : torch.Tensor, optional ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned. of the feature are read. If None, the entire feature is returned.
...@@ -27,11 +39,24 @@ class FeatureStore: ...@@ -27,11 +39,24 @@ class FeatureStore:
""" """
raise NotImplementedError raise NotImplementedError
def update(self, value: torch.Tensor, ids: torch.Tensor = None): def update(
self,
domain: str,
type_name: str,
feature_name: str,
value: torch.Tensor,
ids: torch.Tensor = None,
):
"""Update the feature store. """Update the feature store.
Parameters Parameters
---------- ----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
value : torch.Tensor value : torch.Tensor
The updated value of the feature. The updated value of the feature.
ids : torch.Tensor, optional ids : torch.Tensor, optional
......
...@@ -5,7 +5,7 @@ import shutil ...@@ -5,7 +5,7 @@ import shutil
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import List
import pandas as pd import pandas as pd
import torch import torch
...@@ -24,10 +24,7 @@ from .csc_sampling_graph import ( ...@@ -24,10 +24,7 @@ from .csc_sampling_graph import (
save_csc_sampling_graph, save_csc_sampling_graph,
) )
from .ondisk_metadata import OnDiskGraphTopology, OnDiskMetaData, OnDiskTVTSet from .ondisk_metadata import OnDiskGraphTopology, OnDiskMetaData, OnDiskTVTSet
from .torch_based_feature_store import ( from .torch_based_feature_store import TorchBasedFeatureStore
load_feature_stores,
TorchBasedFeatureStore,
)
__all__ = ["OnDiskDataset", "preprocess_ondisk_dataset"] __all__ = ["OnDiskDataset", "preprocess_ondisk_dataset"]
...@@ -281,7 +278,7 @@ class OnDiskDataset(Dataset): ...@@ -281,7 +278,7 @@ class OnDiskDataset(Dataset):
self._num_classes = self._meta.num_classes self._num_classes = self._meta.num_classes
self._num_labels = self._meta.num_labels self._num_labels = self._meta.num_labels
self._graph = self._load_graph(self._meta.graph_topology) self._graph = self._load_graph(self._meta.graph_topology)
self._feature = load_feature_stores(self._meta.feature_data) self._feature = TorchBasedFeatureStore(self._meta.feature_data)
self._train_set = self._init_tvt_set(self._meta.train_set) self._train_set = self._init_tvt_set(self._meta.train_set)
self._validation_set = self._init_tvt_set(self._meta.validation_set) self._validation_set = self._init_tvt_set(self._meta.validation_set)
self._test_set = self._init_tvt_set(self._meta.test_set) self._test_set = self._init_tvt_set(self._meta.test_set)
...@@ -307,7 +304,7 @@ class OnDiskDataset(Dataset): ...@@ -307,7 +304,7 @@ class OnDiskDataset(Dataset):
return self._graph return self._graph
@property @property
def feature(self) -> Dict[Tuple, TorchBasedFeatureStore]: def feature(self) -> TorchBasedFeatureStore:
"""Return the feature.""" """Return the feature."""
return self._feature return self._feature
......
...@@ -8,11 +8,11 @@ import torch ...@@ -8,11 +8,11 @@ import torch
from ..feature_store import FeatureStore from ..feature_store import FeatureStore
from .ondisk_metadata import OnDiskFeatureData from .ondisk_metadata import OnDiskFeatureData
__all__ = ["TorchBasedFeatureStore", "load_feature_stores"] __all__ = ["TorchBasedFeature", "TorchBasedFeatureStore"]
class TorchBasedFeatureStore(FeatureStore): class TorchBasedFeature:
r"""Torch based feature store.""" r"""Torch based feature."""
def __init__(self, torch_feature: torch.Tensor): def __init__(self, torch_feature: torch.Tensor):
"""Initialize a torch based feature store by a torch feature. """Initialize a torch based feature store by a torch feature.
...@@ -28,7 +28,7 @@ class TorchBasedFeatureStore(FeatureStore): ...@@ -28,7 +28,7 @@ class TorchBasedFeatureStore(FeatureStore):
-------- --------
>>> import torch >>> import torch
>>> torch_feat = torch.arange(0, 5) >>> torch_feat = torch.arange(0, 5)
>>> feature_store = TorchBasedFeatureStore(torch_feat) >>> feature_store = TorchBasedFeature(torch_feat)
>>> feature_store.read() >>> feature_store.read()
tensor([0, 1, 2, 3, 4]) tensor([0, 1, 2, 3, 4])
>>> feature_store.read(torch.tensor([0, 1, 2])) >>> feature_store.read(torch.tensor([0, 1, 2]))
...@@ -43,15 +43,14 @@ class TorchBasedFeatureStore(FeatureStore): ...@@ -43,15 +43,14 @@ class TorchBasedFeatureStore(FeatureStore):
>>> np.save("/tmp/arr.npy", arr) >>> np.save("/tmp/arr.npy", arr)
>>> torch_feat = torch.as_tensor(np.load("/tmp/arr.npy", >>> torch_feat = torch.as_tensor(np.load("/tmp/arr.npy",
... mmap_mode="r+")) ... mmap_mode="r+"))
>>> feature_store = TorchBasedFeatureStore(torch_feat) >>> feature_store = TorchBasedFeature(torch_feat)
>>> feature_store.read() >>> feature_store.read()
tensor([0, 1, 2, 3, 4]) tensor([0, 1, 2, 3, 4])
>>> feature_store.read(torch.tensor([0, 1, 2])) >>> feature_store.read(torch.tensor([0, 1, 2]))
tensor([0, 1, 2]) tensor([0, 1, 2])
""" """
super(TorchBasedFeatureStore, self).__init__()
assert isinstance(torch_feature, torch.Tensor), ( assert isinstance(torch_feature, torch.Tensor), (
f"torch_feature in TorchBasedFeatureStore must be torch.Tensor, " f"torch_feature in TorchBasedFeature must be torch.Tensor, "
f"but got {type(torch_feature)}." f"but got {type(torch_feature)}."
) )
self._tensor = torch_feature self._tensor = torch_feature
...@@ -106,7 +105,10 @@ class TorchBasedFeatureStore(FeatureStore): ...@@ -106,7 +105,10 @@ class TorchBasedFeatureStore(FeatureStore):
self._tensor[ids] = value self._tensor[ids] = value
def load_feature_stores(feat_data: List[OnDiskFeatureData]): class TorchBasedFeatureStore(FeatureStore):
r"""Torch based feature store."""
def __init__(self, feat_data: List[OnDiskFeatureData]):
r"""Load feature stores from disk. r"""Load feature stores from disk.
The feature stores are described by the `feat_data`. The `feat_data` is a The feature stores are described by the `feat_data`. The `feat_data` is a
...@@ -144,14 +146,10 @@ def load_feature_stores(feat_data: List[OnDiskFeatureData]): ...@@ -144,14 +146,10 @@ def load_feature_stores(feat_data: List[OnDiskFeatureData]):
... gb.OnDiskFeatureData(domain="node", type="paper", name="feat", ... gb.OnDiskFeatureData(domain="node", type="paper", name="feat",
... format="numpy", path="/tmp/node_feat.npy", in_memory=False), ... format="numpy", path="/tmp/node_feat.npy", in_memory=False),
... ] ... ]
>>> gb.load_feature_stores(feat_data) >>> feature_store = gb.TorchBasedFeatureStore(feat_data)
... {("edge", "author:writes:paper", "label"):
... <dgl.graphbolt.feature_store.TorchBasedFeatureStore object at
... 0x7ff093cb4df0>, ("node", "paper", "feat"):
... <dgl.graphbolt.feature_store.TorchBasedFeatureStore object at
... 0x7ff093cb4dc0>}
""" """
feat_stores = {} super().__init__()
self._features = {}
for spec in feat_data: for spec in feat_data:
key = (spec.domain, spec.type, spec.name) key = (spec.domain, spec.type, spec.name)
if spec.format == "torch": if spec.format == "torch":
...@@ -159,12 +157,72 @@ def load_feature_stores(feat_data: List[OnDiskFeatureData]): ...@@ -159,12 +157,72 @@ def load_feature_stores(feat_data: List[OnDiskFeatureData]):
f"Pytorch tensor can only be loaded in memory, " f"Pytorch tensor can only be loaded in memory, "
f"but the feature {key} is loaded on disk." f"but the feature {key} is loaded on disk."
) )
feat_stores[key] = TorchBasedFeatureStore(torch.load(spec.path)) self._features[key] = TorchBasedFeature(torch.load(spec.path))
elif spec.format == "numpy": elif spec.format == "numpy":
mmap_mode = "r+" if not spec.in_memory else None mmap_mode = "r+" if not spec.in_memory else None
feat_stores[key] = TorchBasedFeatureStore( self._features[key] = TorchBasedFeature(
torch.as_tensor(np.load(spec.path, mmap_mode=mmap_mode)) torch.as_tensor(np.load(spec.path, mmap_mode=mmap_mode))
) )
else: else:
raise ValueError(f"Unknown feature format {spec.format}") raise ValueError(f"Unknown feature format {spec.format}")
return feat_stores
def read(
self,
domain: str,
type_name: str,
feature_name: str,
ids: torch.Tensor = None,
):
"""Read from the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
Returns
-------
torch.Tensor
The read feature.
"""
return self._features[(domain, type_name, feature_name)].read(ids)
def update(
self,
domain: str,
type_name: str,
feature_name: str,
value: torch.Tensor,
ids: torch.Tensor = None,
):
"""Update the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
value : torch.Tensor
The updated value of the feature.
ids : torch.Tensor, optional
The indices of the feature to update. If specified, only the
specified indices of the feature will be updated. For the feature,
the `ids[i]` row is updated to `value[i]`. So the indices and value
must have the same length. If None, the entire feature will be
updated.
"""
self._features[(domain, type_name, feature_name)].update(value, ids)
def __len__(self):
"""Return the number of features."""
return len(self._features)
...@@ -6,10 +6,8 @@ import torch ...@@ -6,10 +6,8 @@ import torch
def get_graphbolt_fetch_func(): def get_graphbolt_fetch_func():
feature_store = { feature_store = {
"feature": dgl.graphbolt.TorchBasedFeatureStore(torch.randn(200, 4)), "feature": dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4)),
"label": dgl.graphbolt.TorchBasedFeatureStore( "label": dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,))),
torch.randint(0, 10, (200,))
),
} }
def fetch_func(data): def fetch_func(data):
......
...@@ -19,7 +19,7 @@ def to_on_disk_tensor(test_dir, name, t): ...@@ -19,7 +19,7 @@ def to_on_disk_tensor(test_dir, name, t):
@pytest.mark.parametrize("in_memory", [True, False]) @pytest.mark.parametrize("in_memory", [True, False])
def test_torch_based_feature_store(in_memory): def test_torch_based_feature(in_memory):
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
a = torch.tensor([1, 2, 3]) a = torch.tensor([1, 2, 3])
b = torch.tensor([[1, 2, 3], [4, 5, 6]]) b = torch.tensor([[1, 2, 3], [4, 5, 6]])
...@@ -27,8 +27,8 @@ def test_torch_based_feature_store(in_memory): ...@@ -27,8 +27,8 @@ def test_torch_based_feature_store(in_memory):
a = to_on_disk_tensor(test_dir, "a", a) a = to_on_disk_tensor(test_dir, "a", a)
b = to_on_disk_tensor(test_dir, "b", b) b = to_on_disk_tensor(test_dir, "b", b)
feat_store_a = gb.TorchBasedFeatureStore(a) feat_store_a = gb.TorchBasedFeature(a)
feat_store_b = gb.TorchBasedFeatureStore(b) feat_store_b = gb.TorchBasedFeature(b)
assert torch.equal(feat_store_a.read(), torch.tensor([1, 2, 3])) assert torch.equal(feat_store_a.read(), torch.tensor([1, 2, 3]))
assert torch.equal( assert torch.equal(
...@@ -71,7 +71,7 @@ def write_tensor_to_disk(dir, name, t, fmt="torch"): ...@@ -71,7 +71,7 @@ def write_tensor_to_disk(dir, name, t, fmt="torch"):
@pytest.mark.parametrize("in_memory", [True, False]) @pytest.mark.parametrize("in_memory", [True, False])
def test_load_feature_stores(in_memory): def test_torch_based_feature_store(in_memory):
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
a = torch.tensor([1, 2, 3]) a = torch.tensor([1, 2, 3])
b = torch.tensor([2, 5, 3]) b = torch.tensor([2, 5, 3])
...@@ -95,12 +95,12 @@ def test_load_feature_stores(in_memory): ...@@ -95,12 +95,12 @@ def test_load_feature_stores(in_memory):
in_memory=in_memory, in_memory=in_memory,
), ),
] ]
feat_stores = gb.load_feature_stores(feat_data) feat_store = gb.TorchBasedFeatureStore(feat_data)
assert torch.equal( assert torch.equal(
feat_stores[("node", "paper", "a")].read(), torch.tensor([1, 2, 3]) feat_store.read("node", "paper", "a"), torch.tensor([1, 2, 3])
) )
assert torch.equal( assert torch.equal(
feat_stores[("edge", "paper-cites-paper", "b")].read(), feat_store.read("edge", "paper-cites-paper", "b"),
torch.tensor([2, 5, 3]), torch.tensor([2, 5, 3]),
) )
...@@ -130,6 +130,8 @@ def test_load_feature_stores(in_memory): ...@@ -130,6 +130,8 @@ def test_load_feature_stores(in_memory):
in_memory=True, in_memory=True,
), ),
] ]
feat_stores = gb.load_feature_stores(feat_data) feat_store = gb.TorchBasedFeatureStore(feat_data)
assert ("node", None, "a") in feat_stores assert torch.equal(
feat_store.read("node", None, "a"), torch.tensor([1, 2, 3])
)
feat_stores = None feat_stores = None
...@@ -27,8 +27,8 @@ def test_DataLoader(): ...@@ -27,8 +27,8 @@ def test_DataLoader():
# TODO(BarclayII): temporarily using DGLGraph. Should test using # TODO(BarclayII): temporarily using DGLGraph. Should test using
# GraphBolt's storage as well once issue #5953 is resolved. # GraphBolt's storage as well once issue #5953 is resolved.
graph = dgl.add_reverse_edges(dgl.rand_graph(200, 6000)) graph = dgl.add_reverse_edges(dgl.rand_graph(200, 6000))
features = dgl.graphbolt.TorchBasedFeatureStore(torch.randn(200, 4)) features = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
labels = dgl.graphbolt.TorchBasedFeatureStore(torch.randint(0, 10, (200,))) labels = dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,)))
minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B) minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
subgraph_sampler = dgl.graphbolt.SubgraphSampler( subgraph_sampler = dgl.graphbolt.SubgraphSampler(
......
...@@ -647,35 +647,25 @@ def test_OnDiskDataset_Feature_heterograph(): ...@@ -647,35 +647,25 @@ def test_OnDiskDataset_Feature_heterograph():
assert len(feature_data) == 4 assert len(feature_data) == 4
# Verify node feature data. # Verify node feature data.
node_paper_feat = feature_data[("node", "paper", "feat")]
assert isinstance(node_paper_feat, gb.TorchBasedFeatureStore)
assert torch.equal( assert torch.equal(
node_paper_feat.read(), torch.tensor(node_data_paper) feature_data.read("node", "paper", "feat"),
torch.tensor(node_data_paper),
) )
node_paper_label = feature_data[("node", "paper", "label")]
assert isinstance(node_paper_label, gb.TorchBasedFeatureStore)
assert torch.equal( assert torch.equal(
node_paper_label.read(), torch.tensor(node_data_label) feature_data.read("node", "paper", "label"),
torch.tensor(node_data_label),
) )
# Verify edge feature data. # Verify edge feature data.
edge_writes_feat = feature_data[("edge", "author:writes:paper", "feat")]
assert isinstance(edge_writes_feat, gb.TorchBasedFeatureStore)
assert torch.equal( assert torch.equal(
edge_writes_feat.read(), torch.tensor(edge_data_writes) feature_data.read("edge", "author:writes:paper", "feat"),
torch.tensor(edge_data_writes),
) )
edge_writes_label = feature_data[
("edge", "author:writes:paper", "label")
]
assert isinstance(edge_writes_label, gb.TorchBasedFeatureStore)
assert torch.equal( assert torch.equal(
edge_writes_label.read(), torch.tensor(edge_data_label) feature_data.read("edge", "author:writes:paper", "label"),
torch.tensor(edge_data_label),
) )
node_paper_feat = None
node_paper_label = None
edge_writes_feat = None
edge_writes_label = None
feature_data = None feature_data = None
dataset = None dataset = None
...@@ -735,25 +725,25 @@ def test_OnDiskDataset_Feature_homograph(): ...@@ -735,25 +725,25 @@ def test_OnDiskDataset_Feature_homograph():
assert len(feature_data) == 4 assert len(feature_data) == 4
# Verify node feature data. # Verify node feature data.
node_feat = feature_data[("node", None, "feat")] assert torch.equal(
assert isinstance(node_feat, gb.TorchBasedFeatureStore) feature_data.read("node", None, "feat"),
assert torch.equal(node_feat.read(), torch.tensor(node_data_feat)) torch.tensor(node_data_feat),
node_label = feature_data[("node", None, "label")] )
assert isinstance(node_label, gb.TorchBasedFeatureStore) assert torch.equal(
assert torch.equal(node_label.read(), torch.tensor(node_data_label)) feature_data.read("node", None, "label"),
torch.tensor(node_data_label),
)
# Verify edge feature data. # Verify edge feature data.
edge_feat = feature_data[("edge", None, "feat")] assert torch.equal(
assert isinstance(edge_feat, gb.TorchBasedFeatureStore) feature_data.read("edge", None, "feat"),
assert torch.equal(edge_feat.read(), torch.tensor(edge_data_feat)) torch.tensor(edge_data_feat),
edge_label = feature_data[("edge", None, "label")] )
assert isinstance(edge_label, gb.TorchBasedFeatureStore) assert torch.equal(
assert torch.equal(edge_label.read(), torch.tensor(edge_data_label)) feature_data.read("edge", None, "label"),
torch.tensor(edge_data_label),
node_feat = None )
node_label = None
edge_feat = None
edge_label = None
feature_data = None feature_data = None
dataset = None dataset = None
......
...@@ -10,8 +10,8 @@ def test_DataLoader(): ...@@ -10,8 +10,8 @@ def test_DataLoader():
B = 4 B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N)) itemset = dgl.graphbolt.ItemSet(torch.arange(N))
graph = gb_test_utils.rand_csc_graph(200, 0.15) graph = gb_test_utils.rand_csc_graph(200, 0.15)
features = dgl.graphbolt.TorchBasedFeatureStore(torch.randn(200, 4)) features = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
labels = dgl.graphbolt.TorchBasedFeatureStore(torch.randint(0, 10, (200,))) labels = dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,)))
def sampler_func(data): def sampler_func(data):
adjs = [] adjs = []
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment