Unverified Commit ff9f05c5 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Remove keys in feature store. (#5938)

parent 28578137
......@@ -8,13 +8,11 @@ class FeatureStore:
def __init__(self):
pass
def read(self, key: str, ids: torch.Tensor = None):
"""Read a feature from the feature store.
def read(self, ids: torch.Tensor = None):
"""Read from the feature store.
Parameters
----------
key : str
The key that uniquely identifies the feature in the feature store.
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
......@@ -26,17 +24,11 @@ class FeatureStore:
"""
raise NotImplementedError
def update(self, key: str, value: torch.Tensor, ids: torch.Tensor = None):
"""Update a feature in the feature store.
This function is used to update a feature in the feature store. The
feature is identified by a unique key, and its value is specified using
a tensor.
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
"""Update the feature store.
Parameters
----------
key : str
The key that uniquely identifies the feature in the feature store.
value : torch.Tensor
The updated value of the feature.
ids : torch.Tensor, optional
......@@ -50,87 +42,58 @@ class FeatureStore:
class TorchBasedFeatureStore(FeatureStore):
r"""Torch based key-value feature store, where the keys are strings and
the values are PyTorch tensors."""
def __init__(self, feature_dict: dict):
"""Initialize a torch based feature store.
r"""Torch based feature store."""
The feature store is initialized with a dictionary of tensors, where the
key is the name of a feature and the value is the tensor. The value can
be multi-dimensional, where the first dimension is the index of the
feature.
def __init__(self, torch_feature: torch.Tensor):
"""Initialize a torch based feature store by a torch feature.
Note that the values can be in memory or on disk.
Note that the feature can be either in memory or on disk.
Parameters
----------
feature_dict : dict, optional
A dictionary of tensors.
torch_feature : torch.Tensor
The torch feature.
Examples
--------
>>> import torch
>>> feature_dict = {
... "user": torch.arange(0, 5),
... "item": torch.arange(0, 6),
... "rel": torch.arange(0, 6).view(2, 3),
... }
>>> feature_store = TorchBasedFeatureStore(feature_dict)
>>> feature_store.read("user", torch.tensor([0, 1, 2]))
>>> torch_feat = torch.arange(0, 5)
>>> feature_store = TorchBasedFeatureStore(torch_feat)
>>> feature_store.read()
tensor([0, 1, 2, 3, 4])
>>> feature_store.read(torch.tensor([0, 1, 2]))
tensor([0, 1, 2])
>>> feature_store.read("item", torch.tensor([0, 1, 2]))
tensor([0, 1, 2])
>>> feature_store.read("rel", torch.tensor([0]))
tensor([[0, 1, 2]])
>>> feature_store.update("user",
... torch.ones(3, dtype=torch.long), torch.tensor([0, 1, 2]))
>>> feature_store.read("user", torch.tensor([0, 1, 2]))
tensor([1, 1, 1])
>>> feature_store.update(torch.ones(3, dtype=torch.long),
... torch.tensor([0, 1, 2]))
>>> feature_store.read(torch.tensor([0, 1, 2, 3]))
tensor([1, 1, 1, 3])
>>> import numpy as np
>>> user = np.arange(0, 5)
>>> item = np.arange(0, 6)
>>> np.save("/tmp/user.npy", user)
>>> np.save("/tmp/item.npy", item)
>>> feature_dict = {
... "user": torch.as_tensor(np.load("/tmp/user.npy",
... mmap_mode="r+")),
... "item": torch.as_tensor(np.load("/tmp/item.npy",
... mmap_mode="r+")),
... }
>>> feature_store = TorchBasedFeatureStore(feature_dict)
>>> feature_store.read("user", torch.tensor([0, 1, 2]))
>>> arr = np.arange(0, 5)
>>> np.save("/tmp/arr.npy", arr)
>>> torch_feat = torch.as_tensor(np.load("/tmp/arr.npy",
... mmap_mode="r+"))
>>> feature_store = TorchBasedFeatureStore(torch_feat)
>>> feature_store.read()
tensor([0, 1, 2, 3, 4])
>>> feature_store.read(torch.tensor([0, 1, 2]))
tensor([0, 1, 2])
>>> feature_store.read("item", torch.tensor([3, 4, 2]))
tensor([3, 4, 2])
"""
super(TorchBasedFeatureStore, self).__init__()
assert isinstance(feature_dict, dict), (
f"feature_dict in TorchBasedFeatureStore must be dict, "
f"but got {type(feature_dict)}."
)
for k, v in feature_dict.items():
assert isinstance(
k, str
), f"Key in TorchBasedFeatureStore must be str, but got {k}."
assert isinstance(v, torch.Tensor), (
f"Value in TorchBasedFeatureStore must be torch.Tensor,"
f"but got {v}."
assert isinstance(torch_feature, torch.Tensor), (
f"torch_feature in TorchBasedFeatureStore must be torch.Tensor, "
f"but got {type(torch_feature)}."
)
self._tensor = torch_feature
self._feature_dict = feature_dict
def read(self, ids: torch.Tensor = None):
"""Read the feature by index.
def read(self, key: str, ids: torch.Tensor = None):
"""Read a feature from the feature store by index.
The returned feature is always in memory, no matter whether the feature
to read is in memory or on disk.
The returned tensor is always in memory, no matter whether the feature
store is in memory or on disk.
Parameters
----------
key : str
The key of the feature.
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
......@@ -140,24 +103,15 @@ class TorchBasedFeatureStore(FeatureStore):
torch.Tensor
The read feature.
"""
assert (
key in self._feature_dict
), f"key {key} not in {self._feature_dict.keys()}"
if ids is None:
return self._feature_dict[key]
return self._feature_dict[key][ids]
def update(self, key: str, value: torch.Tensor, ids: torch.Tensor = None):
"""Update a feature in the feature store.
return self._tensor
return self._tensor[ids]
This function is used to update a feature in the feature store. The
feature is identified by a unique key, and its value is specified using
a tensor.
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
"""Update the feature store.
Parameters
----------
key : str
The key that uniquely identifies the feature in the feature store.
value : torch.Tensor
The updated value of the feature.
ids : torch.Tensor, optional
......@@ -167,14 +121,16 @@ class TorchBasedFeatureStore(FeatureStore):
must have the same length. If None, the entire feature will be
updated.
"""
assert (
key in self._feature_dict
), f"key {key} not in {self._feature_dict.keys()}"
if ids is None:
self._feature_dict[key] = value
assert self._tensor.shape == value.shape, (
f"ids is None, so the entire feature will be updated. "
f"But the shape of the feature is {self._tensor.shape}, "
f"while the shape of the value is {value.shape}."
)
self._tensor[:] = value
else:
assert ids.shape[0] == value.shape[0], (
f"ids and value must have the same length, "
f"but got {ids.shape[0]} and {value.shape[0]}."
)
self._feature_dict[key][ids] = value
self._tensor[ids] = value
......@@ -5,16 +5,18 @@ import torch
def get_graphbolt_fetch_func():
feature_store = dgl.graphbolt.feature_store.TorchBasedFeatureStore(
{
"feature": torch.randn(200, 4),
"label": torch.randint(0, 10, (200,)),
feature_store = {
"feature": dgl.graphbolt.feature_store.TorchBasedFeatureStore(
torch.randn(200, 4)
),
"label": dgl.graphbolt.feature_store.TorchBasedFeatureStore(
torch.randint(0, 10, (200,))
),
}
)
def fetch_func(data):
return feature_store.read("feature", data), feature_store.read(
"label", data
return feature_store["feature"].read(data), feature_store["label"].read(
data
)
return fetch_func
......
......@@ -21,40 +21,39 @@ def to_on_disk_tensor(test_dir, name, t):
def test_torch_based_feature_store(in_memory):
with tempfile.TemporaryDirectory() as test_dir:
a = torch.tensor([1, 2, 3])
b = torch.tensor([3, 4, 5])
c = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[1, 2, 3], [4, 5, 6]])
if not in_memory:
a = to_on_disk_tensor(test_dir, "a", a)
b = to_on_disk_tensor(test_dir, "b", b)
c = to_on_disk_tensor(test_dir, "c", c)
feature_store = gb.TorchBasedFeatureStore({"a": a, "b": b, "c": c})
assert torch.equal(feature_store.read("a"), torch.tensor([1, 2, 3]))
assert torch.equal(feature_store.read("b"), torch.tensor([3, 4, 5]))
feat_store_a = gb.TorchBasedFeatureStore(a)
feat_store_b = gb.TorchBasedFeatureStore(b)
assert torch.equal(feat_store_a.read(), torch.tensor([1, 2, 3]))
assert torch.equal(
feat_store_b.read(), torch.tensor([[1, 2, 3], [4, 5, 6]])
)
assert torch.equal(
feature_store.read("a", torch.tensor([0, 2])),
feat_store_a.read(torch.tensor([0, 2])),
torch.tensor([1, 3]),
)
assert torch.equal(
feature_store.read("a", torch.tensor([1, 1])),
feat_store_a.read(torch.tensor([1, 1])),
torch.tensor([2, 2]),
)
assert torch.equal(
feature_store.read("c", torch.tensor([1])),
feat_store_b.read(torch.tensor([1])),
torch.tensor([[4, 5, 6]]),
)
feature_store.update("a", torch.tensor([0, 1, 2]))
assert torch.equal(feature_store.read("a"), torch.tensor([0, 1, 2]))
assert torch.equal(
feature_store.read("a", torch.tensor([0, 2])),
torch.tensor([0, 2]),
)
with pytest.raises(AssertionError):
feature_store.read("d")
feat_store_a.update(torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2]))
assert torch.equal(feat_store_a.read(), torch.tensor([0, 1, 2]))
feat_store_a.update(torch.tensor([2, 0]), torch.tensor([0, 2]))
assert torch.equal(feat_store_a.read(), torch.tensor([2, 1, 0]))
with pytest.raises(IndexError):
feature_store.read("a", torch.tensor([0, 1, 2, 3]))
feat_store_a.read(torch.tensor([0, 1, 2, 3]))
# On Windows, the file stays locked by numpy.load while references to it
# exist. We need to release them before closing the temporary directory.
a = b = c = feature_store = None
a = b = None
feat_store_a = feat_store_b = None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment