"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "6518fa9b2c74e84d7eb1fc6e3eb51e43213f0c05"
Unverified Commit 85cbe589 authored by Junyu Chen's avatar Junyu Chen Committed by GitHub
Browse files

Minor modification to support DC-AE-turbo (#12169)

* minor modification to support dc-ae-turbo

* minor
parent 4d9b8229
...@@ -299,6 +299,7 @@ class Decoder(nn.Module): ...@@ -299,6 +299,7 @@ class Decoder(nn.Module):
act_fn: Union[str, Tuple[str]] = "silu", act_fn: Union[str, Tuple[str]] = "silu",
upsample_block_type: str = "pixel_shuffle", upsample_block_type: str = "pixel_shuffle",
in_shortcut: bool = True, in_shortcut: bool = True,
conv_act_fn: str = "relu",
): ):
super().__init__() super().__init__()
...@@ -349,7 +350,7 @@ class Decoder(nn.Module): ...@@ -349,7 +350,7 @@ class Decoder(nn.Module):
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1] channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True) self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
self.conv_act = nn.ReLU() self.conv_act = get_activation(conv_act_fn)
self.conv_out = None self.conv_out = None
if layers_per_block[0] > 0: if layers_per_block[0] > 0:
...@@ -414,6 +415,12 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -414,6 +415,12 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
The normalization type(s) to use in the decoder. The normalization type(s) to use in the decoder.
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`): decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
The activation function(s) to use in the decoder. The activation function(s) to use in the decoder.
encoder_out_shortcut (`bool`, defaults to `True`):
Whether to use shortcut at the end of the encoder.
decoder_in_shortcut (`bool`, defaults to `True`):
Whether to use shortcut at the beginning of the decoder.
decoder_conv_act_fn (`str`, defaults to `"relu"`):
The activation function to use at the end of the decoder.
scaling_factor (`float`, defaults to `1.0`): scaling_factor (`float`, defaults to `1.0`):
The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
space to have unit variance when training the diffusion model. The latents are scaled with the formula `z = space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
...@@ -441,6 +448,9 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -441,6 +448,9 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
downsample_block_type: str = "pixel_unshuffle", downsample_block_type: str = "pixel_unshuffle",
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm", decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
decoder_act_fns: Union[str, Tuple[str]] = "silu", decoder_act_fns: Union[str, Tuple[str]] = "silu",
encoder_out_shortcut: bool = True,
decoder_in_shortcut: bool = True,
decoder_conv_act_fn: str = "relu",
scaling_factor: float = 1.0, scaling_factor: float = 1.0,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -454,6 +464,7 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -454,6 +464,7 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
layers_per_block=encoder_layers_per_block, layers_per_block=encoder_layers_per_block,
qkv_multiscales=encoder_qkv_multiscales, qkv_multiscales=encoder_qkv_multiscales,
downsample_block_type=downsample_block_type, downsample_block_type=downsample_block_type,
out_shortcut=encoder_out_shortcut,
) )
self.decoder = Decoder( self.decoder = Decoder(
in_channels=in_channels, in_channels=in_channels,
...@@ -466,6 +477,8 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -466,6 +477,8 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
norm_type=decoder_norm_types, norm_type=decoder_norm_types,
act_fn=decoder_act_fns, act_fn=decoder_act_fns,
upsample_block_type=upsample_block_type, upsample_block_type=upsample_block_type,
in_shortcut=decoder_in_shortcut,
conv_act_fn=decoder_conv_act_fn,
) )
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1) self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment