Commit d420d713 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make style

parent a1fad828
...@@ -20,22 +20,19 @@ from dataclasses import dataclass ...@@ -20,22 +20,19 @@ from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import PIL.Image
import torch import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, ConfigMixin, DiffusionPipeline, SchedulerMixin, UNet2DConditionModel, logging from diffusers import AutoencoderKL, ConfigMixin, DiffusionPipeline, SchedulerMixin, UNet2DConditionModel, logging
from diffusers.configuration_utils import register_to_config from diffusers.configuration_utils import register_to_config
from diffusers.image_processor import VaeImageProcessor, PipelineImageInput from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import BaseOutput from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor from diffusers.utils.torch_utils import randn_tensor
import PIL.Image
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -167,7 +164,19 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline): ...@@ -167,7 +164,19 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
) )
return image, has_nsfw_concept return image, has_nsfw_concept
def prepare_latents(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, latents=None, generator=None): def prepare_latents(
self,
image,
timestep,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
latents=None,
generator=None,
):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
...@@ -201,7 +210,7 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline): ...@@ -201,7 +210,7 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size # expand init_latents for batch_size
deprecation_message = ( (
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
...@@ -643,7 +652,9 @@ class LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin): ...@@ -643,7 +652,9 @@ class LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin):
return sample return sample
def set_timesteps(self, stength, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None): def set_timesteps(
self, stength, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None
):
""" """
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args: Args:
...@@ -662,7 +673,9 @@ class LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin): ...@@ -662,7 +673,9 @@ class LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin):
# LCM Timesteps Setting: # Linear Spacing # LCM Timesteps Setting: # Linear Spacing
c = self.config.num_train_timesteps // lcm_origin_steps c = self.config.num_train_timesteps // lcm_origin_steps
lcm_origin_timesteps = np.asarray(list(range(1, int(lcm_origin_steps * stength) + 1))) * c - 1 # LCM Training Steps Schedule lcm_origin_timesteps = (
np.asarray(list(range(1, int(lcm_origin_steps * stength) + 1))) * c - 1
) # LCM Training Steps Schedule
skipping_step = len(lcm_origin_timesteps) // num_inference_steps skipping_step = len(lcm_origin_timesteps) // num_inference_steps
timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment