Unverified Commit ff9f05c5 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Remove keys in feature store. (#5938)

parent 28578137
...@@ -8,13 +8,11 @@ class FeatureStore: ...@@ -8,13 +8,11 @@ class FeatureStore:
def __init__(self): def __init__(self):
pass pass
def read(self, key: str, ids: torch.Tensor = None): def read(self, ids: torch.Tensor = None):
"""Read a feature from the feature store. """Read from the feature store.
Parameters Parameters
---------- ----------
key : str
The key that uniquely identifies the feature in the feature store.
ids : torch.Tensor, optional ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned. of the feature are read. If None, the entire feature is returned.
...@@ -26,17 +24,11 @@ class FeatureStore: ...@@ -26,17 +24,11 @@ class FeatureStore:
""" """
raise NotImplementedError raise NotImplementedError
def update(self, key: str, value: torch.Tensor, ids: torch.Tensor = None): def update(self, value: torch.Tensor, ids: torch.Tensor = None):
"""Update a feature in the feature store. """Update the feature store.
This function is used to update a feature in the feature store. The
feature is identified by a unique key, and its value is specified using
a tensor.
Parameters Parameters
---------- ----------
key : str
The key that uniquely identifies the feature in the feature store.
value : torch.Tensor value : torch.Tensor
The updated value of the feature. The updated value of the feature.
ids : torch.Tensor, optional ids : torch.Tensor, optional
...@@ -50,87 +42,58 @@ class FeatureStore: ...@@ -50,87 +42,58 @@ class FeatureStore:
class TorchBasedFeatureStore(FeatureStore): class TorchBasedFeatureStore(FeatureStore):
r"""Torch based key-value feature store, where the key are strings and r"""Torch based feature store."""
values are Pytorch tensors."""
def __init__(self, feature_dict: dict):
"""Initialize a torch based feature store.
The feature store is initialized with a dictionary of tensors, where the def __init__(self, torch_feature: torch.Tensor):
key is the name of a feature and the value is the tensor. The value can """Initialize a torch based feature store by a torch feature.
be multi-dimensional, where the first dimension is the index of the
feature.
Note that the values can be in memory or on disk. Note that the feature can be either in memory or on disk.
Parameters Parameters
---------- ----------
feature_dict : dict, optional torch_feature : torch.Tensor
A dictionary of tensors. The torch feature.
Examples Examples
-------- --------
>>> import torch >>> import torch
>>> feature_dict = { >>> torch_feat = torch.arange(0, 5)
... "user": torch.arange(0, 5), >>> feature_store = TorchBasedFeatureStore(torch_feat)
... "item": torch.arange(0, 6), >>> feature_store.read()
... "rel": torch.arange(0, 6).view(2, 3), tensor([0, 1, 2, 3, 4])
... } >>> feature_store.read(torch.tensor([0, 1, 2]))
>>> feature_store = TorchBasedFeatureStore(feature_dict)
>>> feature_store.read("user", torch.tensor([0, 1, 2]))
tensor([0, 1, 2]) tensor([0, 1, 2])
>>> feature_store.read("item", torch.tensor([0, 1, 2])) >>> feature_store.update(torch.ones(3, dtype=torch.long),
tensor([0, 1, 2]) ... torch.tensor([0, 1, 2]))
>>> feature_store.read("rel", torch.tensor([0])) >>> feature_store.read(torch.tensor([0, 1, 2, 3]))
tensor([[0, 1, 2]]) tensor([1, 1, 1, 3])
>>> feature_store.update("user",
... torch.ones(3, dtype=torch.long), torch.tensor([0, 1, 2]))
>>> feature_store.read("user", torch.tensor([0, 1, 2]))
tensor([1, 1, 1])
>>> import numpy as np >>> import numpy as np
>>> user = np.arange(0, 5) >>> arr = np.arange(0, 5)
>>> item = np.arange(0, 6) >>> np.save("/tmp/arr.npy", arr)
>>> np.save("/tmp/user.npy", user) >>> torch_feat = torch.as_tensor(np.load("/tmp/arr.npy",
>>. np.save("/tmp/item.npy", item) ... mmap_mode="r+"))
>>> feature_dict = { >>> feature_store = TorchBasedFeatureStore(torch_feat)
... "user": torch.as_tensor(np.load("/tmp/user.npy", >>> feature_store.read()
... mmap_mode="r+")), tensor([0, 1, 2, 3, 4])
... "item": torch.as_tensor(np.load("/tmp/item.npy", >>> feature_store.read(torch.tensor([0, 1, 2]))
... mmap_mode="r+")),
... }
>>> feature_store = TorchBasedFeatureStore(feature_dict)
>>> feature_store.read("user", torch.tensor([0, 1, 2]))
tensor([0, 1, 2]) tensor([0, 1, 2])
>>> feature_store.read("item", torch.tensor([3, 4, 2]))
tensor([3, 4, 2])
""" """
super(TorchBasedFeatureStore, self).__init__() super(TorchBasedFeatureStore, self).__init__()
assert isinstance(feature_dict, dict), ( assert isinstance(torch_feature, torch.Tensor), (
f"feature_dict in TorchBasedFeatureStore must be dict, " f"torch_feature in TorchBasedFeatureStore must be torch.Tensor, "
f"but got {type(feature_dict)}." f"but got {type(torch_feature)}."
) )
for k, v in feature_dict.items(): self._tensor = torch_feature
assert isinstance(
k, str
), f"Key in TorchBasedFeatureStore must be str, but got {k}."
assert isinstance(v, torch.Tensor), (
f"Value in TorchBasedFeatureStore must be torch.Tensor,"
f"but got {v}."
)
self._feature_dict = feature_dict def read(self, ids: torch.Tensor = None):
"""Read the feature by index.
def read(self, key: str, ids: torch.Tensor = None): The returned tensor is always in memory, no matter whether the feature
"""Read a feature from the feature store by index. store is in memory or on disk.
The returned feature is always in memory, no matter whether the feature
to read is in memory or on disk.
Parameters Parameters
---------- ----------
key : str
The key of the feature.
ids : torch.Tensor, optional ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned. of the feature are read. If None, the entire feature is returned.
...@@ -140,24 +103,15 @@ class TorchBasedFeatureStore(FeatureStore): ...@@ -140,24 +103,15 @@ class TorchBasedFeatureStore(FeatureStore):
torch.Tensor torch.Tensor
The read feature. The read feature.
""" """
assert (
key in self._feature_dict
), f"key {key} not in {self._feature_dict.keys()}"
if ids is None: if ids is None:
return self._feature_dict[key] return self._tensor
return self._feature_dict[key][ids] return self._tensor[ids]
def update(self, key: str, value: torch.Tensor, ids: torch.Tensor = None):
"""Update a feature in the feature store.
This function is used to update a feature in the feature store. The def update(self, value: torch.Tensor, ids: torch.Tensor = None):
feature is identified by a unique key, and its value is specified using """Update the feature store.
a tensor.
Parameters Parameters
---------- ----------
key : str
The key that uniquely identifies the feature in the feature store.
value : torch.Tensor value : torch.Tensor
The updated value of the feature. The updated value of the feature.
ids : torch.Tensor, optional ids : torch.Tensor, optional
...@@ -167,14 +121,16 @@ class TorchBasedFeatureStore(FeatureStore): ...@@ -167,14 +121,16 @@ class TorchBasedFeatureStore(FeatureStore):
must have the same length. If None, the entire feature will be must have the same length. If None, the entire feature will be
updated. updated.
""" """
assert (
key in self._feature_dict
), f"key {key} not in {self._feature_dict.keys()}"
if ids is None: if ids is None:
self._feature_dict[key] = value assert self._tensor.shape == value.shape, (
f"ids is None, so the entire feature will be updated. "
f"But the shape of the feature is {self._tensor.shape}, "
f"while the shape of the value is {value.shape}."
)
self._tensor[:] = value
else: else:
assert ids.shape[0] == value.shape[0], ( assert ids.shape[0] == value.shape[0], (
f"ids and value must have the same length, " f"ids and value must have the same length, "
f"but got {ids.shape[0]} and {value.shape[0]}." f"but got {ids.shape[0]} and {value.shape[0]}."
) )
self._feature_dict[key][ids] = value self._tensor[ids] = value
...@@ -5,16 +5,18 @@ import torch ...@@ -5,16 +5,18 @@ import torch
def get_graphbolt_fetch_func(): def get_graphbolt_fetch_func():
feature_store = dgl.graphbolt.feature_store.TorchBasedFeatureStore( feature_store = {
{ "feature": dgl.graphbolt.feature_store.TorchBasedFeatureStore(
"feature": torch.randn(200, 4), torch.randn(200, 4)
"label": torch.randint(0, 10, (200,)), ),
} "label": dgl.graphbolt.feature_store.TorchBasedFeatureStore(
) torch.randint(0, 10, (200,))
),
}
def fetch_func(data): def fetch_func(data):
return feature_store.read("feature", data), feature_store.read( return feature_store["feature"].read(data), feature_store["label"].read(
"label", data data
) )
return fetch_func return fetch_func
......
...@@ -21,40 +21,39 @@ def to_on_disk_tensor(test_dir, name, t): ...@@ -21,40 +21,39 @@ def to_on_disk_tensor(test_dir, name, t):
def test_torch_based_feature_store(in_memory): def test_torch_based_feature_store(in_memory):
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
a = torch.tensor([1, 2, 3]) a = torch.tensor([1, 2, 3])
b = torch.tensor([3, 4, 5]) b = torch.tensor([[1, 2, 3], [4, 5, 6]])
c = torch.tensor([[1, 2, 3], [4, 5, 6]])
if not in_memory: if not in_memory:
a = to_on_disk_tensor(test_dir, "a", a) a = to_on_disk_tensor(test_dir, "a", a)
b = to_on_disk_tensor(test_dir, "b", b) b = to_on_disk_tensor(test_dir, "b", b)
c = to_on_disk_tensor(test_dir, "c", c)
feature_store = gb.TorchBasedFeatureStore({"a": a, "b": b, "c": c}) feat_store_a = gb.TorchBasedFeatureStore(a)
assert torch.equal(feature_store.read("a"), torch.tensor([1, 2, 3])) feat_store_b = gb.TorchBasedFeatureStore(b)
assert torch.equal(feature_store.read("b"), torch.tensor([3, 4, 5]))
assert torch.equal(feat_store_a.read(), torch.tensor([1, 2, 3]))
assert torch.equal(
feat_store_b.read(), torch.tensor([[1, 2, 3], [4, 5, 6]])
)
assert torch.equal( assert torch.equal(
feature_store.read("a", torch.tensor([0, 2])), feat_store_a.read(torch.tensor([0, 2])),
torch.tensor([1, 3]), torch.tensor([1, 3]),
) )
assert torch.equal( assert torch.equal(
feature_store.read("a", torch.tensor([1, 1])), feat_store_a.read(torch.tensor([1, 1])),
torch.tensor([2, 2]), torch.tensor([2, 2]),
) )
assert torch.equal( assert torch.equal(
feature_store.read("c", torch.tensor([1])), feat_store_b.read(torch.tensor([1])),
torch.tensor([[4, 5, 6]]), torch.tensor([[4, 5, 6]]),
) )
feature_store.update("a", torch.tensor([0, 1, 2])) feat_store_a.update(torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2]))
assert torch.equal(feature_store.read("a"), torch.tensor([0, 1, 2])) assert torch.equal(feat_store_a.read(), torch.tensor([0, 1, 2]))
assert torch.equal( feat_store_a.update(torch.tensor([2, 0]), torch.tensor([0, 2]))
feature_store.read("a", torch.tensor([0, 2])), assert torch.equal(feat_store_a.read(), torch.tensor([2, 1, 0]))
torch.tensor([0, 2]),
)
with pytest.raises(AssertionError):
feature_store.read("d")
with pytest.raises(IndexError): with pytest.raises(IndexError):
feature_store.read("a", torch.tensor([0, 1, 2, 3])) feat_store_a.read(torch.tensor([0, 1, 2, 3]))
# For windows, the file is locked by the numpy.load. We need to delete # For windows, the file is locked by the numpy.load. We need to delete
# it before closing the temporary directory. # it before closing the temporary directory.
a = b = c = feature_store = None a = b = None
feat_store_a = feat_store_b = None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment