Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
bf406ea8
Unverified
Commit
bf406ea8
authored
Nov 09, 2023
by
Patrick von Platen
Committed by
GitHub
Nov 09, 2023
Browse files
Correct consist dec (#5722)
* uP * Update src/diffusers/models/consistency_decoder_vae.py * uP * uP
parent
2fd46405
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
96 additions
and
37 deletions
+96
-37
src/diffusers/models/autoencoder_asym_kl.py
src/diffusers/models/autoencoder_asym_kl.py
+1
-0
src/diffusers/models/autoencoder_tiny.py
src/diffusers/models/autoencoder_tiny.py
+4
-2
src/diffusers/models/consistency_decoder_vae.py
src/diffusers/models/consistency_decoder_vae.py
+68
-3
tests/models/test_models_vae.py
tests/models/test_models_vae.py
+23
-32
No files found.
src/diffusers/models/autoencoder_asym_kl.py
View file @
bf406ea8
...
@@ -138,6 +138,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
...
@@ -138,6 +138,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
def
decode
(
def
decode
(
self
,
self
,
z
:
torch
.
FloatTensor
,
z
:
torch
.
FloatTensor
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
image
:
Optional
[
torch
.
FloatTensor
]
=
None
,
image
:
Optional
[
torch
.
FloatTensor
]
=
None
,
mask
:
Optional
[
torch
.
FloatTensor
]
=
None
,
mask
:
Optional
[
torch
.
FloatTensor
]
=
None
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
...
...
src/diffusers/models/autoencoder_tiny.py
View file @
bf406ea8
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
import
torch
import
torch
...
@@ -307,7 +307,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
...
@@ -307,7 +307,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
return
AutoencoderTinyOutput
(
latents
=
output
)
return
AutoencoderTinyOutput
(
latents
=
output
)
@
apply_forward_hook
@
apply_forward_hook
def
decode
(
self
,
x
:
torch
.
FloatTensor
,
return_dict
:
bool
=
True
)
->
Union
[
DecoderOutput
,
Tuple
[
torch
.
FloatTensor
]]:
def
decode
(
self
,
x
:
torch
.
FloatTensor
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
return_dict
:
bool
=
True
)
->
Union
[
DecoderOutput
,
Tuple
[
torch
.
FloatTensor
]]:
if
self
.
use_slicing
and
x
.
shape
[
0
]
>
1
:
if
self
.
use_slicing
and
x
.
shape
[
0
]
>
1
:
output
=
[
self
.
_tiled_decode
(
x_slice
)
if
self
.
use_tiling
else
self
.
decoder
(
x
)
for
x_slice
in
x
.
split
(
1
)]
output
=
[
self
.
_tiled_decode
(
x_slice
)
if
self
.
use_tiling
else
self
.
decoder
(
x
)
for
x_slice
in
x
.
split
(
1
)]
output
=
torch
.
cat
(
output
)
output
=
torch
.
cat
(
output
)
...
...
src/diffusers/models/consistency_decoder_vae.py
View file @
bf406ea8
...
@@ -68,11 +68,76 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
...
@@ -68,11 +68,76 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
"""
"""
@
register_to_config
@
register_to_config
def
__init__
(
self
,
encoder_args
,
decoder_args
,
scaling_factor
,
block_out_channels
,
latent_channels
):
def
__init__
(
self
,
scaling_factor
=
0.18215
,
latent_channels
=
4
,
encoder_act_fn
=
"silu"
,
encoder_block_out_channels
=
(
128
,
256
,
512
,
512
),
encoder_double_z
=
True
,
encoder_down_block_types
=
(
"DownEncoderBlock2D"
,
"DownEncoderBlock2D"
,
"DownEncoderBlock2D"
,
"DownEncoderBlock2D"
,
),
encoder_in_channels
=
3
,
encoder_layers_per_block
=
2
,
encoder_norm_num_groups
=
32
,
encoder_out_channels
=
4
,
decoder_add_attention
=
False
,
decoder_block_out_channels
=
(
320
,
640
,
1024
,
1024
),
decoder_down_block_types
=
(
"ResnetDownsampleBlock2D"
,
"ResnetDownsampleBlock2D"
,
"ResnetDownsampleBlock2D"
,
"ResnetDownsampleBlock2D"
,
),
decoder_downsample_padding
=
1
,
decoder_in_channels
=
7
,
decoder_layers_per_block
=
3
,
decoder_norm_eps
=
1e-05
,
decoder_norm_num_groups
=
32
,
decoder_num_train_timesteps
=
1024
,
decoder_out_channels
=
6
,
decoder_resnet_time_scale_shift
=
"scale_shift"
,
decoder_time_embedding_type
=
"learned"
,
decoder_up_block_types
=
(
"ResnetUpsampleBlock2D"
,
"ResnetUpsampleBlock2D"
,
"ResnetUpsampleBlock2D"
,
"ResnetUpsampleBlock2D"
,
),
):
super
().
__init__
()
super
().
__init__
()
self
.
encoder
=
Encoder
(
**
encoder_args
)
self
.
encoder
=
Encoder
(
self
.
decoder_unet
=
UNet2DModel
(
**
decoder_args
)
act_fn
=
encoder_act_fn
,
block_out_channels
=
encoder_block_out_channels
,
double_z
=
encoder_double_z
,
down_block_types
=
encoder_down_block_types
,
in_channels
=
encoder_in_channels
,
layers_per_block
=
encoder_layers_per_block
,
norm_num_groups
=
encoder_norm_num_groups
,
out_channels
=
encoder_out_channels
,
)
self
.
decoder_unet
=
UNet2DModel
(
add_attention
=
decoder_add_attention
,
block_out_channels
=
decoder_block_out_channels
,
down_block_types
=
decoder_down_block_types
,
downsample_padding
=
decoder_downsample_padding
,
in_channels
=
decoder_in_channels
,
layers_per_block
=
decoder_layers_per_block
,
norm_eps
=
decoder_norm_eps
,
norm_num_groups
=
decoder_norm_num_groups
,
num_train_timesteps
=
decoder_num_train_timesteps
,
out_channels
=
decoder_out_channels
,
resnet_time_scale_shift
=
decoder_resnet_time_scale_shift
,
time_embedding_type
=
decoder_time_embedding_type
,
up_block_types
=
decoder_up_block_types
,
)
self
.
decoder_scheduler
=
ConsistencyDecoderScheduler
()
self
.
decoder_scheduler
=
ConsistencyDecoderScheduler
()
self
.
register_to_config
(
block_out_channels
=
encoder_block_out_channels
)
self
.
register_buffer
(
self
.
register_buffer
(
"means"
,
"means"
,
torch
.
tensor
([
0.38862467
,
0.02253063
,
0.07381133
,
-
0.0171294
])[
None
,
:,
None
,
None
],
torch
.
tensor
([
0.38862467
,
0.02253063
,
0.07381133
,
-
0.0171294
])[
None
,
:,
None
,
None
],
...
...
tests/models/test_models_vae.py
View file @
bf406ea8
...
@@ -303,39 +303,30 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
...
@@ -303,39 +303,30 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
@
property
@
property
def
init_dict
(
self
):
def
init_dict
(
self
):
return
{
return
{
"encoder_args"
:
{
"encoder_block_out_channels"
:
[
32
,
64
],
"block_out_channels"
:
[
32
,
64
],
"encoder_in_channels"
:
3
,
"in_channels"
:
3
,
"encoder_out_channels"
:
4
,
"out_channels"
:
4
,
"encoder_down_block_types"
:
[
"DownEncoderBlock2D"
,
"DownEncoderBlock2D"
],
"down_block_types"
:
[
"DownEncoderBlock2D"
,
"DownEncoderBlock2D"
],
"decoder_add_attention"
:
False
,
},
"decoder_block_out_channels"
:
[
32
,
64
],
"decoder_args"
:
{
"decoder_down_block_types"
:
[
"act_fn"
:
"silu"
,
"ResnetDownsampleBlock2D"
,
"add_attention"
:
False
,
"ResnetDownsampleBlock2D"
,
"block_out_channels"
:
[
32
,
64
],
],
"down_block_types"
:
[
"decoder_downsample_padding"
:
1
,
"ResnetDownsampleBlock2D"
,
"decoder_in_channels"
:
7
,
"ResnetDownsampleBlock2D"
,
"decoder_layers_per_block"
:
1
,
],
"decoder_norm_eps"
:
1e-05
,
"downsample_padding"
:
1
,
"decoder_norm_num_groups"
:
32
,
"downsample_type"
:
"conv"
,
"decoder_num_train_timesteps"
:
1024
,
"dropout"
:
0.0
,
"decoder_out_channels"
:
6
,
"in_channels"
:
7
,
"decoder_resnet_time_scale_shift"
:
"scale_shift"
,
"layers_per_block"
:
1
,
"decoder_time_embedding_type"
:
"learned"
,
"norm_eps"
:
1e-05
,
"decoder_up_block_types"
:
[
"norm_num_groups"
:
32
,
"ResnetUpsampleBlock2D"
,
"num_train_timesteps"
:
1024
,
"ResnetUpsampleBlock2D"
,
"out_channels"
:
6
,
],
"resnet_time_scale_shift"
:
"scale_shift"
,
"time_embedding_type"
:
"learned"
,
"up_block_types"
:
[
"ResnetUpsampleBlock2D"
,
"ResnetUpsampleBlock2D"
,
],
"upsample_type"
:
"conv"
,
},
"scaling_factor"
:
1
,
"scaling_factor"
:
1
,
"block_out_channels"
:
[
32
,
64
],
"latent_channels"
:
4
,
"latent_channels"
:
4
,
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment