"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "e8cf9613244ddac05c86f50f963de71717790931"
Unverified commit f242eba4, authored by SkyTNT and committed via GitHub.
Browse files

Fix lpw stable diffusion pipeline compatibility (#1622)

parent 3faf204c
...@@ -5,14 +5,37 @@ from typing import Callable, List, Optional, Union ...@@ -5,14 +5,37 @@ from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
import diffusers
import PIL import PIL
from diffusers import SchedulerMixin, StableDiffusionPipeline from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging from diffusers.utils import deprecate, logging
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
# Compatibility shim: older diffusers releases do not export PIL_INTERPOLATION,
# so fall back to building the same mapping locally from Pillow's constants.
try:
    from diffusers.utils import PIL_INTERPOLATION
except ImportError:
    # Pillow 9.1.0 moved the resampling constants into the Image.Resampling enum
    # (the old module-level names were deprecated), so pick the right namespace.
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        PIL_INTERPOLATION = {
            # "linear" deliberately maps to BILINEAR, mirroring diffusers upstream.
            "linear": PIL.Image.Resampling.BILINEAR,
            "bilinear": PIL.Image.Resampling.BILINEAR,
            "bicubic": PIL.Image.Resampling.BICUBIC,
            "lanczos": PIL.Image.Resampling.LANCZOS,
            "nearest": PIL.Image.Resampling.NEAREST,
        }
    else:
        # Pre-9.1.0 Pillow: constants live directly on PIL.Image.
        PIL_INTERPOLATION = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
            "nearest": PIL.Image.NEAREST,
        }
# ------------------------------------------------------------------------------
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
re_attention = re.compile( re_attention = re.compile(
...@@ -404,27 +427,75 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): ...@@ -404,27 +427,75 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
def __init__( if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
self,
vae: AutoencoderKL, def __init__(
text_encoder: CLIPTextModel, self,
tokenizer: CLIPTokenizer, vae: AutoencoderKL,
unet: UNet2DConditionModel, text_encoder: CLIPTextModel,
scheduler: SchedulerMixin, tokenizer: CLIPTokenizer,
safety_checker: StableDiffusionSafetyChecker, unet: UNet2DConditionModel,
feature_extractor: CLIPFeatureExtractor, scheduler: SchedulerMixin,
requires_safety_checker: bool = True, safety_checker: StableDiffusionSafetyChecker,
): feature_extractor: CLIPFeatureExtractor,
super().__init__( requires_safety_checker: bool = True,
vae=vae, ):
text_encoder=text_encoder, super().__init__(
tokenizer=tokenizer, vae=vae,
unet=unet, text_encoder=text_encoder,
scheduler=scheduler, tokenizer=tokenizer,
safety_checker=safety_checker, unet=unet,
feature_extractor=feature_extractor, scheduler=scheduler,
requires_safety_checker=requires_safety_checker, safety_checker=safety_checker,
) feature_extractor=feature_extractor,
requires_safety_checker=requires_safety_checker,
)
self.__init__additional__()
else:
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPFeatureExtractor,
):
    """Constructor used on diffusers < 0.9.0.

    Older StableDiffusionPipeline.__init__ signatures do not accept
    ``requires_safety_checker``, so this variant forwards only the model
    components to the parent and then runs ``__init__additional__`` to
    backfill attributes this subclass relies on.
    """
    super().__init__(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.__init__additional__()
def __init__additional__(self):
    """Backfill attributes that older diffusers releases do not set.

    Newer ``StableDiffusionPipeline.__init__`` computes ``vae_scale_factor``
    itself; on older releases derive it the same way from the VAE config:
    ``2 ** (number of block levels - 1)``.
    """
    if not hasattr(self, "vae_scale_factor"):
        # Direct attribute assignment instead of setattr() with a constant name.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@property
def _execution_device(self):
    r"""
    Returns the device on which the pipeline's models will be executed. After calling
    `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
    hooks.
    """
    # Fast path: no CPU offload in effect, so the pipeline's own device is valid.
    if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        return self.device
    # Otherwise scan the UNet's submodules for the first Accelerate hook
    # that records an execution device.
    hook_devices = (
        submodule._hf_hook.execution_device
        for submodule in self.unet.modules()
        if hasattr(submodule, "_hf_hook")
        and hasattr(submodule._hf_hook, "execution_device")
        and submodule._hf_hook.execution_device is not None
    )
    found = next(hook_devices, None)
    return self.device if found is None else torch.device(found)
def _encode_prompt( def _encode_prompt(
self, self,
...@@ -752,37 +823,33 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): ...@@ -752,37 +823,33 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop # 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order for i, t in enumerate(self.progress_bar(timesteps)):
with self.progress_bar(total=num_inference_steps) as progress_bar: # expand the latents if we are doing classifier free guidance
for i, t in enumerate(timesteps): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# expand the latents if we are doing classifier free guidance latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance
if do_classifier_free_guidance:
# perform guidance noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
if do_classifier_free_guidance: noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample if mask is not None:
# masking
if mask is not None: init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
# masking latents = (init_latents_proper * mask) + (latents * (1 - mask))
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided
if i % callback_steps == 0:
# call the callback, if provided if callback is not None:
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): callback(i, t, latents)
progress_bar.update() if is_cancelled_callback is not None and is_cancelled_callback():
if i % callback_steps == 0: return None
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
# 9. Post-processing # 9. Post-processing
image = self.decode_latents(latents) image = self.decode_latents(latents)
......
...@@ -5,14 +5,55 @@ from typing import Callable, List, Optional, Union ...@@ -5,14 +5,55 @@ from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
import diffusers
import PIL import PIL
from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
from diffusers.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel from diffusers.onnx_utils import OnnxRuntimeModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging from diffusers.utils import deprecate, logging
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTokenizer
# Compatibility shim: older diffusers releases do not export ORT_TO_NP_TYPE,
# so fall back to a local copy of the mapping from ONNX Runtime tensor type
# strings (as reported by model input/output metadata) to NumPy dtypes.
try:
    from diffusers.onnx_utils import ORT_TO_NP_TYPE
except ImportError:
    ORT_TO_NP_TYPE = {
        "tensor(bool)": np.bool_,
        "tensor(int8)": np.int8,
        "tensor(uint8)": np.uint8,
        "tensor(int16)": np.int16,
        "tensor(uint16)": np.uint16,
        "tensor(int32)": np.int32,
        "tensor(uint32)": np.uint32,
        "tensor(int64)": np.int64,
        "tensor(uint64)": np.uint64,
        "tensor(float16)": np.float16,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
    }
# Compatibility shim: older diffusers releases do not export PIL_INTERPOLATION,
# so fall back to building the same mapping locally from Pillow's constants.
try:
    from diffusers.utils import PIL_INTERPOLATION
except ImportError:
    # Pillow 9.1.0 moved the resampling constants into the Image.Resampling enum
    # (the old module-level names were deprecated), so pick the right namespace.
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        PIL_INTERPOLATION = {
            # "linear" deliberately maps to BILINEAR, mirroring diffusers upstream.
            "linear": PIL.Image.Resampling.BILINEAR,
            "bilinear": PIL.Image.Resampling.BILINEAR,
            "bicubic": PIL.Image.Resampling.BICUBIC,
            "lanczos": PIL.Image.Resampling.LANCZOS,
            "nearest": PIL.Image.Resampling.NEAREST,
        }
    else:
        # Pre-9.1.0 Pillow: constants live directly on PIL.Image.
        PIL_INTERPOLATION = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
            "nearest": PIL.Image.NEAREST,
        }
# ------------------------------------------------------------------------------
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
re_attention = re.compile( re_attention = re.compile(
...@@ -390,30 +431,59 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline ...@@ -390,30 +431,59 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
""" """
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
def __init__(
    self,
    vae_encoder: OnnxRuntimeModel,
    vae_decoder: OnnxRuntimeModel,
    text_encoder: OnnxRuntimeModel,
    tokenizer: CLIPTokenizer,
    unet: OnnxRuntimeModel,
    scheduler: SchedulerMixin,
    safety_checker: OnnxRuntimeModel,
    feature_extractor: CLIPFeatureExtractor,
    requires_safety_checker: bool = True,
):
    """Constructor used on diffusers >= 0.9.0.

    Newer OnnxStableDiffusionPipeline.__init__ accepts ``requires_safety_checker``,
    so every component is forwarded unchanged to the parent, then
    ``__init__additional__`` backfills attributes this subclass relies on.
    """
    components = dict(
        vae_encoder=vae_encoder,
        vae_decoder=vae_decoder,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=requires_safety_checker,
    )
    super().__init__(**components)
    self.__init__additional__()
def __init__( else:
self,
vae_encoder: OnnxRuntimeModel, def __init__(
vae_decoder: OnnxRuntimeModel, self,
text_encoder: OnnxRuntimeModel, vae_encoder: OnnxRuntimeModel,
tokenizer: CLIPTokenizer, vae_decoder: OnnxRuntimeModel,
unet: OnnxRuntimeModel, text_encoder: OnnxRuntimeModel,
scheduler: SchedulerMixin, tokenizer: CLIPTokenizer,
safety_checker: OnnxRuntimeModel, unet: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor, scheduler: SchedulerMixin,
requires_safety_checker: bool = True, safety_checker: OnnxRuntimeModel,
): feature_extractor: CLIPFeatureExtractor,
super().__init__( ):
vae_encoder=vae_encoder, super().__init__(
vae_decoder=vae_decoder, vae_encoder=vae_encoder,
text_encoder=text_encoder, vae_decoder=vae_decoder,
tokenizer=tokenizer, text_encoder=text_encoder,
unet=unet, tokenizer=tokenizer,
scheduler=scheduler, unet=unet,
safety_checker=safety_checker, scheduler=scheduler,
feature_extractor=feature_extractor, safety_checker=safety_checker,
requires_safety_checker=requires_safety_checker, feature_extractor=feature_extractor,
) )
self.__init__additional__()
def __init__additional__(self):
self.unet_in_channels = 4 self.unet_in_channels = 4
self.vae_scale_factor = 8 self.vae_scale_factor = 8
...@@ -741,49 +811,47 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline ...@@ -741,49 +811,47 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop # 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order for i, t in enumerate(self.progress_bar(timesteps)):
with self.progress_bar(total=num_inference_steps) as progress_bar: # expand the latents if we are doing classifier free guidance
for i, t in enumerate(timesteps): latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
# expand the latents if we are doing classifier free guidance latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = latent_model_input.numpy()
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = latent_model_input.numpy() # predict the noise residual
noise_pred = self.unet(
# predict the noise residual sample=latent_model_input,
noise_pred = self.unet( timestep=np.array([t], dtype=timestep_dtype),
sample=latent_model_input, encoder_hidden_states=text_embeddings,
timestep=np.array([t], dtype=timestep_dtype), )
encoder_hidden_states=text_embeddings, noise_pred = noise_pred[0]
)
noise_pred = noise_pred[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
)
latents = scheduler_output.prev_sample.numpy()
if mask is not None:
# masking
init_latents_proper = self.scheduler.add_noise(
torch.from_numpy(init_latents_orig),
torch.from_numpy(noise),
t,
).numpy()
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if i % callback_steps == 0:
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
)
latents = scheduler_output.prev_sample.numpy()
if mask is not None:
# masking
init_latents_proper = self.scheduler.add_noise(
torch.from_numpy(init_latents_orig),
torch.from_numpy(noise),
t,
).numpy()
latents = (init_latents_proper * mask) + (latents * (1 - mask))
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if i % callback_steps == 0:
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
# 9. Post-processing # 9. Post-processing
image = self.decode_latents(latents) image = self.decode_latents(latents)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.