Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
bd9c9fbf
Commit
bd9c9fbf
authored
Jun 22, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
parents
f941fc99
e29fc446
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
25 deletions
+29
-25
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+29
-25
No files found.
src/diffusers/models/unet_rl.py
View file @
bd9c9fbf
...
@@ -5,7 +5,6 @@ import math
...
@@ -5,7 +5,6 @@ import math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
try
:
try
:
import
einops
import
einops
from
einops.layers.torch
import
Rearrange
from
einops.layers.torch
import
Rearrange
...
@@ -13,7 +12,6 @@ except:
...
@@ -13,7 +12,6 @@ except:
print
(
"Einops is not installed"
)
print
(
"Einops is not installed"
)
pass
pass
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
...
@@ -106,15 +104,22 @@ class ResidualTemporalBlock(nn.Module):
...
@@ -106,15 +104,22 @@ class ResidualTemporalBlock(nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
# (nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
# (nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
horizon
,
training_horizon
,
transition_dim
,
transition_dim
,
cond_dim
,
cond_dim
,
dim
=
32
,
predict_epsilon
=
False
,
dim_mults
=
(
1
,
2
,
4
,
8
),
clip_denoised
=
True
,
dim
=
32
,
dim_mults
=
(
1
,
2
,
4
,
8
),
):
):
super
().
__init__
()
super
().
__init__
()
self
.
transition_dim
=
transition_dim
self
.
cond_dim
=
cond_dim
self
.
predict_epsilon
=
predict_epsilon
self
.
clip_denoised
=
clip_denoised
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
...
@@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self
.
downs
.
append
(
self
.
downs
.
append
(
nn
.
ModuleList
(
nn
.
ModuleList
(
[
[
ResidualTemporalBlock
(
dim_in
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
),
Downsample1d
(
dim_out
)
if
not
is_last
else
nn
.
Identity
(),
Downsample1d
(
dim_out
)
if
not
is_last
else
nn
.
Identity
(),
]
]
)
)
)
)
if
not
is_last
:
if
not
is_last
:
horizon
=
horizon
//
2
training_
horizon
=
training_
horizon
//
2
mid_dim
=
dims
[
-
1
]
mid_dim
=
dims
[
-
1
]
self
.
mid_block1
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
horizon
)
self
.
mid_block1
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
)
self
.
mid_block2
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
horizon
)
self
.
mid_block2
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
)
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
is_last
=
ind
>=
(
num_resolutions
-
1
)
is_last
=
ind
>=
(
num_resolutions
-
1
)
...
@@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self
.
ups
.
append
(
self
.
ups
.
append
(
nn
.
ModuleList
(
nn
.
ModuleList
(
[
[
ResidualTemporalBlock
(
dim_out
*
2
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
*
2
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
),
Upsample1d
(
dim_in
)
if
not
is_last
else
nn
.
Identity
(),
Upsample1d
(
dim_in
)
if
not
is_last
else
nn
.
Identity
(),
]
]
)
)
)
)
if
not
is_last
:
if
not
is_last
:
horizon
=
horizon
*
2
training_
horizon
=
training_
horizon
*
2
self
.
final_conv
=
nn
.
Sequential
(
self
.
final_conv
=
nn
.
Sequential
(
Conv1dBlock
(
dim
,
dim
,
kernel_size
=
5
),
Conv1dBlock
(
dim
,
dim
,
kernel_size
=
5
),
...
@@ -206,14 +211,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -206,14 +211,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
class
TemporalValue
(
nn
.
Module
):
class
TemporalValue
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
horizon
,
horizon
,
transition_dim
,
transition_dim
,
cond_dim
,
cond_dim
,
dim
=
32
,
dim
=
32
,
time_dim
=
None
,
time_dim
=
None
,
out_dim
=
1
,
out_dim
=
1
,
dim_mults
=
(
1
,
2
,
4
,
8
),
dim_mults
=
(
1
,
2
,
4
,
8
),
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -232,7 +237,6 @@ class TemporalValue(nn.Module):
...
@@ -232,7 +237,6 @@ class TemporalValue(nn.Module):
print
(
in_out
)
print
(
in_out
)
for
dim_in
,
dim_out
in
in_out
:
for
dim_in
,
dim_out
in
in_out
:
self
.
blocks
.
append
(
self
.
blocks
.
append
(
nn
.
ModuleList
(
nn
.
ModuleList
(
[
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment