Unverified Commit 2fd46405 authored by Will Berman, committed by GitHub

consistency decoder (#5694)



* consistency decoder

* rename

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py

* uP

* Apply suggestions from code review

* uP

* uP

* uP

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 43346adc
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time, for t in [0, 1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
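# Illustrative sanity check (a sketch, not part of the library): for the
# default cosine transform the betas are non-decreasing and capped at
# `max_beta`:
#
#   betas = betas_for_alpha_bar(10)
#   assert bool(torch.all(betas[1:] >= betas[:-1]))
#   assert float(betas.max()) <= 0.999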
@dataclass
class ConsistencyDecoderSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1024,
sigma_data: float = 0.5,
):
betas = betas_for_alpha_bar(num_train_timesteps)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
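# Boundary-condition coefficients of the consistency parameterization
# (analogous to Karras-style preconditioning): `step` recovers the clean
# estimate as x_0 = c_out[t] * model_output + c_skip[t] * sample, while
# `scale_model_input` multiplies the model input by c_in[t].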
self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5
def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
):
if num_inference_steps != 2:
raise ValueError("Currently more than 2 inference steps are not supported.")
self.timesteps = torch.tensor([1008, 512], dtype=torch.long, device=device)
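# The distilled decoder supports exactly two denoising steps, taken at
# these fixed timesteps out of the default `num_train_timesteps=1024`.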
self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
self.c_skip = self.c_skip.to(device)
self.c_out = self.c_out.to(device)
self.c_in = self.c_in.to(device)
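# `init_noise_sigma` below is the standard deviation of the initial noise
# distribution, i.e. sqrt(1 - alpha_bar) at the first (largest) timestep;
# it is used to scale the starting latents before the first step.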
@property
def init_noise_sigma(self):
return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
return sample * self.c_in[timestep]
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`float`):
The current timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`:
If return_dict is `True`,
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
a tuple is returned where the first element is the sample tensor.
"""
x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample
timestep_idx = torch.where(self.timesteps == timestep)[0]
if timestep_idx == len(self.timesteps) - 1:
prev_sample = x_0
else:
noise = randn_tensor(x_0.shape, generator=generator, dtype=x_0.dtype, device=x_0.device)
prev_sample = (
self.sqrt_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * x_0
+ self.sqrt_one_minus_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * noise
)
if not return_dict:
return (prev_sample,)
return ConsistencyDecoderSchedulerOutput(prev_sample=prev_sample)
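# Illustrative two-step decoding loop (a sketch under assumed names;
# `decoder_unet` and `latent_shape` are hypothetical stand-ins, not part
# of this file):
#
#   scheduler = ConsistencyDecoderScheduler()
#   scheduler.set_timesteps(2)
#   sample = randn_tensor(latent_shape) * scheduler.init_noise_sigma
#   for t in scheduler.timesteps:
#       model_output = decoder_unet(scheduler.scale_model_input(sample, t), t)
#       sample = scheduler.step(model_output, t, sample).prev_sample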
@@ -47,6 +47,21 @@ class AutoencoderTiny(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class ConsistencyDecoderVAE(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -196,11 +196,15 @@ class UNetTesterMixin:
class ModelTesterMixin:
main_input_name = None # overwrite in model specific tester class
base_precision = 1e-3
forward_requires_fresh_args = False
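# When True, tests rebuild their inputs via `self.inputs_dict(...)` before
# every forward pass instead of reusing one shared dict; models like
# ConsistencyDecoderVAE need this because their inputs include a stateful
# `torch.Generator` that is consumed on each call.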
def test_from_save_pretrained(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
if hasattr(model, "set_default_attn_processor"):
model.set_default_attn_processor()
model.to(torch_device)
@@ -214,11 +218,18 @@ class ModelTesterMixin:
new_model.to(torch_device)
with torch.no_grad():
if self.forward_requires_fresh_args:
image = model(**self.inputs_dict(0))
else:
image = model(**inputs_dict)
if isinstance(image, dict):
image = image.to_tuple()[0]
if self.forward_requires_fresh_args:
new_image = new_model(**self.inputs_dict(0))
else:
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
@@ -275,8 +286,11 @@ class ModelTesterMixin:
)
def test_set_xformers_attn_processor_for_determinism(self):
torch.use_deterministic_algorithms(False)
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
@@ -286,17 +300,26 @@ class ModelTesterMixin:
model.set_default_attn_processor()
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
if self.forward_requires_fresh_args:
output = model(**self.inputs_dict(0))[0]
else:
output = model(**inputs_dict)[0]
model.enable_xformers_memory_efficient_attention()
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
if self.forward_requires_fresh_args:
output_2 = model(**self.inputs_dict(0))[0]
else:
output_2 = model(**inputs_dict)[0]
model.set_attn_processor(XFormersAttnProcessor())
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
if self.forward_requires_fresh_args:
output_3 = model(**self.inputs_dict(0))[0]
else:
output_3 = model(**inputs_dict)[0]
torch.use_deterministic_algorithms(True)
@@ -307,8 +330,12 @@ class ModelTesterMixin:
@require_torch_gpu
def test_set_attn_processor_for_determinism(self):
torch.use_deterministic_algorithms(False)
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
@@ -317,22 +344,34 @@ class ModelTesterMixin:
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
with torch.no_grad():
if self.forward_requires_fresh_args:
output_1 = model(**self.inputs_dict(0))[0]
else:
output_1 = model(**inputs_dict)[0]
model.set_default_attn_processor()
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
if self.forward_requires_fresh_args:
output_2 = model(**self.inputs_dict(0))[0]
else:
output_2 = model(**inputs_dict)[0]
model.set_attn_processor(AttnProcessor2_0())
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
with torch.no_grad():
if self.forward_requires_fresh_args:
output_4 = model(**self.inputs_dict(0))[0]
else:
output_4 = model(**inputs_dict)[0]
model.set_attn_processor(AttnProcessor())
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
if self.forward_requires_fresh_args:
output_5 = model(**self.inputs_dict(0))[0]
else:
output_5 = model(**inputs_dict)[0]
torch.use_deterministic_algorithms(True)
@@ -342,9 +381,12 @@ class ModelTesterMixin:
assert torch.allclose(output_2, output_5, atol=self.base_precision)
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
if hasattr(model, "set_default_attn_processor"):
model.set_default_attn_processor()
@@ -367,11 +409,17 @@ class ModelTesterMixin:
new_model.to(torch_device)
with torch.no_grad():
if self.forward_requires_fresh_args:
image = model(**self.inputs_dict(0))
else:
image = model(**inputs_dict)
if isinstance(image, dict):
image = image.to_tuple()[0]
if self.forward_requires_fresh_args:
new_image = new_model(**self.inputs_dict(0))
else:
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
@@ -405,17 +453,26 @@ class ModelTesterMixin:
assert new_model.dtype == dtype
def test_determinism(self, expected_max_diff=1e-5):
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
if self.forward_requires_fresh_args:
first = model(**self.inputs_dict(0))
else:
first = model(**inputs_dict)
if isinstance(first, dict):
first = first.to_tuple()[0]
if self.forward_requires_fresh_args:
second = model(**self.inputs_dict(0))
else:
second = model(**inputs_dict)
if isinstance(second, dict):
second = second.to_tuple()[0]
@@ -548,15 +605,22 @@ class ModelTesterMixin:
),
)
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
if self.forward_requires_fresh_args:
outputs_dict = model(**self.inputs_dict(0))
outputs_tuple = model(**self.inputs_dict(0), return_dict=False)
else:
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
@@ -16,11 +16,19 @@
import gc
import unittest
import numpy as np
import torch
from parameterized import parameterized
from diffusers import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderTiny,
ConsistencyDecoderVAE,
StableDiffusionPipeline,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.loading_utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
@@ -30,6 +38,7 @@ from diffusers.utils.testing_utils import (
torch_all_close,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
@@ -269,6 +278,79 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
pass
class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
model_class = ConsistencyDecoderVAE
main_input_name = "sample"
base_precision = 1e-2
forward_requires_fresh_args = True
def inputs_dict(self, seed=None):
generator = torch.Generator("cpu")
if seed is not None:
generator.manual_seed(seed)
image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
return {"sample": image, "generator": generator}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
@property
def init_dict(self):
return {
"encoder_args": {
"block_out_channels": [32, 64],
"in_channels": 3,
"out_channels": 4,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
},
"decoder_args": {
"act_fn": "silu",
"add_attention": False,
"block_out_channels": [32, 64],
"down_block_types": [
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
],
"downsample_padding": 1,
"downsample_type": "conv",
"dropout": 0.0,
"in_channels": 7,
"layers_per_block": 1,
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_train_timesteps": 1024,
"out_channels": 6,
"resnet_time_scale_shift": "scale_shift",
"time_embedding_type": "learned",
"up_block_types": [
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
],
"upsample_type": "conv",
},
"scaling_factor": 1,
"block_out_channels": [32, 64],
"latent_channels": 4,
}
def prepare_init_args_and_inputs_for_common(self):
return self.init_dict, self.inputs_dict()
@unittest.skip
def test_training(self):
...
@unittest.skip
def test_ema_training(self):
...
@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
def tearDown(self):
@@ -721,3 +803,94 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
tolerance = 3e-3 if torch_device != "mps" else 1e-2
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
@slow
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_encode_decode(self):
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
vae.to(torch_device)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
).resize((256, 256))
image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[
None, :, :, :
].cuda()
latent = vae.encode(image).latent_dist.mean
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
actual_output = sample[0, :2, :2, :2].flatten().cpu()
expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])
assert torch_all_close(actual_output, expected_output, atol=5e-3)
def test_sd(self):
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
pipe.to(torch_device)
out = pipe(
"horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
).images[0]
actual_output = out[:2, :2, :2].flatten().cpu()
expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])
assert torch_all_close(actual_output, expected_output, atol=5e-3)
def test_encode_decode_f16(self):
vae = ConsistencyDecoderVAE.from_pretrained(
"openai/consistency-decoder", torch_dtype=torch.float16
) # TODO - update
vae.to(torch_device)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
).resize((256, 256))
image = (
torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
.half()
.cuda()
)
latent = vae.encode(image).latent_dist.mean
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
actual_output = sample[0, :2, :2, :2].flatten().cpu()
expected_output = torch.tensor(
[-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16
)
assert torch_all_close(actual_output, expected_output, atol=5e-3)
def test_sd_f16(self):
vae = ConsistencyDecoderVAE.from_pretrained(
"openai/consistency-decoder", torch_dtype=torch.float16
) # TODO - update
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None
)
pipe.to(torch_device)
out = pipe(
"horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
).images[0]
actual_output = out[:2, :2, :2].flatten().cpu()
expected_output = torch.tensor(
[0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16
)
assert torch_all_close(actual_output, expected_output, atol=5e-3)