Unverified Commit 73b125df authored by Andy Shih's avatar Andy Shih Committed by GitHub
Browse files

[Pipeline] Add new pipeline for ParaDiGMS -- parallel sampling of diffusion models (#3716)



* add paradigms parallel sampling pipeline

* linting

* ran make fix-copies

* add paradigms parallel sampling pipeline

* linting

* ran make fix-copies

* Apply suggestions from code review
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* changes based on review

* add docs for paradigms

* update docs with paradigms abstract

* improve documentation, and add tests for ddim/ddpm batch_step_no_noise

* fix docs and run make fix-copies

* minor changes to docs.

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* move parallel scheduler to new classes for DDPMParallelScheduler and DDIMParallelScheduler

* remove changes for scheduling_ddim, adjust licenses, credits, and commented code

* fix tensor type that is breaking tests

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 88eb0448
...@@ -188,6 +188,8 @@ ...@@ -188,6 +188,8 @@
title: MultiDiffusion Panorama title: MultiDiffusion Panorama
- local: api/pipelines/paint_by_example - local: api/pipelines/paint_by_example
title: PaintByExample title: PaintByExample
- local: api/pipelines/paradigms
title: Parallel Sampling of Diffusion Models
- local: api/pipelines/pix2pix_zero - local: api/pipelines/pix2pix_zero
title: Pix2Pix Zero title: Pix2Pix Zero
- local: api/pipelines/pndm - local: api/pipelines/pndm
......
...@@ -61,6 +61,7 @@ available a colab notebook to directly try them out. ...@@ -61,6 +61,7 @@ available a colab notebook to directly try them out.
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image | | [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | | [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [paint_by_example](./paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting | | [paint_by_example](./paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting |
| [paradigms](./paradigms) | [**Parallel Sampling of Diffusion Models**](https://arxiv.org/abs/2305.16317) | Text-to-Image Generation |
| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | | [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
......
<!--Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Parallel Sampling of Diffusion Models (ParaDiGMS)
## Overview
[Parallel Sampling of Diffusion Models](https://arxiv.org/abs/2305.16317) by Andy Shih, Suneel Belkhale, Stefano Ermon, Dorsa Sadigh, Nima Anari.
The abstract of the paper is the following:
*Diffusion models are powerful generative models but suffer from slow sampling, often taking 1000 sequential denoising steps for one sample. As a result, considerable efforts have been directed toward reducing the number of denoising steps, but these methods hurt sample quality. Instead of reducing the number of denoising steps (trading quality for speed), in this paper we explore an orthogonal approach: can we run the denoising steps in parallel (trading compute for speed)? In spite of the sequential nature of the denoising steps, we show that surprisingly it is possible to parallelize sampling via Picard iterations, by guessing the solution of future denoising steps and iteratively refining until convergence. With this insight, we present ParaDiGMS, a novel method to accelerate the sampling of pretrained diffusion models by denoising multiple steps in parallel. ParaDiGMS is the first diffusion sampling method that enables trading compute for speed and is even compatible with existing fast sampling techniques such as DDIM and DPMSolver. Using ParaDiGMS, we improve sampling speed by 2-4x across a range of robotics and image generation models, giving state-of-the-art sampling speeds of 0.2s on 100-step DiffusionPolicy and 16s on 1000-step StableDiffusion-v2 with no measurable degradation of task reward, FID score, or CLIP score.*
Resources:
* [Paper](https://arxiv.org/abs/2305.16317).
* [Original Code](https://github.com/AndyShih12/paradigms).
## Available Pipelines:
| Pipeline | Tasks | Demo
|---|---|:---:|
| [StableDiffusionParadigmsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py) | *Faster Text-to-Image Generation* | |
## Usage example
```python
import torch
from diffusers import DDPMParallelScheduler
from diffusers import StableDiffusionParadigmsPipeline
scheduler = DDPMParallelScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
pipe = StableDiffusionParadigmsPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
ngpu, batch_per_device = torch.cuda.device_count(), 5
pipe.wrapped_unet = torch.nn.DataParallel(pipe.unet, device_ids=[d for d in range(ngpu)])
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt, parallel=ngpu * batch_per_device, num_inference_steps=1000).images[0]
```
<Tip>
This pipeline improves sampling speed by running denoising steps in parallel, at the cost of increased total FLOPs.
Therefore, it is better to call this pipeline when running on multiple GPUs. Otherwise, without enough GPU bandwidth
sampling may be even slower than sequential sampling.
The two parameters to play with are `parallel` (batch size) and `tolerance`.
- If it fits in memory, for 1000-step DDPM you can aim for a batch size of around 100
(e.g. 8 GPUs and batch_per_device=12 to get parallel=96). Higher batch size
may not fit in memory, and lower batch size gives less parallelism.
- For tolerance, using a higher tolerance may get better speedups but can risk sample quality degradation.
If there is quality degradation with the default tolerance, then use a lower tolerance (e.g. 0.001).
For 1000-step DDPM on 8 A100 GPUs, you can expect around a 3x speedup by StableDiffusionParadigmsPipeline instead of StableDiffusionPipeline
by setting parallel=80 and tolerance=0.1.
</Tip>
<Tip>
Diffusers also offers distributed inference support for generating multiple prompts
in parallel on multiple GPUs. Check out the docs [here](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference).
In contrast, this pipeline is designed for speeding up sampling of a single prompt (by using multiple GPUs).
</Tip>
## StableDiffusionParadigmsPipeline
[[autodoc]] StableDiffusionParadigmsPipeline
- __call__
- all
...@@ -73,7 +73,9 @@ else: ...@@ -73,7 +73,9 @@ else:
) )
from .schedulers import ( from .schedulers import (
DDIMInverseScheduler, DDIMInverseScheduler,
DDIMParallelScheduler,
DDIMScheduler, DDIMScheduler,
DDPMParallelScheduler,
DDPMScheduler, DDPMScheduler,
DEISMultistepScheduler, DEISMultistepScheduler,
DPMSolverMultistepInverseScheduler, DPMSolverMultistepInverseScheduler,
...@@ -152,6 +154,7 @@ else: ...@@ -152,6 +154,7 @@ else:
StableDiffusionLDM3DPipeline, StableDiffusionLDM3DPipeline,
StableDiffusionModelEditingPipeline, StableDiffusionModelEditingPipeline,
StableDiffusionPanoramaPipeline, StableDiffusionPanoramaPipeline,
StableDiffusionParadigmsPipeline,
StableDiffusionPipeline, StableDiffusionPipeline,
StableDiffusionPipelineSafe, StableDiffusionPipelineSafe,
StableDiffusionPix2PixZeroPipeline, StableDiffusionPix2PixZeroPipeline,
......
...@@ -80,6 +80,7 @@ else: ...@@ -80,6 +80,7 @@ else:
StableDiffusionLDM3DPipeline, StableDiffusionLDM3DPipeline,
StableDiffusionModelEditingPipeline, StableDiffusionModelEditingPipeline,
StableDiffusionPanoramaPipeline, StableDiffusionPanoramaPipeline,
StableDiffusionParadigmsPipeline,
StableDiffusionPipeline, StableDiffusionPipeline,
StableDiffusionPix2PixZeroPipeline, StableDiffusionPix2PixZeroPipeline,
StableDiffusionSAGPipeline, StableDiffusionSAGPipeline,
......
...@@ -53,6 +53,7 @@ else: ...@@ -53,6 +53,7 @@ else:
from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline
from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .pipeline_stable_unclip import StableUnCLIPPipeline from .pipeline_stable_unclip import StableUnCLIPPipeline
......
...@@ -30,7 +30,9 @@ except OptionalDependencyNotAvailable: ...@@ -30,7 +30,9 @@ except OptionalDependencyNotAvailable:
else: else:
from .scheduling_ddim import DDIMScheduler from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler
from .scheduling_ddim_parallel import DDIMParallelScheduler
from .scheduling_ddpm import DDPMScheduler from .scheduling_ddpm import DDPMScheduler
from .scheduling_ddpm_parallel import DDPMParallelScheduler
from .scheduling_deis_multistep import DEISMultistepScheduler from .scheduling_deis_multistep import DEISMultistepScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler
......
This diff is collapsed.
This diff is collapsed.
...@@ -405,6 +405,21 @@ class DDIMInverseScheduler(metaclass=DummyObject): ...@@ -405,6 +405,21 @@ class DDIMInverseScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class DDIMParallelScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DDIMScheduler(metaclass=DummyObject): class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -420,6 +435,21 @@ class DDIMScheduler(metaclass=DummyObject): ...@@ -420,6 +435,21 @@ class DDIMScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class DDPMParallelScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DDPMScheduler(metaclass=DummyObject): class DDPMScheduler(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -497,6 +497,21 @@ class StableDiffusionPanoramaPipeline(metaclass=DummyObject): ...@@ -497,6 +497,21 @@ class StableDiffusionPanoramaPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"]) requires_backends(cls, ["torch", "transformers"])
class StableDiffusionParadigmsPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionPipeline(metaclass=DummyObject): class StableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"] _backends = ["torch", "transformers"]
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDIMParallelScheduler,
DDPMParallelScheduler,
StableDiffusionParadigmsPipeline,
UNet2DConditionModel,
)
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
enable_full_determinism()
class StableDiffusionParadigmsPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionParadigmsPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
# SD2-specific config below
attention_head_dim=(2, 4),
use_linear_projection=True,
)
scheduler = DDIMParallelScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=128,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
# SD2-specific config below
hidden_act="gelu",
projection_dim=512,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "a photograph of an astronaut riding a horse",
"generator": generator,
"num_inference_steps": 10,
"guidance_scale": 6.0,
"output_type": "numpy",
"parallel": 3,
"debug": True,
}
return inputs
def test_stable_diffusion_paradigms_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionParadigmsPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4773, 0.5417, 0.4723, 0.4925, 0.5631, 0.4752, 0.5240, 0.4935, 0.5023])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_paradigms_default_case_ddpm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
torch.manual_seed(0)
components["scheduler"] = DDPMParallelScheduler()
torch.manual_seed(0)
sd_pipe = StableDiffusionParadigmsPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.3573, 0.4420, 0.4960, 0.4799, 0.3796, 0.3879, 0.4819, 0.4365, 0.4468])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
# override to speed the overall test timing up.
def test_inference_batch_consistent(self):
super().test_inference_batch_consistent(batch_sizes=[1, 2])
# override to speed the overall test timing up.
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3e-3)
def test_stable_diffusion_paradigms_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionParadigmsPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
negative_prompt = "french fries"
output = sd_pipe(**inputs, negative_prompt=negative_prompt)
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4771, 0.5420, 0.4683, 0.4918, 0.5636, 0.4725, 0.5230, 0.4923, 0.5015])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
@require_torch_gpu
class StableDiffusionParadigmsPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, seed=0):
generator = torch.Generator(device=torch_device).manual_seed(seed)
inputs = {
"prompt": "a photograph of an astronaut riding a horse",
"generator": generator,
"num_inference_steps": 10,
"guidance_scale": 7.5,
"output_type": "numpy",
"parallel": 3,
"debug": True,
}
return inputs
def test_stable_diffusion_paradigms_default(self):
model_ckpt = "stabilityai/stable-diffusion-2-base"
scheduler = DDIMParallelScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
pipe = StableDiffusionParadigmsPipeline.from_pretrained(model_ckpt, scheduler=scheduler, safety_checker=None)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs()
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.9622, 0.9602, 0.9748, 0.9591, 0.9630, 0.9691, 0.9661, 0.9631, 0.9741])
assert np.abs(expected_slice - image_slice).max() < 1e-2
# Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from diffusers import DDIMParallelScheduler
from .test_schedulers import SchedulerCommonTest
class DDIMParallelSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDIMParallelScheduler,)
forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"clip_sample": True,
}
config.update(**kwargs)
return config
def full_loop(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
num_inference_steps, eta = 10, 0.0
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
for t in scheduler.timesteps:
residual = model(sample, t)
sample = scheduler.step(residual, t, sample, eta).prev_sample
return sample
def test_timesteps(self):
for timesteps in [100, 500, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_steps_offset(self):
for steps_offset in [0, 1]:
self.check_over_configs(steps_offset=steps_offset)
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(steps_offset=1)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(5)
assert torch.equal(scheduler.timesteps, torch.LongTensor([801, 601, 401, 201, 1]))
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
def test_clip_sample(self):
for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample)
def test_timestep_spacing(self):
for timestep_spacing in ["trailing", "leading"]:
self.check_over_configs(timestep_spacing=timestep_spacing)
def test_rescale_betas_zero_snr(self):
for rescale_betas_zero_snr in [True, False]:
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
def test_thresholding(self):
self.check_over_configs(thresholding=False)
for threshold in [0.5, 1.0, 2.0]:
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(
thresholding=True,
prediction_type=prediction_type,
sample_max_value=threshold,
)
def test_time_indices(self):
for t in [1, 10, 49]:
self.check_over_forward(time_step=t)
def test_inference_steps(self):
for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
def test_eta(self):
for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
self.check_over_forward(time_step=t, eta=eta)
def test_variance(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
assert torch.sum(torch.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(420, 400) - 0.14771)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(980, 960) - 0.32460)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(487, 486) - 0.00979)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5
def test_batch_step_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps, eta = 10, 0.0
scheduler.set_timesteps(num_inference_steps)
model = self.dummy_model()
sample1 = self.dummy_sample_deter
sample2 = self.dummy_sample_deter + 0.1
sample3 = self.dummy_sample_deter - 0.1
per_sample_batch = sample1.shape[0]
samples = torch.stack([sample1, sample2, sample3], dim=0)
timesteps = torch.arange(num_inference_steps)[0:3, None].repeat(1, per_sample_batch)
residual = model(samples.flatten(0, 1), timesteps.flatten(0, 1))
pred_prev_sample = scheduler.batch_step_no_noise(residual, timesteps.flatten(0, 1), samples.flatten(0, 1), eta)
result_sum = torch.sum(torch.abs(pred_prev_sample))
result_mean = torch.mean(torch.abs(pred_prev_sample))
assert abs(result_sum.item() - 1147.7904) < 1e-2
assert abs(result_mean.item() - 0.4982) < 1e-3
def test_full_loop_no_noise(self):
sample = self.full_loop()
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 172.0067) < 1e-2
assert abs(result_mean.item() - 0.223967) < 1e-3
def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction")
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 52.5302) < 1e-2
assert abs(result_mean.item() - 0.0684) < 1e-3
def test_full_loop_with_set_alpha_to_one(self):
# We specify different beta, so that the first alpha is 0.99
sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 149.8295) < 1e-2
assert abs(result_mean.item() - 0.1951) < 1e-3
def test_full_loop_with_no_set_alpha_to_one(self):
# We specify different beta, so that the first alpha is 0.99
sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 149.0784) < 1e-2
assert abs(result_mean.item() - 0.1941) < 1e-3
# Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from diffusers import DDPMParallelScheduler
from .test_schedulers import SchedulerCommonTest
class DDPMParallelSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDPMParallelScheduler,)
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"variance_type": "fixed_small",
"clip_sample": True,
}
config.update(**kwargs)
return config
def test_timesteps(self):
for timesteps in [1, 5, 100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_variance_type(self):
for variance in ["fixed_small", "fixed_large", "other"]:
self.check_over_configs(variance_type=variance)
def test_clip_sample(self):
for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample)
def test_thresholding(self):
self.check_over_configs(thresholding=False)
for threshold in [0.5, 1.0, 2.0]:
for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(
thresholding=True,
prediction_type=prediction_type,
sample_max_value=threshold,
)
def test_prediction_type(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
def test_time_indices(self):
for t in [0, 500, 999]:
self.check_over_forward(time_step=t)
def test_variance(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
assert torch.sum(torch.abs(scheduler._get_variance(0) - 0.0)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5
def test_batch_step_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_trained_timesteps = len(scheduler)
model = self.dummy_model()
sample1 = self.dummy_sample_deter
sample2 = self.dummy_sample_deter + 0.1
sample3 = self.dummy_sample_deter - 0.1
per_sample_batch = sample1.shape[0]
samples = torch.stack([sample1, sample2, sample3], dim=0)
timesteps = torch.arange(num_trained_timesteps)[0:3, None].repeat(1, per_sample_batch)
residual = model(samples.flatten(0, 1), timesteps.flatten(0, 1))
pred_prev_sample = scheduler.batch_step_no_noise(residual, timesteps.flatten(0, 1), samples.flatten(0, 1))
result_sum = torch.sum(torch.abs(pred_prev_sample))
result_mean = torch.mean(torch.abs(pred_prev_sample))
assert abs(result_sum.item() - 1153.1833) < 1e-2
assert abs(result_mean.item() - 0.5005) < 1e-3
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_trained_timesteps = len(scheduler)
model = self.dummy_model()
sample = self.dummy_sample_deter
generator = torch.manual_seed(0)
for t in reversed(range(num_trained_timesteps)):
# 1. predict noise residual
residual = model(sample, t)
# 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
sample = pred_prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 258.9606) < 1e-2
assert abs(result_mean.item() - 0.3372) < 1e-3
def test_full_loop_with_v_prediction(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
scheduler = scheduler_class(**scheduler_config)
num_trained_timesteps = len(scheduler)
model = self.dummy_model()
sample = self.dummy_sample_deter
generator = torch.manual_seed(0)
for t in reversed(range(num_trained_timesteps)):
# 1. predict noise residual
residual = model(sample, t)
# 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
sample = pred_prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 202.0296) < 1e-2
assert abs(result_mean.item() - 0.2631) < 1e-3
def test_custom_timesteps(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [100, 87, 50, 1, 0]
scheduler.set_timesteps(timesteps=timesteps)
scheduler_timesteps = scheduler.timesteps
for i, timestep in enumerate(scheduler_timesteps):
if i == len(timesteps) - 1:
expected_prev_t = -1
else:
expected_prev_t = timesteps[i + 1]
prev_t = scheduler.previous_timestep(timestep)
prev_t = prev_t.item()
self.assertEqual(prev_t, expected_prev_t)
def test_custom_timesteps_increasing_order(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [100, 87, 50, 51, 0]
with self.assertRaises(ValueError, msg="`custom_timesteps` must be in descending order."):
scheduler.set_timesteps(timesteps=timesteps)
def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [100, 87, 50, 1, 0]
num_inference_steps = len(timesteps)
with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `custom_timesteps`."):
scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps)
def test_custom_timesteps_too_large(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [scheduler.config.num_train_timesteps]
with self.assertRaises(
ValueError,
msg="`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}}",
):
scheduler.set_timesteps(timesteps=timesteps)
...@@ -238,6 +238,12 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -238,6 +238,12 @@ class SchedulerCommonTest(unittest.TestCase):
def dummy_model(self): def dummy_model(self):
def model(sample, t, *args): def model(sample, t, *args):
# if t is a tensor, match the number of dimensions of sample
if isinstance(t, torch.Tensor):
num_dims = len(sample.shape)
# pad t with 1s to match num_dims
t = t.reshape(-1, *(1,) * (num_dims - 1)).to(sample.device).to(sample.dtype)
return sample * t / (t + 1) return sample * t / (t + 1)
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment