Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
bd9c9fbf
Commit
bd9c9fbf
authored
Jun 22, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
parents
f941fc99
e29fc446
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
25 deletions
+29
-25
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+29
-25
No files found.
src/diffusers/models/unet_rl.py
View file @
bd9c9fbf
...
@@ -5,7 +5,6 @@ import math
...
@@ -5,7 +5,6 @@ import math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
try
:
try
:
import
einops
import
einops
from
einops.layers.torch
import
Rearrange
from
einops.layers.torch
import
Rearrange
...
@@ -13,7 +12,6 @@ except:
...
@@ -13,7 +12,6 @@ except:
print
(
"Einops is not installed"
)
print
(
"Einops is not installed"
)
pass
pass
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
...
@@ -106,15 +104,22 @@ class ResidualTemporalBlock(nn.Module):
...
@@ -106,15 +104,22 @@ class ResidualTemporalBlock(nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
# (nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
# (nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
horizon
,
training_horizon
,
transition_dim
,
transition_dim
,
cond_dim
,
cond_dim
,
dim
=
32
,
predict_epsilon
=
False
,
dim_mults
=
(
1
,
2
,
4
,
8
),
clip_denoised
=
True
,
dim
=
32
,
dim_mults
=
(
1
,
2
,
4
,
8
),
):
):
super
().
__init__
()
super
().
__init__
()
self
.
transition_dim
=
transition_dim
self
.
cond_dim
=
cond_dim
self
.
predict_epsilon
=
predict_epsilon
self
.
clip_denoised
=
clip_denoised
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
...
@@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self
.
downs
.
append
(
self
.
downs
.
append
(
nn
.
ModuleList
(
nn
.
ModuleList
(
[
[
ResidualTemporalBlock
(
dim_in
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
),
Downsample1d
(
dim_out
)
if
not
is_last
else
nn
.
Identity
(),
Downsample1d
(
dim_out
)
if
not
is_last
else
nn
.
Identity
(),
]
]
)
)
)
)
if
not
is_last
:
if
not
is_last
:
horizon
=
horizon
//
2
training_
horizon
=
training_
horizon
//
2
mid_dim
=
dims
[
-
1
]
mid_dim
=
dims
[
-
1
]
self
.
mid_block1
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
horizon
)
self
.
mid_block1
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
)
self
.
mid_block2
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
horizon
)
self
.
mid_block2
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
)
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
is_last
=
ind
>=
(
num_resolutions
-
1
)
is_last
=
ind
>=
(
num_resolutions
-
1
)
...
@@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self
.
ups
.
append
(
self
.
ups
.
append
(
nn
.
ModuleList
(
nn
.
ModuleList
(
[
[
ResidualTemporalBlock
(
dim_out
*
2
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
*
2
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
),
Upsample1d
(
dim_in
)
if
not
is_last
else
nn
.
Identity
(),
Upsample1d
(
dim_in
)
if
not
is_last
else
nn
.
Identity
(),
]
]
)
)
)
)
if
not
is_last
:
if
not
is_last
:
horizon
=
horizon
*
2
training_
horizon
=
training_
horizon
*
2
self
.
final_conv
=
nn
.
Sequential
(
self
.
final_conv
=
nn
.
Sequential
(
Conv1dBlock
(
dim
,
dim
,
kernel_size
=
5
),
Conv1dBlock
(
dim
,
dim
,
kernel_size
=
5
),
...
@@ -206,14 +211,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -206,14 +211,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
class
TemporalValue
(
nn
.
Module
):
class
TemporalValue
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
horizon
,
horizon
,
transition_dim
,
transition_dim
,
cond_dim
,
cond_dim
,
dim
=
32
,
dim
=
32
,
time_dim
=
None
,
time_dim
=
None
,
out_dim
=
1
,
out_dim
=
1
,
dim_mults
=
(
1
,
2
,
4
,
8
),
dim_mults
=
(
1
,
2
,
4
,
8
),
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -232,7 +237,6 @@ class TemporalValue(nn.Module):
...
@@ -232,7 +237,6 @@ class TemporalValue(nn.Module):
print
(
in_out
)
print
(
in_out
)
for
dim_in
,
dim_out
in
in_out
:
for
dim_in
,
dim_out
in
in_out
:
self
.
blocks
.
append
(
self
.
blocks
.
append
(
nn
.
ModuleList
(
nn
.
ModuleList
(
[
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment