renzhc / diffusers_dcu
Commit 183056f2, authored Jun 27, 2022 by patil-suraj

consolidate Upsample

Parent: dc7c49e4
Showing 6 changed files with 14 additions and 93 deletions:
  src/diffusers/models/resnet.py         +1   -1
  src/diffusers/models/unet.py           +2  -15
  src/diffusers/models/unet_glide.py     +4  -33
  src/diffusers/models/unet_grad_tts.py  +2  -10
  src/diffusers/models/unet_ldm.py       +4  -32
  tests/test_modeling_utils.py           +1   -2
src/diffusers/models/resnet.py

@@ -64,7 +64,7 @@ class Upsample(nn.Module):
                  upsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
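The hunk above only shows the signature change, but the call sites across this commit (use_conv=..., use_conv_transpose=True, dims=..., out_channels=...) together with the per-file copies deleted below pin down what the consolidated class must do. A minimal sketch of what the Upsample in resnet.py plausibly looks like at this commit, reconstructed from those pieces; the actual body may differ (e.g. upstream uses a conv_nd helper rather than a dict of conv classes):

import torch
import torch.nn as nn
import torch.nn.functional as F


class Upsample(nn.Module):
    """An upsampling layer with an optional convolution or transposed convolution.

    Reconstructed sketch; the real implementation lives in resnet.py at this commit.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.dims = dims
        if use_conv_transpose:
            # Covers the old Grad-TTS variant: ConvTranspose2d(dim, dim, 4, 2, 1)
            # doubles the spatial size with no interpolation step.
            self.conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            # Covers the DDPM/GLIDE/LDM variants: a 3x3 conv applied after
            # nearest-neighbor upsampling; dims selects a 1D/2D/3D convolution.
            conv_cls = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}[dims]
            self.conv = conv_cls(self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(x)
        if self.dims == 3:
            # 3D signals are upsampled only in the inner-two dimensions.
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x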
src/diffusers/models/unet.py
@@ -31,6 +31,7 @@ from tqdm import tqdm
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
+from .resnet import Upsample
 
 
 def nonlinearity(x):
@@ -42,20 +43,6 @@ def Normalize(in_channels):
...
@@ -42,20 +43,6 @@ def Normalize(in_channels):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
Upsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
x
=
torch
.
nn
.
functional
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
with_conv
:
x
=
self
.
conv
(
x
)
return
x
class
Downsample
(
nn
.
Module
):
class
Downsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
super
().
__init__
()
...
@@ -259,7 +246,7 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -259,7 +246,7 @@ class UNetModel(ModelMixin, ConfigMixin):
up
.
block
=
block
up
.
block
=
block
up
.
attn
=
attn
up
.
attn
=
attn
if
i_level
!=
0
:
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
up
.
upsample
=
Upsample
(
block_in
,
use_conv
=
resamp_with_conv
)
curr_res
=
curr_res
*
2
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
...
...
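The deleted DDPM-style Upsample was nearest-neighbor interpolation plus an optional 3x3 convolution, which is exactly the use_conv path of the consolidated class, so the keyword form preserves behavior. A hypothetical smoke test against the sketch above (shapes only):

x = torch.randn(1, 128, 32, 32)

# Old: Upsample(block_in, resamp_with_conv) in unet.py.
# New: Upsample(block_in, use_conv=resamp_with_conv) from resnet.py.
up = Upsample(128, use_conv=True)
assert up(x).shape == (1, 128, 64, 64)  # spatial size doubles, channels kept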
src/diffusers/models/unet_glide.py
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
+from .resnet import Upsample
 
 
 def convert_module_to_f16(l):
@@ -125,36 +126,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         return x
 
 
-class Upsample(nn.Module):
-    """
-    An upsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-                 upsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
-
 class Downsample(nn.Module):
     """
     A downsampling layer with an optional convolution.
@@ -231,8 +202,8 @@ class ResBlock(TimestepBlock):
         self.updown = up or down
 
         if up:
-            self.h_upd = Upsample(channels, False, dims)
-            self.x_upd = Upsample(channels, False, dims)
+            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
+            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
         elif down:
             self.h_upd = Downsample(channels, False, dims)
             self.x_upd = Downsample(channels, False, dims)
@@ -567,7 +538,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                         up=True,
                     )
                     if resblock_updown
-                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
                 )
                 ds //= 2
             self.output_blocks.append(TimestepEmbedSequential(*layers))
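In ResBlock the up path uses Upsample purely as a resampler, so use_conv=False leaves it parameter-free, while the output blocks pass use_conv=conv_resample with out_channels=out_ch and can change the channel count. Hedged checks against the sketch above:

# Parameter-free resampling path used by ResBlock's h_upd/x_upd.
up = Upsample(64, use_conv=False)
assert len(list(up.parameters())) == 0
assert up(torch.randn(2, 64, 16, 16)).shape == (2, 64, 32, 32)

# Output-block path: the 3x3 conv maps ch -> out_ch while upsampling.
up2 = Upsample(64, use_conv=True, out_channels=32)
assert up2(torch.randn(2, 64, 16, 16)).shape == (2, 32, 32, 32)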
src/diffusers/models/unet_grad_tts.py
@@ -10,6 +10,7 @@ except:
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
+from .resnet import Upsample
 
 
 class Mish(torch.nn.Module):
@@ -17,15 +18,6 @@ class Mish(torch.nn.Module):
         return x * torch.tanh(torch.nn.functional.softplus(x))
 
 
-class Upsample(torch.nn.Module):
-    def __init__(self, dim):
-        super(Upsample, self).__init__()
-        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 class Downsample(torch.nn.Module):
     def __init__(self, dim):
        super(Downsample, self).__init__()
@@ -166,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                         ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
                         ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                         Residual(Rezero(LinearAttention(dim_in))),
-                        Upsample(dim_in),
+                        Upsample(dim_in, use_conv_transpose=True),
                     ]
                 )
             )
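The Grad-TTS copy upsampled with a transposed convolution rather than interpolation, so its call site opts into use_conv_transpose=True to preserve behavior. A kernel-4, stride-2, padding-1 transposed conv doubles each spatial dimension exactly: H_out = (H_in - 1) * 2 - 2 * 1 + 4 = 2 * H_in. Against the sketch above:

# Learned upsampling, matching the deleted ConvTranspose2d(dim, dim, 4, 2, 1).
up = Upsample(80, use_conv_transpose=True)
assert up(torch.randn(2, 80, 10, 24)).shape == (2, 80, 20, 48)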
src/diffusers/models/unet_ldm.py
@@ -17,6 +17,7 @@ except:
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
+from .resnet import Upsample
 
 
 def exists(val):
@@ -377,35 +378,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         return x
 
 
-class Upsample(nn.Module):
-    """
-    An upsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-                 upsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
-
 class Downsample(nn.Module):
     """
     A downsampling layer with an optional convolution.
@@ -480,8 +452,8 @@ class ResBlock(TimestepBlock):
         self.updown = up or down
 
         if up:
-            self.h_upd = Upsample(channels, False, dims)
-            self.x_upd = Upsample(channels, False, dims)
+            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
+            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
         elif down:
             self.h_upd = Downsample(channels, False, dims)
             self.x_upd = Downsample(channels, False, dims)
@@ -948,7 +920,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                         up=True,
                     )
                     if resblock_updown
-                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
                 )
                 ds //= 2
             self.output_blocks.append(TimestepEmbedSequential(*layers))
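The LDM copy also handled 3D inputs by doubling only the inner-two dimensions, which the consolidated module keeps behind dims=3; note that the deleted LDM __init__ took an extra padding argument that the shared signature drops. A hedged 3D check against the sketch above:

# dims=3: the first spatial dimension (e.g. depth/time) is left unchanged.
up3d = Upsample(16, use_conv=False, dims=3)
assert up3d(torch.randn(1, 16, 4, 8, 8)).shape == (1, 16, 4, 16, 16)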
tests/test_modeling_utils.py
@@ -21,7 +21,7 @@ import unittest
 import numpy as np
 import torch
 
 from diffusers import (
+    # GradTTSPipeline,
     BDDMPipeline,
     DDIMPipeline,
     DDIMScheduler,
@@ -30,7 +30,6 @@ from diffusers import (
     GlidePipeline,
     GlideSuperResUNetModel,
     GlideTextToImageUNetModel,
-    GradTTSPipeline,
     GradTTSScheduler,
     LatentDiffusionPipeline,
     NCSNpp,