Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
a859b199
Commit
a859b199
authored
Jun 28, 2022
by
Patrick von Platen
Browse files
fix rl model tests
parent
85d991a1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
23 deletions
+52
-23
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+6
-8
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+46
-15
No files found.
src/diffusers/models/unet_rl.py
View file @
a859b199
...
...
@@ -122,13 +122,13 @@ class ResidualTemporalBlock(nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
# (nn.Module):
def
__init__
(
self
,
training_horizon
,
transition_dim
,
cond_dim
,
training_horizon
=
128
,
transition_dim
=
14
,
cond_dim
=
3
,
predict_epsilon
=
False
,
clip_denoised
=
True
,
dim
=
32
,
dim_mults
=
(
1
,
2
,
4
,
8
),
dim_mults
=
(
1
,
4
,
8
),
):
super
().
__init__
()
...
...
@@ -139,7 +139,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
time_dim
=
dim
self
.
time_mlp
=
nn
.
Sequential
(
...
...
@@ -153,7 +152,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self
.
ups
=
nn
.
ModuleList
([])
num_resolutions
=
len
(
in_out
)
print
(
in_out
)
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
in_out
):
is_last
=
ind
>=
(
num_resolutions
-
1
)
...
...
@@ -195,7 +193,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
nn
.
Conv1d
(
dim
,
transition_dim
,
1
),
)
def
forward
(
self
,
x
,
time
):
def
forward
(
self
,
x
,
time
steps
):
"""
x : [ batch x horizon x transition ]
"""
...
...
@@ -203,7 +201,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
# x = einops.rearrange(x, "b h t -> b t h")
x
=
x
.
permute
(
0
,
2
,
1
)
t
=
self
.
time_mlp
(
time
)
t
=
self
.
time_mlp
(
time
steps
)
h
=
[]
for
resnet
,
resnet2
,
downsample
in
self
.
downs
:
...
...
tests/test_modeling_utils.py
View file @
a859b199
...
...
@@ -190,7 +190,7 @@ class ModelTesterMixin:
model
.
to
(
torch_device
)
model
.
train
()
output
=
model
(
**
inputs_dict
)
noise
=
torch
.
randn
((
inputs_dict
[
"x"
].
shape
[
0
],)
+
self
.
get_
output_shape
).
to
(
torch_device
)
noise
=
torch
.
randn
((
inputs_dict
[
"x"
].
shape
[
0
],)
+
self
.
output_shape
).
to
(
torch_device
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
.
backward
()
...
...
@@ -210,11 +210,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
...
...
@@ -276,11 +276,11 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"low_res"
:
low_res
}
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
6
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
...
...
@@ -367,11 +367,11 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"transformer_out"
:
emb
}
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
6
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
...
...
@@ -459,11 +459,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
4
,
32
,
32
)
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
4
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
...
...
@@ -552,11 +552,11 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"mu"
:
condition
,
"mask"
:
mask
}
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
4
,
32
,
16
)
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
4
,
32
,
16
)
def
prepare_init_args_and_inputs_for_common
(
self
):
...
...
@@ -610,6 +610,38 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
class
TemporalUNetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
TemporalUNet
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_features
=
14
seq_len
=
16
noise
=
floats_tensor
((
batch_size
,
seq_len
,
num_features
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
batch_size
).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
input_shape
(
self
):
return
(
4
,
16
,
14
)
@
property
def
output_shape
(
self
):
return
(
4
,
16
,
14
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"training_horizon"
:
128
,
"dim"
:
32
,
"dim_mults"
:
[
1
,
4
,
8
],
"predict_epsilon"
:
False
,
"clip_denoised"
:
True
,
"transition_dim"
:
14
,
"cond_dim"
:
3
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
TemporalUNet
.
from_pretrained
(
"fusing/ddpm-unet-rl-hopper-hor128"
,
output_loading_info
=
True
...
...
@@ -640,8 +672,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
output_slice
=
output
[
0
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.2714
,
0.1042
,
-
0.0794
,
-
0.2820
,
0.0803
,
-
0.0811
,
-
0.2345
,
0.0580
,
-
0.0584
])
expected_output_slice
=
torch
.
tensor
([
-
0.2714
,
0.1042
,
-
0.0794
,
-
0.2820
,
0.0803
,
-
0.0811
,
-
0.2345
,
0.0580
,
-
0.0584
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
...
...
@@ -662,11 +693,11 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
get_
input_shape
(
self
):
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_
output_shape
(
self
):
def
output_shape
(
self
):
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment