Unverified Commit 96ddf410 authored by Rhett Ying, committed by GitHub

[GraphBolt] refactor FeatureStore and related impl (#6103)

parent 8e86c89c
"""GraphBolt Dataset."""
from typing import Dict
from .feature_store import FeatureStore
from .itemset import ItemSet, ItemSetDict
@@ -52,7 +50,7 @@ class Dataset:
raise NotImplementedError
@property
def feature(self) -> Dict[object, FeatureStore]:
def feature(self) -> FeatureStore:
"""Return the feature."""
raise NotImplementedError
......
@@ -11,11 +11,23 @@ class FeatureStore:
def __init__(self):
pass
def read(self, ids: torch.Tensor = None):
def read(
self,
domain: str,
type_name: str,
feature_name: str,
ids: torch.Tensor = None,
):
"""Read from the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
@@ -27,11 +39,24 @@ class FeatureStore:
"""
raise NotImplementedError
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
def update(
self,
domain: str,
type_name: str,
feature_name: str,
value: torch.Tensor,
ids: torch.Tensor = None,
):
"""Update the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
value : torch.Tensor
The updated value of the feature.
ids : torch.Tensor, optional
......
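For reference, a minimal sketch of how the new keyed signature is meant to be
called. `store` stands in for any concrete FeatureStore subclass (for example
the TorchBasedFeatureStore added later in this commit), and the key values and
the `new_rows` tensor are purely illustrative:
>>> # Read rows 0..2 of the "feat" feature of "paper" nodes.
>>> store.read("node", "paper", "feat", ids=torch.tensor([0, 1, 2]))
>>> # Overwrite the same rows; `value` must match `ids` in length.
>>> store.update("node", "paper", "feat", value=new_rows,
...     ids=torch.tensor([0, 1, 2]))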
@@ -5,7 +5,7 @@ import shutil
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Tuple
from typing import List
import pandas as pd
import torch
@@ -24,10 +24,7 @@ from .csc_sampling_graph import (
save_csc_sampling_graph,
)
from .ondisk_metadata import OnDiskGraphTopology, OnDiskMetaData, OnDiskTVTSet
from .torch_based_feature_store import (
load_feature_stores,
TorchBasedFeatureStore,
)
from .torch_based_feature_store import TorchBasedFeatureStore
__all__ = ["OnDiskDataset", "preprocess_ondisk_dataset"]
@@ -281,7 +278,7 @@ class OnDiskDataset(Dataset):
self._num_classes = self._meta.num_classes
self._num_labels = self._meta.num_labels
self._graph = self._load_graph(self._meta.graph_topology)
self._feature = load_feature_stores(self._meta.feature_data)
self._feature = TorchBasedFeatureStore(self._meta.feature_data)
self._train_set = self._init_tvt_set(self._meta.train_set)
self._validation_set = self._init_tvt_set(self._meta.validation_set)
self._test_set = self._init_tvt_set(self._meta.test_set)
@@ -307,7 +304,7 @@ class OnDiskDataset(Dataset):
return self._graph
@property
def feature(self) -> Dict[Tuple, TorchBasedFeatureStore]:
def feature(self) -> TorchBasedFeatureStore:
"""Return the feature."""
return self._feature
......
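Because `feature` now returns a single store rather than a dict of stores,
callers look features up through keyed `read()` calls instead of subscripting.
A hedged sketch, assuming `dataset` is an already-loaded OnDiskDataset and the
key names match the on-disk metadata:
>>> feat_store = dataset.feature
>>> paper_feat = feat_store.read("node", "paper", "feat")
>>> # Previously: dataset.feature[("node", "paper", "feat")].read()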
@@ -8,11 +8,11 @@ import torch
from ..feature_store import FeatureStore
from .ondisk_metadata import OnDiskFeatureData
__all__ = ["TorchBasedFeatureStore", "load_feature_stores"]
__all__ = ["TorchBasedFeature", "TorchBasedFeatureStore"]
class TorchBasedFeatureStore(FeatureStore):
r"""Torch based feature store."""
class TorchBasedFeature:
r"""Torch based feature."""
def __init__(self, torch_feature: torch.Tensor):
"""Initialize a torch based feature store by a torch feature.
@@ -28,7 +28,7 @@ class TorchBasedFeatureStore(FeatureStore):
--------
>>> import torch
>>> torch_feat = torch.arange(0, 5)
>>> feature_store = TorchBasedFeatureStore(torch_feat)
>>> feature_store = TorchBasedFeature(torch_feat)
>>> feature_store.read()
tensor([0, 1, 2, 3, 4])
>>> feature_store.read(torch.tensor([0, 1, 2]))
@@ -43,15 +43,14 @@ class TorchBasedFeatureStore(FeatureStore):
>>> np.save("/tmp/arr.npy", arr)
>>> torch_feat = torch.as_tensor(np.load("/tmp/arr.npy",
... mmap_mode="r+"))
>>> feature_store = TorchBasedFeatureStore(torch_feat)
>>> feature_store = TorchBasedFeature(torch_feat)
>>> feature_store.read()
tensor([0, 1, 2, 3, 4])
>>> feature_store.read(torch.tensor([0, 1, 2]))
tensor([0, 1, 2])
"""
super(TorchBasedFeatureStore, self).__init__()
assert isinstance(torch_feature, torch.Tensor), (
f"torch_feature in TorchBasedFeatureStore must be torch.Tensor, "
f"torch_feature in TorchBasedFeature must be torch.Tensor, "
f"but got {type(torch_feature)}."
)
self._tensor = torch_feature
@@ -106,65 +105,124 @@ class TorchBasedFeatureStore(FeatureStore):
self._tensor[ids] = value
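To complement the read-only examples in the docstring above, a minimal
round-trip sketch for the renamed TorchBasedFeature, assuming the in-memory
tensor path and the `update(value, ids)` signature shown here:
>>> feat = gb.TorchBasedFeature(torch.zeros(4, 2))
>>> feat.update(torch.ones(2, 2), torch.tensor([1, 3]))
>>> feat.read(torch.tensor([1, 3]))
tensor([[1., 1.],
[1., 1.]])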
def load_feature_stores(feat_data: List[OnDiskFeatureData]):
r"""Load feature stores from disk.
The feature stores are described by the `feat_data`. The `feat_data` is a
list of `OnDiskFeatureData`.
For a feature store, its format must be either "pt" or "npy" for Pytorch or
Numpy formats. If the format is "pt", the feature store must be loaded in
memory. If the format is "npy", the feature store can be loaded in memory or
on disk.
Parameters
----------
feat_data : List[OnDiskFeatureData]
The description of the feature stores.
Returns
-------
dict
The loaded feature stores. The keys are the names of the feature stores,
and the values are the feature stores.
Examples
--------
>>> import torch
>>> import numpy as np
>>> from dgl import graphbolt as gb
>>> edge_label = torch.tensor([1, 2, 3])
>>> node_feat = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> torch.save(edge_label, "/tmp/edge_label.pt")
>>> np.save("/tmp/node_feat.npy", node_feat.numpy())
>>> feat_data = [
... gb.OnDiskFeatureData(domain="edge", type="author:writes:paper",
... name="label", format="torch", path="/tmp/edge_label.pt",
... in_memory=True),
... gb.OnDiskFeatureData(domain="node", type="paper", name="feat",
... format="numpy", path="/tmp/node_feat.npy", in_memory=False),
... ]
>>> gb.load_feature_stores(feat_data)
... {("edge", "author:writes:paper", "label"):
... <dgl.graphbolt.feature_store.TorchBasedFeatureStore object at
... 0x7ff093cb4df0>, ("node", "paper", "feat"):
... <dgl.graphbolt.feature_store.TorchBasedFeatureStore object at
... 0x7ff093cb4dc0>}
"""
feat_stores = {}
for spec in feat_data:
key = (spec.domain, spec.type, spec.name)
if spec.format == "torch":
assert spec.in_memory, (
f"Pytorch tensor can only be loaded in memory, "
f"but the feature {key} is loaded on disk."
)
feat_stores[key] = TorchBasedFeatureStore(torch.load(spec.path))
elif spec.format == "numpy":
mmap_mode = "r+" if not spec.in_memory else None
feat_stores[key] = TorchBasedFeatureStore(
torch.as_tensor(np.load(spec.path, mmap_mode=mmap_mode))
)
else:
raise ValueError(f"Unknown feature format {spec.format}")
return feat_stores
class TorchBasedFeatureStore(FeatureStore):
r"""Torch based feature store."""
def __init__(self, feat_data: List[OnDiskFeatureData]):
r"""Load feature stores from disk.
The feature stores are described by `feat_data`, a list of
`OnDiskFeatureData`.
For each feature, the format must be either "torch" or "numpy". If the
format is "torch", the feature must be loaded in memory. If the format is
"numpy", the feature can be loaded in memory or memory-mapped from disk.
The loaded features are keyed internally by (domain, type, name) tuples.
Parameters
----------
feat_data : List[OnDiskFeatureData]
The description of the feature stores.
Examples
--------
>>> import torch
>>> import numpy as np
>>> from dgl import graphbolt as gb
>>> edge_label = torch.tensor([1, 2, 3])
>>> node_feat = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> torch.save(edge_label, "/tmp/edge_label.pt")
>>> np.save("/tmp/node_feat.npy", node_feat.numpy())
>>> feat_data = [
... gb.OnDiskFeatureData(domain="edge", type="author:writes:paper",
... name="label", format="torch", path="/tmp/edge_label.pt",
... in_memory=True),
... gb.OnDiskFeatureData(domain="node", type="paper", name="feat",
... format="numpy", path="/tmp/node_feat.npy", in_memory=False),
... ]
>>> feature_store = gb.TorchBasedFeatureStore(feat_data)
"""
super().__init__()
self._features = {}
for spec in feat_data:
key = (spec.domain, spec.type, spec.name)
if spec.format == "torch":
assert spec.in_memory, (
f"Pytorch tensor can only be loaded in memory, "
f"but the feature {key} is loaded on disk."
)
self._features[key] = TorchBasedFeature(torch.load(spec.path))
elif spec.format == "numpy":
mmap_mode = "r+" if not spec.in_memory else None
self._features[key] = TorchBasedFeature(
torch.as_tensor(np.load(spec.path, mmap_mode=mmap_mode))
)
else:
raise ValueError(f"Unknown feature format {spec.format}")
def read(
self,
domain: str,
type_name: str,
feature_name: str,
ids: torch.Tensor = None,
):
"""Read from the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
Returns
-------
torch.Tensor
The read feature.
"""
return self._features[(domain, type_name, feature_name)].read(ids)
def update(
self,
domain: str,
type_name: str,
feature_name: str,
value: torch.Tensor,
ids: torch.Tensor = None,
):
"""Update the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
value : torch.Tensor
The updated value of the feature.
ids : torch.Tensor, optional
The indices of the feature to update. If specified, only the
specified indices of the feature are updated: row `ids[i]` of the
feature is set to `value[i]`, so `ids` and `value` must have the
same length. If None, the entire feature is updated.
"""
self._features[(domain, type_name, feature_name)].update(value, ids)
def __len__(self):
"""Return the number of features."""
return len(self._features)
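Putting the pieces together, a hedged end-to-end sketch that reuses the
`feat_data` list from the constructor example above and exercises the keyed
read/update path:
>>> store = gb.TorchBasedFeatureStore(feat_data)
>>> len(store)
2
>>> store.read("node", "paper", "feat")
tensor([[1, 2, 3],
[4, 5, 6]])
>>> store.update("edge", "author:writes:paper", "label",
... torch.tensor([9]), ids=torch.tensor([0]))
>>> store.read("edge", "author:writes:paper", "label")
tensor([9, 2, 3])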
@@ -6,10 +6,8 @@ import torch
def get_graphbolt_fetch_func():
feature_store = {
"feature": dgl.graphbolt.TorchBasedFeatureStore(torch.randn(200, 4)),
"label": dgl.graphbolt.TorchBasedFeatureStore(
torch.randint(0, 10, (200,))
),
"feature": dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4)),
"label": dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,))),
}
def fetch_func(data):
......
@@ -19,7 +19,7 @@ def to_on_disk_tensor(test_dir, name, t):
@pytest.mark.parametrize("in_memory", [True, False])
def test_torch_based_feature_store(in_memory):
def test_torch_based_feature(in_memory):
with tempfile.TemporaryDirectory() as test_dir:
a = torch.tensor([1, 2, 3])
b = torch.tensor([[1, 2, 3], [4, 5, 6]])
@@ -27,8 +27,8 @@ def test_torch_based_feature_store(in_memory):
a = to_on_disk_tensor(test_dir, "a", a)
b = to_on_disk_tensor(test_dir, "b", b)
feat_store_a = gb.TorchBasedFeatureStore(a)
feat_store_b = gb.TorchBasedFeatureStore(b)
feat_store_a = gb.TorchBasedFeature(a)
feat_store_b = gb.TorchBasedFeature(b)
assert torch.equal(feat_store_a.read(), torch.tensor([1, 2, 3]))
assert torch.equal(
@@ -71,7 +71,7 @@ def write_tensor_to_disk(dir, name, t, fmt="torch"):
@pytest.mark.parametrize("in_memory", [True, False])
def test_load_feature_stores(in_memory):
def test_torch_based_feature_store(in_memory):
with tempfile.TemporaryDirectory() as test_dir:
a = torch.tensor([1, 2, 3])
b = torch.tensor([2, 5, 3])
@@ -95,12 +95,12 @@ def test_load_feature_stores(in_memory):
in_memory=in_memory,
),
]
feat_stores = gb.load_feature_stores(feat_data)
feat_store = gb.TorchBasedFeatureStore(feat_data)
assert torch.equal(
feat_stores[("node", "paper", "a")].read(), torch.tensor([1, 2, 3])
feat_store.read("node", "paper", "a"), torch.tensor([1, 2, 3])
)
assert torch.equal(
feat_stores[("edge", "paper-cites-paper", "b")].read(),
feat_store.read("edge", "paper-cites-paper", "b"),
torch.tensor([2, 5, 3]),
)
@@ -130,6 +130,8 @@ def test_load_feature_stores(in_memory):
in_memory=True,
),
]
feat_stores = gb.load_feature_stores(feat_data)
assert ("node", None, "a") in feat_stores
feat_store = gb.TorchBasedFeatureStore(feat_data)
assert torch.equal(
feat_store.read("node", None, "a"), torch.tensor([1, 2, 3])
)
feat_stores = None
@@ -27,8 +27,8 @@ def test_DataLoader():
# TODO(BarclayII): temporarily using DGLGraph. Should test using
# GraphBolt's storage as well once issue #5953 is resolved.
graph = dgl.add_reverse_edges(dgl.rand_graph(200, 6000))
features = dgl.graphbolt.TorchBasedFeatureStore(torch.randn(200, 4))
labels = dgl.graphbolt.TorchBasedFeatureStore(torch.randint(0, 10, (200,)))
features = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
labels = dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,)))
minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
subgraph_sampler = dgl.graphbolt.SubgraphSampler(
......
@@ -647,35 +647,25 @@ def test_OnDiskDataset_Feature_heterograph():
assert len(feature_data) == 4
# Verify node feature data.
node_paper_feat = feature_data[("node", "paper", "feat")]
assert isinstance(node_paper_feat, gb.TorchBasedFeatureStore)
assert torch.equal(
node_paper_feat.read(), torch.tensor(node_data_paper)
feature_data.read("node", "paper", "feat"),
torch.tensor(node_data_paper),
)
node_paper_label = feature_data[("node", "paper", "label")]
assert isinstance(node_paper_label, gb.TorchBasedFeatureStore)
assert torch.equal(
node_paper_label.read(), torch.tensor(node_data_label)
feature_data.read("node", "paper", "label"),
torch.tensor(node_data_label),
)
# Verify edge feature data.
edge_writes_feat = feature_data[("edge", "author:writes:paper", "feat")]
assert isinstance(edge_writes_feat, gb.TorchBasedFeatureStore)
assert torch.equal(
edge_writes_feat.read(), torch.tensor(edge_data_writes)
feature_data.read("edge", "author:writes:paper", "feat"),
torch.tensor(edge_data_writes),
)
edge_writes_label = feature_data[
("edge", "author:writes:paper", "label")
]
assert isinstance(edge_writes_label, gb.TorchBasedFeatureStore)
assert torch.equal(
edge_writes_label.read(), torch.tensor(edge_data_label)
feature_data.read("edge", "author:writes:paper", "label"),
torch.tensor(edge_data_label),
)
node_paper_feat = None
node_paper_label = None
edge_writes_feat = None
edge_writes_label = None
feature_data = None
dataset = None
@@ -735,25 +725,25 @@ def test_OnDiskDataset_Feature_homograph():
assert len(feature_data) == 4
# Verify node feature data.
node_feat = feature_data[("node", None, "feat")]
assert isinstance(node_feat, gb.TorchBasedFeatureStore)
assert torch.equal(node_feat.read(), torch.tensor(node_data_feat))
node_label = feature_data[("node", None, "label")]
assert isinstance(node_label, gb.TorchBasedFeatureStore)
assert torch.equal(node_label.read(), torch.tensor(node_data_label))
assert torch.equal(
feature_data.read("node", None, "feat"),
torch.tensor(node_data_feat),
)
assert torch.equal(
feature_data.read("node", None, "label"),
torch.tensor(node_data_label),
)
# Verify edge feature data.
edge_feat = feature_data[("edge", None, "feat")]
assert isinstance(edge_feat, gb.TorchBasedFeatureStore)
assert torch.equal(edge_feat.read(), torch.tensor(edge_data_feat))
edge_label = feature_data[("edge", None, "label")]
assert isinstance(edge_label, gb.TorchBasedFeatureStore)
assert torch.equal(edge_label.read(), torch.tensor(edge_data_label))
node_feat = None
node_label = None
edge_feat = None
edge_label = None
assert torch.equal(
feature_data.read("edge", None, "feat"),
torch.tensor(edge_data_feat),
)
assert torch.equal(
feature_data.read("edge", None, "label"),
torch.tensor(edge_data_label),
)
feature_data = None
dataset = None
......
@@ -10,8 +10,8 @@ def test_DataLoader():
B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N))
graph = gb_test_utils.rand_csc_graph(200, 0.15)
features = dgl.graphbolt.TorchBasedFeatureStore(torch.randn(200, 4))
labels = dgl.graphbolt.TorchBasedFeatureStore(torch.randint(0, 10, (200,)))
features = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
labels = dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,)))
def sampler_func(data):
adjs = []
......