Commit 7b9b946c authored by patil-suraj's avatar patil-suraj
Browse files

add tests for downsample block

parent b9de7172
...@@ -22,7 +22,7 @@ import numpy as np ...@@ -22,7 +22,7 @@ import numpy as np
import torch import torch
from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Upsample from diffusers.models.resnet import Downsample, Upsample
from diffusers.testing_utils import floats_tensor, slow, torch_device from diffusers.testing_utils import floats_tensor, slow, torch_device
...@@ -164,3 +164,58 @@ class UpsampleBlockTests(unittest.TestCase): ...@@ -164,3 +164,58 @@ class UpsampleBlockTests(unittest.TestCase):
output_slice = upsampled[0, -1, -3:, -3:] output_slice = upsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046]) expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class DownsampleBlockTests(unittest.TestCase):
def test_downsample_default(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64)
downsample = Downsample(channels=32, use_conv=False)
with torch.no_grad():
downsampled = downsample(sample)
assert downsampled.shape == (1, 32, 32, 32)
output_slice = downsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor([-0.0513, -0.3889, 0.0640, 0.0836, -0.5460, -0.0341, -0.0169, -0.6967, 0.1179])
max_diff = (output_slice.flatten() - expected_slice).abs().sum().item()
assert max_diff <= 1e-3
# assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)
def test_downsample_with_conv(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64)
downsample = Downsample(channels=32, use_conv=True)
with torch.no_grad():
downsampled = downsample(sample)
assert downsampled.shape == (1, 32, 32, 32)
output_slice = downsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor(
[0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913],
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_downsample_with_conv_pad1(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64)
downsample = Downsample(channels=32, use_conv=True, padding=1)
with torch.no_grad():
downsampled = downsample(sample)
assert downsampled.shape == (1, 32, 32, 32)
output_slice = downsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor([0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913])
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_downsample_with_conv_out_dim(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64)
downsample = Downsample(channels=32, use_conv=True, out_channels=16)
with torch.no_grad():
downsampled = downsample(sample)
assert downsampled.shape == (1, 16, 32, 32)
output_slice = downsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment