Unverified commit a9190bad authored by Dhruv Naik, committed by GitHub

Add Flax stable diffusion img2img pipeline (#1355)



* add flax img2img pipeline

* update pipeline

* black format file

* remove arg from get_timesteps

* update get_timesteps

* fix bug: make use of timesteps in the for-loop

* black file

* black, isort, flake8

* update docstring

* update readme

* update flax img2img readme

* update sd pipeline init

* Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* update inits

* revert change

* update var name to image, typo

* update readme

* return new t_start instead of modified timestep

* black format

* isort files

* update docs

* fix-copies

* update prng_seed typing
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent d07f7300
@@ -235,6 +235,55 @@ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
Diffusers also has an image-to-image generation pipeline with Flax/JAX:
```python
from io import BytesIO

import jax
import jax.numpy as jnp
import numpy as np
import requests
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from diffusers import FlaxStableDiffusionImg2ImgPipeline


def create_key(seed=0):
    return jax.random.PRNGKey(seed)


rng = create_key(0)

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_img = Image.open(BytesIO(response.content)).convert("RGB")
init_img = init_img.resize((768, 512))

prompts = "A fantasy landscape, trending on artstation"

pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="flax",
    dtype=jnp.bfloat16,
)

num_samples = jax.device_count()
rng = jax.random.split(rng, jax.device_count())
prompt_ids, processed_image = pipeline.prepare_inputs(prompt=[prompts] * num_samples, image=[init_img] * num_samples)

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)

output = pipeline(
    prompt_ids=prompt_ids,
    image=processed_image,
    params=p_params,
    prng_seed=rng,
    strength=0.75,
    num_inference_steps=50,
    jit=True,
    height=512,
    width=768,
).images

output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
```
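When `jit=True`, the returned images are sharded across devices with shape `(num_devices, batch_per_device, height, width, 3)`, which is why the output is reshaped to `(num_samples, height, width, 3)` before converting it to PIL images above.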
### Image-to-Image text-guided generation with Stable Diffusion
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
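Below is a minimal sketch of the equivalent PyTorch workflow. It assumes the `runwayml/stable-diffusion-v1-5` checkpoint and that the PyTorch pipeline accepts the same `image` and `strength` arguments as the Flax example above; consult the pipeline documentation for the exact signature.

```python
from io import BytesIO

import requests
import torch
from PIL import Image

from diffusers import StableDiffusionImg2ImgPipeline

# Hypothetical checkpoint choice; any Stable Diffusion checkpoint with PyTorch weights should work.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB").resize((768, 512))

prompt = "A fantasy landscape, trending on artstation"

# Lower `strength` keeps more of the initial image; higher values allow larger changes.
images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```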
@@ -164,10 +164,11 @@ else:
FlaxScoreSdeVeScheduler,
)
try:
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipelines import FlaxStableDiffusionPipeline
from .pipelines import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline
@@ -91,4 +91,4 @@ try:
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .stable_diffusion import FlaxStableDiffusionPipeline
from .stable_diffusion import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline
@@ -98,4 +98,5 @@ if is_transformers_available() and is_flax_available():
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from functools import partial
from typing import Dict, List, Union
import numpy as np
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...pipeline_flax_utils import FlaxDiffusionPipeline
from ...schedulers import (
FlaxDDIMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, logging
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
r"""
Pipeline for image-to-image generation using Stable Diffusion.
This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`FlaxAutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`FlaxCLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
[`FlaxDPMSolverMultistepScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: FlaxAutoencoderKL,
text_encoder: FlaxCLIPTextModel,
tokenizer: CLIPTokenizer,
unet: FlaxUNet2DConditionModel,
scheduler: Union[
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
],
safety_checker: FlaxStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
dtype: jnp.dtype = jnp.float32,
):
super().__init__()
self.dtype = dtype
if safety_checker is None:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
if not isinstance(prompt, (str, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if not isinstance(image, (Image.Image, list)):
raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
if isinstance(image, Image.Image):
image = [image]
processed_image = []
for img in image:
processed_image.append(preprocess(img, self.dtype))
processed_image = jnp.array(processed_image).squeeze()
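# `processed_image` stacks the per-image (1, 3, H, W) arrays and drops singleton axes,
# yielding a (batch, 3, height, width) batch when more than one image is passed.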
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
return text_input.input_ids, processed_image
def _get_has_nsfw_concepts(self, features, params):
has_nsfw_concepts = self.safety_checker(features, params)
return has_nsfw_concepts
def _run_safety_checker(self, images, safety_model_params, jit=False):
# safety_model_params should already be replicated when jit is True
pil_images = [Image.fromarray(image) for image in images]
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
if jit:
features = shard(features)
has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
has_nsfw_concepts = unshard(has_nsfw_concepts)
safety_model_params = unreplicate(safety_model_params)
else:
has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
images_was_copied = False
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
if has_nsfw_concept:
if not images_was_copied:
images_was_copied = True
images = images.copy()
images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image
if any(has_nsfw_concepts):
warnings.warn(
"Potential NSFW content was detected in one or more images. A black image will be returned"
" instead. Try again with a different prompt and/or seed."
)
return images, has_nsfw_concepts
def get_timestep_start(self, num_inference_steps, strength, scheduler_state):
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
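# Illustrative values: with num_inference_steps=50, strength=0.75 and offset=1,
# init_timestep is 38 and t_start is 13, so roughly 37 of the 50 scheduler steps are run.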
return t_start
def _generate(
self,
prompt_ids: jnp.array,
image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
strength: float = 0.8,
num_inference_steps: int = 50,
height: int = 512,
width: int = 512,
guidance_scale: float = 7.5,
debug: bool = False,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# get prompt text embeddings
text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
# Create init_latents
init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist
init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2))
init_latents = 0.18215 * init_latents
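# 0.18215 is the latent scaling factor of the Stable Diffusion VAE; it is undone (1 / 0.18215) before decoding.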
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
def loop_body(step, args):
latents, scheduler_state = args
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
latents_input = jnp.concatenate([latents] * 2)
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
timestep = jnp.broadcast_to(t, latents_input.shape[0])
latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
# predict the noise residual
noise_pred = self.unet.apply(
{"params": params["unet"]},
jnp.array(latents_input),
jnp.array(timestep, dtype=jnp.int32),
encoder_hidden_states=context,
).sample
# perform guidance
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
return latents, scheduler_state
scheduler_state = self.scheduler.set_timesteps(
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
)
t_start = self.get_timestep_start(num_inference_steps, strength, scheduler_state)
latent_timestep = scheduler_state.timesteps[t_start : t_start + 1].repeat(batch_size)
init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep)
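# The encoded image is noised only up to the timestep at which denoising starts (t_start),
# so a larger `strength` adds more noise and leaves more denoising steps to run below.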
latents = init_latents
if debug:
# run with python for loop
for i in range(t_start, len(scheduler_state.timesteps)):
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
else:
latents, _ = jax.lax.fori_loop(
t_start, len(scheduler_state.timesteps), loop_body, (latents, scheduler_state)
)
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
return image
def __call__(
self,
prompt_ids: jnp.array,
image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
num_inference_steps: int = 50,
height: int = 512,
width: int = 512,
guidance_scale: float = 7.5,
strength: float = 0.8,
return_dict: bool = True,
jit: bool = False,
debug: bool = False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt_ids (`jnp.array`):
The prompt or prompts to guide the image generation.
image (`jnp.array`):
Array representing an image batch that will be used as the starting point for the process.
params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
`prompt`, usually at the expense of lower image quality.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
debug (`bool`, *optional*, defaults to `False`): Whether to run the denoising loop as a plain Python `for` loop instead of `jax.lax.fori_loop` (useful for debugging).
Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
if jit:
image = _p_generate(
self,
prompt_ids,
image,
params,
prng_seed,
strength,
num_inference_steps,
height,
width,
guidance_scale,
debug,
)
else:
image = self._generate(
prompt_ids,
image,
params,
prng_seed,
strength,
num_inference_steps,
height,
width,
guidance_scale,
debug,
)
if self.safety_checker is not None:
safety_params = params["safety_checker"]
image_uint8_casted = (image * 255).round().astype("uint8")
num_devices, batch_size = image.shape[:2]
image_uint8_casted = np.asarray(image_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
image_uint8_casted, has_nsfw_concept = self._run_safety_checker(image_uint8_casted, safety_params, jit)
image = np.asarray(image)
# block images
if any(has_nsfw_concept):
for i, is_nsfw in enumerate(has_nsfw_concept):
if is_nsfw:
image[i] = np.asarray(image_uint8_casted[i])
image = image.reshape(num_devices, batch_size, height, width, 3)
else:
has_nsfw_concept = False
if not return_dict:
return (image, has_nsfw_concept)
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# TODO: maybe use a config dict instead of so many static argnums
@partial(jax.pmap, static_broadcasted_argnums=(0, 5, 6, 7, 8, 9, 10))
def _p_generate(
pipe,
prompt_ids,
image,
params,
prng_seed,
strength,
num_inference_steps,
height,
width,
guidance_scale,
debug,
):
return pipe._generate(
prompt_ids, image, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug
)
@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_has_nsfw_concepts(pipe, features, params):
return pipe._get_has_nsfw_concepts(features, params)
def unshard(x: jnp.ndarray):
# einops.rearrange(x, 'd b ... -> (d b) ...')
num_devices, batch_size = x.shape[:2]
rest = x.shape[2:]
return x.reshape(num_devices * batch_size, *rest)
def preprocess(image, dtype):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = jnp.array(image).astype(dtype) / 255.0
image = image[None].transpose(0, 3, 1, 2)
return 2.0 * image - 1.0
@@ -4,6 +4,21 @@
from ..utils import DummyObject, requires_backends
class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
_backends = ["flax", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax", "transformers"])
class FlaxStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["flax", "transformers"]