Unverified Commit 2f997f30 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

Fix bug in panorama pipeline when using dpmsolver scheduler (#3499)

fix panorama pipeline with dpmsolver scheduler
parent 67cd4601
......@@ -612,6 +612,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
# 6. Define panorama grid and initialize views for synthesis.
views = self.get_views(height, width)
blocks_model_outputs = [None] * len(views)
count = torch.zeros_like(latents)
value = torch.zeros_like(latents)
......@@ -632,7 +633,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
# denoised (latent) crops are then averaged to produce the final latent
# for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
# MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
for h_start, h_end, w_start, w_end in views:
for j, (h_start, h_end, w_start, w_end) in enumerate(views):
# get the latents corresponding to the current view coordinates
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
......@@ -656,9 +657,21 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_view_denoised = self.scheduler.step(
noise_pred, t, latents_for_view, **extra_step_kwargs
).prev_sample
if hasattr(self.scheduler, "model_outputs"):
# rematch model_outputs in each block
if i >= 1:
self.scheduler.model_outputs = blocks_model_outputs[j]
latents_view_denoised = self.scheduler.step(
noise_pred, t, latents_for_view, **extra_step_kwargs
).prev_sample
# collect model_outputs
blocks_model_outputs[j] = [
output if output is not None else None for output in self.scheduler.model_outputs
]
else:
latents_view_denoised = self.scheduler.step(
noise_pred, t, latents_for_view, **extra_step_kwargs
).prev_sample
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment