Unverified Commit 958e17da authored by dg845, committed by GitHub

Add Latent Consistency Models Pipeline (#5448)



* initial commit for LatentConsistencyModelPipeline and LCMScheduler based on the community pipeline

* Add callback and freeu support.

* apply suggestions from review

* Clean up LCMScheduler

* Remove timeindex argument to LCMScheduler.step.

* Add support for clipping or thresholding the predicted original sample.

* Remove unused methods and arguments in LCMScheduler.

* Improve comment about (lack of) negative prompt support.

* Change input guidance_scale to match the StableDiffusionPipeline (Imagen) CFG formulation.

* Move lcm_origin_steps from pipeline __call__ to LCMScheduler.__init__/config (as origin_steps).

* Fix typo when clipping/thresholding in LCMScheduler.

* Add some initial LCMScheduler tests.

* add type annotations from review

* Fix type annotation bug.

* Override test_add_noise_device in LCMSchedulerTest since hardcoded timesteps don't work under default settings.

* Add generator argument to pipeline prepare_latents call.

* Cast LCMScheduler.timesteps to long in set_timesteps.

* Add onestep and multistep full loop scheduler tests.

* Set default height/width to None and don't hardcode guidance scale embedding dim.

* Add initial LatentConsistencyPipeline fast and slow tests.

* Add initial documentation for LatentConsistencyModelPipeline and LCMScheduler.

* Make remaining failing fast tests pass.

* make style

* Make original_inference_steps configurable from pipeline __call__ again.

* make style

* Remove guidance_rescale arg from pipeline __call__ since LCM currently doesn't support CFG.

* Make LCMScheduler defaults match config of LCM_Dreamshaper_v7 checkpoint.

* Fix LatentConsistencyPipeline slow tests and add dummy expected slices.

* Add checks for original_steps in LCMScheduler.set_timesteps.

* make fix-copies

* Improve LatentConsistencyModelPipeline docs.

* Apply suggestions from code review
Co-authored-by: Aryan V S <avs050602@gmail.com>

* Apply suggestions from code review
Co-authored-by: Aryan V S <avs050602@gmail.com>

* Apply suggestions from code review
Co-authored-by: Aryan V S <avs050602@gmail.com>

* Update src/diffusers/schedulers/scheduling_lcm.py

* Apply suggestions from code review
Co-authored-by: Aryan V S <avs050602@gmail.com>

* finish

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Aryan V S <avs050602@gmail.com>
parent 7c3a75a1
@@ -252,6 +252,8 @@
     title: Kandinsky
   - local: api/pipelines/kandinsky_v22
     title: Kandinsky 2.2
+  - local: api/pipelines/latent_consistency_models
+    title: Latent Consistency Models
   - local: api/pipelines/latent_diffusion
     title: Latent Diffusion
   - local: api/pipelines/panorama
@@ -368,6 +370,8 @@
     title: KDPM2AncestralDiscreteScheduler
   - local: api/schedulers/dpm_discrete
     title: KDPM2DiscreteScheduler
+  - local: api/schedulers/lcm
+    title: LCMScheduler
   - local: api/schedulers/lms_discrete
     title: LMSDiscreteScheduler
   - local: api/schedulers/pndm
...
# Latent Consistency Models

Latent Consistency Models (LCMs) were proposed in [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) by Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao.

The abstract of the [paper](https://arxiv.org/pdf/2310.04378.pdf) is as follows:

*Latent Diffusion models (LDMs) have achieved remarkable results in synthesizing high-resolution images. However, the iterative sampling process is computationally intensive and leads to slow generation. Inspired by Consistency Models (Song et al.), we propose Latent Consistency Models (LCMs), enabling swift inference with minimal steps on any pre-trained LDMs, including Stable Diffusion (Rombach et al.). Viewing the guided reverse diffusion process as solving an augmented probability flow ODE (PF-ODE), LCMs are designed to directly predict the solution of such ODE in latent space, mitigating the need for numerous iterations and allowing rapid, high-fidelity sampling. Efficiently distilled from pre-trained classifier-free guided diffusion models, a high-quality 768 x 768 2~4-step LCM takes only 32 A100 GPU hours for training. Furthermore, we introduce Latent Consistency Fine-tuning (LCF), a novel method that is tailored for fine-tuning LCMs on customized image datasets. Evaluation on the LAION-5B-Aesthetics dataset demonstrates that LCMs achieve state-of-the-art text-to-image generation performance with few-step inference.*

A demo for the [SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) checkpoint can be found [here](https://huggingface.co/spaces/SimianLuo/Latent_Consistency_Model).

This pipeline was contributed by [luosiallen](https://luosiallen.github.io/) and [dg845](https://github.com/dg845).
```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float32)

# To save GPU memory, torch.float16 can be used instead, but it may compromise image quality.
pipe.to(torch_device="cuda", torch_dtype=torch.float32)

prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"

# num_inference_steps can be set from 1 to 50; LCM supports fast inference with as few as 4 or fewer steps. 1 to 8 steps are recommended.
num_inference_steps = 4

images = pipe(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=8.0).images
```
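As the comment above notes, the pipeline can also be loaded in half precision to save GPU memory, at some possible cost to image quality. A minimal float16 variant of the snippet above (an illustrative sketch):

```python
import torch
from diffusers import DiffusionPipeline

# Sketch: the same checkpoint loaded in float16 to reduce GPU memory usage;
# image quality may degrade slightly compared to float32.
pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16)
pipe.to("cuda")

prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
images = pipe(prompt=prompt, num_inference_steps=4, guidance_scale=8.0).images
```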
## LatentConsistencyModelPipeline

[[autodoc]] LatentConsistencyModelPipeline
- all
- __call__
- enable_freeu
- disable_freeu
- enable_vae_slicing
- disable_vae_slicing
- enable_vae_tiling
- disable_vae_tiling

## StableDiffusionPipelineOutput

[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
# Latent Consistency Model Multistep Scheduler

## Overview

A multistep and one-step scheduler (Algorithm 3) introduced alongside latent consistency models in the paper [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) by Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao.

This scheduler should be able to generate good samples from [`LatentConsistencyModelPipeline`] in 1-8 steps, and it can be dropped into an existing pipeline as shown below.
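For illustration, a minimal sketch (mirroring this PR's slow tests) of rebuilding the scheduler from an existing pipeline's scheduler config:

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float32)

# Re-instantiate the scheduler from the pipeline's existing scheduler config,
# as done in the slow tests added by this PR.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

images = pipe(prompt="a photograph of an astronaut riding a horse", num_inference_steps=4).images
```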
## LCMScheduler

[[autodoc]] LCMScheduler
@@ -142,6 +142,7 @@ else:
         "KarrasVeScheduler",
         "KDPM2AncestralDiscreteScheduler",
         "KDPM2DiscreteScheduler",
+        "LCMScheduler",
         "PNDMScheduler",
         "RePaintScheduler",
         "SchedulerMixin",
@@ -226,6 +227,7 @@ else:
         "KandinskyV22Pipeline",
         "KandinskyV22PriorEmb2EmbPipeline",
         "KandinskyV22PriorPipeline",
+        "LatentConsistencyModelPipeline",
         "LDMTextToImagePipeline",
         "MusicLDMPipeline",
         "PaintByExamplePipeline",
@@ -499,6 +501,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         KarrasVeScheduler,
         KDPM2AncestralDiscreteScheduler,
         KDPM2DiscreteScheduler,
+        LCMScheduler,
         PNDMScheduler,
         RePaintScheduler,
         SchedulerMixin,
@@ -564,6 +567,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         KandinskyV22Pipeline,
         KandinskyV22PriorEmb2EmbPipeline,
         KandinskyV22PriorPipeline,
+        LatentConsistencyModelPipeline,
         LDMTextToImagePipeline,
         MusicLDMPipeline,
         PaintByExamplePipeline,
...
@@ -109,6 +109,7 @@ else:
         "KandinskyV22PriorEmb2EmbPipeline",
         "KandinskyV22PriorPipeline",
     ]
+    _import_structure["latent_consistency_models"] = ["LatentConsistencyModelPipeline"]
     _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
     _import_structure["musicldm"] = ["MusicLDMPipeline"]
     _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
@@ -331,6 +332,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         KandinskyV22PriorEmb2EmbPipeline,
         KandinskyV22PriorPipeline,
     )
+    from .latent_consistency_models import LatentConsistencyModelPipeline
     from .latent_diffusion import LDMTextToImagePipeline
     from .musicldm import MusicLDMPipeline
     from .paint_by_example import PaintByExamplePipeline
...
from typing import TYPE_CHECKING

from ...utils import (
    _LazyModule,
)

_import_structure = {"pipeline_latent_consistency_models": ["LatentConsistencyModelPipeline"]}

if TYPE_CHECKING:
    from .pipeline_latent_consistency_models import LatentConsistencyModelPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
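This registration is what makes the class importable lazily. As a quick illustration (a sketch, assuming a diffusers install that includes this change), the pipeline resolves identically through the top-level package and the new subpackage:

```python
# Both import paths resolve to the same class; _LazyModule defers loading
# pipeline_latent_consistency_models until the attribute is first accessed.
from diffusers import LatentConsistencyModelPipeline
from diffusers.pipelines.latent_consistency_models import (
    LatentConsistencyModelPipeline as LCMPipeline,
)

assert LatentConsistencyModelPipeline is LCMPipeline
```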
@@ -56,6 +56,7 @@ else:
     _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
     _import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"]
     _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
+    _import_structure["scheduling_lcm"] = ["LCMScheduler"]
     _import_structure["scheduling_pndm"] = ["PNDMScheduler"]
     _import_structure["scheduling_repaint"] = ["RePaintScheduler"]
     _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
@@ -145,6 +146,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
     from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
     from .scheduling_karras_ve import KarrasVeScheduler
+    from .scheduling_lcm import LCMScheduler
     from .scheduling_pndm import PNDMScheduler
     from .scheduling_repaint import RePaintScheduler
     from .scheduling_sde_ve import ScoreSdeVeScheduler
...
This diff is collapsed.
@@ -825,6 +825,21 @@ class KDPM2DiscreteScheduler(metaclass=DummyObject):
         requires_backends(cls, ["torch"])

+class LCMScheduler(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
 class PNDMScheduler(metaclass=DummyObject):
     _backends = ["torch"]
...
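These `DummyObject` stubs are diffusers' standard fallback when a backend is missing: the name still imports, but any use raises an informative error from `requires_backends`. A minimal sketch of the behavior, assuming the torch backend is not installed:

```python
# Minimal sketch, assuming the "torch" backend is NOT installed:
# the dummy class imports fine, but constructing it raises an error
# from requires_backends that names the missing backend.
from diffusers.utils import dummy_pt_objects

try:
    dummy_pt_objects.LCMScheduler()
except ImportError as err:
    print(err)  # message explains that LCMScheduler requires the torch backend
```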
@@ -482,6 +482,21 @@ class KandinskyV22PriorPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])

+class LatentConsistencyModelPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
 class LDMTextToImagePipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
...
import gc
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    LatentConsistencyModelPipeline,
    LCMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    require_torch_gpu,
    slow,
    torch_device,
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


enable_full_determinism()


class LatentConsistencyModelPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
    pipeline_class = LatentConsistencyModelPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"negative_prompt", "negative_prompt_embeds"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {"negative_prompt"}
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=1,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            norm_num_groups=2,
            time_cond_proj_dim=32,
        )
        scheduler = LCMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[4, 8],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            norm_num_groups=2,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=64,
            layer_norm_eps=1e-05,
            num_attention_heads=8,
            num_hidden_layers=3,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
            "requires_safety_checker": False,
        }
        return components
    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "np",
        }
        return inputs
    def test_lcm_onestep(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = LatentConsistencyModelPipeline(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["num_inference_steps"] = 1
        output = pipe(**inputs)
        image = output.images
        assert image.shape == (1, 64, 64, 3)

        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.1441, 0.5304, 0.5452, 0.1361, 0.4011, 0.4370, 0.5326, 0.3492, 0.3637])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
    def test_lcm_multistep(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = LatentConsistencyModelPipeline(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = pipe(**inputs)
        image = output.images
        assert image.shape == (1, 64, 64, 3)

        image_slice = image[0, -3:, -3:, -1]
        # TODO: replace the dummy expected slice below with verified values
        expected_slice = np.array([0.1540, 0.5205, 0.5458, 0.1200, 0.3983, 0.4350, 0.5386, 0.3522, 0.3614])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2

    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=5e-4)
@slow
@require_torch_gpu
class LatentConsistencyModelPipelineSlowTests(unittest.TestCase):
    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
        inputs = {
            "prompt": "a photograph of an astronaut riding a horse",
            "latents": latents,
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 7.5,
            "output_type": "np",
        }
        return inputs
    def test_lcm_onestep(self):
        pipe = LatentConsistencyModelPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", safety_checker=None)
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        inputs["num_inference_steps"] = 1
        image = pipe(**inputs).images
        assert image.shape == (1, 512, 512, 3)

        image_slice = image[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.1025, 0.0911, 0.0984, 0.0981, 0.0901, 0.0918, 0.1055, 0.0940, 0.0730])
        assert np.abs(image_slice - expected_slice).max() < 1e-3
    def test_lcm_multistep(self):
        pipe = LatentConsistencyModelPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", safety_checker=None)
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        assert image.shape == (1, 512, 512, 3)

        image_slice = image[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.01855, 0.01855, 0.01489, 0.01392, 0.01782, 0.01465, 0.01831, 0.02539, 0.0])
        assert np.abs(image_slice - expected_slice).max() < 1e-3
import tempfile
from typing import Dict, List, Tuple

import torch

from diffusers import LCMScheduler
from diffusers.utils.testing_utils import torch_device

from .test_schedulers import SchedulerCommonTest


class LCMSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (LCMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 10),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.00085,
            "beta_end": 0.0120,
            "beta_schedule": "scaled_linear",
            "prediction_type": "epsilon",
        }

        config.update(**kwargs)
        return config
    @property
    def default_valid_timestep(self):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        scheduler_config = self.get_scheduler_config()
        scheduler = self.scheduler_classes[0](**scheduler_config)

        scheduler.set_timesteps(num_inference_steps)
        timestep = scheduler.timesteps[-1]
        return timestep
    def test_timesteps(self):
        for timesteps in [100, 500, 1000]:
            # 0 is not guaranteed to be in the timestep schedule, but timesteps - 1 is
            self.check_over_configs(time_step=timesteps - 1, num_train_timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(time_step=self.default_valid_timestep, beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "scaled_linear", "squaredcos_cap_v2"]:
            self.check_over_configs(time_step=self.default_valid_timestep, beta_schedule=schedule)

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "v_prediction"]:
            self.check_over_configs(time_step=self.default_valid_timestep, prediction_type=prediction_type)

    def test_clip_sample(self):
        for clip_sample in [True, False]:
            self.check_over_configs(time_step=self.default_valid_timestep, clip_sample=clip_sample)

    def test_thresholding(self):
        self.check_over_configs(time_step=self.default_valid_timestep, thresholding=False)
        for threshold in [0.5, 1.0, 2.0]:
            for prediction_type in ["epsilon", "v_prediction"]:
                self.check_over_configs(
                    time_step=self.default_valid_timestep,
                    thresholding=True,
                    prediction_type=prediction_type,
                    sample_max_value=threshold,
                )
    def test_time_indices(self):
        # Get default timestep schedule.
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        scheduler_config = self.get_scheduler_config()
        scheduler = self.scheduler_classes[0](**scheduler_config)

        scheduler.set_timesteps(num_inference_steps)
        timesteps = scheduler.timesteps
        for t in timesteps:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        # Hardcoded for now
        for t, num_inference_steps in zip([99, 39, 19], [10, 25, 50]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
    # Override test_add_noise_device because the hardcoded num_inference_steps of 100 doesn't work
    # for LCMScheduler under default settings
    def test_add_noise_device(self, num_inference_steps=10):
        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)

            sample = self.dummy_sample.to(torch_device)
            scaled_sample = scheduler.scale_model_input(sample, 0.0)
            self.assertEqual(sample.shape, scaled_sample.shape)

            noise = torch.randn_like(scaled_sample).to(torch_device)
            t = scheduler.timesteps[5][None]
            noised = scheduler.add_noise(scaled_sample, noise, t)
            self.assertEqual(noised.shape, scaled_sample.shape)
    # Override test_from_save_pretrained because it hardcodes a timestep of 1
    def test_from_save_pretrained(self):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            timestep = self.default_valid_timestep

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)

            scheduler.set_timesteps(num_inference_steps)
            new_scheduler.set_timesteps(num_inference_steps)

            kwargs["generator"] = torch.manual_seed(0)
            output = scheduler.step(residual, timestep, sample, **kwargs).prev_sample

            kwargs["generator"] = torch.manual_seed(0)
            new_output = new_scheduler.step(residual, timestep, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
    # Override test_step_shape because it uses 0 and 1 as hardcoded timesteps
    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler.set_timesteps(num_inference_steps)
            timestep_0 = scheduler.timesteps[-2]
            timestep_1 = scheduler.timesteps[-1]

            output_0 = scheduler.step(residual, timestep_0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(residual, timestep_1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)
    # Override test_scheduler_outputs_equivalence since it uses 0 as a hardcoded timestep
    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            t[t != t] = 0
            return t

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    torch.allclose(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                    ),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", 50)
        timestep = self.default_valid_timestep

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler.set_timesteps(num_inference_steps)
            kwargs["generator"] = torch.manual_seed(0)
            outputs_dict = scheduler.step(residual, timestep, sample, **kwargs)

            scheduler.set_timesteps(num_inference_steps)
            kwargs["generator"] = torch.manual_seed(0)
            outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs)

            recursive_check(outputs_tuple, outputs_dict)
    def full_loop(self, num_inference_steps=10, seed=0, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        generator = torch.manual_seed(seed)

        scheduler.set_timesteps(num_inference_steps)
        for t in scheduler.timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
        return sample

    def test_full_loop_onestep(self):
        sample = self.full_loop(num_inference_steps=1)

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        # TODO: verify the expected sum and mean (the current values are placeholders)
        assert abs(result_sum.item() - 18.7097) < 1e-2
        assert abs(result_mean.item() - 0.0244) < 1e-3

    def test_full_loop_multistep(self):
        sample = self.full_loop(num_inference_steps=10)

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        # TODO: verify the expected sum and mean (the current values are placeholders)
        assert abs(result_sum.item() - 280.5618) < 1e-2
        assert abs(result_mean.item() - 0.3653) < 1e-3