"vscode:/vscode.git/clone" did not exist on "76f9b5228953273b22fe26f9157896070a710147"
Unverified Commit 21f023ec authored by Aritra Roy Gosthipaty's avatar Aritra Roy Gosthipaty Committed by GitHub
Browse files

[Tests] reduce the model size in the ddpm fast test (#7797)



* chore: reducing unet size for faster tests

* review suggestions

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 31d9f9ea
......@@ -30,9 +30,10 @@ class DDPMPipelineFastTests(unittest.TestCase):
def dummy_uncond_unet(self):
torch.manual_seed(0)
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
block_out_channels=(4, 8),
layers_per_block=1,
norm_num_groups=4,
sample_size=8,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
......@@ -58,10 +59,8 @@ class DDPMPipelineFastTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array(
[9.956e-01, 5.785e-01, 4.675e-01, 9.930e-01, 0.0, 1.000, 1.199e-03, 2.648e-04, 5.101e-04]
)
assert image.shape == (1, 8, 8, 3)
expected_slice = np.array([0.0, 0.9996672, 0.00329116, 1.0, 0.9995991, 1.0, 0.0060907, 0.00115037, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
......@@ -83,7 +82,7 @@ class DDPMPipelineFastTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
image_eps_slice = image_eps[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
assert image.shape == (1, 8, 8, 3)
tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment