Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
a859b199
Commit
a859b199
authored
Jun 28, 2022
by
Patrick von Platen
Browse files
fix rl model tests
parent
85d991a1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
23 deletions
+52
-23
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+6
-8
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+46
-15
No files found.
src/diffusers/models/unet_rl.py
View file @
a859b199
...
@@ -122,13 +122,13 @@ class ResidualTemporalBlock(nn.Module):
...
@@ -122,13 +122,13 @@ class ResidualTemporalBlock(nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
# (nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
# (nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
training_horizon
,
training_horizon
=
128
,
transition_dim
,
transition_dim
=
14
,
cond_dim
,
cond_dim
=
3
,
predict_epsilon
=
False
,
predict_epsilon
=
False
,
clip_denoised
=
True
,
clip_denoised
=
True
,
dim
=
32
,
dim
=
32
,
dim_mults
=
(
1
,
2
,
4
,
8
),
dim_mults
=
(
1
,
4
,
8
),
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -139,7 +139,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -139,7 +139,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
time_dim
=
dim
time_dim
=
dim
self
.
time_mlp
=
nn
.
Sequential
(
self
.
time_mlp
=
nn
.
Sequential
(
...
@@ -153,7 +152,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -153,7 +152,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self
.
ups
=
nn
.
ModuleList
([])
self
.
ups
=
nn
.
ModuleList
([])
num_resolutions
=
len
(
in_out
)
num_resolutions
=
len
(
in_out
)
print
(
in_out
)
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
in_out
):
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
in_out
):
is_last
=
ind
>=
(
num_resolutions
-
1
)
is_last
=
ind
>=
(
num_resolutions
-
1
)
...
@@ -195,7 +193,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -195,7 +193,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
nn
.
Conv1d
(
dim
,
transition_dim
,
1
),
nn
.
Conv1d
(
dim
,
transition_dim
,
1
),
)
)
def
forward
(
self
,
x
,
time
):
def
forward
(
self
,
x
,
time
steps
):
"""
"""
x : [ batch x horizon x transition ]
x : [ batch x horizon x transition ]
"""
"""
...
@@ -203,7 +201,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -203,7 +201,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
# x = einops.rearrange(x, "b h t -> b t h")
# x = einops.rearrange(x, "b h t -> b t h")
x
=
x
.
permute
(
0
,
2
,
1
)
x
=
x
.
permute
(
0
,
2
,
1
)
t
=
self
.
time_mlp
(
time
)
t
=
self
.
time_mlp
(
time
steps
)
h
=
[]
h
=
[]
for
resnet
,
resnet2
,
downsample
in
self
.
downs
:
for
resnet
,
resnet2
,
downsample
in
self
.
downs
:
...
...
tests/test_modeling_utils.py
View file @
a859b199
...
@@ -190,7 +190,7 @@ class ModelTesterMixin:
...
@@ -190,7 +190,7 @@ class ModelTesterMixin:
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
train
()
model
.
train
()
output
=
model
(
**
inputs_dict
)
output
=
model
(
**
inputs_dict
)
noise
=
torch
.
randn
((
inputs_dict
[
"x"
].
shape
[
0
],)
+
self
.
get_
output_shape
).
to
(
torch_device
)
noise
=
torch
.
randn
((
inputs_dict
[
"x"
].
shape
[
0
],)
+
self
.
output_shape
).
to
(
torch_device
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
.
backward
()
loss
.
backward
()
...
@@ -210,11 +210,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -210,11 +210,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
return
(
3
,
32
,
32
)
@
property
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
3
,
32
,
32
)
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
...
@@ -276,11 +276,11 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
...
@@ -276,11 +276,11 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"low_res"
:
low_res
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"low_res"
:
low_res
}
@
property
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
return
(
3
,
32
,
32
)
@
property
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
6
,
32
,
32
)
return
(
6
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
...
@@ -367,11 +367,11 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -367,11 +367,11 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"transformer_out"
:
emb
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"transformer_out"
:
emb
}
@
property
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
return
(
3
,
32
,
32
)
@
property
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
6
,
32
,
32
)
return
(
6
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
...
@@ -459,11 +459,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -459,11 +459,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
4
,
32
,
32
)
return
(
4
,
32
,
32
)
@
property
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
4
,
32
,
32
)
return
(
4
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
...
@@ -552,11 +552,11 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -552,11 +552,11 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"mu"
:
condition
,
"mask"
:
mask
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"mu"
:
condition
,
"mask"
:
mask
}
@
property
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
4
,
32
,
16
)
return
(
4
,
32
,
16
)
@
property
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
4
,
32
,
16
)
return
(
4
,
32
,
16
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
...
@@ -610,6 +610,38 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -610,6 +610,38 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
class
TemporalUNetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
class
TemporalUNetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
TemporalUNet
model_class
=
TemporalUNet
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_features
=
14
seq_len
=
16
noise
=
floats_tensor
((
batch_size
,
seq_len
,
num_features
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
batch_size
).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
input_shape
(
self
):
return
(
4
,
16
,
14
)
@
property
def
output_shape
(
self
):
return
(
4
,
16
,
14
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"training_horizon"
:
128
,
"dim"
:
32
,
"dim_mults"
:
[
1
,
4
,
8
],
"predict_epsilon"
:
False
,
"clip_denoised"
:
True
,
"transition_dim"
:
14
,
"cond_dim"
:
3
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
TemporalUNet
.
from_pretrained
(
model
,
loading_info
=
TemporalUNet
.
from_pretrained
(
"fusing/ddpm-unet-rl-hopper-hor128"
,
output_loading_info
=
True
"fusing/ddpm-unet-rl-hopper-hor128"
,
output_loading_info
=
True
...
@@ -640,8 +672,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -640,8 +672,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
output_slice
=
output
[
0
,
-
3
:,
-
3
:].
flatten
()
output_slice
=
output
[
0
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.2714
,
0.1042
,
-
0.0794
,
-
0.2820
,
0.0803
,
-
0.0811
,
-
0.2345
,
0.0580
,
expected_output_slice
=
torch
.
tensor
([
-
0.2714
,
0.1042
,
-
0.0794
,
-
0.2820
,
0.0803
,
-
0.0811
,
-
0.2345
,
0.0580
,
-
0.0584
])
-
0.0584
])
# fmt: on
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
...
@@ -662,11 +693,11 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -662,11 +693,11 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
return
(
3
,
32
,
32
)
@
property
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
3
,
32
,
32
)
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment