Unverified Commit a45dca07 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Fix BaseOutput initialization from dict (#570)

* Fix BaseOutput initialization from dict

* style

* Simplify post-init, add tests

* remove debug
parent c01ec2d1
......@@ -59,10 +59,17 @@ class BaseOutput(OrderedDict):
if not len(class_fields):
raise ValueError(f"{self.__class__.__name__} has no fields.")
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v
first_field = getattr(self, class_fields[0].name)
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
if other_fields_are_none and isinstance(first_field, dict):
for key, value in first_field.items():
self[key] = value
else:
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
......
import unittest
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from diffusers.utils.outputs import BaseOutput
@dataclass
class CustomOutput(BaseOutput):
images: Union[List[PIL.Image.Image], np.ndarray]
class ConfigTester(unittest.TestCase):
def test_outputs_single_attribute(self):
outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4))
# check every way of getting the attribute
assert isinstance(outputs.images, np.ndarray)
assert outputs.images.shape == (1, 3, 4, 4)
assert isinstance(outputs["images"], np.ndarray)
assert outputs["images"].shape == (1, 3, 4, 4)
assert isinstance(outputs[0], np.ndarray)
assert outputs[0].shape == (1, 3, 4, 4)
# test with a non-tensor attribute
outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
# check every way of getting the attribute
assert isinstance(outputs.images, list)
assert isinstance(outputs.images[0], PIL.Image.Image)
assert isinstance(outputs["images"], list)
assert isinstance(outputs["images"][0], PIL.Image.Image)
assert isinstance(outputs[0], list)
assert isinstance(outputs[0][0], PIL.Image.Image)
def test_outputs_dict_init(self):
# test output reinitialization with a `dict` for compatibility with `accelerate`
outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)})
# check every way of getting the attribute
assert isinstance(outputs.images, np.ndarray)
assert outputs.images.shape == (1, 3, 4, 4)
assert isinstance(outputs["images"], np.ndarray)
assert outputs["images"].shape == (1, 3, 4, 4)
assert isinstance(outputs[0], np.ndarray)
assert outputs[0].shape == (1, 3, 4, 4)
# test with a non-tensor attribute
outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]})
# check every way of getting the attribute
assert isinstance(outputs.images, list)
assert isinstance(outputs.images[0], PIL.Image.Image)
assert isinstance(outputs["images"], list)
assert isinstance(outputs["images"][0], PIL.Image.Image)
assert isinstance(outputs[0], list)
assert isinstance(outputs[0][0], PIL.Image.Image)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment