Commit a859b199 authored by Patrick von Platen
Browse files

fix rl model tests

parent 85d991a1
...@@ -122,13 +122,13 @@ class ResidualTemporalBlock(nn.Module): ...@@ -122,13 +122,13 @@ class ResidualTemporalBlock(nn.Module):
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
def __init__( def __init__(
self, self,
training_horizon, training_horizon=128,
transition_dim, transition_dim=14,
cond_dim, cond_dim=3,
predict_epsilon=False, predict_epsilon=False,
clip_denoised=True, clip_denoised=True,
dim=32, dim=32,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 4, 8),
): ):
super().__init__() super().__init__()
...@@ -139,7 +139,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): ...@@ -139,7 +139,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:])) in_out = list(zip(dims[:-1], dims[1:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
time_dim = dim time_dim = dim
self.time_mlp = nn.Sequential( self.time_mlp = nn.Sequential(
...@@ -153,7 +152,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): ...@@ -153,7 +152,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self.ups = nn.ModuleList([]) self.ups = nn.ModuleList([])
num_resolutions = len(in_out) num_resolutions = len(in_out)
print(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out): for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1) is_last = ind >= (num_resolutions - 1)
...@@ -195,7 +193,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): ...@@ -195,7 +193,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
nn.Conv1d(dim, transition_dim, 1), nn.Conv1d(dim, transition_dim, 1),
) )
def forward(self, x, time): def forward(self, x, timesteps):
""" """
x : [ batch x horizon x transition ] x : [ batch x horizon x transition ]
""" """
...@@ -203,7 +201,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): ...@@ -203,7 +201,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
# x = einops.rearrange(x, "b h t -> b t h") # x = einops.rearrange(x, "b h t -> b t h")
x = x.permute(0, 2, 1) x = x.permute(0, 2, 1)
t = self.time_mlp(time) t = self.time_mlp(timesteps)
h = [] h = []
for resnet, resnet2, downsample in self.downs: for resnet, resnet2, downsample in self.downs:
......
...@@ -190,7 +190,7 @@ class ModelTesterMixin: ...@@ -190,7 +190,7 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
model.train() model.train()
output = model(**inputs_dict) output = model(**inputs_dict)
noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device) noise = torch.randn((inputs_dict["x"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise) loss = torch.nn.functional.mse_loss(output, noise)
loss.backward() loss.backward()
...@@ -210,11 +210,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -210,11 +210,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step} return {"x": noise, "timesteps": time_step}
@property @property
def get_input_shape(self): def input_shape(self):
return (3, 32, 32) return (3, 32, 32)
@property @property
def get_output_shape(self): def output_shape(self):
return (3, 32, 32) return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self): def prepare_init_args_and_inputs_for_common(self):
...@@ -276,11 +276,11 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase): ...@@ -276,11 +276,11 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step, "low_res": low_res} return {"x": noise, "timesteps": time_step, "low_res": low_res}
@property @property
def get_input_shape(self): def input_shape(self):
return (3, 32, 32) return (3, 32, 32)
@property @property
def get_output_shape(self): def output_shape(self):
return (6, 32, 32) return (6, 32, 32)
def prepare_init_args_and_inputs_for_common(self): def prepare_init_args_and_inputs_for_common(self):
...@@ -367,11 +367,11 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -367,11 +367,11 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step, "transformer_out": emb} return {"x": noise, "timesteps": time_step, "transformer_out": emb}
@property @property
def get_input_shape(self): def input_shape(self):
return (3, 32, 32) return (3, 32, 32)
@property @property
def get_output_shape(self): def output_shape(self):
return (6, 32, 32) return (6, 32, 32)
def prepare_init_args_and_inputs_for_common(self): def prepare_init_args_and_inputs_for_common(self):
...@@ -459,11 +459,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -459,11 +459,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step} return {"x": noise, "timesteps": time_step}
@property @property
def get_input_shape(self): def input_shape(self):
return (4, 32, 32) return (4, 32, 32)
@property @property
def get_output_shape(self): def output_shape(self):
return (4, 32, 32) return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self): def prepare_init_args_and_inputs_for_common(self):
...@@ -552,11 +552,11 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -552,11 +552,11 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step, "mu": condition, "mask": mask} return {"x": noise, "timesteps": time_step, "mu": condition, "mask": mask}
@property @property
def get_input_shape(self): def input_shape(self):
return (4, 32, 16) return (4, 32, 16)
@property @property
def get_output_shape(self): def output_shape(self):
return (4, 32, 16) return (4, 32, 16)
def prepare_init_args_and_inputs_for_common(self): def prepare_init_args_and_inputs_for_common(self):
...@@ -610,6 +610,38 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -610,6 +610,38 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase): class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
model_class = TemporalUNet model_class = TemporalUNet
@property
def dummy_input(self):
    """Return the keyword arguments for one dummy forward pass.

    ``x`` is a tensor of shape (4, 16, 14) obtained from ``floats_tensor``;
    ``timesteps`` holds the value 10 once per batch element. Both tensors
    are moved to ``torch_device``.
    """
    # Laid out as (batch_size, seq_len, num_features).
    sample_shape = (4, 16, 14)
    sample = floats_tensor(sample_shape).to(torch_device)
    timesteps = torch.tensor([10] * sample_shape[0]).to(torch_device)
    return {"x": sample, "timesteps": timesteps}
@property
def input_shape(self):
    """Per-sample input shape ``(seq_len, num_features)``.

    The batch axis is deliberately left out, matching the other model test
    classes in this file (e.g. ``(3, 32, 32)`` for the image models). The
    previous value ``(4, 16, 14)`` repeated the batch size of
    ``dummy_input`` as a leading dimension.
    """
    return (16, 14)
@property
def output_shape(self):
    """Per-sample output shape ``(seq_len, num_features)``.

    The shared training test builds its target as
    ``(batch_size,) + self.output_shape``, so this tuple must not contain
    the batch axis. The previous value ``(4, 16, 14)`` produced a 4-D
    target of shape (4, 4, 16, 14) for a batch of 4.
    NOTE(review): assumes the model returns (batch, seq_len, num_features),
    i.e. the same layout as its input -- confirm against ``forward``.
    """
    return (16, 14)
def prepare_init_args_and_inputs_for_common(self):
    """Return ``(init_dict, inputs_dict)`` for the shared model tests.

    ``init_dict`` holds the constructor keyword arguments for the model
    under test; ``inputs_dict`` is whatever ``self.dummy_input`` provides.
    """
    init_dict = dict(
        training_horizon=128,
        dim=32,
        dim_mults=[1, 4, 8],
        predict_epsilon=False,
        clip_denoised=True,
        transition_dim=14,
        cond_dim=3,
    )
    return init_dict, self.dummy_input
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
model, loading_info = TemporalUNet.from_pretrained( model, loading_info = TemporalUNet.from_pretrained(
"fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True "fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True
...@@ -640,8 +672,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -640,8 +672,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
output_slice = output[0, -3:, -3:].flatten() output_slice = output[0, -3:, -3:].flatten()
# fmt: off # fmt: off
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584])
-0.0584])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
...@@ -662,11 +693,11 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -662,11 +693,11 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step} return {"x": noise, "timesteps": time_step}
@property @property
def get_input_shape(self): def input_shape(self):
return (3, 32, 32) return (3, 32, 32)
@property @property
def get_output_shape(self): def output_shape(self):
return (3, 32, 32) return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self): def prepare_init_args_and_inputs_for_common(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment