Unverified Commit 48ce118d authored by Sayak Paul, committed by GitHub

[torch.compile] fix graph break problems partially (#5453)

* fix: controlnet graph?

* fix: sample

* fix:

* remove print

* styling

* fix-copies

* prevent more graph breaks?

* prevent more graph breaks?

* see?

* revert.

* compilation.

* propagate changes to controlnet sdxl pipeline too.

* add: clean version checking.
parent 1ade42f7
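
For context on what "graph break" means here, the sketch below (not part of this commit) shows one way to surface breaks with torch._dynamo.explain while iterating on fixes like these. The any()-over-a-generator pattern is the same one the UNet hunks below replace; whether it still breaks depends on the PyTorch version, and the explain calling convention changed around 2.1.

import torch

def check(sample: torch.Tensor) -> torch.Tensor:
    # the pattern removed from the UNet forward below; the generator
    # expression inside any() graph-broke on the dynamo of this era
    if any(s % 64 != 0 for s in sample.shape[-2:]):
        sample = sample + 1.0
    return sample * 2.0

# PyTorch >= 2.1 style: explain(fn) returns a callable whose result
# carries graph_break_count and break_reasons
explanation = torch._dynamo.explain(check)(torch.randn(1, 4, 96, 96))
print(explanation.graph_break_count)
print(explanation.break_reasons)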
@@ -817,7 +817,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-
             scales = scales * conditioning_scale
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
@@ -874,9 +874,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         forward_upsample_size = False
         upsample_size = None

-        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            # Forward upsample size to force interpolation output size.
-            forward_upsample_size = True
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break

         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
         # expects mask of shape:
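
A standalone sketch of the rewrite in the hunk above, with default_overall_up_factor stubbed to a hypothetical value (in the real model it is 2**num_upsamplers). The explicit loop with an early break computes the same boolean as the any() call but avoids tracing a generator expression:

import torch

default_overall_up_factor = 2**3  # hypothetical: a UNet with three upsamplers

def needs_forward_upsample_size(sample: torch.Tensor) -> bool:
    forward_upsample_size = False
    for dim in sample.shape[-2:]:
        if dim % default_overall_up_factor != 0:
            # Forward upsample size to force interpolation output size.
            forward_upsample_size = True
            break
    return forward_upsample_size

print(needs_forward_upsample_size(torch.randn(1, 4, 64, 64)))  # False: 64 % 8 == 0
print(needs_forward_upsample_size(torch.randn(1, 4, 60, 64)))  # True: 60 % 8 != 0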
@@ -35,7 +35,7 @@ from ...utils import (
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -976,8 +976,15 @@ class StableDiffusionControlNetPipeline(

         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
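
A condensed, self-contained sketch of the guard this hunk adds (the SDXL pipeline below gets the same change). The modules and loop are stand-ins; the real code gates on both models being compiled and on torch >= 2.1, where the step marker tells inductor's CUDA-graph runtime that a new iteration begins and outputs from the previous step may be overwritten safely (see the linked dev-discuss thread). It assumes a CUDA device and imports diffusers' helpers:

import torch
from diffusers.utils import is_torch_version
from diffusers.utils.torch_utils import is_compiled_module

unet = torch.compile(torch.nn.Linear(4, 4).cuda(), mode="reduce-overhead")
controlnet = torch.compile(torch.nn.Linear(4, 4).cuda(), mode="reduce-overhead")

# hoisted out of the loop so the checks run once, not per step
is_unet_compiled = is_compiled_module(unet)
is_controlnet_compiled = is_compiled_module(controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")

x = torch.randn(2, 4, device="cuda")
for t in range(3):
    # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
    if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
        torch._inductor.cudagraph_mark_step_begin()
    x = controlnet(unet(x))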
@@ -36,7 +36,7 @@ from ...models.attention_processor import (
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -1144,8 +1144,15 @@ class StableDiffusionXLControlNetPipeline(

         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -814,7 +814,8 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
                     down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
-                ).sample
+                    return_dict=False,
+                )[0]

                 # perform guidance
                 if do_classifier_free_guidance:
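
The adapter-pipeline change above swaps attribute access on the UNet's output dataclass (.sample) for return_dict=False plus tuple indexing. A minimal sketch (plain torch, not diffusers code) of why that helps: returning and indexing a plain tuple is cheap for dynamo to trace, whereas ModelOutput-style dataclass construction was a known source of graph breaks in this era:

import torch
from dataclasses import dataclass

@dataclass
class Output:  # stand-in for diffusers' ModelOutput subclasses
    sample: torch.Tensor

class Net(torch.nn.Module):
    def forward(self, x: torch.Tensor, return_dict: bool = True):
        out = x * 2.0
        if not return_dict:
            return (out,)  # plain tuple: friendly to dynamo tracing
        return Output(sample=out)

net = torch.compile(Net())
noise_pred = net(torch.randn(2, 2), return_dict=False)[0]  # tuple indexing, no .sample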
@@ -1084,9 +1084,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         forward_upsample_size = False
         upsample_size = None

-        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            # Forward upsample size to force interpolation output size.
-            forward_upsample_size = True
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break

         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
         # expects mask of shape: