Unverified Commit 1c60e094 authored by Sayak Paul, committed by GitHub

[Tests] reduce block sizes of UNet and VAE tests (#7560)

* reduce block sizes for unet1d.

* reduce blocks for unet_2d.

* reduce block size for unet_motion

* increase channels.

* correctly increase channels.

* reduce number of layers in unet2dconditionmodel tests.

* reduce block sizes for unet2dconditionmodel tests

* reduce block sizes for unet3dconditionmodel.

* fix: test_feed_forward_chunking

* fix: test_forward_with_norm_groups

* skip spatiotemporal tests on MPS.

* reduce block size in AutoencoderKL.

* reduce block sizes for vqmodel.

* further reduce block size.

* make style.

* Empty-Commit

* reduce sizes for ConsistencyDecoderVAETests

* further reduction.

* further block reductions in AutoencoderKL and AsymmetricAutoencoderKL.

* massively reduce the block size in unet2dconditionmodel.

* reduce sizes for unet3d

* fix tests in unet3d.

* reduce blocks further in motion unet.

* fix: output shape

* add attention_head_dim to the test configuration.

* remove unexpected keyword arg

* up a bit.

* groups.

* up again

* fix
parent 71f49a5d
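Taken together, these commits shrink every test model to a near-minimal size. A rough sketch (not part of the PR, assuming a diffusers version that includes it) of how much the headline change shrinks the `UNet2DConditionModel` test fixture, built from the before/after configs in the diff below:

```python
# Sketch: compare parameter counts of the old vs. new test configs for
# UNet2DConditionModel (values taken from the diff below). Counts are
# printed rather than asserted, since they depend on the diffusers version.
from diffusers import UNet2DConditionModel

common = dict(
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    in_channels=4,
    out_channels=4,
)

old = UNet2DConditionModel(
    block_out_channels=(32, 64), cross_attention_dim=32,
    attention_head_dim=8, layers_per_block=2, sample_size=32, **common
)
new = UNet2DConditionModel(
    block_out_channels=(4, 8), norm_num_groups=4, cross_attention_dim=8,
    attention_head_dim=2, layers_per_block=1, sample_size=16, **common
)


def n_params(model):
    return sum(p.numel() for p in model.parameters())


print(f"old: {n_params(old):,} params -> new: {n_params(new):,} params")
```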
@@ -53,8 +53,8 @@ enable_full_determinism()
 def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
-    block_out_channels = block_out_channels or [32, 64]
-    norm_num_groups = norm_num_groups or 32
+    block_out_channels = block_out_channels or [2, 4]
+    norm_num_groups = norm_num_groups or 2
     init_dict = {
         "block_out_channels": block_out_channels,
         "in_channels": 3,
@@ -68,8 +68,8 @@ def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
 def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
-    block_out_channels = block_out_channels or [32, 64]
-    norm_num_groups = norm_num_groups or 32
+    block_out_channels = block_out_channels or [2, 4]
+    norm_num_groups = norm_num_groups or 2
     init_dict = {
         "in_channels": 3,
         "out_channels": 3,
@@ -102,8 +102,8 @@ def get_autoencoder_tiny_config(block_out_channels=None):
 def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
-    block_out_channels = block_out_channels or [32, 64]
-    norm_num_groups = norm_num_groups or 32
+    block_out_channels = block_out_channels or [2, 4]
+    norm_num_groups = norm_num_groups or 2
     return {
         "encoder_block_out_channels": block_out_channels,
         "encoder_in_channels": 3,
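A quick way to sanity-check the reduced VAE config is a round trip through a freshly constructed model. A minimal sketch, assuming the keys truncated from the hunk above (down/up block types, `latent_channels`) keep their existing values:

```python
# Sketch: smoke-test the reduced AutoencoderKL config. GroupNorm with
# 2 groups still divides both block widths (2 and 4), so construction works.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL(
    block_out_channels=[2, 4],
    norm_num_groups=2,
    in_channels=3,
    out_channels=3,
    down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],  # assumed unchanged
    up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],  # assumed unchanged
    latent_channels=4,  # assumed unchanged
)

sample = torch.randn(1, 3, 32, 32)
recon = vae(sample).sample  # forward returns an output object with .sample
assert recon.shape == sample.shape
```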
@@ -54,7 +54,8 @@ class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": [32, 64],
+            "block_out_channels": [8, 16],
+            "norm_num_groups": 8,
             "in_channels": 3,
             "out_channels": 3,
             "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
@@ -77,7 +77,7 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64, 128, 256),
+            "block_out_channels": (8, 8, 16, 16),
             "in_channels": 14,
             "out_channels": 14,
             "time_embedding_type": "positional",
@@ -63,7 +63,8 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 2,
             "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
             "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
             "attention_head_dim": 3,
@@ -78,9 +79,8 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_mid_block_attn_groups(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        init_dict["norm_num_groups"] = 16
         init_dict["add_attention"] = True
-        init_dict["attn_norm_num_groups"] = 8
+        init_dict["attn_norm_num_groups"] = 4

         model = self.model_class(**init_dict)
         model.to(torch_device)
@@ -247,33 +247,34 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def dummy_input(self):
         batch_size = 4
         num_channels = 4
-        sizes = (32, 32)
+        sizes = (16, 16)

         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)

         return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}

     @property
     def input_shape(self):
-        return (4, 32, 32)
+        return (4, 16, 16)

     @property
     def output_shape(self):
-        return (4, 32, 32)
+        return (4, 16, 16)

     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 4,
             "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
             "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
-            "cross_attention_dim": 32,
-            "attention_head_dim": 8,
+            "cross_attention_dim": 8,
+            "attention_head_dim": 2,
             "out_channels": 4,
             "in_channels": 4,
-            "layers_per_block": 2,
-            "sample_size": 32,
+            "layers_per_block": 1,
+            "sample_size": 16,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
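Note the pattern throughout the PR: whenever `block_out_channels` shrinks, `norm_num_groups` is lowered with it, because `GroupNorm` requires the channel count to be divisible by the group count. A small sketch of the constraint:

```python
# Sketch: GroupNorm rejects group counts that do not divide the channels,
# which is why the old default of 32 groups cannot survive 4- or 8-channel blocks.
import torch.nn as nn

nn.GroupNorm(num_groups=4, num_channels=8)  # fine: 8 % 4 == 0

try:
    nn.GroupNorm(num_groups=32, num_channels=8)  # old default group count
except ValueError as err:
    print(err)  # num_channels must be divisible by num_groups
```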
@@ -337,6 +338,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_model_with_attention_head_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -375,7 +377,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_model_with_cross_attention_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        init_dict["cross_attention_dim"] = (32, 32)
+        init_dict["cross_attention_dim"] = (8, 8)

         model = self.model_class(**init_dict)
         model.to(torch_device)
@@ -443,6 +445,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_model_attention_slicing(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -467,6 +470,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_model_sliceable_head_dim(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -485,6 +489,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_gradient_checkpointing_is_applied(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model_class_copy = copy.copy(self.model_class)
@@ -561,6 +566,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -571,7 +577,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         model.set_attn_processor(processor)
         model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample

-        assert processor.counter == 12
+        assert processor.counter == 8
         assert processor.is_run
         assert processor.number == 123
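The expected counter drops from 12 to 8 because the reduced config has fewer attention layers: with `layers_per_block=1`, there is one self/cross-attention pair in the cross-attention down block, one in the mid block, and two in the cross-attention up block, so eight processors, each invoked once per forward pass. A sketch of how to verify that count, assuming the reduced config from the hunk above:

```python
# Sketch: count attention processors in the reduced test model; with this
# config the count should match the `processor.counter == 8` assertion.
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel(
    block_out_channels=(4, 8),
    norm_num_groups=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=8,
    attention_head_dim=2,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    sample_size=16,
)
print(len(model.attn_processors))  # expected: 8
```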
@@ -587,7 +593,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_model_xattn_mask(self, mask_dtype):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

-        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
+        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
         model.to(torch_device)
         model.eval()
@@ -649,6 +655,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -675,6 +682,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         torch.manual_seed(0)
@@ -714,6 +722,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         torch.manual_seed(0)
@@ -739,6 +748,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -770,6 +780,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_ip_adapter(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -842,6 +853,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_ip_adapter_plus(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -41,36 +41,37 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         batch_size = 4
         num_channels = 4
         num_frames = 4
-        sizes = (32, 32)
+        sizes = (16, 16)

         noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)

         return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}

     @property
     def input_shape(self):
-        return (4, 4, 32, 32)
+        return (4, 4, 16, 16)

     @property
     def output_shape(self):
-        return (4, 4, 32, 32)
+        return (4, 4, 16, 16)

     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 4,
             "down_block_types": (
                 "CrossAttnDownBlock3D",
                 "DownBlock3D",
             ),
             "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D"),
-            "cross_attention_dim": 32,
-            "attention_head_dim": 8,
+            "cross_attention_dim": 8,
+            "attention_head_dim": 2,
             "out_channels": 4,
             "in_channels": 4,
             "layers_per_block": 1,
-            "sample_size": 32,
+            "sample_size": 16,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
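With the reduced 3D config, a full forward pass stays cheap even on CPU. A sketch using the dummy shapes from this hunk (batch 4, 4 latent channels, 4 frames, 16x16 spatial, 8-dim text states), assuming a diffusers version that includes this PR:

```python
# Sketch: forward pass through the reduced UNet3DConditionModel test config.
import torch
from diffusers import UNet3DConditionModel

model = UNet3DConditionModel(
    block_out_channels=(4, 8),
    norm_num_groups=4,
    down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
    cross_attention_dim=8,
    attention_head_dim=2,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    sample_size=16,
)

sample = torch.randn(4, 4, 4, 16, 16)  # (batch, channels, frames, height, width)
encoder_hidden_states = torch.randn(4, 4, 8)
out = model(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
assert out.shape == sample.shape
```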
@@ -93,7 +94,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     # Overriding to set `norm_num_groups` needs to be different for this model.
     def test_forward_with_norm_groups(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (32, 64)
         init_dict["norm_num_groups"] = 32

         model = self.model_class(**init_dict)
@@ -140,6 +141,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_model_attention_slicing(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = 8

         model = self.model_class(**init_dict)
@@ -163,6 +165,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_feed_forward_chunking(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (32, 64)
         init_dict["norm_num_groups"] = 32

         model = self.model_class(**init_dict)
@@ -46,34 +46,35 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def dummy_input(self):
         batch_size = 4
         num_channels = 4
-        num_frames = 8
-        sizes = (32, 32)
+        num_frames = 4
+        sizes = (16, 16)

         noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 16)).to(torch_device)

         return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}

     @property
     def input_shape(self):
-        return (4, 8, 32, 32)
+        return (4, 4, 16, 16)

     @property
     def output_shape(self):
-        return (4, 8, 32, 32)
+        return (4, 4, 16, 16)

     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (16, 32),
+            "norm_num_groups": 16,
             "down_block_types": ("CrossAttnDownBlockMotion", "DownBlockMotion"),
             "up_block_types": ("UpBlockMotion", "CrossAttnUpBlockMotion"),
-            "cross_attention_dim": 32,
-            "num_attention_heads": 4,
+            "cross_attention_dim": 16,
+            "num_attention_heads": 2,
             "out_channels": 4,
             "in_channels": 4,
             "layers_per_block": 1,
-            "sample_size": 32,
+            "sample_size": 16,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
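The motion UNet keeps slightly wider blocks (16, 32) than the reduced 2D/3D variants but halves the frame count and spatial size, which is where most of the activation cost lives. A sketch mirroring the reduced config and dummy shapes above, again assuming a diffusers version that includes this PR:

```python
# Sketch: forward pass through the reduced UNetMotionModel test config.
import torch
from diffusers import UNetMotionModel

model = UNetMotionModel(
    block_out_channels=(16, 32),
    norm_num_groups=16,
    down_block_types=("CrossAttnDownBlockMotion", "DownBlockMotion"),
    up_block_types=("UpBlockMotion", "CrossAttnUpBlockMotion"),
    cross_attention_dim=16,
    num_attention_heads=2,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    sample_size=16,
)

sample = torch.randn(4, 4, 4, 16, 16)  # frames reduced from 8 to 4
encoder_hidden_states = torch.randn(4, 4, 16)
out = model(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
assert out.shape == sample.shape
```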
@@ -194,6 +195,7 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_feed_forward_chunking(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (32, 64)
         init_dict["norm_num_groups"] = 32

         model = self.model_class(**init_dict)
@@ -24,6 +24,7 @@ from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
+    skip_mps,
     torch_all_close,
     torch_device,
 )
@@ -36,6 +37,7 @@ logger = logging.get_logger(__name__)
 enable_full_determinism()


+@skip_mps
 class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNetSpatioTemporalConditionModel
     main_input_name = "sample"