Unverified Commit 4a861775 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] support both pydantic 1.x and 2.x (#6411)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent b79dae36
...@@ -28,14 +28,14 @@ from .utils import ( ...@@ -28,14 +28,14 @@ from .utils import (
def load_graphbolt(): def load_graphbolt():
"""Load Graphbolt C++ library""" """Load Graphbolt C++ library"""
version = torch.__version__.split("+", maxsplit=1)[0] vers = torch.__version__.split("+", maxsplit=1)[0]
if sys.platform.startswith("linux"): if sys.platform.startswith("linux"):
basename = f"libgraphbolt_pytorch_{version}.so" basename = f"libgraphbolt_pytorch_{vers}.so"
elif sys.platform.startswith("darwin"): elif sys.platform.startswith("darwin"):
basename = f"libgraphbolt_pytorch_{version}.dylib" basename = f"libgraphbolt_pytorch_{vers}.dylib"
elif sys.platform.startswith("win"): elif sys.platform.startswith("win"):
basename = f"graphbolt_pytorch_{version}.dll" basename = f"graphbolt_pytorch_{vers}.dll"
else: else:
raise NotImplementedError("Unsupported system: %s" % sys.platform) raise NotImplementedError("Unsupported system: %s" % sys.platform)
......
...@@ -5,6 +5,8 @@ from typing import Any, Dict, List, Optional ...@@ -5,6 +5,8 @@ from typing import Any, Dict, List, Optional
import pydantic import pydantic
from ...utils import version
__all__ = [ __all__ = [
"OnDiskFeatureDataFormat", "OnDiskFeatureDataFormat",
"OnDiskTVTSetData", "OnDiskTVTSetData",
...@@ -80,6 +82,11 @@ class OnDiskTaskData(pydantic.BaseModel, extra="allow"): ...@@ -80,6 +82,11 @@ class OnDiskTaskData(pydantic.BaseModel, extra="allow"):
test_set: Optional[List[OnDiskTVTSet]] = [] test_set: Optional[List[OnDiskTVTSet]] = []
extra_fields: Optional[Dict[str, Any]] = {} extra_fields: Optional[Dict[str, Any]] = {}
# As pydantic 2.0 has changed the API of validators, we need to use
# different validators for different versions to be compatible with
# previous versions.
if version.parse(pydantic.__version__) >= version.parse("2.0"):
@pydantic.model_validator(mode="before") @pydantic.model_validator(mode="before")
@classmethod @classmethod
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
...@@ -90,6 +97,18 @@ class OnDiskTaskData(pydantic.BaseModel, extra="allow"): ...@@ -90,6 +97,18 @@ class OnDiskTaskData(pydantic.BaseModel, extra="allow"):
values["extra_fields"][key] = values.pop(key) values["extra_fields"][key] = values.pop(key)
return values return values
else:
@pydantic.root_validator(pre=True)
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra fields."""
for key in list(values.keys()):
if key not in ["train_set", "validation_set", "test_set"]:
values["extra_fields"] = values.get("extra_fields", {})
values["extra_fields"][key] = values.pop(key)
return values
class OnDiskMetaData(pydantic.BaseModel): class OnDiskMetaData(pydantic.BaseModel):
"""Metadata specification in YAML. """Metadata specification in YAML.
......
...@@ -1582,7 +1582,9 @@ def test_OnDiskDataset_load_graph(): ...@@ -1582,7 +1582,9 @@ def test_OnDiskDataset_load_graph():
dataset.yaml_data["graph_topology"]["type"] = "fake_type" dataset.yaml_data["graph_topology"]["type"] = "fake_type"
with pytest.raises( with pytest.raises(
pydantic.ValidationError, pydantic.ValidationError,
match="Input should be 'CSCSamplingGraph'", # As error message diffs in pydantic 1.x and 2.x, we just match
# keyword only.
match="'CSCSamplingGraph'",
): ):
dataset.load() dataset.load()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment