Unverified Commit fb38bb16 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Support grayscale images in `numpy_to_pil` (#1025)

parent de00c632
...@@ -444,7 +444,11 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -444,7 +444,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
if images.ndim == 3: if images.ndim == 3:
images = images[None, ...] images = images[None, ...]
images = (images * 255).round().astype("uint8") images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images] if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images return pil_images
......
...@@ -625,7 +625,11 @@ class DiffusionPipeline(ConfigMixin): ...@@ -625,7 +625,11 @@ class DiffusionPipeline(ConfigMixin):
if images.ndim == 3: if images.ndim == 3:
images = images[None, ...] images = images[None, ...]
images = (images * 255).round().astype("uint8") images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images] if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images return pil_images
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment