Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
a859b199
Commit
a859b199
authored
Jun 28, 2022
by
Patrick von Platen
Browse files
fix rl model tests
parent
85d991a1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
23 deletions
+52
-23
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+6
-8
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+46
-15
No files found.
src/diffusers/models/unet_rl.py
View file @
a859b199
...
@@ -122,13 +122,13 @@ class ResidualTemporalBlock(nn.Module):
...
@@ -122,13 +122,13 @@ class ResidualTemporalBlock(nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
# (nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
# (nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
training_horizon
,
training_horizon
=
128
,
transition_dim
,
transition_dim
=
14
,
cond_dim
,
cond_dim
=
3
,
predict_epsilon
=
False
,
predict_epsilon
=
False
,
clip_denoised
=
True
,
clip_denoised
=
True
,
dim
=
32
,
dim
=
32
,
dim_mults
=
(
1
,
2
,
4
,
8
),
dim_mults
=
(
1
,
4
,
8
),
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -139,7 +139,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -139,7 +139,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
time_dim
=
dim
time_dim
=
dim
self
.
time_mlp
=
nn
.
Sequential
(
self
.
time_mlp
=
nn
.
Sequential
(
...
@@ -153,7 +152,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -153,7 +152,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self
.
ups
=
nn
.
ModuleList
([])
self
.
ups
=
nn
.
ModuleList
([])
num_resolutions
=
len
(
in_out
)
num_resolutions
=
len
(
in_out
)
print
(
in_out
)
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
in_out
):
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
in_out
):
is_last
=
ind
>=
(
num_resolutions
-
1
)
is_last
=
ind
>=
(
num_resolutions
-
1
)
...
@@ -195,7 +193,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -195,7 +193,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
nn
.
Conv1d
(
dim
,
transition_dim
,
1
),
nn
.
Conv1d
(
dim
,
transition_dim
,
1
),
)
)
def
forward
(
self
,
x
,
time
):
def
forward
(
self
,
x
,
time
steps
):
"""
"""
x : [ batch x horizon x transition ]
x : [ batch x horizon x transition ]
"""
"""
...
@@ -203,7 +201,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -203,7 +201,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
# x = einops.rearrange(x, "b h t -> b t h")
# x = einops.rearrange(x, "b h t -> b t h")
x
=
x
.
permute
(
0
,
2
,
1
)
x
=
x
.
permute
(
0
,
2
,
1
)
t
=
self
.
time_mlp
(
time
)
t
=
self
.
time_mlp
(
time
steps
)
h
=
[]
h
=
[]
for
resnet
,
resnet2
,
downsample
in
self
.
downs
:
for
resnet
,
resnet2
,
downsample
in
self
.
downs
:
...
...
tests/test_modeling_utils.py
View file @
a859b199
...
@@ -190,7 +190,7 @@ class ModelTesterMixin:
...
@@ -190,7 +190,7 @@ class ModelTesterMixin:
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
train
()
model
.
train
()
output
=
model
(
**
inputs_dict
)
output
=
model
(
**
inputs_dict
)
noise
=
torch
.
randn
((
inputs_dict
[
"x"
].
shape
[
0
],)
+
self
.
get_
output_shape
).
to
(
torch_device
)
noise
=
torch
.
randn
((
inputs_dict
[
"x"
].
shape
[
0
],)
+
self
.
output_shape
).
to
(
torch_device
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
.
backward
()
loss
.
backward
()
...
@@ -210,11 +210,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -210,11 +210,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
return
(
3
,
32
,
32
)
@
property
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
3
,
32
,
32
)
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
...
@@ -276,11 +276,11 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
...
@@ -276,11 +276,11 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"low_res"
:
low_res
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"low_res"
:
low_res
}
@
property
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
return
(
3
,
32
,
32
)
@
property
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
6
,
32
,
32
)
return
(
6
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
...
@@ -367,11 +367,11 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -367,11 +367,11 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"transformer_out"
:
emb
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"transformer_out"
:
emb
}
@
property
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
return
(
3
,
32
,
32
)
@
property
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
6
,
32
,
32
)
return
(
6
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
...
@@ -459,11 +459,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -459,11 +459,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
4
,
32
,
32
)
return
(
4
,
32
,
32
)
@
property
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
4
,
32
,
32
)
return
(
4
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
...
@@ -552,11 +552,11 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -552,11 +552,11 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"mu"
:
condition
,
"mask"
:
mask
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"mu"
:
condition
,
"mask"
:
mask
}
@
property
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
4
,
32
,
16
)
return
(
4
,
32
,
16
)
@
property
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
4
,
32
,
16
)
return
(
4
,
32
,
16
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
...
@@ -610,6 +610,38 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -610,6 +610,38 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
class
TemporalUNetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
class
TemporalUNetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
TemporalUNet
model_class
=
TemporalUNet
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_features
=
14
seq_len
=
16
noise
=
floats_tensor
((
batch_size
,
seq_len
,
num_features
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
batch_size
).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
input_shape
(
self
):
return
(
4
,
16
,
14
)
@
property
def
output_shape
(
self
):
return
(
4
,
16
,
14
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"training_horizon"
:
128
,
"dim"
:
32
,
"dim_mults"
:
[
1
,
4
,
8
],
"predict_epsilon"
:
False
,
"clip_denoised"
:
True
,
"transition_dim"
:
14
,
"cond_dim"
:
3
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
TemporalUNet
.
from_pretrained
(
model
,
loading_info
=
TemporalUNet
.
from_pretrained
(
"fusing/ddpm-unet-rl-hopper-hor128"
,
output_loading_info
=
True
"fusing/ddpm-unet-rl-hopper-hor128"
,
output_loading_info
=
True
...
@@ -640,8 +672,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -640,8 +672,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
output_slice
=
output
[
0
,
-
3
:,
-
3
:].
flatten
()
output_slice
=
output
[
0
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.2714
,
0.1042
,
-
0.0794
,
-
0.2820
,
0.0803
,
-
0.0811
,
-
0.2345
,
0.0580
,
expected_output_slice
=
torch
.
tensor
([
-
0.2714
,
0.1042
,
-
0.0794
,
-
0.2820
,
0.0803
,
-
0.0811
,
-
0.2345
,
0.0580
,
-
0.0584
])
-
0.0584
])
# fmt: on
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
...
@@ -662,11 +693,11 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -662,11 +693,11 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
return
(
3
,
32
,
32
)
@
property
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
3
,
32
,
32
)
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment