Unverified Commit 2fd46405 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

consistency decoder (#5694)



* consistency decoder

* rename

* Apply suggestions from code review
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py

* uP

* Apply suggestions from code review

* uP

* uP

* uP

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 43346adc
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        # NOTE(review): fixed the typo "alpha_tranform_type" in this message.
        # The upstream copy in scheduling_ddpm should be updated to match so
        # the "Copied from" consistency check stays green.
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # Each beta is the per-step relative drop of the cumulative alpha
        # product, clipped at `max_beta` to avoid a singularity near t = 1.
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
@dataclass
class ConsistencyDecoderSchedulerOutput(BaseOutput):
    """
    Output returned by [`ConsistencyDecoderScheduler.step`].

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The denoised sample `(x_{t-1})` for the previous timestep; feed it back to the model as the input of the
            next denoising iteration.
    """

    prev_sample: torch.FloatTensor
class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
    """
    Two-step consistency-model scheduler for the consistency decoder VAE.

    Precomputes, for every training timestep, the consistency-model boundary
    scalings `c_skip` and `c_out` (applied in `step`) and the input scaling
    `c_in` (applied in `scale_model_input`) from a cosine beta schedule.

    Args:
        num_train_timesteps (`int`, defaults to 1024):
            The number of diffusion steps used to train the model.
        sigma_data (`float`, defaults to 0.5):
            Standard deviation of the data distribution used in the
            consistency-model preconditioning.
    """

    # One model evaluation is consumed per `step` call.
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1024,
        sigma_data: float = 0.5,
    ):
        betas = betas_for_alpha_bar(num_train_timesteps)

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

        sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)

        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)

        # Boundary scalings: `step` computes
        # x_0 = c_out * model_output + c_skip * sample.
        self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
        self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
        self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain and moves the
        precomputed scaling tables to `device`.

        Args:
            num_inference_steps (`int`, *optional*):
                Must be exactly 2; only the fixed two-step schedule
                `[1008, 512]` is supported.
            device (`str` or `torch.device`, *optional*):
                The device to move the timesteps and scaling tensors to.

        Raises:
            ValueError: If `num_inference_steps` is not 2.
        """
        if num_inference_steps != 2:
            # Bug fix: the previous message said "more than 2 inference steps
            # are not supported", but any value other than 2 (including 1) is
            # rejected here.
            raise ValueError("Currently `num_inference_steps` other than 2 is not supported.")

        self.timesteps = torch.tensor([1008, 512], dtype=torch.long, device=device)
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
        self.c_skip = self.c_skip.to(device)
        self.c_out = self.c_out.to(device)
        self.c_in = self.c_in.to(device)

    @property
    def init_noise_sigma(self):
        # Standard deviation of the initial noise at the first timestep.
        return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        return sample * self.c_in[timestep]

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a
                [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
                a tuple is returned where the first element is the sample tensor.
        """
        # Consistency-model prediction of the clean sample.
        x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample

        # Position of `timestep` in the fixed schedule (length-1 tensor).
        timestep_idx = torch.where(self.timesteps == timestep)[0]

        if timestep_idx == len(self.timesteps) - 1:
            # Last step: the consistency prediction is the final sample.
            prev_sample = x_0
        else:
            # Re-noise the prediction to the next timestep's noise level.
            noise = randn_tensor(x_0.shape, generator=generator, dtype=x_0.dtype, device=x_0.device)
            prev_sample = (
                self.sqrt_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * x_0
                + self.sqrt_one_minus_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * noise
            )

        if not return_dict:
            return (prev_sample,)

        return ConsistencyDecoderSchedulerOutput(prev_sample=prev_sample)
...@@ -47,6 +47,21 @@ class AutoencoderTiny(metaclass=DummyObject): ...@@ -47,6 +47,21 @@ class AutoencoderTiny(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class ConsistencyDecoderVAE(metaclass=DummyObject):
    # Placeholder installed when `torch` is unavailable: constructing or
    # loading the real model raises an informative error via
    # `requires_backends` instead of an opaque ImportError.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
class ControlNetModel(metaclass=DummyObject): class ControlNetModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -196,11 +196,15 @@ class UNetTesterMixin: ...@@ -196,11 +196,15 @@ class UNetTesterMixin:
class ModelTesterMixin: class ModelTesterMixin:
main_input_name = None # overwrite in model specific tester class main_input_name = None # overwrite in model specific tester class
base_precision = 1e-3 base_precision = 1e-3
forward_requires_fresh_args = False
def test_from_save_pretrained(self, expected_max_diff=5e-5): def test_from_save_pretrained(self, expected_max_diff=5e-5):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model = self.model_class(**init_dict)
if hasattr(model, "set_default_attn_processor"): if hasattr(model, "set_default_attn_processor"):
model.set_default_attn_processor() model.set_default_attn_processor()
model.to(torch_device) model.to(torch_device)
...@@ -214,11 +218,18 @@ class ModelTesterMixin: ...@@ -214,11 +218,18 @@ class ModelTesterMixin:
new_model.to(torch_device) new_model.to(torch_device)
with torch.no_grad(): with torch.no_grad():
image = model(**inputs_dict) if self.forward_requires_fresh_args:
image = model(**self.inputs_dict(0))
else:
image = model(**inputs_dict)
if isinstance(image, dict): if isinstance(image, dict):
image = image.to_tuple()[0] image = image.to_tuple()[0]
new_image = new_model(**inputs_dict) if self.forward_requires_fresh_args:
new_image = new_model(**self.inputs_dict(0))
else:
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict): if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0] new_image = new_image.to_tuple()[0]
...@@ -275,8 +286,11 @@ class ModelTesterMixin: ...@@ -275,8 +286,11 @@ class ModelTesterMixin:
) )
def test_set_xformers_attn_processor_for_determinism(self): def test_set_xformers_attn_processor_for_determinism(self):
torch.use_deterministic_algorithms(False) torch.use_deterministic_algorithms(False)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() if self.forward_requires_fresh_args:
model = self.model_class(**init_dict) model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device) model.to(torch_device)
if not hasattr(model, "set_attn_processor"): if not hasattr(model, "set_attn_processor"):
...@@ -286,17 +300,26 @@ class ModelTesterMixin: ...@@ -286,17 +300,26 @@ class ModelTesterMixin:
model.set_default_attn_processor() model.set_default_attn_processor()
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values()) assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad(): with torch.no_grad():
output = model(**inputs_dict)[0] if self.forward_requires_fresh_args:
output = model(**self.inputs_dict(0))[0]
else:
output = model(**inputs_dict)[0]
model.enable_xformers_memory_efficient_attention() model.enable_xformers_memory_efficient_attention()
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values()) assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
with torch.no_grad(): with torch.no_grad():
output_2 = model(**inputs_dict)[0] if self.forward_requires_fresh_args:
output_2 = model(**self.inputs_dict(0))[0]
else:
output_2 = model(**inputs_dict)[0]
model.set_attn_processor(XFormersAttnProcessor()) model.set_attn_processor(XFormersAttnProcessor())
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values()) assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
with torch.no_grad(): with torch.no_grad():
output_3 = model(**inputs_dict)[0] if self.forward_requires_fresh_args:
output_3 = model(**self.inputs_dict(0))[0]
else:
output_3 = model(**inputs_dict)[0]
torch.use_deterministic_algorithms(True) torch.use_deterministic_algorithms(True)
...@@ -307,8 +330,12 @@ class ModelTesterMixin: ...@@ -307,8 +330,12 @@ class ModelTesterMixin:
@require_torch_gpu @require_torch_gpu
def test_set_attn_processor_for_determinism(self): def test_set_attn_processor_for_determinism(self):
torch.use_deterministic_algorithms(False) torch.use_deterministic_algorithms(False)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() if self.forward_requires_fresh_args:
model = self.model_class(**init_dict) model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device) model.to(torch_device)
if not hasattr(model, "set_attn_processor"): if not hasattr(model, "set_attn_processor"):
...@@ -317,22 +344,34 @@ class ModelTesterMixin: ...@@ -317,22 +344,34 @@ class ModelTesterMixin:
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values()) assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
with torch.no_grad(): with torch.no_grad():
output_1 = model(**inputs_dict)[0] if self.forward_requires_fresh_args:
output_1 = model(**self.inputs_dict(0))[0]
else:
output_1 = model(**inputs_dict)[0]
model.set_default_attn_processor() model.set_default_attn_processor()
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values()) assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad(): with torch.no_grad():
output_2 = model(**inputs_dict)[0] if self.forward_requires_fresh_args:
output_2 = model(**self.inputs_dict(0))[0]
else:
output_2 = model(**inputs_dict)[0]
model.set_attn_processor(AttnProcessor2_0()) model.set_attn_processor(AttnProcessor2_0())
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values()) assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
with torch.no_grad(): with torch.no_grad():
output_4 = model(**inputs_dict)[0] if self.forward_requires_fresh_args:
output_4 = model(**self.inputs_dict(0))[0]
else:
output_4 = model(**inputs_dict)[0]
model.set_attn_processor(AttnProcessor()) model.set_attn_processor(AttnProcessor())
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values()) assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad(): with torch.no_grad():
output_5 = model(**inputs_dict)[0] if self.forward_requires_fresh_args:
output_5 = model(**self.inputs_dict(0))[0]
else:
output_5 = model(**inputs_dict)[0]
torch.use_deterministic_algorithms(True) torch.use_deterministic_algorithms(True)
...@@ -342,9 +381,12 @@ class ModelTesterMixin: ...@@ -342,9 +381,12 @@ class ModelTesterMixin:
assert torch.allclose(output_2, output_5, atol=self.base_precision) assert torch.allclose(output_2, output_5, atol=self.base_precision)
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model = self.model_class(**init_dict)
if hasattr(model, "set_default_attn_processor"): if hasattr(model, "set_default_attn_processor"):
model.set_default_attn_processor() model.set_default_attn_processor()
...@@ -367,11 +409,17 @@ class ModelTesterMixin: ...@@ -367,11 +409,17 @@ class ModelTesterMixin:
new_model.to(torch_device) new_model.to(torch_device)
with torch.no_grad(): with torch.no_grad():
image = model(**inputs_dict) if self.forward_requires_fresh_args:
image = model(**self.inputs_dict(0))
else:
image = model(**inputs_dict)
if isinstance(image, dict): if isinstance(image, dict):
image = image.to_tuple()[0] image = image.to_tuple()[0]
new_image = new_model(**inputs_dict) if self.forward_requires_fresh_args:
new_image = new_model(**self.inputs_dict(0))
else:
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict): if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0] new_image = new_image.to_tuple()[0]
...@@ -405,17 +453,26 @@ class ModelTesterMixin: ...@@ -405,17 +453,26 @@ class ModelTesterMixin:
assert new_model.dtype == dtype assert new_model.dtype == dtype
def test_determinism(self, expected_max_diff=1e-5): def test_determinism(self, expected_max_diff=1e-5):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() if self.forward_requires_fresh_args:
model = self.model_class(**init_dict) model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
first = model(**inputs_dict) if self.forward_requires_fresh_args:
first = model(**self.inputs_dict(0))
else:
first = model(**inputs_dict)
if isinstance(first, dict): if isinstance(first, dict):
first = first.to_tuple()[0] first = first.to_tuple()[0]
second = model(**inputs_dict) if self.forward_requires_fresh_args:
second = model(**self.inputs_dict(0))
else:
second = model(**inputs_dict)
if isinstance(second, dict): if isinstance(second, dict):
second = second.to_tuple()[0] second = second.to_tuple()[0]
...@@ -548,15 +605,22 @@ class ModelTesterMixin: ...@@ -548,15 +605,22 @@ class ModelTesterMixin:
), ),
) )
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model = self.model_class(**init_dict)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs_dict = model(**inputs_dict) if self.forward_requires_fresh_args:
outputs_tuple = model(**inputs_dict, return_dict=False) outputs_dict = model(**self.inputs_dict(0))
outputs_tuple = model(**self.inputs_dict(0), return_dict=False)
else:
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict) recursive_check(outputs_tuple, outputs_dict)
......
...@@ -16,11 +16,19 @@ ...@@ -16,11 +16,19 @@
import gc import gc
import unittest import unittest
import numpy as np
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny from diffusers import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderTiny,
ConsistencyDecoderVAE,
StableDiffusionPipeline,
)
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.loading_utils import load_image
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
enable_full_determinism, enable_full_determinism,
floats_tensor, floats_tensor,
...@@ -30,6 +38,7 @@ from diffusers.utils.testing_utils import ( ...@@ -30,6 +38,7 @@ from diffusers.utils.testing_utils import (
torch_all_close, torch_all_close,
torch_device, torch_device,
) )
from diffusers.utils.torch_utils import randn_tensor
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
...@@ -269,6 +278,79 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase): ...@@ -269,6 +278,79 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
pass pass
class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
    """Common model tests for `ConsistencyDecoderVAE`."""

    model_class = ConsistencyDecoderVAE
    main_input_name = "sample"
    base_precision = 1e-2
    # The forward pass consumes the generator's state, so the shared tests
    # must rebuild the inputs for every forward call instead of reusing one
    # prepared dict.
    forward_requires_fresh_args = True

    def inputs_dict(self, seed=None):
        """Return fresh forward-call kwargs; pass `seed` for reproducibility."""
        generator = torch.Generator("cpu")
        if seed is not None:
            # Bug fix: previously seeded with the hard-coded constant 0,
            # silently ignoring the requested `seed`. All existing callers
            # pass seed=0, so their behavior is unchanged.
            generator.manual_seed(seed)
        image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))

        return {"sample": image, "generator": generator}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    @property
    def init_dict(self):
        # Deliberately tiny encoder/decoder config to keep the common tests fast.
        return {
            "encoder_args": {
                "block_out_channels": [32, 64],
                "in_channels": 3,
                "out_channels": 4,
                "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            },
            "decoder_args": {
                "act_fn": "silu",
                "add_attention": False,
                "block_out_channels": [32, 64],
                "down_block_types": [
                    "ResnetDownsampleBlock2D",
                    "ResnetDownsampleBlock2D",
                ],
                "downsample_padding": 1,
                "downsample_type": "conv",
                "dropout": 0.0,
                "in_channels": 7,
                "layers_per_block": 1,
                "norm_eps": 1e-05,
                "norm_num_groups": 32,
                "num_train_timesteps": 1024,
                "out_channels": 6,
                "resnet_time_scale_shift": "scale_shift",
                "time_embedding_type": "learned",
                "up_block_types": [
                    "ResnetUpsampleBlock2D",
                    "ResnetUpsampleBlock2D",
                ],
                "upsample_type": "conv",
            },
            "scaling_factor": 1,
            "block_out_channels": [32, 64],
            "latent_channels": 4,
        }

    def prepare_init_args_and_inputs_for_common(self):
        return self.init_dict, self.inputs_dict()

    @unittest.skip
    def test_training(self):
        ...

    @unittest.skip
    def test_ema_training(self):
        ...
@slow @slow
class AutoencoderTinyIntegrationTests(unittest.TestCase): class AutoencoderTinyIntegrationTests(unittest.TestCase):
def tearDown(self): def tearDown(self):
...@@ -721,3 +803,94 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase): ...@@ -721,3 +803,94 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
tolerance = 3e-3 if torch_device != "mps" else 1e-2 tolerance = 3e-3 if torch_device != "mps" else 1e-2
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance) assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
@slow
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
    # Slow, network- and GPU-dependent tests exercising the pretrained
    # consistency decoder end to end (encode/decode round trips and full
    # Stable Diffusion pipelines), in fp32 and fp16.

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()

        gc.collect()
        torch.cuda.empty_cache()

    def test_encode_decode(self):
        # Round-trips a real image through the pretrained VAE and pins a
        # slice of the decoded output against recorded values.
        vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")  # TODO - update
        vae.to(torch_device)

        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/sketch-mountains-input.jpg"
        ).resize((256, 256))
        # HWC uint8 -> NCHW float in [-1, 1].
        # NOTE(review): `.cuda()` assumes a CUDA device even though the model
        # was moved with `vae.to(torch_device)` — confirm these always agree.
        image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[
            None, :, :, :
        ].cuda()

        latent = vae.encode(image).latent_dist.mean

        sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample

        actual_output = sample[0, :2, :2, :2].flatten().cpu()
        expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])

        assert torch_all_close(actual_output, expected_output, atol=5e-3)

    def test_sd(self):
        # Full SD v1-5 generation with the consistency decoder swapped in as
        # the VAE; pins a slice of the generated image.
        vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")  # TODO - update

        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
        pipe.to(torch_device)

        out = pipe(
            "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
        ).images[0]

        actual_output = out[:2, :2, :2].flatten().cpu()
        expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])

        assert torch_all_close(actual_output, expected_output, atol=5e-3)

    def test_encode_decode_f16(self):
        # fp16 variant of test_encode_decode; expected values are recorded in
        # half precision.
        vae = ConsistencyDecoderVAE.from_pretrained(
            "openai/consistency-decoder", torch_dtype=torch.float16
        )  # TODO - update
        vae.to(torch_device)

        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/sketch-mountains-input.jpg"
        ).resize((256, 256))
        # HWC uint8 -> NCHW fp16 in [-1, 1]; same `.cuda()` caveat as above.
        image = (
            torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
            .half()
            .cuda()
        )

        latent = vae.encode(image).latent_dist.mean

        sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample

        actual_output = sample[0, :2, :2, :2].flatten().cpu()
        expected_output = torch.tensor(
            [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16
        )

        assert torch_all_close(actual_output, expected_output, atol=5e-3)

    def test_sd_f16(self):
        # fp16 variant of test_sd: both the pipeline and the VAE run in half
        # precision.
        vae = ConsistencyDecoderVAE.from_pretrained(
            "openai/consistency-decoder", torch_dtype=torch.float16
        )  # TODO - update

        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None
        )
        pipe.to(torch_device)

        out = pipe(
            "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
        ).images[0]

        actual_output = out[:2, :2, :2].flatten().cpu()
        expected_output = torch.tensor(
            [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16
        )

        assert torch_all_close(actual_output, expected_output, atol=5e-3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment