renzhc / diffusers_dcu · Commits

Commit d1fb3093, authored Jun 27, 2022 by patil-suraj

consolidate downsample

Parent: 7b9b946c
Showing 4 changed files, with 20 additions and 95 deletions.
src/diffusers/models/unet.py          +2  -20
src/diffusers/models/unet_glide.py    +6  -32
src/diffusers/models/unet_grad_tts.py +3  -11
src/diffusers/models/unet_ldm.py      +9  -32
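All four UNet variants below previously carried a private copy of Downsample; this commit deletes those copies and imports a single shared implementation from .resnet instead. The shared class itself is not part of this diff. Judging from the new call sites (use_conv, dims, out_channels, padding, name), it presumably behaves roughly like the sketch below; the class body and the name DownsampleSketch are assumptions inferred from the removed code, not the actual .resnet source.

# A minimal sketch of the consolidated Downsample, 2-D case only.
# Assumption: reconstructed from the removed per-file classes and the
# new call sites; the real class lives in src/diffusers/models/resnet.py.
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownsampleSketch(nn.Module):
    def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
        super().__init__()
        assert dims == 2, "this sketch covers the 2-D case only"
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name
        if use_conv:
            op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=padding)
        else:
            # Every removed avg-pool path required the channel count to be unchanged.
            assert self.channels == self.out_channels
            op = nn.AvgPool2d(kernel_size=2, stride=2)
        # Register the module under a configurable attribute name so that
        # checkpoints keyed on "op" (unet_glide/unet_ldm) or "conv"
        # (unet_grad_tts) can still load.
        setattr(self, name, op)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            # unet.py path: torch convs have no asymmetric padding, so pad
            # right/bottom by one before the stride-2 conv, as the removed
            # class did.
            x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return getattr(self, self.name)(x)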
src/diffusers/models/unet.py

@@ -31,7 +31,7 @@ from tqdm import tqdm
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Upsample
+from .resnet import Downsample, Upsample


 def nonlinearity(x):
@@ -43,24 +43,6 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


-class Downsample(nn.Module):
-    def __init__(self, in_channels, with_conv):
-        super().__init__()
-        self.with_conv = with_conv
-        if self.with_conv:
-            # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
-
-    def forward(self, x):
-        if self.with_conv:
-            pad = (0, 1, 0, 1)
-            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
-            x = self.conv(x)
-        else:
-            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
-        return x
-
-
 class ResnetBlock(nn.Module):
     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
         super().__init__()
@@ -207,7 +189,7 @@ class UNetModel(ModelMixin, ConfigMixin):
             down.block = block
             down.attn = attn
             if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample(block_in, resamp_with_conv)
+                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
             self.down.append(down)
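The new call passes padding=0, matching the removed class, which padded asymmetrically by hand because torch convs only support symmetric padding. A quick shape check of that behavior (the channel count 64 and input size 32 are arbitrary):

import torch

conv = torch.nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=0)
x = torch.randn(1, 64, 32, 32)
# Pad one pixel on the right and bottom, then downsample: 32x32 -> 33x33 -> 16x16.
out = conv(torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0))
assert out.shape == (1, 64, 16, 16)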
src/diffusers/models/unet_glide.py

@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Upsample
+from .resnet import Downsample, Upsample


 def convert_module_to_f16(l):
@@ -126,34 +126,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         return x


-class Downsample(nn.Module):
-    """
-    A downsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        downsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        stride = 2 if dims != 3 else (1, 2, 2)
-        if use_conv:
-            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)
-        else:
-            assert self.channels == self.out_channels
-            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        return self.op(x)
-
-
 class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels.
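Note the removed class used stride (1, 2, 2) when dims == 3, i.e. the temporal axis is kept and only height and width are halved. A shape check of the avg-pool branch for that case (tensor sizes are arbitrary):

import torch

pool = torch.nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
x = torch.randn(1, 16, 8, 32, 32)  # (N, C, T, H, W)
assert pool(x).shape == (1, 16, 8, 16, 16)  # T preserved, H and W halved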
@@ -205,8 +177,8 @@ class ResBlock(TimestepBlock):
             self.h_upd = Upsample(channels, use_conv=False, dims=dims)
             self.x_upd = Upsample(channels, use_conv=False, dims=dims)
         elif down:
-            self.h_upd = Downsample(channels, False, dims)
-            self.x_upd = Downsample(channels, False, dims)
+            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
+            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
         else:
             self.h_upd = self.x_upd = nn.Identity()
@@ -463,7 +435,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                             down=True,
                         )
                         if resblock_updown
-                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                        else Downsample(
+                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
+                        )
                     )
                 )
                 ch = out_ch
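The new keyword name="op" is presumably what keeps the consolidated class state-dict compatible: the removed GLIDE Downsample registered its module as self.op, so pretrained checkpoints address the conv weights under keys like "...op.weight". With the hypothetical DownsampleSketch above, that would look like:

down = DownsampleSketch(64, use_conv=True, padding=1, name="op")
print(sorted(down.state_dict().keys()))  # ['op.bias', 'op.weight'], matching the removed self.op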
src/diffusers/models/unet_grad_tts.py

 import torch
+from numpy import pad

 try:
@@ -10,7 +11,7 @@ except:
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Upsample
+from .resnet import Downsample, Upsample


 class Mish(torch.nn.Module):
@@ -18,15 +19,6 @@ class Mish(torch.nn.Module):
         return x * torch.tanh(torch.nn.functional.softplus(x))


-class Downsample(torch.nn.Module):
-    def __init__(self, dim):
-        super(Downsample, self).__init__()
-        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 class Rezero(torch.nn.Module):
     def __init__(self, fn):
         super(Rezero, self).__init__()
@@ -141,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                         ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                         ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                         Residual(Rezero(LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else torch.nn.Identity(),
+                        Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                     ]
                 )
             )
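For grad-tts, the removed layer was a plain Conv2d(dim, dim, 3, 2, 1), so Downsample(dim_out, use_conv=True, padding=1) must reproduce a 3x3, stride-2, padding-1 conv that halves height and width. A shape check, reusing the hypothetical DownsampleSketch from above (dim 64 and input size 32 are arbitrary):

import torch

x = torch.randn(1, 64, 32, 32)
old = torch.nn.Conv2d(64, 64, 3, 2, 1)                 # the removed grad-tts downsample
new = DownsampleSketch(64, use_conv=True, padding=1)   # stand-in for .resnet.Downsample
assert old(x).shape == new(x).shape == (1, 64, 16, 16)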
src/diffusers/models/unet_ldm.py

@@ -17,7 +17,7 @@ except:
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Upsample
+from .resnet import Downsample, Upsample


 def exists(val):
@@ -380,33 +380,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         return x


-class Downsample(nn.Module):
-    """
-    A downsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        downsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        stride = 2 if dims != 3 else (1, 2, 2)
-        if use_conv:
-            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
-        else:
-            assert self.channels == self.out_channels
-            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        return self.op(x)
-
-
 class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels.
@@ -457,8 +430,8 @@ class ResBlock(TimestepBlock):
             self.h_upd = Upsample(channels, use_conv=False, dims=dims)
             self.x_upd = Upsample(channels, use_conv=False, dims=dims)
         elif down:
-            self.h_upd = Downsample(channels, False, dims)
-            self.x_upd = Downsample(channels, False, dims)
+            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
+            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
         else:
             self.h_upd = self.x_upd = nn.Identity()
@@ -825,7 +798,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                             down=True,
                         )
                         if resblock_updown
-                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                        else Downsample(
+                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
+                        )
                     )
                 )
                 ch = out_ch
@@ -1098,7 +1073,9 @@ class EncoderUNetModel(nn.Module):
                             down=True,
                         )
                         if resblock_updown
-                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                        else Downsample(
+                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
+                        )
                     )
                 )
                 ch = out_ch
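Taken together, every old call site maps onto the shared class: unet.py keeps its asymmetric padding via padding=0, grad-tts its padding-1 conv, and glide/ldm their self.op attribute and avg-pool fallback. Exercising all three branches of the hypothetical DownsampleSketch above:

import torch

x = torch.randn(2, 64, 32, 32)
assert DownsampleSketch(64, use_conv=True, padding=1)(x).shape == (2, 64, 16, 16)  # conv path (glide/ldm/grad-tts)
assert DownsampleSketch(64, use_conv=True, padding=0)(x).shape == (2, 64, 16, 16)  # asymmetric-pad path (unet.py)
assert DownsampleSketch(64, use_conv=False)(x).shape == (2, 64, 16, 16)            # avg-pool path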