Unverified Commit 2f997f30 authored by Isotr0py, committed by GitHub
Browse files

Fix bug in panorama pipeline when using dpmsolver scheduler (#3499)

fix panorama pipeline with dpmsolver scheduler
parent 67cd4601
...@@ -612,6 +612,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -612,6 +612,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
# 6. Define panorama grid and initialize views for synthesis. # 6. Define panorama grid and initialize views for synthesis.
views = self.get_views(height, width) views = self.get_views(height, width)
blocks_model_outputs = [None] * len(views)
count = torch.zeros_like(latents) count = torch.zeros_like(latents)
value = torch.zeros_like(latents) value = torch.zeros_like(latents)
...@@ -632,7 +633,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -632,7 +633,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
# denoised (latent) crops are then averaged to produce the final latent # denoised (latent) crops are then averaged to produce the final latent
# for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
# MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113 # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
for h_start, h_end, w_start, w_end in views: for j, (h_start, h_end, w_start, w_end) in enumerate(views):
# get the latents corresponding to the current view coordinates # get the latents corresponding to the current view coordinates
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end] latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
...@@ -656,9 +657,21 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -656,9 +657,21 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents_view_denoised = self.scheduler.step( if hasattr(self.scheduler, "model_outputs"):
noise_pred, t, latents_for_view, **extra_step_kwargs # rematch model_outputs in each block
).prev_sample if i >= 1:
self.scheduler.model_outputs = blocks_model_outputs[j]
latents_view_denoised = self.scheduler.step(
noise_pred, t, latents_for_view, **extra_step_kwargs
).prev_sample
# collect model_outputs
blocks_model_outputs[j] = [
output if output is not None else None for output in self.scheduler.model_outputs
]
else:
latents_view_denoised = self.scheduler.step(
noise_pred, t, latents_for_view, **extra_step_kwargs
).prev_sample
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1 count[:, :, h_start:h_end, w_start:w_end] += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment