Unverified Commit e6dcf8ab authored by cyyever's avatar cyyever Committed by GitHub
Browse files

Fix the deprecation warning of _torch_pytree._register_pytree_node (#27803)

parent f85a1e82
......@@ -306,7 +306,7 @@ class ModelOutput(OrderedDict):
`static_graph=True` with modules that output `ModelOutput` subclasses.
"""
if is_torch_available():
_torch_pytree._register_pytree_node(
torch_pytree_register_pytree_node(
cls,
_model_output_flatten,
_model_output_unflatten,
......@@ -438,7 +438,11 @@ if is_torch_available():
output_type, keys = context
return output_type(**dict(zip(keys, values)))
_torch_pytree._register_pytree_node(
if hasattr(_torch_pytree, "register_pytree_node"):
torch_pytree_register_pytree_node = _torch_pytree.register_pytree_node
else:
torch_pytree_register_pytree_node = _torch_pytree._register_pytree_node
torch_pytree_register_pytree_node(
ModelOutput,
_model_output_flatten,
_model_output_unflatten,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment