Unverified Commit 25f11424 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Ensure Flax pipeline always returns numpy array (#1435)

* Ensure Flax pipeline always returns numpy array.

* Clarify documentation.
parent 89300131
......@@ -63,15 +63,14 @@ if is_transformers_available() and is_flax_available():
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
images (`np.ndarray`)
Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
images: np.ndarray
nsfw_content_detected: List[bool]
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
......
......@@ -316,9 +316,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
......@@ -382,6 +379,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
images = images.reshape(num_devices, batch_size, height, width, 3)
else:
images = np.asarray(images)
has_nsfw_concept = False
if not return_dict:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment