Unverified Commit a94977b8 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

Fix panorama to support all schedulers (#3546)

* refactor blocks init

* refactor blocks loop

* remove unused function and warnings

* fix scheduler update location

* reformat code

* reformat code again

* fix PNDM test case

* reformat pndm test case
parent 8e69708b
......@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union
......@@ -21,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, PNDMScheduler
from ...schedulers import DDIMScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
......@@ -96,9 +97,6 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
):
super().__init__()
if isinstance(scheduler, PNDMScheduler):
logger.error("PNDMScheduler for this pipeline is currently not supported.")
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
......@@ -612,7 +610,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
# 6. Define panorama grid and initialize views for synthesis.
views = self.get_views(height, width)
blocks_model_outputs = [None] * len(views)
views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views)
count = torch.zeros_like(latents)
value = torch.zeros_like(latents)
......@@ -637,6 +635,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
# get the latents corresponding to the current view coordinates
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
# rematch block's scheduler status
self.scheduler.__dict__.update(views_scheduler_status[j])
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
......@@ -657,21 +658,13 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if hasattr(self.scheduler, "model_outputs"):
# rematch model_outputs in each block
if i >= 1:
self.scheduler.model_outputs = blocks_model_outputs[j]
latents_view_denoised = self.scheduler.step(
noise_pred, t, latents_for_view, **extra_step_kwargs
).prev_sample
# collect model_outputs
blocks_model_outputs[j] = [
output if output is not None else None for output in self.scheduler.model_outputs
]
else:
latents_view_denoised = self.scheduler.step(
noise_pred, t, latents_for_view, **extra_step_kwargs
).prev_sample
# save views scheduler status after sample
views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
......
......@@ -174,15 +174,22 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
def test_stable_diffusion_panorama_pndm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
components["scheduler"] = PNDMScheduler()
components["scheduler"] = PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
)
sd_pipe = StableDiffusionPanoramaPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
# the pipeline does not expect pndm so test if it raises error.
with self.assertRaises(ValueError):
_ = sd_pipe(**inputs).images
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.6391, 0.6291, 0.4861, 0.5134, 0.5552, 0.4578, 0.5032, 0.5023, 0.4539])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment