Unverified Commit 07bd2fab authored by Pakkapon Phongthawee's avatar Pakkapon Phongthawee Committed by GitHub
Browse files

make controlnet support interrupt (#9620)

* make controlnet support interrupt

* remove white space in controlnet interrupt
parent af28ae2d
...@@ -893,6 +893,10 @@ class StableDiffusionControlNetPipeline( ...@@ -893,6 +893,10 @@ class StableDiffusionControlNetPipeline(
def num_timesteps(self): def num_timesteps(self):
return self._num_timesteps return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -1089,6 +1093,7 @@ class StableDiffusionControlNetPipeline( ...@@ -1089,6 +1093,7 @@ class StableDiffusionControlNetPipeline(
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._clip_skip = clip_skip self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
...@@ -1235,6 +1240,9 @@ class StableDiffusionControlNetPipeline( ...@@ -1235,6 +1240,9 @@ class StableDiffusionControlNetPipeline(
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
if self.interrupt:
continue
# Relevant thread: # Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
......
...@@ -891,6 +891,10 @@ class StableDiffusionControlNetImg2ImgPipeline( ...@@ -891,6 +891,10 @@ class StableDiffusionControlNetImg2ImgPipeline(
def num_timesteps(self): def num_timesteps(self):
return self._num_timesteps return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -1081,6 +1085,7 @@ class StableDiffusionControlNetImg2ImgPipeline( ...@@ -1081,6 +1085,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._clip_skip = clip_skip self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
...@@ -1211,6 +1216,9 @@ class StableDiffusionControlNetImg2ImgPipeline( ...@@ -1211,6 +1216,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
......
...@@ -976,6 +976,10 @@ class StableDiffusionControlNetInpaintPipeline( ...@@ -976,6 +976,10 @@ class StableDiffusionControlNetInpaintPipeline(
def num_timesteps(self): def num_timesteps(self):
return self._num_timesteps return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -1191,6 +1195,7 @@ class StableDiffusionControlNetInpaintPipeline( ...@@ -1191,6 +1195,7 @@ class StableDiffusionControlNetInpaintPipeline(
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._clip_skip = clip_skip self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
...@@ -1375,6 +1380,9 @@ class StableDiffusionControlNetInpaintPipeline( ...@@ -1375,6 +1380,9 @@ class StableDiffusionControlNetInpaintPipeline(
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
......
...@@ -1145,6 +1145,10 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1145,6 +1145,10 @@ class StableDiffusionXLControlNetInpaintPipeline(
def num_timesteps(self): def num_timesteps(self):
return self._num_timesteps return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -1427,6 +1431,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1427,6 +1431,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._clip_skip = clip_skip self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
...@@ -1695,6 +1700,9 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1695,6 +1700,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
......
...@@ -990,6 +990,10 @@ class StableDiffusionXLControlNetPipeline( ...@@ -990,6 +990,10 @@ class StableDiffusionXLControlNetPipeline(
def num_timesteps(self): def num_timesteps(self):
return self._num_timesteps return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -1245,6 +1249,7 @@ class StableDiffusionXLControlNetPipeline( ...@@ -1245,6 +1249,7 @@ class StableDiffusionXLControlNetPipeline(
self._clip_skip = clip_skip self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end self._denoising_end = denoising_end
self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
...@@ -1442,6 +1447,9 @@ class StableDiffusionXLControlNetPipeline( ...@@ -1442,6 +1447,9 @@ class StableDiffusionXLControlNetPipeline(
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
if self.interrupt:
continue
# Relevant thread: # Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
......
...@@ -1070,6 +1070,10 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1070,6 +1070,10 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
def num_timesteps(self): def num_timesteps(self):
return self._num_timesteps return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -1338,6 +1342,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1338,6 +1342,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._clip_skip = clip_skip self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
...@@ -1510,6 +1515,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1510,6 +1515,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment