Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d224c637
Unverified
Commit
d224c637
authored
Jul 03, 2022
by
Patrick von Platen
Committed by
GitHub
Jul 03, 2022
Browse files
Resnet => Resnet2D (#66)
parent
44705a64
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
41 additions
and
41 deletions
+41
-41
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+1
-1
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+5
-5
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+8
-8
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+7
-7
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+6
-6
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+7
-7
src/diffusers/models/vae.py
src/diffusers/models/vae.py
+7
-7
No files found.
src/diffusers/models/resnet.py
View file @
d224c637
...
@@ -176,7 +176,7 @@ class Downsample(nn.Module):
...
@@ -176,7 +176,7 @@ class Downsample(nn.Module):
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
# => All 2D-Resnets are included here now!
# => All 2D-Resnets are included here now!
class
ResnetBlock
(
nn
.
Module
):
class
ResnetBlock
2D
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
*
,
*
,
...
...
src/diffusers/models/unet.py
View file @
d224c637
...
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
...
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
ResnetBlock
,
Upsample
from
.resnet
import
Downsample
,
ResnetBlock
2D
,
Upsample
def
nonlinearity
(
x
):
def
nonlinearity
(
x
):
...
@@ -89,7 +89,7 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -89,7 +89,7 @@ class UNetModel(ModelMixin, ConfigMixin):
block_out
=
ch
*
ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
block
.
append
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
)
)
...
@@ -106,11 +106,11 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -106,11 +106,11 @@ class UNetModel(ModelMixin, ConfigMixin):
# middle
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
self
.
mid
.
block_1
=
ResnetBlock
2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock
(
self
.
mid
.
block_2
=
ResnetBlock
2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
...
@@ -125,7 +125,7 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -125,7 +125,7 @@ class UNetModel(ModelMixin, ConfigMixin):
if
i_block
==
self
.
num_res_blocks
:
if
i_block
==
self
.
num_res_blocks
:
skip_in
=
ch
*
in_ch_mult
[
i_level
]
skip_in
=
ch
*
in_ch_mult
[
i_level
]
block
.
append
(
block
.
append
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
block_in
+
skip_in
,
in_channels
=
block_in
+
skip_in
,
out_channels
=
block_out
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
temb_channels
=
self
.
temb_ch
,
...
...
src/diffusers/models/unet_glide.py
View file @
d224c637
...
@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
...
@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
ResnetBlock
,
TimestepBlock
,
Upsample
from
.resnet
import
Downsample
,
ResnetBlock
2D
,
TimestepBlock
,
Upsample
def
convert_module_to_f16
(
l
):
def
convert_module_to_f16
(
l
):
...
@@ -88,7 +88,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...
@@ -88,7 +88,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
def
forward
(
self
,
x
,
emb
,
encoder_out
=
None
):
def
forward
(
self
,
x
,
emb
,
encoder_out
=
None
):
for
layer
in
self
:
for
layer
in
self
:
if
isinstance
(
layer
,
TimestepBlock
)
or
isinstance
(
layer
,
ResnetBlock
):
if
isinstance
(
layer
,
TimestepBlock
)
or
isinstance
(
layer
,
ResnetBlock
2D
):
x
=
layer
(
x
,
emb
)
x
=
layer
(
x
,
emb
)
elif
isinstance
(
layer
,
AttentionBlock
):
elif
isinstance
(
layer
,
AttentionBlock
):
x
=
layer
(
x
,
encoder_out
)
x
=
layer
(
x
,
encoder_out
)
...
@@ -177,7 +177,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -177,7 +177,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
for
level
,
mult
in
enumerate
(
channel_mult
):
for
level
,
mult
in
enumerate
(
channel_mult
):
for
_
in
range
(
num_res_blocks
):
for
_
in
range
(
num_res_blocks
):
layers
=
[
layers
=
[
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
ch
,
in_channels
=
ch
,
out_channels
=
mult
*
model_channels
,
out_channels
=
mult
*
model_channels
,
dropout
=
dropout
,
dropout
=
dropout
,
...
@@ -206,7 +206,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -206,7 +206,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
out_ch
=
ch
out_ch
=
ch
self
.
input_blocks
.
append
(
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
TimestepEmbedSequential
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
ch
,
in_channels
=
ch
,
out_channels
=
out_ch
,
out_channels
=
out_ch
,
dropout
=
dropout
,
dropout
=
dropout
,
...
@@ -229,7 +229,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -229,7 +229,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
self
.
_feature_size
+=
ch
self
.
_feature_size
+=
ch
self
.
middle_block
=
TimestepEmbedSequential
(
self
.
middle_block
=
TimestepEmbedSequential
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
ch
,
in_channels
=
ch
,
dropout
=
dropout
,
dropout
=
dropout
,
temb_channels
=
time_embed_dim
,
temb_channels
=
time_embed_dim
,
...
@@ -245,7 +245,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -245,7 +245,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
num_head_channels
=
num_head_channels
,
num_head_channels
=
num_head_channels
,
encoder_channels
=
transformer_dim
,
encoder_channels
=
transformer_dim
,
),
),
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
ch
,
in_channels
=
ch
,
dropout
=
dropout
,
dropout
=
dropout
,
temb_channels
=
time_embed_dim
,
temb_channels
=
time_embed_dim
,
...
@@ -262,7 +262,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -262,7 +262,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
for
i
in
range
(
num_res_blocks
+
1
):
for
i
in
range
(
num_res_blocks
+
1
):
ich
=
input_block_chans
.
pop
()
ich
=
input_block_chans
.
pop
()
layers
=
[
layers
=
[
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
ch
+
ich
,
in_channels
=
ch
+
ich
,
out_channels
=
model_channels
*
mult
,
out_channels
=
model_channels
*
mult
,
dropout
=
dropout
,
dropout
=
dropout
,
...
@@ -287,7 +287,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -287,7 +287,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
if
level
and
i
==
num_res_blocks
:
if
level
and
i
==
num_res_blocks
:
out_ch
=
ch
out_ch
=
ch
layers
.
append
(
layers
.
append
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
ch
,
in_channels
=
ch
,
out_channels
=
out_ch
,
out_channels
=
out_ch
,
dropout
=
dropout
,
dropout
=
dropout
,
...
...
src/diffusers/models/unet_grad_tts.py
View file @
d224c637
...
@@ -4,7 +4,7 @@ from ..configuration_utils import ConfigMixin
...
@@ -4,7 +4,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
LinearAttention
from
.attention
import
LinearAttention
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
ResnetBlock
,
Upsample
from
.resnet
import
Downsample
,
ResnetBlock
2D
,
Upsample
class
Mish
(
torch
.
nn
.
Module
):
class
Mish
(
torch
.
nn
.
Module
):
...
@@ -84,7 +84,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
...
@@ -84,7 +84,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self
.
downs
.
append
(
self
.
downs
.
append
(
torch
.
nn
.
ModuleList
(
torch
.
nn
.
ModuleList
(
[
[
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
dim_in
,
in_channels
=
dim_in
,
out_channels
=
dim_out
,
out_channels
=
dim_out
,
temb_channels
=
dim
,
temb_channels
=
dim
,
...
@@ -94,7 +94,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
...
@@ -94,7 +94,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
non_linearity
=
"mish"
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
,
overwrite_for_grad_tts
=
True
,
),
),
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
dim_out
,
in_channels
=
dim_out
,
out_channels
=
dim_out
,
out_channels
=
dim_out
,
temb_channels
=
dim
,
temb_channels
=
dim
,
...
@@ -111,7 +111,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
...
@@ -111,7 +111,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
)
)
mid_dim
=
dims
[
-
1
]
mid_dim
=
dims
[
-
1
]
self
.
mid_block1
=
ResnetBlock
(
self
.
mid_block1
=
ResnetBlock
2D
(
in_channels
=
mid_dim
,
in_channels
=
mid_dim
,
out_channels
=
mid_dim
,
out_channels
=
mid_dim
,
temb_channels
=
dim
,
temb_channels
=
dim
,
...
@@ -122,7 +122,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
...
@@ -122,7 +122,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
overwrite_for_grad_tts
=
True
,
overwrite_for_grad_tts
=
True
,
)
)
self
.
mid_attn
=
Residual
(
Rezero
(
LinearAttention
(
mid_dim
)))
self
.
mid_attn
=
Residual
(
Rezero
(
LinearAttention
(
mid_dim
)))
self
.
mid_block2
=
ResnetBlock
(
self
.
mid_block2
=
ResnetBlock
2D
(
in_channels
=
mid_dim
,
in_channels
=
mid_dim
,
out_channels
=
mid_dim
,
out_channels
=
mid_dim
,
temb_channels
=
dim
,
temb_channels
=
dim
,
...
@@ -137,7 +137,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
...
@@ -137,7 +137,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self
.
ups
.
append
(
self
.
ups
.
append
(
torch
.
nn
.
ModuleList
(
torch
.
nn
.
ModuleList
(
[
[
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
dim_out
*
2
,
in_channels
=
dim_out
*
2
,
out_channels
=
dim_in
,
out_channels
=
dim_in
,
temb_channels
=
dim
,
temb_channels
=
dim
,
...
@@ -147,7 +147,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
...
@@ -147,7 +147,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
non_linearity
=
"mish"
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
,
overwrite_for_grad_tts
=
True
,
),
),
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
dim_in
,
in_channels
=
dim_in
,
out_channels
=
dim_in
,
out_channels
=
dim_in
,
temb_channels
=
dim
,
temb_channels
=
dim
,
...
...
src/diffusers/models/unet_ldm.py
View file @
d224c637
...
@@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin
...
@@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
ResnetBlock
,
TimestepBlock
,
Upsample
from
.resnet
import
Downsample
,
ResnetBlock
2D
,
TimestepBlock
,
Upsample
# from .resnet import ResBlock
# from .resnet import ResBlock
...
@@ -148,7 +148,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...
@@ -148,7 +148,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
def
forward
(
self
,
x
,
emb
,
context
=
None
):
def
forward
(
self
,
x
,
emb
,
context
=
None
):
for
layer
in
self
:
for
layer
in
self
:
if
isinstance
(
layer
,
TimestepBlock
)
or
isinstance
(
layer
,
ResnetBlock
):
if
isinstance
(
layer
,
TimestepBlock
)
or
isinstance
(
layer
,
ResnetBlock
2D
):
x
=
layer
(
x
,
emb
)
x
=
layer
(
x
,
emb
)
elif
isinstance
(
layer
,
SpatialTransformer
):
elif
isinstance
(
layer
,
SpatialTransformer
):
x
=
layer
(
x
,
context
)
x
=
layer
(
x
,
context
)
...
@@ -310,7 +310,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -310,7 +310,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
for
level
,
mult
in
enumerate
(
channel_mult
):
for
level
,
mult
in
enumerate
(
channel_mult
):
for
_
in
range
(
num_res_blocks
):
for
_
in
range
(
num_res_blocks
):
layers
=
[
layers
=
[
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
ch
,
in_channels
=
ch
,
out_channels
=
mult
*
model_channels
,
out_channels
=
mult
*
model_channels
,
dropout
=
dropout
,
dropout
=
dropout
,
...
@@ -367,7 +367,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -367,7 +367,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
# num_heads = 1
# num_heads = 1
dim_head
=
ch
//
num_heads
if
use_spatial_transformer
else
num_head_channels
dim_head
=
ch
//
num_heads
if
use_spatial_transformer
else
num_head_channels
self
.
middle_block
=
TimestepEmbedSequential
(
self
.
middle_block
=
TimestepEmbedSequential
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
ch
,
in_channels
=
ch
,
out_channels
=
None
,
out_channels
=
None
,
dropout
=
dropout
,
dropout
=
dropout
,
...
@@ -385,7 +385,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -385,7 +385,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
)
)
if
not
use_spatial_transformer
if
not
use_spatial_transformer
else
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
,
context_dim
=
context_dim
),
else
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
,
context_dim
=
context_dim
),
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
ch
,
in_channels
=
ch
,
out_channels
=
None
,
out_channels
=
None
,
dropout
=
dropout
,
dropout
=
dropout
,
...
@@ -402,7 +402,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -402,7 +402,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
for
i
in
range
(
num_res_blocks
+
1
):
for
i
in
range
(
num_res_blocks
+
1
):
ich
=
input_block_chans
.
pop
()
ich
=
input_block_chans
.
pop
()
layers
=
[
layers
=
[
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
ch
+
ich
,
in_channels
=
ch
+
ich
,
out_channels
=
model_channels
*
mult
,
out_channels
=
model_channels
*
mult
,
dropout
=
dropout
,
dropout
=
dropout
,
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
d224c637
...
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
...
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.attention
import
AttentionBlock
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.resnet
import
Downsample
,
ResnetBlock
,
Upsample
,
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
Downsample
,
ResnetBlock
2D
,
Upsample
,
downsample_2d
,
upfirdn2d
,
upsample_2d
def
_setup_kernel
(
k
):
def
_setup_kernel
(
k
):
...
@@ -345,7 +345,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -345,7 +345,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
for
i_block
in
range
(
num_res_blocks
):
for
i_block
in
range
(
num_res_blocks
):
out_ch
=
nf
*
ch_mult
[
i_level
]
out_ch
=
nf
*
ch_mult
[
i_level
]
modules
.
append
(
modules
.
append
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
in_ch
,
in_channels
=
in_ch
,
out_channels
=
out_ch
,
out_channels
=
out_ch
,
temb_channels
=
4
*
nf
,
temb_channels
=
4
*
nf
,
...
@@ -364,7 +364,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -364,7 +364,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if
i_level
!=
self
.
num_resolutions
-
1
:
if
i_level
!=
self
.
num_resolutions
-
1
:
modules
.
append
(
modules
.
append
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
in_ch
,
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
output_scale_factor
=
np
.
sqrt
(
2.0
),
...
@@ -391,7 +391,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -391,7 +391,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
in_ch
=
hs_c
[
-
1
]
in_ch
=
hs_c
[
-
1
]
modules
.
append
(
modules
.
append
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
in_ch
,
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
output_scale_factor
=
np
.
sqrt
(
2.0
),
...
@@ -403,7 +403,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -403,7 +403,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
)
)
modules
.
append
(
AttnBlock
(
channels
=
in_ch
))
modules
.
append
(
AttnBlock
(
channels
=
in_ch
))
modules
.
append
(
modules
.
append
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
in_ch
,
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
output_scale_factor
=
np
.
sqrt
(
2.0
),
...
@@ -421,7 +421,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -421,7 +421,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
out_ch
=
nf
*
ch_mult
[
i_level
]
out_ch
=
nf
*
ch_mult
[
i_level
]
in_ch
=
in_ch
+
hs_c
.
pop
()
in_ch
=
in_ch
+
hs_c
.
pop
()
modules
.
append
(
modules
.
append
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
in_ch
,
in_channels
=
in_ch
,
out_channels
=
out_ch
,
out_channels
=
out_ch
,
temb_channels
=
4
*
nf
,
temb_channels
=
4
*
nf
,
...
@@ -464,7 +464,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -464,7 +464,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if
i_level
!=
0
:
if
i_level
!=
0
:
modules
.
append
(
modules
.
append
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
in_ch
,
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
output_scale_factor
=
np
.
sqrt
(
2.0
),
...
...
src/diffusers/models/vae.py
View file @
d224c637
...
@@ -5,7 +5,7 @@ import torch.nn as nn
...
@@ -5,7 +5,7 @@ import torch.nn as nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.attention
import
AttentionBlock
from
.resnet
import
Downsample
,
ResnetBlock
,
Upsample
from
.resnet
import
Downsample
,
ResnetBlock
2D
,
Upsample
def
nonlinearity
(
x
):
def
nonlinearity
(
x
):
...
@@ -54,7 +54,7 @@ class Encoder(nn.Module):
...
@@ -54,7 +54,7 @@ class Encoder(nn.Module):
block_out
=
ch
*
ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
block
.
append
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
)
)
...
@@ -71,11 +71,11 @@ class Encoder(nn.Module):
...
@@ -71,11 +71,11 @@ class Encoder(nn.Module):
# middle
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
self
.
mid
.
block_1
=
ResnetBlock
2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock
(
self
.
mid
.
block_2
=
ResnetBlock
2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
...
@@ -152,11 +152,11 @@ class Decoder(nn.Module):
...
@@ -152,11 +152,11 @@ class Decoder(nn.Module):
# middle
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
self
.
mid
.
block_1
=
ResnetBlock
2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock
(
self
.
mid
.
block_2
=
ResnetBlock
2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
...
@@ -168,7 +168,7 @@ class Decoder(nn.Module):
...
@@ -168,7 +168,7 @@ class Decoder(nn.Module):
block_out
=
ch
*
ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
block
.
append
(
ResnetBlock
(
ResnetBlock
2D
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment