Unverified Commit 4a861775 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] support both pydantic 1.x and 2.x (#6411)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent b79dae36
......@@ -28,14 +28,14 @@ from .utils import (
def load_graphbolt():
"""Load Graphbolt C++ library"""
version = torch.__version__.split("+", maxsplit=1)[0]
vers = torch.__version__.split("+", maxsplit=1)[0]
if sys.platform.startswith("linux"):
basename = f"libgraphbolt_pytorch_{version}.so"
basename = f"libgraphbolt_pytorch_{vers}.so"
elif sys.platform.startswith("darwin"):
basename = f"libgraphbolt_pytorch_{version}.dylib"
basename = f"libgraphbolt_pytorch_{vers}.dylib"
elif sys.platform.startswith("win"):
basename = f"graphbolt_pytorch_{version}.dll"
basename = f"graphbolt_pytorch_{vers}.dll"
else:
raise NotImplementedError("Unsupported system: %s" % sys.platform)
......
......@@ -5,6 +5,8 @@ from typing import Any, Dict, List, Optional
import pydantic
from ...utils import version
__all__ = [
"OnDiskFeatureDataFormat",
"OnDiskTVTSetData",
......@@ -80,15 +82,32 @@ class OnDiskTaskData(pydantic.BaseModel, extra="allow"):
test_set: Optional[List[OnDiskTVTSet]] = []
extra_fields: Optional[Dict[str, Any]] = {}
@pydantic.model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra fields."""
for key in list(values.keys()):
if key not in cls.model_fields:
values["extra_fields"] = values.get("extra_fields", {})
values["extra_fields"][key] = values.pop(key)
return values
# As pydantic 2.0 has changed the API of validators, we need to use
# different validators for different versions to be compatible with
# previous versions.
if version.parse(pydantic.__version__) >= version.parse("2.0"):
@pydantic.model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra fields."""
for key in list(values.keys()):
if key not in cls.model_fields:
values["extra_fields"] = values.get("extra_fields", {})
values["extra_fields"][key] = values.pop(key)
return values
else:
@pydantic.root_validator(pre=True)
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra fields."""
for key in list(values.keys()):
if key not in ["train_set", "validation_set", "test_set"]:
values["extra_fields"] = values.get("extra_fields", {})
values["extra_fields"][key] = values.pop(key)
return values
class OnDiskMetaData(pydantic.BaseModel):
......
......@@ -1582,7 +1582,9 @@ def test_OnDiskDataset_load_graph():
dataset.yaml_data["graph_topology"]["type"] = "fake_type"
with pytest.raises(
pydantic.ValidationError,
match="Input should be 'CSCSamplingGraph'",
# As error message diffs in pydantic 1.x and 2.x, we just match
# keyword only.
match="'CSCSamplingGraph'",
):
dataset.load()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment