"tests/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "5739930f8ed41567883260ff0ba8dc4f1669175f"
Commit ea8ae8c6 · authored by Patrick von Platen, committed by GitHub

Complete set_attn_processor for prior and vae (#3796)



* relax tolerance slightly

* Add more tests

* upload readme

* upload readme

* Apply suggestions from code review

* Improve API Autoencoder KL

* finalize

* finalize tests

* finalize tests

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* up

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 958d9ec7
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, apply_forward_hook
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -156,6 +157,69 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         """
         self.use_slicing = False

+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # collect the processors recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+                The instantiated processor class, or a dictionary of processor classes that will be set as the
+                processor of **all** `Attention` layers. If `processor` is a dict, each key must name the path to
+                the corresponding cross-attention processor; using a dict is strongly recommended when setting
+                trainable attention processors.
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     @apply_forward_hook
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
         if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
...
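With this change a VAE exposes the same attention-processor API as the UNet. A minimal sketch of the intended usage (the checkpoint name is only illustrative; any `AutoencoderKL` checkpoint with attention layers behaves the same way):

    from diffusers import AutoencoderKL
    from diffusers.models.attention_processor import AttnProcessor

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # illustrative checkpoint

    # inspect the processors, keyed by weight path, e.g. "encoder.mid_block.attentions.0.processor"
    print(vae.attn_processors.keys())

    # drop back to the default attention implementation everywhere
    vae.set_default_attn_processor()
    assert all(type(p) == AttnProcessor for p in vae.attn_processors.values())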
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 import torch
 import torch.nn.functional as F
@@ -8,6 +8,7 @@ from torch import nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
 from .attention import BasicTransformerBlock
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
@@ -104,6 +105,69 @@ class PriorTransformer(ModelMixin, ConfigMixin):
         self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
         self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))

+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # collect the processors recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+                The instantiated processor class, or a dictionary of processor classes that will be set as the
+                processor of **all** `Attention` layers. If `processor` is a dict, each key must name the path to
+                the corresponding cross-attention processor; using a dict is strongly recommended when setting
+                trainable attention processors.
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def forward(
         self,
         hidden_states,
...
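Because `set_attn_processor` also accepts a dict, processors can be assigned per layer. A short sketch, assuming `prior` is an instantiated `PriorTransformer`; the keys are exactly those returned by `prior.attn_processors`, and a dict with a different number of entries would trigger the `ValueError` above:

    from diffusers.models.attention_processor import AttnProcessor2_0

    # map every attention layer path to a fresh processor instance
    processor_dict = {name: AttnProcessor2_0() for name in prior.attn_processors.keys()}
    prior.set_attn_processor(processor_dict)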
...@@ -26,9 +26,10 @@ import torch ...@@ -26,9 +26,10 @@ import torch
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from diffusers.utils import logging, torch_device from diffusers.utils import logging, torch_device
from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, run_test_in_subprocess from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, require_torch_gpu, run_test_in_subprocess
# Will be run via run_test_in_subprocess # Will be run via run_test_in_subprocess
@@ -150,7 +151,43 @@ class ModelUtilsTest(unittest.TestCase):
         assert model.config.in_channels == 9


+class UNetTesterMixin:
+    def test_forward_signature(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        signature = inspect.signature(model.forward)
+        # signature.parameters is an OrderedDict => so arg_names order is deterministic
+        arg_names = [*signature.parameters.keys()]
+
+        expected_arg_names = ["sample", "timestep"]
+        self.assertListEqual(arg_names[:2], expected_arg_names)
+
+    def test_forward_with_norm_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["norm_num_groups"] = 16
+        init_dict["block_out_channels"] = (16, 32)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+
 class ModelTesterMixin:
+    main_input_name = None  # overwrite in model specific tester class
+    base_precision = 1e-3
+
     def test_from_save_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -170,12 +207,12 @@ class ModelTesterMixin:
         with torch.no_grad():
             image = model(**inputs_dict)
             if isinstance(image, dict):
-                image = image.sample
+                image = image.to_tuple()[0]

             new_image = new_model(**inputs_dict)

             if isinstance(new_image, dict):
-                new_image = new_image.sample
+                new_image = new_image.to_tuple()[0]

         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
@@ -223,12 +260,62 @@ class ModelTesterMixin:
         assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"

+    @require_torch_gpu
+    def test_set_attn_processor_for_determinism(self):
+        torch.use_deterministic_algorithms(False)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        if not hasattr(model, "set_attn_processor"):
+            # the model does not expose `set_attn_processor`, so there is nothing to test
+            return
+
+        assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
+        with torch.no_grad():
+            output_1 = model(**inputs_dict)[0]
+
+        model.set_default_attn_processor()
+        assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
+        with torch.no_grad():
+            output_2 = model(**inputs_dict)[0]
+
+        model.enable_xformers_memory_efficient_attention()
+        assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
+        with torch.no_grad():
+            output_3 = model(**inputs_dict)[0]
+
+        model.set_attn_processor(AttnProcessor2_0())
+        assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
+        with torch.no_grad():
+            output_4 = model(**inputs_dict)[0]
+
+        model.set_attn_processor(AttnProcessor())
+        assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
+        with torch.no_grad():
+            output_5 = model(**inputs_dict)[0]
+
+        model.set_attn_processor(XFormersAttnProcessor())
+        assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
+        with torch.no_grad():
+            output_6 = model(**inputs_dict)[0]
+
+        torch.use_deterministic_algorithms(True)
+
+        # make sure that all attention backends produce matching outputs
+        assert torch.allclose(output_2, output_1, atol=self.base_precision)
+        assert torch.allclose(output_2, output_3, atol=self.base_precision)
+        assert torch.allclose(output_2, output_4, atol=self.base_precision)
+        assert torch.allclose(output_2, output_5, atol=self.base_precision)
+        assert torch.allclose(output_2, output_6, atol=self.base_precision)
+
     def test_from_save_pretrained_variant(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)

         if hasattr(model, "set_default_attn_processor"):
             model.set_default_attn_processor()

         model.to(torch_device)
         model.eval()
@@ -250,12 +337,12 @@ class ModelTesterMixin:
         with torch.no_grad():
             image = model(**inputs_dict)
             if isinstance(image, dict):
-                image = image.sample
+                image = image.to_tuple()[0]

             new_image = new_model(**inputs_dict)

             if isinstance(new_image, dict):
-                new_image = new_image.sample
+                new_image = new_image.to_tuple()[0]

         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
@@ -293,11 +380,11 @@ class ModelTesterMixin:
         with torch.no_grad():
             first = model(**inputs_dict)
             if isinstance(first, dict):
-                first = first.sample
+                first = first.to_tuple()[0]

             second = model(**inputs_dict)
             if isinstance(second, dict):
-                second = second.sample
+                second = second.to_tuple()[0]

         out_1 = first.cpu().numpy()
         out_2 = second.cpu().numpy()
@@ -316,43 +403,15 @@ class ModelTesterMixin:
             output = model(**inputs_dict)

             if isinstance(output, dict):
-                output = output.sample
+                output = output.to_tuple()[0]

         self.assertIsNotNone(output)
-        expected_shape = inputs_dict["sample"].shape
-        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
-    def test_forward_with_norm_groups(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        init_dict["norm_num_groups"] = 16
-        init_dict["block_out_channels"] = (16, 32)
-
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-        model.eval()
-
-        with torch.no_grad():
-            output = model(**inputs_dict)
-
-            if isinstance(output, dict):
-                output = output.sample
-
-        self.assertIsNotNone(output)
-        expected_shape = inputs_dict["sample"].shape
-        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
-    def test_forward_signature(self):
-        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict)
-        signature = inspect.signature(model.forward)
-        # signature.parameters is an OrderedDict => so arg_names order is deterministic
-        arg_names = [*signature.parameters.keys()]
-
-        expected_arg_names = ["sample", "timestep"]
-        self.assertListEqual(arg_names[:2], expected_arg_names)
+
+        # input & output have to have the same shape
+        input_tensor = inputs_dict[self.main_input_name]
+        expected_shape = input_tensor.shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

     def test_model_from_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -378,12 +437,12 @@ class ModelTesterMixin:
         output_1 = model(**inputs_dict)

         if isinstance(output_1, dict):
-            output_1 = output_1.sample
+            output_1 = output_1.to_tuple()[0]

         output_2 = new_model(**inputs_dict)

         if isinstance(output_2, dict):
-            output_2 = output_2.sample
+            output_2 = output_2.to_tuple()[0]

         self.assertEqual(output_1.shape, output_2.shape)
@@ -397,9 +456,10 @@ class ModelTesterMixin:
         output = model(**inputs_dict)

         if isinstance(output, dict):
-            output = output.sample
+            output = output.to_tuple()[0]

-        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
+        input_tensor = inputs_dict[self.main_input_name]
+        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()
@@ -415,9 +475,10 @@ class ModelTesterMixin:
         output = model(**inputs_dict)

         if isinstance(output, dict):
-            output = output.sample
+            output = output.to_tuple()[0]

-        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
+        input_tensor = inputs_dict[self.main_input_name]
+        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()
         ema_model.step(model.parameters())
...
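Outside the test harness, the same inspect-then-swap pattern is useful for checking that two attention backends agree on a given model. A sketch, assuming `model` is any diffusers model exposing `set_attn_processor` and `inputs` is a matching input dict:

    import torch

    with torch.no_grad():
        baseline = model(**inputs)[0]

    model.set_default_attn_processor()  # every attention layer now uses AttnProcessor
    with torch.no_grad():
        default_out = model(**inputs)[0]

    # the two backends should agree within a loose tolerance
    assert torch.allclose(baseline, default_out, atol=1e-3)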
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import inspect
import unittest

import torch
from parameterized import parameterized

from diffusers import PriorTransformer
from diffusers.utils import floats_tensor, slow, torch_all_close, torch_device
from diffusers.utils.testing_utils import enable_full_determinism

from .test_modeling_common import ModelTesterMixin


enable_full_determinism()


class PriorTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = PriorTransformer
    main_input_name = "hidden_states"

    @property
    def dummy_input(self):
        batch_size = 4
        embedding_dim = 8
        num_embeddings = 7

        hidden_states = floats_tensor((batch_size, embedding_dim)).to(torch_device)
        proj_embedding = floats_tensor((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    def get_dummy_seed_input(self, seed=0):
        torch.manual_seed(seed)
        batch_size = 4
        embedding_dim = 8
        num_embeddings = 7

        hidden_states = torch.randn((batch_size, embedding_dim)).to(torch_device)
        proj_embedding = torch.randn((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    @property
    def input_shape(self):
        return (4, 8)

    @property
    def output_shape(self):
        return (4, 8)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "num_attention_heads": 2,
            "attention_head_dim": 4,
            "num_layers": 2,
            "embedding_dim": 8,
            "num_embeddings": 7,
            "additional_embeddings": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        model, loading_info = PriorTransformer.from_pretrained(
            "hf-internal-testing/prior-dummy", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        hidden_states = model(**self.dummy_input)[0]

        assert hidden_states is not None, "Make sure output is not None"

    def test_forward_signature(self):
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)
        # signature.parameters is an OrderedDict => so arg_names order is deterministic
        arg_names = [*signature.parameters.keys()]

        expected_arg_names = ["hidden_states", "timestep"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_output_pretrained(self):
        model = PriorTransformer.from_pretrained("hf-internal-testing/prior-dummy")
        model = model.to(torch_device)

        if hasattr(model, "set_default_attn_processor"):
            model.set_default_attn_processor()

        input = self.get_dummy_seed_input()

        with torch.no_grad():
            output = model(**input)[0]

        output_slice = output[0, :5].flatten().cpu()
        print(output_slice)

        # Since the prior's generator is seeded on the appropriate device,
        # the expected output slices are not the same for CPU and GPU.
        expected_output_slice = torch.tensor([-1.3436, -0.2870, 0.7538, 0.4368, -0.0239])
        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))


@slow
class PriorTransformerIntegrationTests(unittest.TestCase):
    def get_dummy_seed_input(self, batch_size=1, embedding_dim=768, num_embeddings=77, seed=0):
        torch.manual_seed(seed)

        hidden_states = torch.randn((batch_size, embedding_dim)).to(torch_device)
        proj_embedding = torch.randn((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @parameterized.expand(
        [
            # fmt: off
            [13, [-0.5861, 0.1283, -0.0931, 0.0882, 0.4476, 0.1329, -0.0498, 0.0640]],
            [37, [-0.4913, 0.0110, -0.0483, 0.0541, 0.4954, -0.0170, 0.0354, 0.1651]],
            # fmt: on
        ]
    )
    def test_kandinsky_prior(self, seed, expected_slice):
        model = PriorTransformer.from_pretrained("kandinsky-community/kandinsky-2-1-prior", subfolder="prior")
        model.to(torch_device)
        input = self.get_dummy_seed_input(seed=seed)

        with torch.no_grad():
            sample = model(**input)[0]

        assert list(sample.shape) == [1, 768]

        output_slice = sample[0, :8].flatten().cpu()
        print(output_slice)
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
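The new prior tests hook into the shared tester mixin, so they run under the usual pytest invocation; the file path below is an assumption based on the repository's test layout:

    python -m pytest tests/models/test_models_prior.py -v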
@@ -20,11 +20,12 @@ import torch
 from diffusers import UNet1DModel
 from diffusers.utils import floats_tensor, slow, torch_device

-from .test_modeling_common import ModelTesterMixin
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


-class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
+class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet1DModel
+    main_input_name = "sample"

     @property
     def dummy_input(self):
@@ -153,8 +154,9 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
         assert (output_max - 0.0607).abs() < 4e-4


-class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
+class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet1DModel
+    main_input_name = "sample"

     @property
     def dummy_input(self):
...
@@ -23,7 +23,7 @@ from diffusers import UNet2DModel
 from diffusers.utils import floats_tensor, logging, slow, torch_all_close, torch_device
 from diffusers.utils.testing_utils import enable_full_determinism

-from .test_modeling_common import ModelTesterMixin
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


 logger = logging.get_logger(__name__)
@@ -31,8 +31,9 @@ logger = logging.get_logger(__name__)
 enable_full_determinism()


-class Unet2DModelTests(ModelTesterMixin, unittest.TestCase):
+class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
+    main_input_name = "sample"

     @property
     def dummy_input(self):
@@ -68,8 +69,9 @@ class Unet2DModelTests(ModelTesterMixin, unittest.TestCase):
         return init_dict, inputs_dict


-class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
+class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
+    main_input_name = "sample"

     @property
     def dummy_input(self):
@@ -182,8 +184,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))


-class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
+class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
+    main_input_name = "sample"

     @property
     def dummy_input(self, sizes=(32, 32)):
...
@@ -36,7 +36,7 @@ from diffusers.utils import (
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism

-from .test_modeling_common import ModelTesterMixin
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


 logger = logging.get_logger(__name__)
@@ -120,8 +120,9 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
     return custom_diffusion_attn_procs


-class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
+class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DConditionModel
+    main_input_name = "sample"

     @property
     def dummy_input(self):
...
@@ -31,7 +31,7 @@ from diffusers.utils import (
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism

-from .test_modeling_common import ModelTesterMixin
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


 enable_full_determinism()
@@ -73,8 +73,9 @@ def create_lora_layers(model, mock_weights: bool = True):

 @skip_mps
-class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase):
+class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet3DConditionModel
+    main_input_name = "sample"

     @property
     def dummy_input(self):
...
@@ -24,14 +24,16 @@ from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slo
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism

-from .test_modeling_common import ModelTesterMixin
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


 enable_full_determinism()


-class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL
+    main_input_name = "sample"
+    base_precision = 1e-2

     @property
     def dummy_input(self):
...
@@ -21,14 +21,15 @@ from diffusers import VQModel
 from diffusers.utils import floats_tensor, torch_device
 from diffusers.utils.testing_utils import enable_full_determinism

-from .test_modeling_common import ModelTesterMixin
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


 enable_full_determinism()


-class VQModelTests(ModelTesterMixin, unittest.TestCase):
+class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = VQModel
+    main_input_name = "sample"

     @property
     def dummy_input(self, sizes=(32, 32)):
...