Unverified Commit e6dcf8ab authored by cyyever's avatar cyyever Committed by GitHub
Browse files

Fix the deprecation warning of _torch_pytree._register_pytree_node (#27803)

parent f85a1e82
...@@ -306,7 +306,7 @@ class ModelOutput(OrderedDict): ...@@ -306,7 +306,7 @@ class ModelOutput(OrderedDict):
`static_graph=True` with modules that output `ModelOutput` subclasses. `static_graph=True` with modules that output `ModelOutput` subclasses.
""" """
if is_torch_available(): if is_torch_available():
_torch_pytree._register_pytree_node( torch_pytree_register_pytree_node(
cls, cls,
_model_output_flatten, _model_output_flatten,
_model_output_unflatten, _model_output_unflatten,
...@@ -438,7 +438,11 @@ if is_torch_available(): ...@@ -438,7 +438,11 @@ if is_torch_available():
output_type, keys = context output_type, keys = context
return output_type(**dict(zip(keys, values))) return output_type(**dict(zip(keys, values)))
_torch_pytree._register_pytree_node( if hasattr(_torch_pytree, "register_pytree_node"):
torch_pytree_register_pytree_node = _torch_pytree.register_pytree_node
else:
torch_pytree_register_pytree_node = _torch_pytree._register_pytree_node
torch_pytree_register_pytree_node(
ModelOutput, ModelOutput,
_model_output_flatten, _model_output_flatten,
_model_output_unflatten, _model_output_unflatten,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment