Unverified commit c5594795, authored by YiYi Xu, committed by GitHub

Postprocessing refactor all others (#3337)



* add text2img

* fix-copies

* add

* add all other pipelines

* add

* add

* add

* add

* add

* make style

* style + fix copies

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent a757b2db
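
Note: every pipeline touched by this commit follows the same pattern: __call__ decodes the latents with the VAE directly, run_safety_checker accepts either a torch tensor or a numpy array, and denormalization plus numpy/PIL conversion is delegated to VaeImageProcessor.postprocess instead of the now-deprecated decode_latents / numpy_to_pil path. A minimal usage sketch of the refactored output handling follows; the checkpoint id and prompts are illustrative and not part of this commit.

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any Stable Diffusion checkpoint behaves the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# output_type="pil" (the default) now runs through VaeImageProcessor.postprocess internally.
image = pipe("an astronaut riding a horse").images[0]

# output_type="latent" skips VAE decoding and the safety checker and returns the raw latents.
latents = pipe("an astronaut riding a horse", output_type="latent").images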
@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
+import warnings
from typing import Any, Callable, Dict, List, Optional, Union
import torch
@@ -22,6 +23,7 @@ from transformers import CLIPImageProcessor, XLMRobertaTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -174,6 +176,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    def enable_vae_slicing(self):
@@ -426,16 +429,27 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
        return prompt_embeds
    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
-        else:
-            has_nsfw_concept = None
        return image, has_nsfw_concept
    def decode_latents(self, latents):
+        warnings.warn(
+            (
+                "The decode_latents method is deprecated and will be removed in a future version. Please"
+                " use VaeImageProcessor instead"
+            ),
+            FutureWarning,
+        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
@@ -700,24 +714,19 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
            image = latents
            has_nsfw_concept = None
-        elif output_type == "pil":
-            # 8. Post-processing
-            image = self.decode_latents(latents)
-            # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-            # 10. Convert to PIL
-            image = self.numpy_to_pil(image)
-        else:
-            # 8. Post-processing
-            image = self.decode_latents(latents)
-            # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
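
Note: the run_safety_checker rewrite above is repeated (via "# Copied from") in the pipelines that follow: the method now accepts either the torch tensor coming straight out of the VAE or a numpy array from the legacy path, and only uses the image processor to build the PIL input for the CLIP feature extractor. A standalone sketch of that branch, with pipe standing in for any of these pipelines; the helper name is hypothetical and not part of the diff.

import torch

def prepare_safety_checker_input(pipe, image, device):
    # Hypothetical helper mirroring the branch added to run_safety_checker.
    if torch.is_tensor(image):
        # Tensor straight from vae.decode: denormalize and convert to PIL via the image processor.
        feature_extractor_input = pipe.image_processor.postprocess(image, output_type="pil")
    else:
        # Numpy array produced by the deprecated decode_latents path.
        feature_extractor_input = pipe.image_processor.numpy_to_pil(image)
    return pipe.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)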
@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
+import warnings
from typing import Callable, List, Optional, Union
import numpy as np
@@ -22,6 +23,7 @@ from transformers import CLIPImageProcessor
from diffusers.utils import is_accelerate_available
+from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging, randn_tensor
@@ -184,6 +186,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -226,13 +229,17 @@ class PaintByExamplePipeline(DiffusionPipeline):
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
-        else:
-            has_nsfw_concept = None
        return image, has_nsfw_concept
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -255,6 +262,11 @@ class PaintByExamplePipeline(DiffusionPipeline):
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
@@ -560,15 +572,19 @@ class PaintByExamplePipeline(DiffusionPipeline):
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)
-        # 11. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
-        # 12. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-        # 13. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
        if not return_dict:
            return (image, has_nsfw_concept)
......
import inspect
+import warnings
from itertools import repeat
from typing import Callable, List, Optional, Union
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -129,10 +131,31 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
@@ -681,20 +704,19 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
                    callback(i, t, latents)
        # 8. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
-            image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
-            )
-        else:
-            has_nsfw_concept = None
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
        if not return_dict:
            return (image, has_nsfw_concept)
......
@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
+import warnings
from typing import Callable, List, Optional, Union
import numpy as np
@@ -24,6 +25,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler
@@ -220,6 +222,8 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
@@ -504,17 +508,26 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
-        else:
-            has_nsfw_concept = None
        return image, has_nsfw_concept
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
@@ -770,14 +783,19 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                    callback(i, t, latents)
        # 9. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
-        # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
        if not return_dict:
            return (image, has_nsfw_concept)
......
@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
+import warnings
from typing import Any, Callable, Dict, List, Optional, Union
import torch
@@ -20,6 +21,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -177,6 +179,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    def enable_vae_slicing(self):
@@ -429,16 +432,25 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
        return prompt_embeds
    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
-        else:
-            has_nsfw_concept = None
        return image, has_nsfw_concept
    def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
@@ -703,24 +715,19 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
            image = latents
            has_nsfw_concept = None
-        elif output_type == "pil":
-            # 8. Post-processing
-            image = self.decode_latents(latents)
-            # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-            # 10. Convert to PIL
-            image = self.numpy_to_pil(image)
-        else:
-            # 8. Post-processing
-            image = self.decode_latents(latents)
-            # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -14,6 +14,7 @@
import inspect
import math
+import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -21,6 +22,7 @@ import torch
from torch.nn import functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
@@ -228,6 +230,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
@@ -442,17 +445,26 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
-        else:
-            has_nsfw_concept = None
        return image, has_nsfw_concept
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
@@ -972,14 +984,19 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
                    callback(i, t, latents)
        # 8. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
-        # 9. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-        # 10. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
        if not return_dict:
            return (image, has_nsfw_concept)
......
@@ -15,6 +15,7 @@
import inspect
import os
+import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -24,6 +25,7 @@ import torch.nn.functional as F
from torch import nn
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.controlnet import ControlNetOutput
@@ -230,6 +232,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
@@ -485,17 +488,26 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
-        else:
-            has_nsfw_concept = None
        return image, has_nsfw_concept
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
@@ -1061,24 +1073,19 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
            image = latents
            has_nsfw_concept = None
-        elif output_type == "pil":
-            # 8. Post-processing
-            image = self.decode_latents(latents)
-            # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-            # 10. Convert to PIL
-            image = self.numpy_to_pil(image)
-        else:
-            # 8. Post-processing
-            image = self.decode_latents(latents)
-            # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -14,6 +14,7 @@
import contextlib
import inspect
+import warnings
from typing import Callable, List, Optional, Union
import numpy as np
@@ -23,6 +24,7 @@ from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -128,6 +130,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
@@ -314,17 +317,26 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
-        else:
-            has_nsfw_concept = None
        return image, has_nsfw_concept
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
@@ -695,12 +707,12 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)
-        # 10. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)
......
@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
+import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
@@ -23,6 +24,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers
@@ -357,6 +359,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
            inverse_scheduler=inverse_scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
@@ -618,13 +621,17 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
-        else:
-            has_nsfw_concept = None
        return image, has_nsfw_concept
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -647,6 +654,11 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
@@ -1052,7 +1064,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
        # 9. Convert to Numpy array or PIL.
        if output_type == "pil":
-            mask_image = self.numpy_to_pil(mask_image)
+            mask_image = self.image_processor.numpy_to_pil(mask_image)
        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
@@ -1287,7 +1299,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
        # 9. Convert to PIL.
        if decode_latents and output_type == "pil":
-            image = self.numpy_to_pil(image)
+            image = self.image_processor.numpy_to_pil(image)
        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
@@ -1510,15 +1522,19 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)
-        # 9. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
-        # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
+import warnings
from typing import Callable, List, Optional, Union
import PIL
@@ -21,6 +22,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
@@ -118,6 +120,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -183,17 +186,26 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
-        else:
-            has_nsfw_concept = None
        return image, has_nsfw_concept
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
@@ -398,15 +410,19 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)
-        # 8. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
-        # 9. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-        # 10. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
        if not return_dict:
            return (image, has_nsfw_concept)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import warnings
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
...@@ -22,6 +23,7 @@ from packaging import version ...@@ -22,6 +23,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
...@@ -270,6 +272,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -270,6 +272,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
...@@ -495,13 +498,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -495,13 +498,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None: if self.safety_checker is None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker( image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype) images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
) )
else:
has_nsfw_concept = None
return image, has_nsfw_concept return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
...@@ -524,6 +531,11 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -524,6 +531,11 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents): def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0] image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
...@@ -896,15 +908,19 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -896,15 +908,19 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
-        # 11. Post-processing
-        image = self.decode_latents(latents)
-        # 12. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-        # 13. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload last model to CPU # Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
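The tail of `__call__` above is the heart of the refactor: decoding, the safety check, and the output conversion now all funnel into a single `image_processor.postprocess` call, and `output_type="latent"` skips the VAE entirely. Below is a minimal sketch of that contract using a random tensor in place of a real VAE decode, so it runs without weights; `fake_decoded` and the flag values are made up for illustration.

```python
import torch

from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor(vae_scale_factor=8)

# Stand-in for `self.vae.decode(latents / self.vae.config.scaling_factor, ...)[0]`:
# a batch of two images in [-1, 1], NCHW layout.
fake_decoded = torch.rand(2, 3, 64, 64) * 2 - 1

# Pretend the safety checker flagged the second image.
has_nsfw_concept = [False, True]
do_denormalize = [not flagged for flagged in has_nsfw_concept]

pil_images = image_processor.postprocess(fake_decoded, output_type="pil", do_denormalize=do_denormalize)
print(len(pil_images), pil_images[0].size)  # 2 (64, 64)
```

Computing `do_denormalize` per image means one flagged sample no longer forces the whole batch down a separate code path.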
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import warnings
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
...@@ -22,6 +23,7 @@ from packaging import version ...@@ -22,6 +23,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
...@@ -209,6 +211,7 @@ class StableDiffusionInpaintPipelineLegacy( ...@@ -209,6 +211,7 @@ class StableDiffusionInpaintPipelineLegacy(
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
...@@ -434,17 +437,26 @@ class StableDiffusionInpaintPipelineLegacy( ...@@ -434,17 +437,26 @@ class StableDiffusionInpaintPipelineLegacy(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-        else:
-            has_nsfw_concept = None
         return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents): def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0] image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
...@@ -720,15 +732,19 @@ class StableDiffusionInpaintPipelineLegacy( ...@@ -720,15 +732,19 @@ class StableDiffusionInpaintPipelineLegacy(
# use original latents corresponding to unmasked portions of the image # use original latents corresponding to unmasked portions of the image
latents = (init_latents_orig * mask) + (latents * (1 - mask)) latents = (init_latents_orig * mask) + (latents * (1 - mask))
-        # 10. Post-processing
-        image = self.decode_latents(latents)
-        # 11. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-        # 12. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload last model to CPU # Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
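`run_safety_checker` now has to accept whatever the decode branch hands it: a torch tensor on the new path or a numpy array from legacy callers, normalized to PIL before the CLIP feature extractor. A sketch of just that normalization step outside any pipeline; `to_pil_for_safety_checker` is a made-up helper name, not diffusers API.

```python
import numpy as np
import torch

from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor(vae_scale_factor=8)


def to_pil_for_safety_checker(image):
    # Tensors are assumed to be raw VAE output in [-1, 1] (NCHW) and go through
    # postprocess; numpy arrays are assumed to already be NHWC in [0, 1].
    if torch.is_tensor(image):
        return image_processor.postprocess(image, output_type="pil")
    return image_processor.numpy_to_pil(image)


print(len(to_pil_for_safety_checker(torch.rand(1, 3, 64, 64) * 2 - 1)))                  # 1
print(len(to_pil_for_safety_checker(np.random.rand(1, 64, 64, 3).astype(np.float32))))   # 1
```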
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import warnings
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
...@@ -20,6 +21,7 @@ import PIL ...@@ -20,6 +21,7 @@ import PIL
import torch import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
...@@ -136,6 +138,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion ...@@ -136,6 +138,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
@torch.no_grad() @torch.no_grad()
...@@ -386,15 +389,19 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion ...@@ -386,15 +389,19 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
-        # 10. Post-processing
-        image = self.decode_latents(latents)
-        # 11. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-        # 12. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload last model to CPU # Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
...@@ -628,13 +635,17 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion ...@@ -628,13 +635,17 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-        else:
-            has_nsfw_concept = None
         return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
...@@ -657,6 +668,11 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion ...@@ -657,6 +668,11 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents): def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0] image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
......
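`decode_latents` is kept only for backwards compatibility and now warns. Roughly, the old helper and its replacement compare as below; `pipe` is a hypothetical object with `vae` and `image_processor` attributes, so this is a sketch rather than pipeline code.

```python
def decode_latents_old(pipe, latents):
    # Deprecated path: unscale, decode, shift [-1, 1] -> [0, 1], return numpy NHWC.
    latents = 1 / pipe.vae.config.scaling_factor * latents
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = (image / 2 + 0.5).clamp(0, 1)
    return image.cpu().permute(0, 2, 3, 1).float().numpy()


def decode_latents_new(pipe, latents, output_type="np"):
    # New path: decode once, then let VaeImageProcessor own the denormalization
    # and the numpy / PIL / tensor conversion in a single place.
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    return pipe.image_processor.postprocess(image, output_type=output_type)
```

The old helper hard-coded the numpy conversion, which is why every pipeline also needed its own `numpy_to_pil` call afterwards.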
...@@ -13,12 +13,14 @@ ...@@ -13,12 +13,14 @@
# limitations under the License. # limitations under the License.
import importlib import importlib
import warnings
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import torch import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from k_diffusion.sampling import get_sigmas_karras from k_diffusion.sampling import get_sigmas_karras
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin from ...loaders import TextualInversionLoaderMixin
from ...pipelines import DiffusionPipeline from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler from ...schedulers import LMSDiscreteScheduler
...@@ -111,6 +113,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade ...@@ -111,6 +113,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
) )
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
model = ModelWrapper(unet, scheduler.alphas_cumprod) model = ModelWrapper(unet, scheduler.alphas_cumprod)
if scheduler.config.prediction_type == "v_prediction": if scheduler.config.prediction_type == "v_prediction":
...@@ -346,17 +349,26 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade ...@@ -346,17 +349,26 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-        else:
-            has_nsfw_concept = None
         return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents): def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0] image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
...@@ -590,15 +602,19 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade ...@@ -590,15 +602,19 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
# 8. Run k-diffusion solver # 8. Run k-diffusion solver
latents = self.sampler(model_fn, latents, sigmas) latents = self.sampler(model_fn, latents, sigmas)
-        # 9. Post-processing
-        image = self.decode_latents(latents)
-        # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload last model to CPU # Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
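From the caller's side, every `output_type` (including the new `"latent"`) now goes through the same postprocess call. An illustrative invocation of the k-diffusion pipeline; it assumes the `k-diffusion` extra, a CUDA device, and a weights download, so treat it as documentation rather than a test.

```python
import torch

from diffusers import StableDiffusionKDiffusionPipeline

pipe = StableDiffusionKDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.set_scheduler("sample_euler")

prompt = "a photo of an astronaut riding a horse"
latents = pipe(prompt, output_type="latent").images    # raw latents, VAE decode skipped
pt_image = pipe(prompt, output_type="pt").images       # torch tensor in [0, 1]
pil_image = pipe(prompt, output_type="pil").images[0]  # default PIL output
```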
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
...@@ -20,6 +21,7 @@ import torch ...@@ -20,6 +21,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler from ...schedulers import EulerDiscreteScheduler
from ...utils import is_accelerate_available, logging, randn_tensor from ...utils import is_accelerate_available, logging, randn_tensor
...@@ -91,6 +93,8 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline): ...@@ -91,6 +93,8 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
unet=unet, unet=unet,
scheduler=scheduler, scheduler=scheduler,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def enable_sequential_cpu_offload(self, gpu_id=0): def enable_sequential_cpu_offload(self, gpu_id=0):
r""" r"""
...@@ -220,6 +224,11 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline): ...@@ -220,6 +224,11 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents): def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0] image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
...@@ -505,12 +514,12 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline): ...@@ -505,12 +514,12 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
-        # 10. Post-processing
-        image = self.decode_latents(latents)
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents
+        image = self.image_processor.postprocess(image, output_type=output_type)
if not return_dict: if not return_dict:
return (image,) return (image,)
......
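The latent upscaler has no safety checker, so its tail calls `postprocess` without `do_denormalize`; the processor then falls back to its own `do_normalize` config for the whole batch. A minimal sketch with a synthetic tensor standing in for the decoded output.

```python
import torch

from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor(vae_scale_factor=8)

decoded = torch.rand(1, 3, 128, 128) * 2 - 1  # stand-in for vae.decode(...)[0]
np_image = image_processor.postprocess(decoded, output_type="np")
print(np_image.shape, float(np_image.min()), float(np_image.max()))  # (1, 128, 128, 3), values in [0, 1]
```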
...@@ -13,11 +13,13 @@ ...@@ -13,11 +13,13 @@
import copy import copy
import inspect import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler from ...schedulers import PNDMScheduler
...@@ -129,6 +131,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -129,6 +131,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
self.with_to_k = with_to_k self.with_to_k = with_to_k
...@@ -373,17 +376,26 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -373,17 +376,26 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-        else:
-            has_nsfw_concept = None
         return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents): def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0] image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
...@@ -767,24 +779,19 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -767,24 +779,19 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
if output_type == "latent": if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents image = latents
has_nsfw_concept = None has_nsfw_concept = None
elif output_type == "pil":
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker if has_nsfw_concept is None:
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) do_denormalize = [True] * image.shape[0]
# 10. Convert to PIL
image = self.numpy_to_pil(image)
else: else:
# 8. Post-processing do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.decode_latents(latents)
# 9. Run safety checker image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# Offload last model to CPU # Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
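The per-image `do_denormalize` list exists because the safety checker replaces flagged images with zeros; denormalizing a zero tensor from [-1, 1] to [0, 1] would turn that black frame into mid-gray. The difference, shown with synthetic tensors and no safety checker involved.

```python
import torch

from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor(vae_scale_factor=8)
blacked_out = torch.zeros(1, 3, 64, 64)  # what the safety checker returns for a flagged image

gray = image_processor.postprocess(blacked_out, output_type="np", do_denormalize=[True])
black = image_processor.postprocess(blacked_out, output_type="np", do_denormalize=[False])
print(float(gray.mean()), float(black.mean()))  # ~0.5 vs 0.0
```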
...@@ -12,11 +12,13 @@ ...@@ -12,11 +12,13 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, PNDMScheduler
...@@ -123,6 +125,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -123,6 +125,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
...@@ -337,17 +340,26 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -337,17 +340,26 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-        else:
-            has_nsfw_concept = None
         return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents): def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0] image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
...@@ -659,15 +671,19 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -659,15 +671,19 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
-        # 8. Post-processing
-        image = self.decode_latents(latents)
-        # 9. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-        # 10. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
if not return_dict: if not return_dict:
return (image, has_nsfw_concept) return (image, has_nsfw_concept)
......
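Every constructor hunk registers the same helper: `vae_scale_factor` is derived from the VAE config and passed to `VaeImageProcessor`, which uses it to keep resized inputs divisible by the VAE's downsampling factor. The arithmetic, with a typical Stable Diffusion VAE config assumed for the example.

```python
from diffusers.image_processor import VaeImageProcessor

block_out_channels = (128, 256, 512, 512)  # typical SD VAE config (assumed here)
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(vae_scale_factor)  # 8: three downsampling stages, each halving height and width

image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
```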
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
...@@ -28,6 +29,7 @@ from transformers import ( ...@@ -28,6 +29,7 @@ from transformers import (
CLIPTokenizer, CLIPTokenizer,
) )
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention from ...models.attention_processor import Attention
...@@ -358,6 +360,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): ...@@ -358,6 +360,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
inverse_scheduler=inverse_scheduler, inverse_scheduler=inverse_scheduler,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
...@@ -578,17 +581,26 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): ...@@ -578,17 +581,26 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-        else:
-            has_nsfw_concept = None
         return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents): def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0] image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
...@@ -1045,24 +1057,28 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): ...@@ -1045,24 +1057,28 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
-        # 11. Post-process the latents.
-        edited_image = self.decode_latents(latents)
-        # 12. Run the safety checker.
-        edited_image, has_nsfw_concept = self.run_safety_checker(edited_image, device, prompt_embeds.dtype)
-        # 13. Convert to PIL.
-        if output_type == "pil":
-            edited_image = self.numpy_to_pil(edited_image)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
         if not return_dict:
-            return (edited_image, has_nsfw_concept)
+            return (image, has_nsfw_concept)
-        return StableDiffusionPipelineOutput(images=edited_image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING)
...@@ -1259,7 +1275,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): ...@@ -1259,7 +1275,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
# 9. Convert to PIL. # 9. Convert to PIL.
if output_type == "pil": if output_type == "pil":
-            image = self.numpy_to_pil(image)
+            image = self.image_processor.numpy_to_pil(image)
if not return_dict: if not return_dict:
return (inverted_latents, image) return (inverted_latents, image)
......
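In the generation path `edited_image` is renamed to `image` so the shared postprocessing block applies unchanged, while `invert` keeps producing numpy arrays and only swaps `self.numpy_to_pil` for the image processor's static helper, shown here on a synthetic batch.

```python
import numpy as np

from diffusers.image_processor import VaeImageProcessor

np_images = np.random.rand(2, 64, 64, 3).astype(np.float32)  # NHWC, values in [0, 1]
pil_images = VaeImageProcessor.numpy_to_pil(np_images)
print(len(pil_images), pil_images[0].size)  # 2 (64, 64)
```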
...@@ -13,12 +13,14 @@ ...@@ -13,12 +13,14 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
...@@ -140,6 +142,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) ...@@ -140,6 +142,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
...@@ -354,17 +357,26 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) ...@@ -354,17 +357,26 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-        else:
-            has_nsfw_concept = None
         return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents): def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0] image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
...@@ -682,15 +694,19 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) ...@@ -682,15 +694,19 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
-        # 8. Post-processing
-        image = self.decode_latents(latents)
-        # 9. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-        # 10. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
if not return_dict: if not return_dict:
return (image, has_nsfw_concept) return (image, has_nsfw_concept)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import warnings
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union
import numpy as np import numpy as np
...@@ -372,6 +373,11 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -372,6 +373,11 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents): def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0] image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
......
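Only the `decode_latents` deprecation is visible for the upscale pipeline in this view. During migration, downstream code can promote the `FutureWarning` to an error to locate remaining call sites; a self-contained sketch with a stand-in for the deprecated method.

```python
import warnings


def deprecated_decode_latents(latents):
    # Stand-in for the pipelines' decode_latents, which now warns before decoding.
    warnings.warn(
        "The decode_latents method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor instead",
        FutureWarning,
    )
    return latents


with warnings.catch_warnings():
    warnings.simplefilter("error", FutureWarning)
    try:
        deprecated_decode_latents(None)
    except FutureWarning:
        print("a caller still relies on the deprecated decode_latents path")
```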