Unverified commit b8e93575, authored by Hongzhi (Steve), Chen; committed by GitHub
Browse files

[Graphbolt] Add BasicFeatureStore. (#6204)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 60d4ea69
......@@ -2,7 +2,45 @@
import torch
__all__ = ["FeatureStore"]
__all__ = ["Feature", "FeatureStore"]
class Feature:
    r"""Abstract interface for a single feature.

    The methods defined here only describe the contract; both ``read`` and
    ``update`` raise ``NotImplementedError`` and are meant to be overridden
    by concrete subclasses.
    """

    def __init__(self):
        """Create the feature. The base class keeps no state."""

    def read(self, ids: torch.Tensor = None):
        """Fetch values of the feature.

        Parameters
        ----------
        ids : torch.Tensor, optional
            Indices of the entries to fetch. When left as None, the whole
            feature is requested.

        Returns
        -------
        torch.Tensor
            The requested feature values.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()

    def update(self, value: torch.Tensor, ids: torch.Tensor = None):
        """Overwrite values of the feature.

        Parameters
        ----------
        value : torch.Tensor
            The new values to store.
        ids : torch.Tensor, optional
            Indices of the entries to overwrite: row ``ids[i]`` of the
            feature receives ``value[i]``, so ``ids`` and ``value`` must
            have the same length. When left as None, the whole feature is
            replaced.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()
class FeatureStore:
......
"""Implementation of GraphBolt."""
from .basic_feature_store import *
from .csc_sampling_graph import *
from .neighbor_sampler import *
from .ondisk_dataset import *
......
"""Basic feature store for GraphBolt."""
from typing import Dict, Tuple
import torch
from ..feature_store import Feature, FeatureStore
__all__ = ["BasicFeatureStore"]
class BasicFeatureStore(FeatureStore):
    r"""Feature store that serves features out of a dictionary.

    Every feature is addressed by the triple
    ``(domain, type_name, feature_name)``; reads and updates are forwarded
    to the ``Feature`` object registered under that triple.
    """

    def __init__(self, features: Dict[Tuple[str, str, str], Feature]):
        r"""Set up the store with the features it serves.

        Parameters
        ----------
        features : Dict[Tuple[str, str, str], Feature]
            Mapping from ``(domain, type_name, feature_name)`` to the
            feature object. The mapping is kept by reference, not copied.
        """
        super().__init__()
        self._features = features

    def read(
        self,
        domain: str,
        type_name: str,
        feature_name: str,
        ids: torch.Tensor = None,
    ):
        """Fetch values of one feature held by the store.

        Parameters
        ----------
        domain : str
            The domain of the feature such as "node", "edge" or "graph".
        type_name : str
            The node or edge type name.
        feature_name : str
            The feature name.
        ids : torch.Tensor, optional
            Indices of the entries to fetch. When left as None, the whole
            feature is returned.

        Returns
        -------
        torch.Tensor
            The requested feature values.
        """
        # A lookup miss propagates whatever the mapping raises
        # (KeyError for a plain dict).
        key = (domain, type_name, feature_name)
        feature = self._features[key]
        return feature.read(ids)

    def update(
        self,
        domain: str,
        type_name: str,
        feature_name: str,
        value: torch.Tensor,
        ids: torch.Tensor = None,
    ):
        """Overwrite values of one feature held by the store.

        Parameters
        ----------
        domain : str
            The domain of the feature such as "node", "edge" or "graph".
        type_name : str
            The node or edge type name.
        feature_name : str
            The feature name.
        value : torch.Tensor
            The new values to store.
        ids : torch.Tensor, optional
            Indices of the entries to overwrite: row ``ids[i]`` of the
            feature receives ``value[i]``, so ``ids`` and ``value`` must
            have the same length. When left as None, the whole feature is
            replaced.
        """
        key = (domain, type_name, feature_name)
        target = self._features[key]
        target.update(value, ids)

    def __len__(self):
        """Return how many features the store holds."""
        return len(self._features)
......@@ -4,13 +4,14 @@ from typing import List
import numpy as np
import torch
from ..feature_store import FeatureStore
from ..feature_store import Feature
from .basic_feature_store import BasicFeatureStore
from .ondisk_metadata import OnDiskFeatureData
__all__ = ["TorchBasedFeature", "TorchBasedFeatureStore"]
class TorchBasedFeature:
class TorchBasedFeature(Feature):
r"""Torch based feature."""
def __init__(self, torch_feature: torch.Tensor):
......@@ -48,6 +49,7 @@ class TorchBasedFeature:
>>> feature_store.read(torch.tensor([0, 1, 2]))
tensor([0, 1, 2])
"""
super().__init__()
assert isinstance(torch_feature, torch.Tensor), (
f"torch_feature in TorchBasedFeature must be torch.Tensor, "
f"but got {type(torch_feature)}."
......@@ -104,7 +106,7 @@ class TorchBasedFeature:
self._tensor[ids] = value
class TorchBasedFeatureStore(FeatureStore):
class TorchBasedFeatureStore(BasicFeatureStore):
r"""Torch based feature store."""
def __init__(self, feat_data: List[OnDiskFeatureData]):
......@@ -147,8 +149,7 @@ class TorchBasedFeatureStore(FeatureStore):
... ]
>>> feature_sotre = gb.TorchBasedFeatureStore(feat_data)
"""
super().__init__()
self._features = {}
features = {}
for spec in feat_data:
key = (spec.domain, spec.type, spec.name)
if spec.format == "torch":
......@@ -156,72 +157,12 @@ class TorchBasedFeatureStore(FeatureStore):
f"Pytorch tensor can only be loaded in memory, "
f"but the feature {key} is loaded on disk."
)
self._features[key] = TorchBasedFeature(torch.load(spec.path))
features[key] = TorchBasedFeature(torch.load(spec.path))
elif spec.format == "numpy":
mmap_mode = "r+" if not spec.in_memory else None
self._features[key] = TorchBasedFeature(
features[key] = TorchBasedFeature(
torch.as_tensor(np.load(spec.path, mmap_mode=mmap_mode))
)
else:
raise ValueError(f"Unknown feature format {spec.format}")
def read(
self,
domain: str,
type_name: str,
feature_name: str,
ids: torch.Tensor = None,
):
"""Read from the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
Returns
-------
torch.Tensor
The read feature.
"""
return self._features[(domain, type_name, feature_name)].read(ids)
def update(
self,
domain: str,
type_name: str,
feature_name: str,
value: torch.Tensor,
ids: torch.Tensor = None,
):
"""Update the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
value : torch.Tensor
The updated value of the feature.
ids : torch.Tensor, optional
The indices of the feature to update. If specified, only the
specified indices of the feature will be updated. For the feature,
the `ids[i]` row is updated to `value[i]`. So the indices and value
must have the same length. If None, the entire feature will be
updated.
"""
self._features[(domain, type_name, feature_name)].update(value, ids)
def __len__(self):
"""Return the number of features."""
return len(self._features)
super().__init__(features)
import pytest
import torch
from dgl import graphbolt as gb
def test_basic_feature_store():
    """Check whole-feature and indexed reads through BasicFeatureStore."""
    node_key = ("node", "paper", "a")
    edge_key = ("edge", "paper-cites-paper", "b")

    # Register one node feature and one edge feature under their keys.
    features = {
        node_key: gb.TorchBasedFeature(torch.tensor([3, 2, 1])),
        edge_key: gb.TorchBasedFeature(torch.tensor([2, 5, 3])),
    }
    feature_store = gb.BasicFeatureStore(features)

    # Reading without ids yields the complete tensors.
    node_all = feature_store.read(*node_key)
    assert torch.equal(node_all, torch.tensor([3, 2, 1]))

    edge_all = feature_store.read(*edge_key)
    assert torch.equal(edge_all, torch.tensor([2, 5, 3]))

    # Reading with ids yields only the selected rows.
    node_subset = feature_store.read(*node_key, torch.tensor([0, 1]))
    assert torch.equal(node_subset, torch.tensor([3, 2]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment