Unverified Commit a94977b8 authored by Isotr0py, committed by GitHub

Fix panorama to support all schedulers (#3546)

* refactor blocks init

* refactor blocks loop

* remove unused function and warnings

* fix scheduler update location

* reformat code

* reformat code again

* fix PNDM test case

* reformat pndm test case
parent 8e69708b
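
The core of the fix is a per-view snapshot of the scheduler state: schedulers such as PNDM or DPMSolver keep history across `step()` calls (for example `ets` or `model_outputs`), so when a single timestep is split over many panorama views, each view must resume from its own copy of that state. Below is a minimal, self-contained sketch of that save/restore pattern using a toy stand-in scheduler; `ToyScheduler` and its `ets` list are illustrative only, not diffusers API or the pipeline's actual code.

```python
import copy


class ToyScheduler:
    """Stand-in for a stateful scheduler such as PNDM (keeps a history of model outputs)."""

    def __init__(self):
        self.ets = []  # analogous to PNDMScheduler.ets

    def step(self, model_output):
        self.ets.append(model_output)          # multistep schedulers accumulate history
        return sum(self.ets) / len(self.ets)   # dummy update rule


scheduler = ToyScheduler()
views = ["view_0", "view_1", "view_2"]

# one state snapshot per panorama view, refreshed after every scheduler.step call
views_scheduler_status = [copy.deepcopy(scheduler.__dict__) for _ in views]

for t in range(3):  # denoising timesteps
    for j, _ in enumerate(views):
        # restore the history that belongs to this view before stepping
        scheduler.__dict__.update(views_scheduler_status[j])
        scheduler.step(model_output=float(t))
        # save the updated history so the next timestep continues this view's trajectory
        views_scheduler_status[j] = copy.deepcopy(scheduler.__dict__)

# each view ends with its own 3-entry history instead of one shared 9-entry history
assert all(len(status["ets"]) == 3 for status in views_scheduler_status)
```

Because the snapshot copies the scheduler's whole `__dict__` rather than a scheduler-specific attribute like `model_outputs`, the same trick works for any scheduler, which is what lets the pipeline drop the special cases in the diff below.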
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import inspect
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -21,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import DDIMScheduler, PNDMScheduler
+from ...schedulers import DDIMScheduler
 from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -96,9 +97,6 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     ):
         super().__init__()

-        if isinstance(scheduler, PNDMScheduler):
-            logger.error("PNDMScheduler for this pipeline is currently not supported.")
-
         if safety_checker is None and requires_safety_checker:
             logger.warning(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -612,7 +610,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         # 6. Define panorama grid and initialize views for synthesis.
         views = self.get_views(height, width)
-        blocks_model_outputs = [None] * len(views)
+        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views)
         count = torch.zeros_like(latents)
         value = torch.zeros_like(latents)
@@ -637,6 +635,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 # get the latents corresponding to the current view coordinates
                 latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]

+                # rematch block's scheduler status
+                self.scheduler.__dict__.update(views_scheduler_status[j])
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = (
                     torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
@@ -657,21 +658,13 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
-                if hasattr(self.scheduler, "model_outputs"):
-                    # rematch model_outputs in each block
-                    if i >= 1:
-                        self.scheduler.model_outputs = blocks_model_outputs[j]
-                    latents_view_denoised = self.scheduler.step(
-                        noise_pred, t, latents_for_view, **extra_step_kwargs
-                    ).prev_sample
-                    # collect model_outputs
-                    blocks_model_outputs[j] = [
-                        output if output is not None else None for output in self.scheduler.model_outputs
-                    ]
-                else:
-                    latents_view_denoised = self.scheduler.step(
-                        noise_pred, t, latents_for_view, **extra_step_kwargs
-                    ).prev_sample
+                latents_view_denoised = self.scheduler.step(
+                    noise_pred, t, latents_for_view, **extra_step_kwargs
+                ).prev_sample
+
+                # save views scheduler status after sample
+                views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)
+
                 value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                 count[:, :, h_start:h_end, w_start:w_end] += 1
......
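
With the guard removed and scheduler state tracked per view, a history-keeping scheduler can be passed to the pipeline directly. A hedged usage sketch follows; the `stabilityai/stable-diffusion-2-base` checkpoint and the fp16/CUDA settings mirror the pipeline's docstring example and are assumptions, not part of this diff.

```python
import torch
from diffusers import PNDMScheduler, StableDiffusionPanoramaPipeline

model_ckpt = "stabilityai/stable-diffusion-2-base"  # assumed checkpoint id
scheduler = PNDMScheduler.from_pretrained(model_ckpt, subfolder="scheduler", skip_prk_steps=True)
pipe = StableDiffusionPanoramaPipeline.from_pretrained(
    model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")

# panorama defaults to a wide canvas; height/width shown explicitly for clarity
image = pipe(prompt="a photo of the dolomites", height=512, width=2048).images[0]
image.save("panorama_pndm.png")
```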
@@ -174,15 +174,22 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
     def test_stable_diffusion_panorama_pndm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
-        components["scheduler"] = PNDMScheduler()
+        components["scheduler"] = PNDMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+        )
         sd_pipe = StableDiffusionPanoramaPipeline(**components)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
         inputs = self.get_dummy_inputs(device)
-        # the pipeline does not expect pndm so test if it raises error.
-        with self.assertRaises(ValueError):
-            _ = sd_pipe(**inputs).images
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+
+        expected_slice = np.array([0.6391, 0.6291, 0.4861, 0.5134, 0.5552, 0.4578, 0.5032, 0.5023, 0.4539])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
......
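
The same applies to other multistep schedulers swapped in with the usual `from_config` pattern; a brief, hedged sketch, with DPMSolverMultistepScheduler chosen only as an illustration and the checkpoint id assumed as above:

```python
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPanoramaPipeline

# load the panorama pipeline, then replace its scheduler from the existing config
pipe = StableDiffusionPanoramaPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

image = pipe(prompt="a photo of the dolomites").images[0]
```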