Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
66ee73ee
Commit
66ee73ee
authored
Jun 29, 2022
by
patil-suraj
Browse files
refactor up/down sample blocks in unet_rl
parent
597b7ae2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
21 deletions
+3
-21
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+3
-21
No files found.
src/diffusers/models/unet_rl.py
View file @
66ee73ee
...
@@ -6,7 +6,7 @@ import torch.nn as nn
...
@@ -6,7 +6,7 @@ import torch.nn as nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
ResidualTemporalBlock
from
.resnet
import
Downsample
,
ResidualTemporalBlock
,
Upsample
class
SinusoidalPosEmb
(
nn
.
Module
):
class
SinusoidalPosEmb
(
nn
.
Module
):
...
@@ -18,24 +18,6 @@ class SinusoidalPosEmb(nn.Module):
...
@@ -18,24 +18,6 @@ class SinusoidalPosEmb(nn.Module):
return
get_timestep_embedding
(
x
,
self
.
dim
)
return
get_timestep_embedding
(
x
,
self
.
dim
)
class
Downsample1d
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv1d
(
dim
,
dim
,
3
,
2
,
1
)
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
class
Upsample1d
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
conv
=
nn
.
ConvTranspose1d
(
dim
,
dim
,
4
,
2
,
1
)
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
class
RearrangeDim
(
nn
.
Module
):
class
RearrangeDim
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
...
@@ -114,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -114,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
[
[
ResidualTemporalBlock
(
dim_in
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
training_horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
training_horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
training_horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
training_horizon
),
Downsample
1d
(
dim_out
)
if
not
is_last
else
nn
.
Identity
(),
Downsample
(
dim_out
,
use_conv
=
True
,
dims
=
1
)
if
not
is_last
else
nn
.
Identity
(),
]
]
)
)
)
)
...
@@ -134,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -134,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
[
[
ResidualTemporalBlock
(
dim_out
*
2
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
training_horizon
),
ResidualTemporalBlock
(
dim_out
*
2
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
training_horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
training_horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
training_horizon
),
Upsample
1d
(
dim_in
)
if
not
is_last
else
nn
.
Identity
(),
Upsample
(
dim_in
,
use_conv_transpose
=
True
,
dims
=
1
)
if
not
is_last
else
nn
.
Identity
(),
]
]
)
)
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment