Unverified Commit dc90ea16 authored by czkkkkkk, committed by GitHub

[Graphbolt] Add load_feature_stores (#5949)

parent e5ddc62b
"""Feature store for GraphBolt.""" """Feature store for GraphBolt."""
from typing import List
import numpy as np
import pydantic
import pydantic_yaml
import torch
@@ -134,3 +140,82 @@ class TorchBasedFeatureStore(FeatureStore):
                f"but got {ids.shape[0]} and {value.shape[0]}."
            )
            self._tensor[ids] = value


# FIXME(Rui): To avoid a circular import, we make a copy of
# `OnDiskDataFormatEnum` from dataset.py. We need to merge the two
# definitions later.
class OnDiskDataFormatEnum(pydantic_yaml.YamlStrEnum):
    """Enum of data format."""

    TORCH = "torch"
    NUMPY = "numpy"


class OnDiskFeatureData(pydantic.BaseModel):
    r"""The description of an on-disk feature."""

    name: str
    format: OnDiskDataFormatEnum
    path: str
    in_memory: bool = True
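
# A minimal sketch of how a description is built (values are illustrative,
# not from the diff): since `OnDiskDataFormatEnum` is a string enum, pydantic
# coerces a plain string such as "torch" into the enum, so descriptions can
# be constructed directly from parsed YAML/dict values.
#
#   spec = OnDiskFeatureData(name="a", format="torch", path="/tmp/a.pt")
#   assert spec.format == OnDiskDataFormatEnum.TORCH
#   assert spec.in_memory  # defaults to True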


def load_feature_stores(feat_data: List[OnDiskFeatureData]):
    r"""Load feature stores from disk.

    The feature stores are described by `feat_data`, a list of
    `OnDiskFeatureData`.

    The format of a feature store must be either "torch" or "numpy", for the
    PyTorch and NumPy formats respectively. If the format is "torch", the
    feature store must be loaded in memory. If the format is "numpy", the
    feature store can be loaded in memory or on disk.

    Parameters
    ----------
    feat_data : List[OnDiskFeatureData]
        The description of the feature stores.

    Returns
    -------
    dict
        The loaded feature stores. The keys are the names of the feature
        stores, and the values are the feature stores.

    Examples
    --------
    >>> import torch
    >>> import numpy as np
    >>> from dgl import graphbolt as gb
    >>> a = torch.tensor([1, 2, 3])
    >>> b = torch.tensor([[1, 2, 3], [4, 5, 6]])
    >>> torch.save(a, "/tmp/a.pt")
    >>> np.save("/tmp/b.npy", b.numpy())
    >>> feat_data = [
    ...     gb.OnDiskFeatureData(name="a", format="torch", path="/tmp/a.pt",
    ...         in_memory=True),
    ...     gb.OnDiskFeatureData(name="b", format="numpy", path="/tmp/b.npy",
    ...         in_memory=False),
    ... ]
    >>> gb.load_feature_stores(feat_data)
    {'a': <dgl.graphbolt.feature_store.TorchBasedFeatureStore object at
    0x7ff093cb4df0>, 'b':
    <dgl.graphbolt.feature_store.TorchBasedFeatureStore object at
    0x7ff093cb4dc0>}
    """
    feat_stores = {}
    for spec in feat_data:
        key = spec.name
        if spec.format == "torch":
            # Torch tensors are always fully loaded into memory.
            assert spec.in_memory, (
                f"Pytorch tensor can only be loaded in memory, "
                f"but the feature {key} is loaded on disk."
            )
            feat_stores[key] = TorchBasedFeatureStore(torch.load(spec.path))
        elif spec.format == "numpy":
            # With in_memory=False, the .npy file is memory-mapped in
            # read-write mode rather than read into memory.
            mmap_mode = "r+" if not spec.in_memory else None
            feat_stores[key] = TorchBasedFeatureStore(
                torch.as_tensor(np.load(spec.path, mmap_mode=mmap_mode))
            )
        else:
            raise ValueError(f"Unknown feature format {spec.format}")
    return feat_stores
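
# A minimal read-back sketch (continuing the docstring example above; the
# no-argument `read()` call is taken from the tests below):
#
#   feat_stores = load_feature_stores(feat_data)
#   feat_stores["a"].read()  # -> tensor([1, 2, 3]), fully in memory
#   feat_stores["b"].read()  # -> tensor([[1, 2, 3], [4, 5, 6]]), loaded
#                            #    via np.load(..., mmap_mode="r+")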

@@ -57,3 +57,42 @@ def test_torch_based_feature_store(in_memory):
        # it before closing the temporary directory.
        a = b = None
        feat_store_a = feat_store_b = None


def write_tensor_to_disk(dir, name, t, fmt="pt"):
    if fmt == "pt":
        torch.save(t, os.path.join(dir, name + ".pt"))
    else:
        t = t.numpy()
        np.save(os.path.join(dir, name + ".npy"), t)


@pytest.mark.parametrize("in_memory", [True, False])
def test_load_feature_stores(in_memory):
    with tempfile.TemporaryDirectory() as test_dir:
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([2, 5, 3])
        write_tensor_to_disk(test_dir, "a", a, fmt="pt")
        write_tensor_to_disk(test_dir, "b", b, fmt="npy")
        feat_data = [
            gb.OnDiskFeatureData(
                name="a",
                format="torch",
                path=os.path.join(test_dir, "a.pt"),
                in_memory=True,
            ),
            gb.OnDiskFeatureData(
                name="b",
                format="numpy",
                path=os.path.join(test_dir, "b.npy"),
                in_memory=in_memory,
            ),
        ]
        feat_stores = gb.load_feature_stores(feat_data)
        assert torch.equal(feat_stores["a"].read(), torch.tensor([1, 2, 3]))
        assert torch.equal(feat_stores["b"].read(), torch.tensor([2, 5, 3]))
        # For Windows, the file is locked by numpy.load. We need to delete
        # it before closing the temporary directory.
        a = b = None
        feat_stores = None
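

# A minimal extra check (a sketch, not part of the original diff): a
# "torch"-format feature with in_memory=False should trip the assertion
# inside load_feature_stores.
def test_load_feature_stores_torch_requires_in_memory():
    with tempfile.TemporaryDirectory() as test_dir:
        write_tensor_to_disk(test_dir, "a", torch.tensor([1, 2, 3]), fmt="pt")
        feat_data = [
            gb.OnDiskFeatureData(
                name="a",
                format="torch",
                path=os.path.join(test_dir, "a.pt"),
                in_memory=False,
            )
        ]
        with pytest.raises(AssertionError):
            gb.load_feature_stores(feat_data)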