Unverified Commit 07349c25 authored by nbpppp's avatar nbpppp Committed by GitHub
Browse files

Fix deprecation warning for torch.utils._pytree._register_pytree_node in PyTorch 2.2 (#7008)



Fixed deprecation warning for torch.utils._pytree._register_pytree_node in PyTorch 2.2
Co-authored-by: default avatarYinghua <yzho0423@uni.sydney.edu.au>
parent 8974c50b
...@@ -21,7 +21,7 @@ from typing import Any, Tuple ...@@ -21,7 +21,7 @@ from typing import Any, Tuple
import numpy as np import numpy as np
from .import_utils import is_torch_available from .import_utils import is_torch_available, is_torch_version
def is_tensor(x) -> bool: def is_tensor(x) -> bool:
...@@ -60,11 +60,18 @@ class BaseOutput(OrderedDict): ...@@ -60,11 +60,18 @@ class BaseOutput(OrderedDict):
if is_torch_available(): if is_torch_available():
import torch.utils._pytree import torch.utils._pytree
torch.utils._pytree._register_pytree_node( if is_torch_version("<", "2.2"):
cls, torch.utils._pytree._register_pytree_node(
torch.utils._pytree._dict_flatten, cls,
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), torch.utils._pytree._dict_flatten,
) lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
)
else:
torch.utils._pytree.register_pytree_node(
cls,
torch.utils._pytree._dict_flatten,
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
)
def __post_init__(self) -> None: def __post_init__(self) -> None:
class_fields = fields(self) class_fields = fields(self)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment