Unverified commit 1c60e094 authored by Sayak Paul, committed by GitHub

[Tests] reduce block sizes of UNet and VAE tests (#7560)

* reduce block sizes for unet1d.

* reduce blocks for unet_2d.

* reduce block size for unet_motion

* increase channels.

* correctly increase channels.

* reduce number of layers in unet2dconditionmodel tests.

* reduce block sizes for unet2dconditionmodel tests

* reduce block sizes for unet3dconditionmodel.

* fix: test_feed_forward_chunking

* fix: test_forward_with_norm_groups

* skip spatiotemporal tests on MPS.

* reduce block size in AutoencoderKL.

* reduce block sizes for vqmodel.

* further reduce block size.

* make style.

* Empty-Commit

* reduce sizes for ConsistencyDecoderVAETests

* further reduction.

* further block reductions in AutoencoderKL and AsymmetricAutoencoderKL.

* massively reduce the block size in unet2dconditionmodel.

* reduce sizes for unet3d

* fix tests in unet3d.

* reduce blocks further in motion unet.

* fix: output shape

* add attention_head_dim to the test configuration.

* remove unexpected keyword arg

* up a bit.

* groups.

* up again

* fix
parent 71f49a5d
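The diff below shrinks the dummy model configs used across the UNet and VAE test suites. As a rough illustration (not part of the PR), here is what the new UNet2DConditionModel test config builds when instantiated directly; it is far smaller than the old (32, 64) / layers_per_block=2 config, which is what speeds up the CPU test runs:

```python
from diffusers import UNet2DConditionModel

# The reduced config from this PR's UNet2DConditionModelTests.
tiny = UNet2DConditionModel(
    block_out_channels=(4, 8),
    norm_num_groups=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=8,
    attention_head_dim=2,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    sample_size=16,
)
# Roughly tens of thousands of parameters rather than millions (rough estimate).
print(sum(p.numel() for p in tiny.parameters()))
```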
@@ -53,8 +53,8 @@ enable_full_determinism()
 def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
-    block_out_channels = block_out_channels or [32, 64]
-    norm_num_groups = norm_num_groups or 32
+    block_out_channels = block_out_channels or [2, 4]
+    norm_num_groups = norm_num_groups or 2
     init_dict = {
         "block_out_channels": block_out_channels,
         "in_channels": 3,
@@ -68,8 +68,8 @@ def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
 def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
-    block_out_channels = block_out_channels or [32, 64]
-    norm_num_groups = norm_num_groups or 32
+    block_out_channels = block_out_channels or [2, 4]
+    norm_num_groups = norm_num_groups or 2
     init_dict = {
         "in_channels": 3,
         "out_channels": 3,
@@ -102,8 +102,8 @@ def get_autoencoder_tiny_config(block_out_channels=None):
 def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
-    block_out_channels = block_out_channels or [32, 64]
-    norm_num_groups = norm_num_groups or 32
+    block_out_channels = block_out_channels or [2, 4]
+    norm_num_groups = norm_num_groups or 2
     return {
         "encoder_block_out_channels": block_out_channels,
         "encoder_in_channels": 3,
...
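A minimal sketch of what the reduced AutoencoderKL config builds; the block types, `latent_channels`, and the 32×32 input are assumptions based on the helper's visible fields and diffusers defaults:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    block_out_channels=(2, 4),  # new defaults; both divisible by norm_num_groups
    norm_num_groups=2,
    latent_channels=4,
)
with torch.no_grad():
    out = vae(torch.randn(1, 3, 32, 32)).sample
print(out.shape)  # torch.Size([1, 3, 32, 32]); reconstruction keeps the input shape
```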
@@ -54,7 +54,8 @@ class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": [32, 64],
+            "block_out_channels": [8, 16],
+            "norm_num_groups": 8,
             "in_channels": 3,
             "out_channels": 3,
             "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
...
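The new explicit `norm_num_groups` entries are what keep these shrunken configs valid: PyTorch's GroupNorm requires the channel count to be divisible by the group count, so the library default of 32 groups would reject block channels of 8 or 16. A quick check:

```python
import torch.nn as nn

nn.GroupNorm(num_groups=8, num_channels=16)  # fine: 16 % 8 == 0
try:
    nn.GroupNorm(num_groups=32, num_channels=16)  # old default groups vs. reduced channels
except ValueError as exc:
    print(exc)  # num_channels must be divisible by num_groups
```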
@@ -77,7 +77,7 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64, 128, 256),
+            "block_out_channels": (8, 8, 16, 16),
             "in_channels": 14,
             "out_channels": 14,
             "time_embedding_type": "positional",
...
@@ -63,7 +63,8 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 2,
             "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
             "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
             "attention_head_dim": 3,
@@ -78,9 +79,8 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_mid_block_attn_groups(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        init_dict["norm_num_groups"] = 16
         init_dict["add_attention"] = True
-        init_dict["attn_norm_num_groups"] = 8
+        init_dict["attn_norm_num_groups"] = 4
         model = self.model_class(**init_dict)
         model.to(torch_device)
...
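For reference, a hedged smoke test of the reduced Unet2DModelTests config above (the `in_channels`/`out_channels` of 3 and the 32×32 input are assumptions; those fields are not shown in this hunk):

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    block_out_channels=(4, 8),
    norm_num_groups=2,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    attention_head_dim=3,
    in_channels=3,   # assumed; not visible in the hunk
    out_channels=3,  # assumed; not visible in the hunk
)
with torch.no_grad():
    out = model(torch.randn(1, 3, 32, 32), timestep=1).sample
print(out.shape)  # torch.Size([1, 3, 32, 32])
```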
@@ -247,33 +247,34 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def dummy_input(self):
         batch_size = 4
         num_channels = 4
-        sizes = (32, 32)
+        sizes = (16, 16)

         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)

         return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}

     @property
     def input_shape(self):
-        return (4, 32, 32)
+        return (4, 16, 16)

     @property
     def output_shape(self):
-        return (4, 32, 32)
+        return (4, 16, 16)

     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 4,
             "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
             "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
-            "cross_attention_dim": 32,
-            "attention_head_dim": 8,
+            "cross_attention_dim": 8,
+            "attention_head_dim": 2,
             "out_channels": 4,
             "in_channels": 4,
-            "layers_per_block": 2,
-            "sample_size": 32,
+            "layers_per_block": 1,
+            "sample_size": 16,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
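Several tests below override `attention_head_dim` to (8, 16), and each of those now also bumps `block_out_channels` to (16, 32). The reason is a divisibility constraint: per block, the transformer splits the channel width across `attention_head_dim`, so `channels // attention_head_dim` must come out to at least 1, and the new (4, 8) defaults would not. A sketch of the arithmetic (the exact derivation inside diffusers is an assumption here):

```python
# Per-block width per head must be a positive integer.
for channels, head_dim in zip((16, 32), (8, 16)):
    assert channels % head_dim == 0 and channels // head_dim >= 1  # 2 per head

for channels, head_dim in zip((4, 8), (8, 16)):
    assert channels // head_dim == 0  # the reduced defaults would collapse to zero
```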
@@ -337,6 +338,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_model_with_attention_head_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -375,7 +377,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_model_with_cross_attention_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        init_dict["cross_attention_dim"] = (32, 32)
+        init_dict["cross_attention_dim"] = (8, 8)

         model = self.model_class(**init_dict)
         model.to(torch_device)
@@ -443,6 +445,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_model_attention_slicing(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -467,6 +470,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_model_sliceable_head_dim(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -485,6 +489,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_gradient_checkpointing_is_applied(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model_class_copy = copy.copy(self.model_class)
@@ -561,6 +566,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -571,7 +577,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         model.set_attn_processor(processor)
         model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample

-        assert processor.counter == 12
+        assert processor.counter == 8
         assert processor.is_run
         assert processor.number == 123
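The expected attention-processor call count in the hunk above drops from 12 to 8 because `layers_per_block` went from 2 to 1. A back-of-envelope check, assuming the counter increments once per Attention forward (self- plus cross-attention per transformer layer):

```python
layers_per_block = 1
down = layers_per_block * 2       # CrossAttnDownBlock2D: attn1 + attn2 per layer
mid = 1 * 2                       # the mid block keeps a single transformer layer
up = (layers_per_block + 1) * 2   # up blocks run layers_per_block + 1 layers
assert down + mid + up == 8       # with layers_per_block = 2 the same sum gives 12
```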
@@ -587,7 +593,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_model_xattn_mask(self, mask_dtype):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

-        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
+        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
         model.to(torch_device)
         model.eval()
@@ -649,6 +655,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -675,6 +682,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         torch.manual_seed(0)
@@ -714,6 +722,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         torch.manual_seed(0)
@@ -739,6 +748,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -770,6 +780,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_ip_adapter(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -842,6 +853,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_ip_adapter_plus(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
...
@@ -41,36 +41,37 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         batch_size = 4
         num_channels = 4
         num_frames = 4
-        sizes = (32, 32)
+        sizes = (16, 16)

         noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)

         return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}

     @property
     def input_shape(self):
-        return (4, 4, 32, 32)
+        return (4, 4, 16, 16)

     @property
     def output_shape(self):
-        return (4, 4, 32, 32)
+        return (4, 4, 16, 16)

     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 4,
             "down_block_types": (
                 "CrossAttnDownBlock3D",
                 "DownBlock3D",
             ),
             "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D"),
-            "cross_attention_dim": 32,
-            "attention_head_dim": 8,
+            "cross_attention_dim": 8,
+            "attention_head_dim": 2,
             "out_channels": 4,
             "in_channels": 4,
             "layers_per_block": 1,
-            "sample_size": 32,
+            "sample_size": 16,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -93,7 +94,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     # Overriding to set `norm_num_groups` needs to be different for this model.
     def test_forward_with_norm_groups(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (32, 64)
         init_dict["norm_num_groups"] = 32

         model = self.model_class(**init_dict)
@@ -140,6 +141,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_model_attention_slicing(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = 8

         model = self.model_class(**init_dict)
@@ -163,6 +165,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_feed_forward_chunking(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (32, 64)
         init_dict["norm_num_groups"] = 32

         model = self.model_class(**init_dict)
...
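`test_feed_forward_chunking` (touched above) compares a plain forward pass against one with chunked transformer feed-forwards, which is why it restores the larger (32, 64) blocks with 32 norm groups. A self-contained sketch of the behavior under test, assuming the model exposes diffusers' `enable_forward_chunking` helper and reusing the test's config values:

```python
import torch
from diffusers import UNet3DConditionModel

model = UNet3DConditionModel(
    block_out_channels=(32, 64),
    norm_num_groups=32,
    down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
    cross_attention_dim=8,
    attention_head_dim=2,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    sample_size=16,
).eval()

sample = torch.randn(1, 4, 4, 16, 16)  # (batch, channels, frames, height, width)
states = torch.randn(1, 4, 8)          # encoder hidden states
with torch.no_grad():
    full = model(sample, timestep=1, encoder_hidden_states=states).sample
    model.enable_forward_chunking()    # chunk the transformer feed-forwards
    chunked = model(sample, timestep=1, encoder_hidden_states=states).sample
print((full - chunked).abs().max())    # should be ~0 up to numerical noise
```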
@@ -46,34 +46,35 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def dummy_input(self):
         batch_size = 4
         num_channels = 4
-        num_frames = 8
-        sizes = (32, 32)
+        num_frames = 4
+        sizes = (16, 16)

         noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 16)).to(torch_device)

         return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}

     @property
     def input_shape(self):
-        return (4, 8, 32, 32)
+        return (4, 4, 16, 16)

     @property
     def output_shape(self):
-        return (4, 8, 32, 32)
+        return (4, 4, 16, 16)

     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (16, 32),
+            "norm_num_groups": 16,
             "down_block_types": ("CrossAttnDownBlockMotion", "DownBlockMotion"),
             "up_block_types": ("UpBlockMotion", "CrossAttnUpBlockMotion"),
-            "cross_attention_dim": 32,
-            "num_attention_heads": 4,
+            "cross_attention_dim": 16,
+            "num_attention_heads": 2,
             "out_channels": 4,
             "in_channels": 4,
             "layers_per_block": 1,
-            "sample_size": 32,
+            "sample_size": 16,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -194,6 +195,7 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_feed_forward_chunking(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (32, 64)
         init_dict["norm_num_groups"] = 32

         model = self.model_class(**init_dict)
...
@@ -24,6 +24,7 @@ from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
+    skip_mps,
     torch_all_close,
     torch_device,
 )
@@ -36,6 +37,7 @@ logger = logging.get_logger(__name__)
 enable_full_determinism()


+@skip_mps
 class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNetSpatioTemporalConditionModel
     main_input_name = "sample"
...
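`skip_mps` comes from `diffusers.utils.testing_utils` and skips the decorated test (or, as here, a whole test class) when the configured torch device is Apple's MPS backend. A hedged usage sketch with a hypothetical test class:

```python
import unittest

from diffusers.utils.testing_utils import skip_mps


@skip_mps
class MySpatioTemporalTests(unittest.TestCase):  # hypothetical test class
    def test_something(self):
        self.assertTrue(True)
```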