Unverified Commit 6d4958c2 authored by Hongzhi (Steve), Chen, committed by GitHub

[Graphbolt] Support feature metadata. (#6576)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 8c3fc0d3
@@ -52,6 +52,16 @@ class Feature:
"""
raise NotImplementedError
def metadata(self):
"""Get the metadata of the feature.
Returns
-------
Dict
The metadata of the feature.
"""
return {}
class FeatureStore:
r"""A store to manage multiple features for access."""
@@ -110,6 +120,29 @@ class FeatureStore:
"""
raise NotImplementedError
def metadata(
self,
domain: str,
type_name: str,
feature_name: str,
):
"""Get the metadata of the specified feature in the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
Returns
-------
Dict
The metadata of the feature.
"""
raise NotImplementedError
def update(
self,
domain: str,
......
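Aside (illustration only, not part of this commit): the contract above means any Feature subclass can expose arbitrary key/value metadata by overriding metadata(), while subclasses that keep the default report an empty dict. A minimal sketch, assuming Feature is importable from dgl.graphbolt like the other classes touched here; StaticFeature and its fields are made up for the example.

import torch
from dgl import graphbolt as gb  # assumes gb.Feature is exported; adjust the import otherwise


class StaticFeature(gb.Feature):
    """Illustrative subclass wrapping an in-memory tensor plus optional metadata."""

    def __init__(self, tensor: torch.Tensor, metadata: dict = None):
        super().__init__()
        self._tensor = tensor
        self._metadata = metadata

    def read(self, ids: torch.Tensor = None):
        # Return the whole tensor, or the rows selected by `ids`.
        return self._tensor if ids is None else self._tensor[ids]

    def metadata(self):
        # Fall back to the base class's empty dict when no metadata was given.
        return self._metadata if self._metadata is not None else super().metadata()


assert StaticFeature(torch.zeros(3), metadata={"unit": "km"}).metadata() == {"unit": "km"}
assert StaticFeature(torch.zeros(3)).metadata() == {}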
@@ -81,6 +81,29 @@ class BasicFeatureStore(FeatureStore):
"""
return self._features[(domain, type_name, feature_name)].size()
def metadata(
self,
domain: str,
type_name: str,
feature_name: str,
):
"""Get the metadata of the specified feature in the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
Returns
-------
Dict
The metadata of the feature.
"""
return self._features[(domain, type_name, feature_name)].metadata()
def update(
self,
domain: str,
......
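For reference, a short usage sketch of the store-level lookup implemented above, combining BasicFeatureStore with the TorchBasedFeature(metadata=...) argument added in the next hunk; the "max_value" key is only an example value, not an API convention.

import torch
from dgl import graphbolt as gb

features = {
    ("node", "author", "a"): gb.TorchBasedFeature(
        torch.tensor([[1, 2, 4], [2, 5, 3]]), metadata={"max_value": 5}
    ),
    ("edge", "paper:cites", "b"): gb.TorchBasedFeature(torch.ones(2, 1)),
}
store = gb.BasicFeatureStore(features)

# Metadata is looked up by the same (domain, type_name, feature_name) key as read()/size().
assert store.metadata("node", "author", "a") == {"max_value": 5}
# Features constructed without metadata report an empty dict.
assert store.metadata("edge", "paper:cites", "b") == {}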
"""Torch-based feature store for GraphBolt."""
-from typing import List
+from typing import Dict, List
import numpy as np
import torch
@@ -67,7 +67,7 @@ class TorchBasedFeature(Feature):
device(type='cuda', index=0)
"""
-def __init__(self, torch_feature: torch.Tensor):
+def __init__(self, torch_feature: torch.Tensor, metadata: Dict = None):
super().__init__()
assert isinstance(torch_feature, torch.Tensor), (
f"torch_feature in TorchBasedFeature must be torch.Tensor, "
@@ -79,6 +79,7 @@ class TorchBasedFeature(Feature):
)
# Make sure the tensor is contiguous.
self._tensor = torch_feature.contiguous()
self._metadata = metadata
def read(self, ids: torch.Tensor = None):
"""Read the feature by index.
@@ -151,6 +152,18 @@ class TorchBasedFeature(Feature):
)
self._tensor[ids] = value
def metadata(self):
"""Get the metadata of the feature.
Returns
-------
Dict
The metadata of the feature.
"""
return (
self._metadata if self._metadata is not None else super().metadata()
)
def pin_memory_(self):
"""In-place operation to copy the feature to pinned memory."""
self._tensor = self._tensor.pin_memory()
......
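At the feature level, the new constructor argument works as the tests below exercise: metadata defaults to None, in which case metadata() falls back to the base class's empty dict. A minimal sketch:

import torch
from dgl import graphbolt as gb

feat_a = gb.TorchBasedFeature(
    torch.tensor([[1, 2, 3], [4, 5, 6]]), metadata={"max_value": 6}
)
feat_b = gb.TorchBasedFeature(torch.zeros(2, 2))

assert feat_a.metadata() == {"max_value": 6}
assert feat_b.metadata() == {}  # no metadata passed -> empty dict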
@@ -7,9 +7,10 @@ from dgl import graphbolt as gb
def test_basic_feature_store_homo():
a = torch.tensor([[1, 2, 4], [2, 5, 3]])
b = torch.tensor([[[1, 2], [3, 4]], [[2, 5], [4, 3]]])
metadata = {"max_value": 3}
features = {}
features[("node", None, "a")] = gb.TorchBasedFeature(a)
features[("node", None, "a")] = gb.TorchBasedFeature(a, metadata=metadata)
features[("node", None, "b")] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features)
@@ -38,13 +39,20 @@ def test_basic_feature_store_homo():
assert feature_store.size("node", None, "a") == torch.Size([3])
assert feature_store.size("node", None, "b") == torch.Size([2, 2])
# Test get metadata of the feature.
assert feature_store.metadata("node", None, "a") == metadata
assert feature_store.metadata("node", None, "b") == {}
def test_basic_feature_store_hetero():
a = torch.tensor([[1, 2, 4], [2, 5, 3]])
b = torch.tensor([[[6], [8]], [[8], [9]]])
metadata = {"max_value": 3}
features = {}
features[("node", "author", "a")] = gb.TorchBasedFeature(a)
features[("node", "author", "a")] = gb.TorchBasedFeature(
a, metadata=metadata
)
features[("edge", "paper:cites", "b")] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features)
@@ -69,6 +77,10 @@ def test_basic_feature_store_hetero():
assert feature_store.size("node", "author", "a") == torch.Size([3])
assert feature_store.size("edge", "paper:cites", "b") == torch.Size([2, 1])
# Test get metadata of the feature.
assert feature_store.metadata("node", "author", "a") == metadata
assert feature_store.metadata("edge", "paper:cites", "b") == {}
def test_basic_feature_store_errors():
a = torch.tensor([3, 2, 1])
......
@@ -29,11 +29,12 @@ def test_torch_based_feature(in_memory):
with tempfile.TemporaryDirectory() as test_dir:
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]])
metadata = {"max_value": 3}
if not in_memory:
a = to_on_disk_tensor(test_dir, "a", a)
b = to_on_disk_tensor(test_dir, "b", b)
-feature_a = gb.TorchBasedFeature(a)
+feature_a = gb.TorchBasedFeature(a, metadata=metadata)
feature_b = gb.TorchBasedFeature(b)
# Read the entire feature.
@@ -83,6 +84,11 @@ def test_torch_based_feature(in_memory):
# Test get the size of the entire feature.
assert feature_a.size() == torch.Size([3])
assert feature_b.size() == torch.Size([2, 2])
# Test get metadata of the feature.
assert feature_a.metadata() == metadata
assert feature_b.metadata() == {}
with pytest.raises(IndexError):
feature_a.read(torch.tensor([0, 1, 2, 3]))
......