Unverified Commit 17f9aed7 authored by clarencechen's avatar clarencechen Committed by GitHub
Browse files

[Scheduler] DPM-Solver (++) Inverse Scheduler (#3335)



* Add DPM-Solver Multistep Inverse Scheduler

* Add draft tests for DiffEdit

* Add inverse sde-dpmsolver steps to tune image diversity from inverted latents

* Fix tests

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 886575ee
......@@ -252,6 +252,8 @@
title: Euler scheduler
- local: api/schedulers/heun
title: Heun Scheduler
- local: api/schedulers/multistep_dpm_solver_inverse
title: Inverse Multistep DPM-Solver
- local: api/schedulers/ipndm
title: IPNDM
- local: api/schedulers/lms_discrete
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Inverse Multistep DPM-Solver (DPMSolverMultistepInverse)
## Overview
This scheduler is the inverted scheduler of [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://arxiv.org/abs/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models
](https://arxiv.org/abs/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.
The implementation is mostly based on the DDIM inversion definition of [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/pdf/2211.09794.pdf) and the ad-hoc notebook implementation for DiffEdit latent inversion [here](https://github.com/Xiang-cd/DiffEdit-stable-diffusion/blob/main/diffedit.ipynb).
## DPMSolverMultistepInverseScheduler
[[autodoc]] DPMSolverMultistepInverseScheduler
......@@ -76,6 +76,7 @@ else:
DDIMScheduler,
DDPMScheduler,
DEISMultistepScheduler,
DPMSolverMultistepInverseScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
......
......@@ -33,6 +33,7 @@ else:
from .scheduling_ddpm import DDPMScheduler
from .scheduling_deis_multistep import DEISMultistepScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler
from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from .scheduling_euler_discrete import EulerDiscreteScheduler
......
......@@ -450,6 +450,21 @@ class DEISMultistepScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class DPMSolverMultistepInverseScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DPMSolverMultistepScheduler(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -27,6 +27,8 @@ from diffusers import (
AutoencoderKL,
DDIMInverseScheduler,
DDIMScheduler,
DPMSolverMultistepInverseScheduler,
DPMSolverMultistepScheduler,
StableDiffusionDiffEditPipeline,
UNet2DConditionModel,
)
......@@ -256,6 +258,30 @@ class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, Pipeli
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=5e-3)
def test_inversion_dpm(self):
device = "cpu"
components = self.get_dummy_components()
scheduler_args = {"beta_start": 0.00085, "beta_end": 0.012, "beta_schedule": "scaled_linear"}
components["scheduler"] = DPMSolverMultistepScheduler(**scheduler_args)
components["inverse_scheduler"] = DPMSolverMultistepInverseScheduler(**scheduler_args)
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inversion_inputs(device)
image = pipe.invert(**inputs).images
image_slice = image[0, -1, -3:, -3:]
self.assertEqual(image.shape, (2, 32, 32, 3))
expected_slice = np.array(
[0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.51050, 0.5015, 0.4407, 0.4799],
)
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
@require_torch_gpu
@slow
......@@ -320,3 +346,54 @@ class StableDiffusionDiffEditPipelineIntegrationTests(unittest.TestCase):
/ 255
)
assert np.abs((expected_image - image).max()) < 5e-1
def test_stable_diffusion_diffedit_dpm(self):
generator = torch.manual_seed(0)
pipe = StableDiffusionDiffEditPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DPMSolverMultistepInverseScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
source_prompt = "a bowl of fruit"
target_prompt = "a bowl of pears"
mask_image = pipe.generate_mask(
image=self.raw_image,
source_prompt=source_prompt,
target_prompt=target_prompt,
generator=generator,
)
inv_latents = pipe.invert(
prompt=source_prompt,
image=self.raw_image,
inpaint_strength=0.7,
generator=generator,
num_inference_steps=25,
).latents
image = pipe(
prompt=target_prompt,
mask_image=mask_image,
image_latents=inv_latents,
generator=generator,
negative_prompt=source_prompt,
inpaint_strength=0.7,
num_inference_steps=25,
output_type="numpy",
).images[0]
expected_image = (
np.array(
load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/diffedit/pears.png"
).resize((768, 768))
)
/ 255
)
assert np.abs((expected_image - image).max()) < 5e-1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment