Unverified commit 6d4958c2, authored by Hongzhi (Steve), Chen and committed by GitHub
Browse files

[Graphbolt] Support feature metadata. (#6576)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 8c3fc0d3
...@@ -52,6 +52,16 @@ class Feature: ...@@ -52,6 +52,16 @@ class Feature:
""" """
raise NotImplementedError raise NotImplementedError
def metadata(self):
    """Get the metadata of the feature.

    This implementation carries no metadata of its own: every call
    returns a newly created, empty dictionary.

    Returns
    -------
    Dict
        The metadata of the feature.
    """
    return {}
class FeatureStore: class FeatureStore:
r"""A store to manage multiple features for access.""" r"""A store to manage multiple features for access."""
...@@ -110,6 +120,29 @@ class FeatureStore: ...@@ -110,6 +120,29 @@ class FeatureStore:
""" """
raise NotImplementedError raise NotImplementedError
def metadata(
    self,
    domain: str,
    type_name: str,
    feature_name: str,
):
    """Get the metadata of the specified feature in the feature store.

    This implementation is a placeholder that always raises; a concrete
    feature store is expected to override it.

    Parameters
    ----------
    domain : str
        The domain of the feature such as "node", "edge" or "graph".
    type_name : str
        The node or edge type name.
    feature_name : str
        The feature name.

    Returns
    -------
    Dict
        The metadata of the feature.

    Raises
    ------
    NotImplementedError
        Always, in this implementation.
    """
    raise NotImplementedError
def update( def update(
self, self,
domain: str, domain: str,
......
...@@ -81,6 +81,29 @@ class BasicFeatureStore(FeatureStore): ...@@ -81,6 +81,29 @@ class BasicFeatureStore(FeatureStore):
""" """
return self._features[(domain, type_name, feature_name)].size() return self._features[(domain, type_name, feature_name)].size()
def metadata(
    self,
    domain: str,
    type_name: str,
    feature_name: str,
):
    """Get the metadata of the specified feature in the feature store.

    The feature is looked up in ``self._features`` under the key
    ``(domain, type_name, feature_name)`` and its own ``metadata()``
    result is returned unchanged.

    Parameters
    ----------
    domain : str
        The domain of the feature such as "node", "edge" or "graph".
    type_name : str
        The node or edge type name.
    feature_name : str
        The feature name.

    Returns
    -------
    Dict
        The metadata of the feature.
    """
    # Any lookup failure from `self._features` propagates to the caller
    # (presumably a KeyError if it is a plain dict -- TODO confirm).
    return self._features[(domain, type_name, feature_name)].metadata()
def update( def update(
self, self,
domain: str, domain: str,
......
"""Torch-based feature store for GraphBolt.""" """Torch-based feature store for GraphBolt."""
from typing import List from typing import Dict, List
import numpy as np import numpy as np
import torch import torch
...@@ -67,7 +67,7 @@ class TorchBasedFeature(Feature): ...@@ -67,7 +67,7 @@ class TorchBasedFeature(Feature):
device(type='cuda', index=0) device(type='cuda', index=0)
""" """
def __init__(self, torch_feature: torch.Tensor): def __init__(self, torch_feature: torch.Tensor, metadata: Dict = None):
super().__init__() super().__init__()
assert isinstance(torch_feature, torch.Tensor), ( assert isinstance(torch_feature, torch.Tensor), (
f"torch_feature in TorchBasedFeature must be torch.Tensor, " f"torch_feature in TorchBasedFeature must be torch.Tensor, "
...@@ -79,6 +79,7 @@ class TorchBasedFeature(Feature): ...@@ -79,6 +79,7 @@ class TorchBasedFeature(Feature):
) )
# Make sure the tensor is contiguous. # Make sure the tensor is contiguous.
self._tensor = torch_feature.contiguous() self._tensor = torch_feature.contiguous()
self._metadata = metadata
def read(self, ids: torch.Tensor = None): def read(self, ids: torch.Tensor = None):
"""Read the feature by index. """Read the feature by index.
...@@ -151,6 +152,18 @@ class TorchBasedFeature(Feature): ...@@ -151,6 +152,18 @@ class TorchBasedFeature(Feature):
) )
self._tensor[ids] = value self._tensor[ids] = value
def metadata(self):
    """Get the metadata of the feature.

    Returns the object stored in ``self._metadata`` as-is (no copy is
    made) whenever it is not ``None``, including when it is an empty
    dictionary. Only when it is ``None`` does the call fall back to the
    parent class's ``metadata()``.

    Returns
    -------
    Dict
        The metadata of the feature.
    """
    # Compare against None explicitly so that an empty-but-provided
    # metadata dict is still returned instead of the parent's default.
    return (
        self._metadata if self._metadata is not None else super().metadata()
    )
def pin_memory_(self): def pin_memory_(self):
"""In-place operation to copy the feature to pinned memory.""" """In-place operation to copy the feature to pinned memory."""
self._tensor = self._tensor.pin_memory() self._tensor = self._tensor.pin_memory()
......
...@@ -7,9 +7,10 @@ from dgl import graphbolt as gb ...@@ -7,9 +7,10 @@ from dgl import graphbolt as gb
def test_basic_feature_store_homo(): def test_basic_feature_store_homo():
a = torch.tensor([[1, 2, 4], [2, 5, 3]]) a = torch.tensor([[1, 2, 4], [2, 5, 3]])
b = torch.tensor([[[1, 2], [3, 4]], [[2, 5], [4, 3]]]) b = torch.tensor([[[1, 2], [3, 4]], [[2, 5], [4, 3]]])
metadata = {"max_value": 3}
features = {} features = {}
features[("node", None, "a")] = gb.TorchBasedFeature(a) features[("node", None, "a")] = gb.TorchBasedFeature(a, metadata=metadata)
features[("node", None, "b")] = gb.TorchBasedFeature(b) features[("node", None, "b")] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features) feature_store = gb.BasicFeatureStore(features)
...@@ -38,13 +39,20 @@ def test_basic_feature_store_homo(): ...@@ -38,13 +39,20 @@ def test_basic_feature_store_homo():
assert feature_store.size("node", None, "a") == torch.Size([3]) assert feature_store.size("node", None, "a") == torch.Size([3])
assert feature_store.size("node", None, "b") == torch.Size([2, 2]) assert feature_store.size("node", None, "b") == torch.Size([2, 2])
# Test get metadata of the feature.
assert feature_store.metadata("node", None, "a") == metadata
assert feature_store.metadata("node", None, "b") == {}
def test_basic_feature_store_hetero(): def test_basic_feature_store_hetero():
a = torch.tensor([[1, 2, 4], [2, 5, 3]]) a = torch.tensor([[1, 2, 4], [2, 5, 3]])
b = torch.tensor([[[6], [8]], [[8], [9]]]) b = torch.tensor([[[6], [8]], [[8], [9]]])
metadata = {"max_value": 3}
features = {} features = {}
features[("node", "author", "a")] = gb.TorchBasedFeature(a) features[("node", "author", "a")] = gb.TorchBasedFeature(
a, metadata=metadata
)
features[("edge", "paper:cites", "b")] = gb.TorchBasedFeature(b) features[("edge", "paper:cites", "b")] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features) feature_store = gb.BasicFeatureStore(features)
...@@ -69,6 +77,10 @@ def test_basic_feature_store_hetero(): ...@@ -69,6 +77,10 @@ def test_basic_feature_store_hetero():
assert feature_store.size("node", "author", "a") == torch.Size([3]) assert feature_store.size("node", "author", "a") == torch.Size([3])
assert feature_store.size("edge", "paper:cites", "b") == torch.Size([2, 1]) assert feature_store.size("edge", "paper:cites", "b") == torch.Size([2, 1])
# Test get metadata of the feature.
assert feature_store.metadata("node", "author", "a") == metadata
assert feature_store.metadata("edge", "paper:cites", "b") == {}
def test_basic_feature_store_errors(): def test_basic_feature_store_errors():
a = torch.tensor([3, 2, 1]) a = torch.tensor([3, 2, 1])
......
...@@ -29,11 +29,12 @@ def test_torch_based_feature(in_memory): ...@@ -29,11 +29,12 @@ def test_torch_based_feature(in_memory):
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]]) b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]])
metadata = {"max_value": 3}
if not in_memory: if not in_memory:
a = to_on_disk_tensor(test_dir, "a", a) a = to_on_disk_tensor(test_dir, "a", a)
b = to_on_disk_tensor(test_dir, "b", b) b = to_on_disk_tensor(test_dir, "b", b)
feature_a = gb.TorchBasedFeature(a) feature_a = gb.TorchBasedFeature(a, metadata=metadata)
feature_b = gb.TorchBasedFeature(b) feature_b = gb.TorchBasedFeature(b)
# Read the entire feature. # Read the entire feature.
...@@ -83,6 +84,11 @@ def test_torch_based_feature(in_memory): ...@@ -83,6 +84,11 @@ def test_torch_based_feature(in_memory):
# Test get the size of the entire feature. # Test get the size of the entire feature.
assert feature_a.size() == torch.Size([3]) assert feature_a.size() == torch.Size([3])
assert feature_b.size() == torch.Size([2, 2]) assert feature_b.size() == torch.Size([2, 2])
# Test get metadata of the feature.
assert feature_a.metadata() == metadata
assert feature_b.metadata() == {}
with pytest.raises(IndexError): with pytest.raises(IndexError):
feature_a.read(torch.tensor([0, 1, 2, 3])) feature_a.read(torch.tensor([0, 1, 2, 3]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment