Unverified commit bf406ea8, authored by Patrick von Platen, committed by GitHub
Browse files

Correct consistency decoder (#5722)

* uP

* Update src/diffusers/models/consistency_decoder_vae.py

* uP

* uP
parent 2fd46405
...@@ -138,6 +138,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): ...@@ -138,6 +138,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
def decode( def decode(
self, self,
z: torch.FloatTensor, z: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
image: Optional[torch.FloatTensor] = None, image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None, mask: Optional[torch.FloatTensor] = None,
return_dict: bool = True, return_dict: bool = True,
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
...@@ -307,7 +307,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin): ...@@ -307,7 +307,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
return AutoencoderTinyOutput(latents=output) return AutoencoderTinyOutput(latents=output)
@apply_forward_hook @apply_forward_hook
def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: def decode(
self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
if self.use_slicing and x.shape[0] > 1: if self.use_slicing and x.shape[0] > 1:
output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)] output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
output = torch.cat(output) output = torch.cat(output)
......
...@@ -68,11 +68,76 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): ...@@ -68,11 +68,76 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
""" """
@register_to_config @register_to_config
def __init__(self, encoder_args, decoder_args, scaling_factor, block_out_channels, latent_channels): def __init__(
self,
scaling_factor=0.18215,
latent_channels=4,
encoder_act_fn="silu",
encoder_block_out_channels=(128, 256, 512, 512),
encoder_double_z=True,
encoder_down_block_types=(
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
),
encoder_in_channels=3,
encoder_layers_per_block=2,
encoder_norm_num_groups=32,
encoder_out_channels=4,
decoder_add_attention=False,
decoder_block_out_channels=(320, 640, 1024, 1024),
decoder_down_block_types=(
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
),
decoder_downsample_padding=1,
decoder_in_channels=7,
decoder_layers_per_block=3,
decoder_norm_eps=1e-05,
decoder_norm_num_groups=32,
decoder_num_train_timesteps=1024,
decoder_out_channels=6,
decoder_resnet_time_scale_shift="scale_shift",
decoder_time_embedding_type="learned",
decoder_up_block_types=(
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
),
):
super().__init__() super().__init__()
self.encoder = Encoder(**encoder_args) self.encoder = Encoder(
self.decoder_unet = UNet2DModel(**decoder_args) act_fn=encoder_act_fn,
block_out_channels=encoder_block_out_channels,
double_z=encoder_double_z,
down_block_types=encoder_down_block_types,
in_channels=encoder_in_channels,
layers_per_block=encoder_layers_per_block,
norm_num_groups=encoder_norm_num_groups,
out_channels=encoder_out_channels,
)
self.decoder_unet = UNet2DModel(
add_attention=decoder_add_attention,
block_out_channels=decoder_block_out_channels,
down_block_types=decoder_down_block_types,
downsample_padding=decoder_downsample_padding,
in_channels=decoder_in_channels,
layers_per_block=decoder_layers_per_block,
norm_eps=decoder_norm_eps,
norm_num_groups=decoder_norm_num_groups,
num_train_timesteps=decoder_num_train_timesteps,
out_channels=decoder_out_channels,
resnet_time_scale_shift=decoder_resnet_time_scale_shift,
time_embedding_type=decoder_time_embedding_type,
up_block_types=decoder_up_block_types,
)
self.decoder_scheduler = ConsistencyDecoderScheduler() self.decoder_scheduler = ConsistencyDecoderScheduler()
self.register_to_config(block_out_channels=encoder_block_out_channels)
self.register_buffer( self.register_buffer(
"means", "means",
torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None], torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
......
...@@ -303,39 +303,30 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase): ...@@ -303,39 +303,30 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
@property @property
def init_dict(self): def init_dict(self):
return { return {
"encoder_args": { "encoder_block_out_channels": [32, 64],
"block_out_channels": [32, 64], "encoder_in_channels": 3,
"in_channels": 3, "encoder_out_channels": 4,
"out_channels": 4, "encoder_down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], "decoder_add_attention": False,
}, "decoder_block_out_channels": [32, 64],
"decoder_args": { "decoder_down_block_types": [
"act_fn": "silu",
"add_attention": False,
"block_out_channels": [32, 64],
"down_block_types": [
"ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D",
], ],
"downsample_padding": 1, "decoder_downsample_padding": 1,
"downsample_type": "conv", "decoder_in_channels": 7,
"dropout": 0.0, "decoder_layers_per_block": 1,
"in_channels": 7, "decoder_norm_eps": 1e-05,
"layers_per_block": 1, "decoder_norm_num_groups": 32,
"norm_eps": 1e-05, "decoder_num_train_timesteps": 1024,
"norm_num_groups": 32, "decoder_out_channels": 6,
"num_train_timesteps": 1024, "decoder_resnet_time_scale_shift": "scale_shift",
"out_channels": 6, "decoder_time_embedding_type": "learned",
"resnet_time_scale_shift": "scale_shift", "decoder_up_block_types": [
"time_embedding_type": "learned",
"up_block_types": [
"ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D",
], ],
"upsample_type": "conv",
},
"scaling_factor": 1, "scaling_factor": 1,
"block_out_channels": [32, 64],
"latent_channels": 4, "latent_channels": 4,
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment