"git@developer.sourcefind.cn:OpenDAS/deepspeed.git" did not exist on "be33bea4755e47a7dc460a860f86a02c501f83d7"
Unverified Commit f6345635 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable metadata for feature data in OnDiskDataset (#6595)

parent bfde1422
...@@ -745,8 +745,9 @@ The ``graph`` field is used to specify the graph structure. It has two fields: ...@@ -745,8 +745,9 @@ The ``graph`` field is used to specify the graph structure. It has two fields:
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^
The ``feature_data`` field is used to specify the feature data. It is a list of The ``feature_data`` field is used to specify the feature data. It is a list of
``feature`` objects. Each ``feature`` object has five fields: ``domain``, ``type``, ``feature`` objects. Each ``feature`` object has five canonical fields: ``domain``,
``name``, ``format`` and ``path``. ``type``, ``name``, ``format`` and ``path``. Any other fields will be passed to
the ``Feature.metadata`` object.
- ``domain``: ``string`` - ``domain``: ``string``
......
...@@ -20,6 +20,39 @@ __all__ = [ ...@@ -20,6 +20,39 @@ __all__ = [
] ]
class ExtraMetaData(pydantic.BaseModel, extra="allow"):
    """Group extra fields into metadata. Internal use only.

    Any key in the input that is not a declared model field is moved into
    ``extra_fields`` before validation, so subclasses can surface arbitrary
    user-provided keys (e.g. as ``Feature.metadata``) without declaring them.
    """

    # Collected non-canonical keys; populated by ``build_extra`` below.
    extra_fields: Optional[Dict[str, Any]] = {}

    # As pydantic 2.0 has changed the API of validators, we need to use
    # different validators for different versions to be compatible with
    # previous versions.
    if version.parse(pydantic.__version__) >= version.parse("2.0"):

        @pydantic.model_validator(mode="before")
        @classmethod
        def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
            """Move undeclared keys into ``values["extra_fields"]``."""
            for key in list(values.keys()):
                if key not in cls.model_fields:
                    values["extra_fields"] = values.get("extra_fields", {})
                    values["extra_fields"][key] = values.pop(key)
            return values

    else:

        @pydantic.root_validator(pre=True)
        @classmethod
        def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
            """Move undeclared keys into ``values["extra_fields"]``."""
            for key in list(values.keys()):
                # Consult the declared fields (pydantic v1's ``__fields__``)
                # instead of a hard-coded list of OnDiskTaskData's keys, so
                # every subclass (e.g. OnDiskFeatureData) keeps its own
                # canonical fields out of ``extra_fields``. Mirrors the
                # ``cls.model_fields`` check in the v2 branch above.
                if key not in cls.__fields__:
                    values["extra_fields"] = values.get("extra_fields", {})
                    values["extra_fields"][key] = values.pop(key)
            return values
class OnDiskFeatureDataFormat(str, Enum): class OnDiskFeatureDataFormat(str, Enum):
"""Enum of data format.""" """Enum of data format."""
...@@ -51,7 +84,7 @@ class OnDiskFeatureDataDomain(str, Enum): ...@@ -51,7 +84,7 @@ class OnDiskFeatureDataDomain(str, Enum):
GRAPH = "graph" GRAPH = "graph"
class OnDiskFeatureData(pydantic.BaseModel): class OnDiskFeatureData(ExtraMetaData):
r"""The description of an on-disk feature.""" r"""The description of an on-disk feature."""
domain: OnDiskFeatureDataDomain domain: OnDiskFeatureDataDomain
type: Optional[str] = None type: Optional[str] = None
...@@ -74,40 +107,12 @@ class OnDiskGraphTopology(pydantic.BaseModel): ...@@ -74,40 +107,12 @@ class OnDiskGraphTopology(pydantic.BaseModel):
path: str path: str
class OnDiskTaskData(pydantic.BaseModel, extra="allow"): class OnDiskTaskData(ExtraMetaData):
"""Task specification in YAML.""" """Task specification in YAML."""
train_set: Optional[List[OnDiskTVTSet]] = [] train_set: Optional[List[OnDiskTVTSet]] = []
validation_set: Optional[List[OnDiskTVTSet]] = [] validation_set: Optional[List[OnDiskTVTSet]] = []
test_set: Optional[List[OnDiskTVTSet]] = [] test_set: Optional[List[OnDiskTVTSet]] = []
extra_fields: Optional[Dict[str, Any]] = {}
# As pydantic 2.0 has changed the API of validators, we need to use
# different validators for different versions to be compatible with
# previous versions.
if version.parse(pydantic.__version__) >= version.parse("2.0"):
@pydantic.model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra fields."""
for key in list(values.keys()):
if key not in cls.model_fields:
values["extra_fields"] = values.get("extra_fields", {})
values["extra_fields"][key] = values.pop(key)
return values
else:
@pydantic.root_validator(pre=True)
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra fields."""
for key in list(values.keys()):
if key not in ["train_set", "validation_set", "test_set"]:
values["extra_fields"] = values.get("extra_fields", {})
values["extra_fields"][key] = values.pop(key)
return values
class OnDiskMetaData(pydantic.BaseModel): class OnDiskMetaData(pydantic.BaseModel):
......
...@@ -208,16 +208,20 @@ class TorchBasedFeatureStore(BasicFeatureStore): ...@@ -208,16 +208,20 @@ class TorchBasedFeatureStore(BasicFeatureStore):
features = {} features = {}
for spec in feat_data: for spec in feat_data:
key = (spec.domain, spec.type, spec.name) key = (spec.domain, spec.type, spec.name)
metadata = spec.extra_fields
if spec.format == "torch": if spec.format == "torch":
assert spec.in_memory, ( assert spec.in_memory, (
f"Pytorch tensor can only be loaded in memory, " f"Pytorch tensor can only be loaded in memory, "
f"but the feature {key} is loaded on disk." f"but the feature {key} is loaded on disk."
) )
features[key] = TorchBasedFeature(torch.load(spec.path)) features[key] = TorchBasedFeature(
torch.load(spec.path), metadata=metadata
)
elif spec.format == "numpy": elif spec.format == "numpy":
mmap_mode = "r+" if not spec.in_memory else None mmap_mode = "r+" if not spec.in_memory else None
features[key] = TorchBasedFeature( features[key] = TorchBasedFeature(
torch.as_tensor(np.load(spec.path, mmap_mode=mmap_mode)) torch.as_tensor(np.load(spec.path, mmap_mode=mmap_mode)),
metadata=metadata,
) )
else: else:
raise ValueError(f"Unknown feature format {spec.format}") raise ValueError(f"Unknown feature format {spec.format}")
......
...@@ -861,6 +861,7 @@ def test_OnDiskDataset_Feature_heterograph(): ...@@ -861,6 +861,7 @@ def test_OnDiskDataset_Feature_heterograph():
format: numpy format: numpy
in_memory: false in_memory: false
path: {node_data_paper_path} path: {node_data_paper_path}
num_categories: 10
- domain: node - domain: node
type: paper type: paper
name: labels name: labels
...@@ -873,6 +874,7 @@ def test_OnDiskDataset_Feature_heterograph(): ...@@ -873,6 +874,7 @@ def test_OnDiskDataset_Feature_heterograph():
format: numpy format: numpy
in_memory: false in_memory: false
path: {edge_data_writes_path} path: {edge_data_writes_path}
num_categories: 10
- domain: edge - domain: edge
type: "author:writes:paper" type: "author:writes:paper"
name: labels name: labels
...@@ -896,20 +898,35 @@ def test_OnDiskDataset_Feature_heterograph(): ...@@ -896,20 +898,35 @@ def test_OnDiskDataset_Feature_heterograph():
feature_data.read("node", "paper", "feat"), feature_data.read("node", "paper", "feat"),
torch.tensor(node_data_paper), torch.tensor(node_data_paper),
) )
assert (
feature_data.metadata("node", "paper", "feat")["num_categories"]
== 10
)
assert torch.equal( assert torch.equal(
feature_data.read("node", "paper", "labels"), feature_data.read("node", "paper", "labels"),
node_data_label.clone().detach(), node_data_label.clone().detach(),
) )
assert len(feature_data.metadata("node", "paper", "labels")) == 0
# Verify edge feature data. # Verify edge feature data.
assert torch.equal( assert torch.equal(
feature_data.read("edge", "author:writes:paper", "feat"), feature_data.read("edge", "author:writes:paper", "feat"),
torch.tensor(edge_data_writes), torch.tensor(edge_data_writes),
) )
assert (
feature_data.metadata("edge", "author:writes:paper", "feat")[
"num_categories"
]
== 10
)
assert torch.equal( assert torch.equal(
feature_data.read("edge", "author:writes:paper", "labels"), feature_data.read("edge", "author:writes:paper", "labels"),
edge_data_label.clone().detach(), edge_data_label.clone().detach(),
) )
assert (
len(feature_data.metadata("edge", "author:writes:paper", "labels"))
== 0
)
feature_data = None feature_data = None
dataset = None dataset = None
...@@ -947,6 +964,7 @@ def test_OnDiskDataset_Feature_homograph(): ...@@ -947,6 +964,7 @@ def test_OnDiskDataset_Feature_homograph():
format: numpy format: numpy
in_memory: false in_memory: false
path: {node_data_feat_path} path: {node_data_feat_path}
num_categories: 10
- domain: node - domain: node
name: labels name: labels
format: numpy format: numpy
...@@ -957,6 +975,7 @@ def test_OnDiskDataset_Feature_homograph(): ...@@ -957,6 +975,7 @@ def test_OnDiskDataset_Feature_homograph():
format: numpy format: numpy
in_memory: false in_memory: false
path: {edge_data_feat_path} path: {edge_data_feat_path}
num_categories: 10
- domain: edge - domain: edge
name: labels name: labels
format: numpy format: numpy
...@@ -979,20 +998,28 @@ def test_OnDiskDataset_Feature_homograph(): ...@@ -979,20 +998,28 @@ def test_OnDiskDataset_Feature_homograph():
feature_data.read("node", None, "feat"), feature_data.read("node", None, "feat"),
torch.tensor(node_data_feat), torch.tensor(node_data_feat),
) )
assert (
feature_data.metadata("node", None, "feat")["num_categories"] == 10
)
assert torch.equal( assert torch.equal(
feature_data.read("node", None, "labels"), feature_data.read("node", None, "labels"),
node_data_label.clone().detach(), node_data_label.clone().detach(),
) )
assert len(feature_data.metadata("node", None, "labels")) == 0
# Verify edge feature data. # Verify edge feature data.
assert torch.equal( assert torch.equal(
feature_data.read("edge", None, "feat"), feature_data.read("edge", None, "feat"),
torch.tensor(edge_data_feat), torch.tensor(edge_data_feat),
) )
assert (
feature_data.metadata("edge", None, "feat")["num_categories"] == 10
)
assert torch.equal( assert torch.equal(
feature_data.read("edge", None, "labels"), feature_data.read("edge", None, "labels"),
edge_data_label.clone().detach(), edge_data_label.clone().detach(),
) )
assert len(feature_data.metadata("edge", None, "labels")) == 0
feature_data = None feature_data = None
dataset = None dataset = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment