Commit 7dc71897 authored by patil-suraj's avatar patil-suraj
Browse files

add UnetModelTests

parent 800b2770
......@@ -82,44 +82,25 @@ class ConfigTester(unittest.TestCase):
assert config == new_config
class ModelTesterMixin(unittest.TestCase):
@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return (noise, time_step)
class ModelTesterMixin:
def test_from_pretrained_save_pretrained(self):
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = UNetModel.from_pretrained(tmpdirname)
new_model.to(torch_device)
dummy_input = self.dummy_input
image = model(*dummy_input)
new_image = new_model(*dummy_input)
with torch.no_grad():
image = model(**inputs_dict)
new_image = new_model(**inputs_dict)
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
def test_from_pretrained_hub(self):
model = UNetModel.from_pretrained("fusing/ddpm_dummy")
model.to(torch_device)
image = model(*self.dummy_input)
assert image is not None, "Make sure output is not None"
def test_save_load(self):
pass
def test_determinism(self):
pass
......@@ -137,6 +118,41 @@ class ModelTesterMixin(unittest.TestCase):
pass
class UnetModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNetModel
@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"x": noise, "t": time_step}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"ch": 32,
"ch_mult": (1, 2),
"num_res_blocks": 2,
"attn_resolutions": (16,),
"resolution": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_from_pretrained_hub(self):
model = UNetModel.from_pretrained("fusing/ddpm_dummy")
model.to(torch_device)
image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None"
class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_save_pretrained(self):
# 1. Load models
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment