"vscode:/vscode.git/clone" did not exist on "48441cc5b434ccbc1c088d7cd72e3b86e8afcfdd"
Commit 2b31740d authored by Patrick von Platen
parents 63b34191 3bec90ff
@@ -345,7 +345,8 @@ Textual Inversion is a technique for capturing novel concepts from a small numbe
## Stable Diffusion Community Pipelines
The release of Stable Diffusion as an open source model has fostered a lot of interesting ideas and experimentation.
Our [Community Examples folder](https://github.com/huggingface/diffusers/tree/main/examples/community) contains many ideas worth exploring, like interpolating to create animated videos, using CLIP Guidance for additional prompt fidelity, term weighting, and much more! [Take a look](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) and [contribute your own](https://huggingface.co/docs/diffusers/using-diffusers/contribute_pipeline).
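As a rough sketch of how these are used, a community pipeline can be loaded by passing its name via the `custom_pipeline` argument of `DiffusionPipeline.from_pretrained`; the example below assumes the long-prompt-weighting pipeline (`lpw_stable_diffusion`) is still available under that name in the community folder.

```python
import torch
from diffusers import DiffusionPipeline

# Load the base weights, but swap in the community pipeline implementation.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# The community pipeline is then used like the standard text-to-image pipeline.
image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")
```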
## Other Examples
@@ -394,10 +395,14 @@ image.save("ddpm_generated_image.png")
- [Unconditional Latent Diffusion](https://huggingface.co/CompVis/ldm-celebahq-256)
- [Unconditional Diffusion with continuous scheduler](https://huggingface.co/google/ncsnpp-ffhq-1024)
**Other Image Notebooks**:
* [image-to-image generation with Stable Diffusion](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
* [tweak images via repeated Stable Diffusion seeds](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
**Diffusers for Other Modalities**:
* [Molecule conformation generation](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
* [Model-based reinforcement learning](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
### Web Demos
If you just want to play around with some web demos, you can try out the following 🚀 Spaces:
| Model | Hugging Face Spaces |
......
@@ -31,6 +31,14 @@
- local: using-diffusers/contribute_pipeline
title: "How to contribute a Pipeline"
title: "Pipelines for Inference"
- sections:
- local: using-diffusers/rl
title: "Reinforcement Learning"
- local: using-diffusers/audio
title: "Audio"
- local: using-diffusers/other-modalities
title: "Other Modalities"
title: "Taking Diffusers Beyond Images"
title: "Using Diffusers" title: "Using Diffusers"
- sections: - sections:
- local: optimization/fp16 - local: optimization/fp16
...@@ -107,4 +115,8 @@ ...@@ -107,4 +115,8 @@
- local: api/pipelines/repaint - local: api/pipelines/repaint
title: "RePaint" title: "RePaint"
title: "Pipelines" title: "Pipelines"
- sections:
- local: api/experimental/rl
title: "RL Planning"
title: "Experimental Features"
title: "API" title: "API"
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# TODO
Coming soon!
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Using Diffusers for audio
The [`DanceDiffusionPipeline`] can be used to generate audio rapidly!
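A minimal sketch of what this can look like (assuming the `harmonai/maestro-150k` checkpoint and that the pipeline's UNet exposes the sampling rate in its config):

```python
import scipy.io.wavfile
import torch
from diffusers import DanceDiffusionPipeline

pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# Generate roughly 4 seconds of audio; `audios` is a NumPy array of shape (batch, channels, samples).
output = pipe(audio_length_in_s=4.0, num_inference_steps=100)
audio = output.audios[0]

# Write the clip to disk, transposing to (samples, channels) for scipy.
scipy.io.wavfile.write("dance_diffusion.wav", rate=pipe.unet.config.sample_rate, data=audio.T)
```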
More coming soon!
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Using Diffusers with other modalities
Diffusers is in the process of expanding to modalities other than images.
Currently, one example is for [molecule conformation](https://www.nature.com/subjects/molecular-conformation#:~:text=Definition,to%20changes%20in%20their%20environment.) generation.
* Generate conformations in Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb)
More coming soon!
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Using Diffusers for reinforcement learning
Support for one RL model and its related pipelines is included in the `experimental` module of Diffusers.
To try this out in Colab, take a look at the following example:
* Model-based reinforcement learning on Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb)
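For orientation, here is a rough sketch of how the experimental value-guided pipeline can be driven in an environment loop; the environment name, checkpoint, and call arguments below are assumptions based on the experimental API and may change:

```python
import d4rl  # noqa: F401 -- registers the offline-RL Gym environments (assumed installed)
import gym

from diffusers.experimental import ValueGuidedRLPipeline

env = gym.make("hopper-medium-v2")

pipeline = ValueGuidedRLPipeline.from_pretrained(
    "bglick13/hopper-medium-v2-value-function-gaussian-diffusion",
    env=env,
)

obs = env.reset()
for _ in range(100):
    # Plan a short trajectory with the diffusion model and execute its first action.
    action = pipeline(obs, planning_horizon=32, n_guide_steps=2, scale=0.1)
    obs, reward, done, info = env.step(action)
    if done:
        break
```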
@@ -370,7 +370,7 @@ def dummy(images, **kwargs):
pipe.safety_checker = dummy
images = []
generator = torch.Generator("cuda").manual_seed(0)
seed = 0
prompt = "a forest | a camel"
@@ -399,6 +399,7 @@ import requests
from PIL import Image
from io import BytesIO
import torch
import os
from diffusers import DiffusionPipeline, DDIMScheduler
has_cuda = torch.cuda.is_available()
device = torch.device('cpu' if not has_cuda else 'cuda')
@@ -423,6 +424,7 @@ res = pipe.train(
num_inference_steps=50,
generator=generator)
res = pipe(alpha=1)
os.makedirs("imagic", exist_ok=True)
image = res.images[0]
image.save('./imagic/imagic_image_alpha_1.png')
res = pipe(alpha=1.5)
@@ -652,4 +654,4 @@ prompt = "a cup" # the masked out region will be replaced with this
with autocast("cuda"):
image = pipe(image=image, text=text, prompt=prompt).images[0]
```
@@ -472,7 +472,7 @@ def main(args):
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
......
@@ -372,7 +372,7 @@ def main():
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -605,7 +605,7 @@ def main():
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
......
@@ -441,7 +441,7 @@ def main():
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
train_dataset = TextualInversionDataset(
data_root=args.train_data_dir,
@@ -574,7 +574,7 @@ def main():
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
......
@@ -82,6 +82,7 @@ if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionInpaintPipelineLegacy,
OnnxStableDiffusionPipeline,
StableDiffusionOnnxPipeline,
)
......
@@ -30,6 +30,7 @@ if is_transformers_available() and is_onnx_available():
from .stable_diffusion import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionInpaintPipelineLegacy,
OnnxStableDiffusionPipeline,
StableDiffusionOnnxPipeline,
)
......
@@ -81,7 +81,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
@@ -148,7 +147,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
@@ -168,7 +166,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
@@ -177,7 +174,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -196,7 +192,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
cpu_offload(cpu_offloaded_model, device)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
@@ -214,7 +209,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
@@ -227,14 +221,12 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
@@ -340,7 +332,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
return text_embeddings
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
@@ -351,7 +342,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
has_nsfw_concept = None
return image, has_nsfw_concept
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
@@ -360,7 +350,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
......
@@ -39,6 +39,7 @@ if is_transformers_available() and is_onnx_available():
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline
from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy
if is_transformers_available() and is_flax_available():
import flax
......
@@ -165,6 +165,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
guidance_scale: float = 7.5,
latents: Optional[jnp.array] = None,
debug: bool = False,
neg_prompt_ids: jnp.array = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -177,10 +178,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1]
if neg_prompt_ids is None:
    uncond_input = self.tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
    ).input_ids
else:
    uncond_input = neg_prompt_ids
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
@@ -251,6 +256,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
return_dict: bool = True,
jit: bool = False,
debug: bool = False,
neg_prompt_ids: jnp.array = None,
**kwargs,
):
r"""
@@ -298,11 +304,30 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
"""
if jit:
images = _p_generate(
self,
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
)
else:
images = self._generate(
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
)
if self.safety_checker is not None:
@@ -333,10 +358,29 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
# TODO: maybe use a config dict instead of so many static argnums
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
def _p_generate(
pipe,
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
):
return pipe._generate(
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
)
......
@@ -408,9 +408,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = latent_model_input.cpu().numpy()
latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
......
@@ -35,16 +35,93 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def prepare_mask_and_masked_image(image, mask):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline.
This means that those inputs will be converted to ``torch.Tensor`` with
shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
the ``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in
``[-1, 1]``. The ``mask`` will be binarized (``mask > 0.5``) and cast to
``torch.float32`` too.
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array``
or a ``channels x height x width`` ``torch.Tensor`` or a
``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or
a ``1 x height x width`` ``torch.Tensor`` or a
``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(or the other way around).
Returns:
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
# Batch single image
if image.ndim == 3:
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
image = image.unsqueeze(0)
# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
# Check image is in [-1, 1]
if image.min() < -1 or image.max() > 1:
raise ValueError("Image should be in [-1, 1] range")
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError("Mask should be in [0, 1] range")
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
else:
if isinstance(image, PIL.Image.Image):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
if isinstance(mask, PIL.Image.Image):
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
@@ -586,9 +663,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
......
@@ -243,19 +243,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
    # mps does not support float64
    schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
    timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
    schedule_timesteps = self.timesteps.to(original_samples.device)
    timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
    sigma = sigma.unsqueeze(-1)
......
@@ -118,7 +118,10 @@ class FlaxSchedulerMixin:
"""
config, kwargs = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
return_unused_kwargs=True,
**kwargs,
)
scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)
......