Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
7b4e049e
"docs/vscode:/vscode.git/clone" did not exist on "c44fba889965638f447d20f5730745c7963494d7"
Unverified
Commit
7b4e049e
authored
Jun 22, 2022
by
Nathan Lambert
Browse files
adding properties, formatting
parent
62c2c547
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
25 deletions
+29
-25
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+29
-25
No files found.
src/diffusers/models/unet_rl.py
View file @
7b4e049e
...
@@ -5,7 +5,6 @@ import math
...
@@ -5,7 +5,6 @@ import math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
try
:
try
:
import
einops
import
einops
from
einops.layers.torch
import
Rearrange
from
einops.layers.torch
import
Rearrange
...
@@ -13,7 +12,6 @@ except:
...
@@ -13,7 +12,6 @@ except:
print
(
"Einops is not installed"
)
print
(
"Einops is not installed"
)
pass
pass
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
...
@@ -107,14 +105,21 @@ class ResidualTemporalBlock(nn.Module):
...
@@ -107,14 +105,21 @@ class ResidualTemporalBlock(nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
# (nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
# (nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
horizon
,
training_
horizon
,
transition_dim
,
transition_dim
,
cond_dim
,
cond_dim
,
predict_epsilon
=
False
,
clip_denoised
=
True
,
dim
=
32
,
dim
=
32
,
dim_mults
=
(
1
,
2
,
4
,
8
),
dim_mults
=
(
1
,
2
,
4
,
8
),
):
):
super
().
__init__
()
super
().
__init__
()
self
.
transition_dim
=
transition_dim
self
.
cond_dim
=
cond_dim
self
.
predict_epsilon
=
predict_epsilon
self
.
clip_denoised
=
clip_denoised
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
...
@@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self
.
downs
.
append
(
self
.
downs
.
append
(
nn
.
ModuleList
(
nn
.
ModuleList
(
[
[
ResidualTemporalBlock
(
dim_in
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
),
Downsample1d
(
dim_out
)
if
not
is_last
else
nn
.
Identity
(),
Downsample1d
(
dim_out
)
if
not
is_last
else
nn
.
Identity
(),
]
]
)
)
)
)
if
not
is_last
:
if
not
is_last
:
horizon
=
horizon
//
2
training_
horizon
=
training_
horizon
//
2
mid_dim
=
dims
[
-
1
]
mid_dim
=
dims
[
-
1
]
self
.
mid_block1
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
horizon
)
self
.
mid_block1
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
)
self
.
mid_block2
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
horizon
)
self
.
mid_block2
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
)
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
is_last
=
ind
>=
(
num_resolutions
-
1
)
is_last
=
ind
>=
(
num_resolutions
-
1
)
...
@@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self
.
ups
.
append
(
self
.
ups
.
append
(
nn
.
ModuleList
(
nn
.
ModuleList
(
[
[
ResidualTemporalBlock
(
dim_out
*
2
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
*
2
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
training_
horizon
),
Upsample1d
(
dim_in
)
if
not
is_last
else
nn
.
Identity
(),
Upsample1d
(
dim_in
)
if
not
is_last
else
nn
.
Identity
(),
]
]
)
)
)
)
if
not
is_last
:
if
not
is_last
:
horizon
=
horizon
*
2
training_
horizon
=
training_
horizon
*
2
self
.
final_conv
=
nn
.
Sequential
(
self
.
final_conv
=
nn
.
Sequential
(
Conv1dBlock
(
dim
,
dim
,
kernel_size
=
5
),
Conv1dBlock
(
dim
,
dim
,
kernel_size
=
5
),
...
@@ -232,7 +237,6 @@ class TemporalValue(nn.Module):
...
@@ -232,7 +237,6 @@ class TemporalValue(nn.Module):
print
(
in_out
)
print
(
in_out
)
for
dim_in
,
dim_out
in
in_out
:
for
dim_in
,
dim_out
in
in_out
:
self
.
blocks
.
append
(
self
.
blocks
.
append
(
nn
.
ModuleList
(
nn
.
ModuleList
(
[
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment