"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "bdd3d0c76ddabd48057c97f6560675e546b58f85"
Unverified Commit 243e186e authored by Angela Yi's avatar Angela Yi Committed by GitHub
Browse files

Add serialization logic to pytree types (#27871)

* Add serialized type name to pytrees

* Modify context

* add serde test
parent f1cc6157
......@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
......
......@@ -22,11 +22,13 @@ from collections.abc import MutableMapping
from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass
from enum import Enum
from functools import partial
from typing import Any, ContextManager, Iterable, List, Tuple
import numpy as np
from packaging import version
from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
from .import_utils import get_torch_version, is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
if is_flax_available():
......@@ -306,11 +308,19 @@ class ModelOutput(OrderedDict):
`static_graph=True` with modules that output `ModelOutput` subclasses.
"""
if is_torch_available():
torch_pytree_register_pytree_node(
cls,
_model_output_flatten,
_model_output_unflatten,
)
if version.parse(get_torch_version()) >= version.parse("2.2"):
_torch_pytree.register_pytree_node(
cls,
_model_output_flatten,
partial(_model_output_unflatten, output_type=cls),
serialized_type_name=f"{cls.__module__}.{cls.__name__}",
)
else:
_torch_pytree._register_pytree_node(
cls,
_model_output_flatten,
partial(_model_output_unflatten, output_type=cls),
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
......@@ -432,21 +442,28 @@ if is_torch_available():
import torch.utils._pytree as _torch_pytree
def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
return list(output.values()), (type(output), list(output.keys()))
def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput:
output_type, keys = context
return output_type(**dict(zip(keys, values)))
if hasattr(_torch_pytree, "register_pytree_node"):
torch_pytree_register_pytree_node = _torch_pytree.register_pytree_node
return list(output.values()), list(output.keys())
def _model_output_unflatten(
values: Iterable[Any],
context: "_torch_pytree.Context",
output_type=None,
) -> ModelOutput:
return output_type(**dict(zip(context, values)))
if version.parse(get_torch_version()) >= version.parse("2.2"):
_torch_pytree.register_pytree_node(
ModelOutput,
_model_output_flatten,
partial(_model_output_unflatten, output_type=ModelOutput),
serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
)
else:
torch_pytree_register_pytree_node = _torch_pytree._register_pytree_node
torch_pytree_register_pytree_node(
ModelOutput,
_model_output_flatten,
_model_output_unflatten,
)
_torch_pytree._register_pytree_node(
ModelOutput,
_model_output_flatten,
partial(_model_output_unflatten, output_type=ModelOutput),
)
class ExplicitEnum(str, Enum):
......
......@@ -13,12 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import unittest
from dataclasses import dataclass
from typing import Optional
from transformers import AlbertForMaskedLM
from transformers.testing_utils import require_torch
from transformers.utils import ModelOutput
from transformers.utils import ModelOutput, is_torch_available
if is_torch_available():
import torch
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
@dataclass
......@@ -135,9 +143,7 @@ class ModelOutputTester(unittest.TestCase):
self.assertFalse(pytree._is_leaf(x))
expected_flat_outs = [1.0, 2.0]
expected_tree_spec = pytree.TreeSpec(
ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
)
expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()])
actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
self.assertEqual(expected_flat_outs, actual_flat_outs)
......@@ -146,6 +152,33 @@ class ModelOutputTester(unittest.TestCase):
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x)
if is_torch_greater_or_equal_than_2_2:
self.assertEqual(
pytree.treespec_dumps(actual_tree_spec),
'[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["a", "c"], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
)
@require_torch
def test_export_serialization(self):
if not is_torch_greater_or_equal_than_2_2:
return
model_cls = AlbertForMaskedLM
model_config = model_cls.config_class()
model = model_cls(model_config)
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
ep = torch.export.export(model, (), input_dict)
buffer = io.BytesIO()
torch.export.save(ep, buffer)
buffer.seek(0)
loaded_ep = torch.export.load(buffer)
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
assert torch.allclose(model(**input_dict).logits, loaded_ep(**input_dict).logits)
class ModelOutputTestNoDataclass(ModelOutput):
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment