Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
3a5c8705
Unverified
Commit
3a5c8705
authored
Jun 27, 2022
by
Nathan Lambert
Browse files
add RL test, remove conds from RL model input
parent
a2b72faf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
1 deletion
+42
-1
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+1
-1
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+41
-0
No files found.
src/diffusers/models/unet_rl.py
View file @
3a5c8705
...
...
@@ -195,7 +195,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
nn
.
Conv1d
(
dim
,
transition_dim
,
1
),
)
def
forward
(
self
,
x
,
cond
,
time
):
def
forward
(
self
,
x
,
time
):
"""
x : [ batch x horizon x transition ]
"""
...
...
tests/test_modeling_utils.py
View file @
3a5c8705
...
...
@@ -40,6 +40,7 @@ from diffusers import (
ScoreSdeVeScheduler
,
ScoreSdeVpPipeline
,
ScoreSdeVpScheduler
,
TemporalUNet
,
UNetGradTTSModel
,
UNetLDMModel
,
UNetModel
,
...
...
@@ -606,6 +607,46 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
TemporalUNetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
TemporalUNet
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
TemporalUNet
.
from_pretrained
(
"fusing/ddpm-unet-rl-hopper-hor128"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
TemporalUNet
.
from_pretrained
(
"fusing/ddpm-unet-rl-hopper-hor128"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
num_features
=
model
.
transition_dim
seq_len
=
16
noise
=
torch
.
randn
((
1
,
seq_len
,
num_features
))
time_step
=
torch
.
full
((
num_features
,),
0
)
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.2714
,
0.1042
,
-
0.0794
,
-
0.2820
,
0.0803
,
-
0.0811
,
-
0.2345
,
0.0580
,
-
0.0584
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
NCSNppModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
NCSNpp
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment