Unverified Commit 9cfd4ef0 authored by Charles Bensimon's avatar Charles Bensimon Committed by GitHub
Browse files

Make `BaseOutput` dataclasses picklable (#5234)

* Make BaseOutput dataclasses picklable

* make style

* Test

* Empty commit

* Simpler and safer
parent 78a78515
...@@ -16,7 +16,7 @@ Generic utilities ...@@ -16,7 +16,7 @@ Generic utilities
""" """
from collections import OrderedDict from collections import OrderedDict
from dataclasses import fields from dataclasses import fields, is_dataclass
from typing import Any, Tuple from typing import Any, Tuple
import numpy as np import numpy as np
...@@ -101,6 +101,13 @@ class BaseOutput(OrderedDict): ...@@ -101,6 +101,13 @@ class BaseOutput(OrderedDict):
# Don't call self.__setattr__ to avoid recursion errors # Don't call self.__setattr__ to avoid recursion errors
super().__setattr__(key, value) super().__setattr__(key, value)
def __reduce__(self):
if not is_dataclass(self):
return super().__reduce__()
callable, _args, *remaining = super().__reduce__()
args = tuple(getattr(self, field.name) for field in fields(self))
return callable, args, *remaining
def to_tuple(self) -> Tuple[Any]: def to_tuple(self) -> Tuple[Any]:
""" """
Convert self to a tuple containing all the attributes/keys that are not `None`. Convert self to a tuple containing all the attributes/keys that are not `None`.
......
import pickle as pkl
import unittest import unittest
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union from typing import List, Union
...@@ -58,3 +59,13 @@ class ConfigTester(unittest.TestCase): ...@@ -58,3 +59,13 @@ class ConfigTester(unittest.TestCase):
assert isinstance(outputs["images"][0], PIL.Image.Image) assert isinstance(outputs["images"][0], PIL.Image.Image)
assert isinstance(outputs[0], list) assert isinstance(outputs[0], list)
assert isinstance(outputs[0][0], PIL.Image.Image) assert isinstance(outputs[0][0], PIL.Image.Image)
def test_outputs_serialization(self):
outputs_orig = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
serialized = pkl.dumps(outputs_orig)
outputs_copy = pkl.loads(serialized)
# Check original and copy are equal
assert dir(outputs_orig) == dir(outputs_copy)
assert dict(outputs_orig) == dict(outputs_copy)
assert vars(outputs_orig) == vars(outputs_copy)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment