"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1fa5639438820b9288bd063a07ebf9e29a015b70"
Unverified Commit f6345635 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable metadata for feature data in OnDiskDataset (#6595)

parent bfde1422
......@@ -745,8 +745,9 @@ The ``graph`` field is used to specify the graph structure. It has two fields:
^^^^^^^^^^^^^^^
The ``feature_data`` field is used to specify the feature data. It is a list of
``feature`` objects. Each ``feature`` object has five canonical fields: ``domain``,
``type``, ``name``, ``format`` and ``path``. Any other fields will be passed to
the ``Feature.metadata`` object.
- ``domain``: ``string``
......
......@@ -20,6 +20,39 @@ __all__ = [
]
class ExtraMetaData(pydantic.BaseModel, extra="allow"):
    """Group extra fields into metadata. Internal use only.

    A pre-validator moves every input key that is not a declared model
    field into the ``extra_fields`` dict, so subclasses can accept
    arbitrary user-supplied metadata without declaring it up front.
    """

    # Catch-all container for the undeclared input fields.
    extra_fields: Optional[Dict[str, Any]] = {}

    # As pydantic 2.0 has changed the API of validators, we need to use
    # different validators for different versions to be compatible with
    # previous versions.
    if version.parse(pydantic.__version__) >= version.parse("2.0"):

        @pydantic.model_validator(mode="before")
        @classmethod
        def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
            """Move undeclared input keys into ``extra_fields``."""
            for key in list(values.keys()):
                if key not in cls.model_fields:
                    values["extra_fields"] = values.get("extra_fields", {})
                    values["extra_fields"][key] = values.pop(key)
            return values

    else:

        @pydantic.root_validator(pre=True)
        @classmethod
        def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
            """Move undeclared input keys into ``extra_fields``."""
            for key in list(values.keys()):
                # Check against the model's declared fields (v1 analogue
                # of v2's ``model_fields``) instead of a hardcoded
                # ["train_set", "validation_set", "test_set"] list, which
                # only matched OnDiskTaskData and broke every other
                # subclass of this base (their declared fields were swept
                # into extra_fields, and an explicit "extra_fields" input
                # key became self-referential).
                if key not in cls.__fields__:
                    values["extra_fields"] = values.get("extra_fields", {})
                    values["extra_fields"][key] = values.pop(key)
            return values
class OnDiskFeatureDataFormat(str, Enum):
"""Enum of data format."""
......@@ -51,7 +84,7 @@ class OnDiskFeatureDataDomain(str, Enum):
GRAPH = "graph"
class OnDiskFeatureData(pydantic.BaseModel):
class OnDiskFeatureData(ExtraMetaData):
r"""The description of an on-disk feature."""
domain: OnDiskFeatureDataDomain
type: Optional[str] = None
......@@ -74,40 +107,12 @@ class OnDiskGraphTopology(pydantic.BaseModel):
path: str
class OnDiskTaskData(ExtraMetaData):
    """Task specification in YAML.

    Declares the train/validation/test split fields; any other key in
    the YAML is collected into ``extra_fields`` by the ``ExtraMetaData``
    base class (which owns the version-dependent pydantic validator), so
    the duplicated validator that previously lived here is gone.
    """

    # Each split is optional and defaults to an empty list when the YAML
    # omits it.
    train_set: Optional[List[OnDiskTVTSet]] = []
    validation_set: Optional[List[OnDiskTVTSet]] = []
    test_set: Optional[List[OnDiskTVTSet]] = []
class OnDiskMetaData(pydantic.BaseModel):
......
......@@ -208,16 +208,20 @@ class TorchBasedFeatureStore(BasicFeatureStore):
features = {}
for spec in feat_data:
    # Features are keyed by (domain, type, name), e.g.
    # ("node", "paper", "feat").
    key = (spec.domain, spec.type, spec.name)
    # Undeclared YAML fields were collected by the spec model; they
    # become this feature's metadata.
    metadata = spec.extra_fields
    if spec.format == "torch":
        # torch.load always materializes the tensor, so on-disk
        # (in_memory=False) torch features are unsupported.
        assert spec.in_memory, (
            f"Pytorch tensor can only be loaded in memory, "
            f"but the feature {key} is loaded on disk."
        )
        features[key] = TorchBasedFeature(
            torch.load(spec.path), metadata=metadata
        )
    elif spec.format == "numpy":
        # Memory-map on-disk numpy features ("r+" keeps them writable);
        # in-memory ones are fully loaded (mmap_mode=None).
        mmap_mode = "r+" if not spec.in_memory else None
        features[key] = TorchBasedFeature(
            torch.as_tensor(np.load(spec.path, mmap_mode=mmap_mode)),
            metadata=metadata,
        )
    else:
        raise ValueError(f"Unknown feature format {spec.format}")
......
......@@ -861,6 +861,7 @@ def test_OnDiskDataset_Feature_heterograph():
format: numpy
in_memory: false
path: {node_data_paper_path}
num_categories: 10
- domain: node
type: paper
name: labels
......@@ -873,6 +874,7 @@ def test_OnDiskDataset_Feature_heterograph():
format: numpy
in_memory: false
path: {edge_data_writes_path}
num_categories: 10
- domain: edge
type: "author:writes:paper"
name: labels
......@@ -896,20 +898,35 @@ def test_OnDiskDataset_Feature_heterograph():
feature_data.read("node", "paper", "feat"),
torch.tensor(node_data_paper),
)
assert (
feature_data.metadata("node", "paper", "feat")["num_categories"]
== 10
)
assert torch.equal(
feature_data.read("node", "paper", "labels"),
node_data_label.clone().detach(),
)
assert len(feature_data.metadata("node", "paper", "labels")) == 0
# Verify edge feature data.
assert torch.equal(
feature_data.read("edge", "author:writes:paper", "feat"),
torch.tensor(edge_data_writes),
)
assert (
feature_data.metadata("edge", "author:writes:paper", "feat")[
"num_categories"
]
== 10
)
assert torch.equal(
feature_data.read("edge", "author:writes:paper", "labels"),
edge_data_label.clone().detach(),
)
assert (
len(feature_data.metadata("edge", "author:writes:paper", "labels"))
== 0
)
feature_data = None
dataset = None
......@@ -947,6 +964,7 @@ def test_OnDiskDataset_Feature_homograph():
format: numpy
in_memory: false
path: {node_data_feat_path}
num_categories: 10
- domain: node
name: labels
format: numpy
......@@ -957,6 +975,7 @@ def test_OnDiskDataset_Feature_homograph():
format: numpy
in_memory: false
path: {edge_data_feat_path}
num_categories: 10
- domain: edge
name: labels
format: numpy
......@@ -979,20 +998,28 @@ def test_OnDiskDataset_Feature_homograph():
feature_data.read("node", None, "feat"),
torch.tensor(node_data_feat),
)
assert (
feature_data.metadata("node", None, "feat")["num_categories"] == 10
)
assert torch.equal(
feature_data.read("node", None, "labels"),
node_data_label.clone().detach(),
)
assert len(feature_data.metadata("node", None, "labels")) == 0
# Verify edge feature data.
assert torch.equal(
feature_data.read("edge", None, "feat"),
torch.tensor(edge_data_feat),
)
assert (
feature_data.metadata("edge", None, "feat")["num_categories"] == 10
)
assert torch.equal(
feature_data.read("edge", None, "labels"),
edge_data_label.clone().detach(),
)
assert len(feature_data.metadata("edge", None, "labels")) == 0
feature_data = None
dataset = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment