Unverified commit 243e186e, authored by Angela Yi and committed by GitHub
Browse files

Add serialization logic to pytree types (#27871)

* Add serialized type name to pytrees

* Modify context

* add serde test
parent f1cc6157
...@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__) ...@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1") is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0") is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13") is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
......
...@@ -22,11 +22,13 @@ from collections.abc import MutableMapping ...@@ -22,11 +22,13 @@ from collections.abc import MutableMapping
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass from dataclasses import fields, is_dataclass
from enum import Enum from enum import Enum
from functools import partial
from typing import Any, ContextManager, Iterable, List, Tuple from typing import Any, ContextManager, Iterable, List, Tuple
import numpy as np import numpy as np
from packaging import version
from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy from .import_utils import get_torch_version, is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
if is_flax_available(): if is_flax_available():
...@@ -306,10 +308,18 @@ class ModelOutput(OrderedDict): ...@@ -306,10 +308,18 @@ class ModelOutput(OrderedDict):
`static_graph=True` with modules that output `ModelOutput` subclasses. `static_graph=True` with modules that output `ModelOutput` subclasses.
""" """
if is_torch_available(): if is_torch_available():
torch_pytree_register_pytree_node( if version.parse(get_torch_version()) >= version.parse("2.2"):
_torch_pytree.register_pytree_node(
cls, cls,
_model_output_flatten, _model_output_flatten,
_model_output_unflatten, partial(_model_output_unflatten, output_type=cls),
serialized_type_name=f"{cls.__module__}.{cls.__name__}",
)
else:
_torch_pytree._register_pytree_node(
cls,
_model_output_flatten,
partial(_model_output_unflatten, output_type=cls),
) )
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -432,20 +442,27 @@ if is_torch_available(): ...@@ -432,20 +442,27 @@ if is_torch_available():
import torch.utils._pytree as _torch_pytree import torch.utils._pytree as _torch_pytree
def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]: def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
return list(output.values()), (type(output), list(output.keys())) return list(output.values()), list(output.keys())
def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput: def _model_output_unflatten(
output_type, keys = context values: Iterable[Any],
return output_type(**dict(zip(keys, values))) context: "_torch_pytree.Context",
output_type=None,
) -> ModelOutput:
return output_type(**dict(zip(context, values)))
if hasattr(_torch_pytree, "register_pytree_node"): if version.parse(get_torch_version()) >= version.parse("2.2"):
torch_pytree_register_pytree_node = _torch_pytree.register_pytree_node _torch_pytree.register_pytree_node(
ModelOutput,
_model_output_flatten,
partial(_model_output_unflatten, output_type=ModelOutput),
serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
)
else: else:
torch_pytree_register_pytree_node = _torch_pytree._register_pytree_node _torch_pytree._register_pytree_node(
torch_pytree_register_pytree_node(
ModelOutput, ModelOutput,
_model_output_flatten, _model_output_flatten,
_model_output_unflatten, partial(_model_output_unflatten, output_type=ModelOutput),
) )
......
...@@ -13,12 +13,20 @@ ...@@ -13,12 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import io
import unittest import unittest
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from transformers import AlbertForMaskedLM
from transformers.testing_utils import require_torch from transformers.testing_utils import require_torch
from transformers.utils import ModelOutput from transformers.utils import ModelOutput, is_torch_available
if is_torch_available():
import torch
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
@dataclass @dataclass
...@@ -135,9 +143,7 @@ class ModelOutputTester(unittest.TestCase): ...@@ -135,9 +143,7 @@ class ModelOutputTester(unittest.TestCase):
self.assertFalse(pytree._is_leaf(x)) self.assertFalse(pytree._is_leaf(x))
expected_flat_outs = [1.0, 2.0] expected_flat_outs = [1.0, 2.0]
expected_tree_spec = pytree.TreeSpec( expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()])
ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
)
actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x) actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
self.assertEqual(expected_flat_outs, actual_flat_outs) self.assertEqual(expected_flat_outs, actual_flat_outs)
...@@ -146,6 +152,33 @@ class ModelOutputTester(unittest.TestCase): ...@@ -146,6 +152,33 @@ class ModelOutputTester(unittest.TestCase):
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec) unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x) self.assertEqual(x, unflattened_x)
if is_torch_greater_or_equal_than_2_2:
self.assertEqual(
pytree.treespec_dumps(actual_tree_spec),
'[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["a", "c"], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
)
@require_torch
def test_export_serialization(self):
    """Check that an exported model survives a `torch.export` save/load round trip.

    The model is exported with `torch.export.export`, written to an in-memory
    buffer with `torch.export.save`, read back with `torch.export.load`, and the
    reloaded program's logits are compared against the eager model's logits on a
    fresh random input.
    """
    # Report the test as skipped (not silently passed) when the torch version
    # guard used by this file is not satisfied.
    if not is_torch_greater_or_equal_than_2_2:
        self.skipTest("torch.export serialization test requires torch>=2.2")

    model_cls = AlbertForMaskedLM
    model_config = model_cls.config_class()
    model = model_cls(model_config)

    input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}

    ep = torch.export.export(model, (), input_dict)

    # Round-trip through an in-memory buffer so no file is left on disk.
    buffer = io.BytesIO()
    torch.export.save(ep, buffer)
    buffer.seek(0)
    loaded_ep = torch.export.load(buffer)

    # Compare on a new random input rather than the one used for tracing.
    input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
    # Use a unittest assertion: a bare `assert` is removed under `python -O`.
    self.assertTrue(torch.allclose(model(**input_dict).logits, loaded_ep(**input_dict).logits))
class ModelOutputTestNoDataclass(ModelOutput): class ModelOutputTestNoDataclass(ModelOutput):
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used""" """Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment