Unverified Commit 6133d98f authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

[IF| add set_begin_index for all IF pipelines (#7577)

add set_begin_index for all if pipelines
parent 1c60e094
......@@ -691,6 +691,9 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(0)
# 5. Prepare intermediate images
intermediate_images = self.prepare_intermediate_images(
batch_size * num_images_per_prompt,
......
......@@ -633,12 +633,15 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
......
......@@ -714,13 +714,15 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
return image
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
......
......@@ -723,13 +723,15 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
return mask_image
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
......
......@@ -800,13 +800,15 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
return mask_image
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
......
......@@ -775,6 +775,9 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(0)
# 5. Prepare intermediate images
num_channels = self.unet.config.in_channels // 2
intermediate_images = self.prepare_intermediate_images(
......
......@@ -998,6 +998,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
self._all_hooks = []
hook = None
for model_str in self.model_cpu_offload_seq.split("->"):
model = all_model_components.pop(model_str, None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment