Unverified Commit c01ec2d1 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[FlaxAutoencoderKL] rename weights to align with PT (#584)

* rename weights to align with PT

* DiagonalGaussianDistribution => FlaxDiagonalGaussianDistribution

* fix name
parent 0902449e
...@@ -34,15 +34,15 @@ class FlaxAutoencoderKLOutput(BaseOutput): ...@@ -34,15 +34,15 @@ class FlaxAutoencoderKLOutput(BaseOutput):
Output of AutoencoderKL encoding method. Output of AutoencoderKL encoding method.
Args: Args:
latent_dist (`DiagonalGaussianDistribution`): latent_dist (`FlaxDiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution. `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
""" """
latent_dist: "DiagonalGaussianDistribution" latent_dist: "FlaxDiagonalGaussianDistribution"
class Upsample2D(nn.Module): class FlaxUpsample2D(nn.Module):
in_channels: int in_channels: int
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
...@@ -66,7 +66,7 @@ class Upsample2D(nn.Module): ...@@ -66,7 +66,7 @@ class Upsample2D(nn.Module):
return hidden_states return hidden_states
class Downsample2D(nn.Module): class FlaxDownsample2D(nn.Module):
in_channels: int in_channels: int
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
...@@ -86,7 +86,7 @@ class Downsample2D(nn.Module): ...@@ -86,7 +86,7 @@ class Downsample2D(nn.Module):
return hidden_states return hidden_states
class ResnetBlock2D(nn.Module): class FlaxResnetBlock2D(nn.Module):
in_channels: int in_channels: int
out_channels: int = None out_channels: int = None
dropout_prob: float = 0.0 dropout_prob: float = 0.0
...@@ -144,7 +144,7 @@ class ResnetBlock2D(nn.Module): ...@@ -144,7 +144,7 @@ class ResnetBlock2D(nn.Module):
return hidden_states + residual return hidden_states + residual
class AttentionBlock(nn.Module): class FlaxAttentionBlock(nn.Module):
channels: int channels: int
num_head_channels: int = None num_head_channels: int = None
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
...@@ -201,7 +201,7 @@ class AttentionBlock(nn.Module): ...@@ -201,7 +201,7 @@ class AttentionBlock(nn.Module):
return hidden_states return hidden_states
class DownEncoderBlock2D(nn.Module): class FlaxDownEncoderBlock2D(nn.Module):
in_channels: int in_channels: int
out_channels: int out_channels: int
dropout: float = 0.0 dropout: float = 0.0
...@@ -214,7 +214,7 @@ class DownEncoderBlock2D(nn.Module): ...@@ -214,7 +214,7 @@ class DownEncoderBlock2D(nn.Module):
for i in range(self.num_layers): for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels in_channels = self.in_channels if i == 0 else self.out_channels
res_block = ResnetBlock2D( res_block = FlaxResnetBlock2D(
in_channels=in_channels, in_channels=in_channels,
out_channels=self.out_channels, out_channels=self.out_channels,
dropout_prob=self.dropout, dropout_prob=self.dropout,
...@@ -224,19 +224,19 @@ class DownEncoderBlock2D(nn.Module): ...@@ -224,19 +224,19 @@ class DownEncoderBlock2D(nn.Module):
self.resnets = resnets self.resnets = resnets
if self.add_downsample: if self.add_downsample:
self.downsample = Downsample2D(self.out_channels, dtype=self.dtype) self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True): def __call__(self, hidden_states, deterministic=True):
for resnet in self.resnets: for resnet in self.resnets:
hidden_states = resnet(hidden_states, deterministic=deterministic) hidden_states = resnet(hidden_states, deterministic=deterministic)
if self.add_downsample: if self.add_downsample:
hidden_states = self.downsample(hidden_states) hidden_states = self.downsamplers_0(hidden_states)
return hidden_states return hidden_states
class UpEncoderBlock2D(nn.Module): class FlaxUpEncoderBlock2D(nn.Module):
in_channels: int in_channels: int
out_channels: int out_channels: int
dropout: float = 0.0 dropout: float = 0.0
...@@ -248,7 +248,7 @@ class UpEncoderBlock2D(nn.Module): ...@@ -248,7 +248,7 @@ class UpEncoderBlock2D(nn.Module):
resnets = [] resnets = []
for i in range(self.num_layers): for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels in_channels = self.in_channels if i == 0 else self.out_channels
res_block = ResnetBlock2D( res_block = FlaxResnetBlock2D(
in_channels=in_channels, in_channels=in_channels,
out_channels=self.out_channels, out_channels=self.out_channels,
dropout_prob=self.dropout, dropout_prob=self.dropout,
...@@ -259,19 +259,19 @@ class UpEncoderBlock2D(nn.Module): ...@@ -259,19 +259,19 @@ class UpEncoderBlock2D(nn.Module):
self.resnets = resnets self.resnets = resnets
if self.add_upsample: if self.add_upsample:
self.upsample = Upsample2D(self.out_channels, dtype=self.dtype) self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True): def __call__(self, hidden_states, deterministic=True):
for resnet in self.resnets: for resnet in self.resnets:
hidden_states = resnet(hidden_states, deterministic=deterministic) hidden_states = resnet(hidden_states, deterministic=deterministic)
if self.add_upsample: if self.add_upsample:
hidden_states = self.upsample(hidden_states) hidden_states = self.upsamplers_0(hidden_states)
return hidden_states return hidden_states
class UNetMidBlock2D(nn.Module): class FlaxUNetMidBlock2D(nn.Module):
in_channels: int in_channels: int
dropout: float = 0.0 dropout: float = 0.0
num_layers: int = 1 num_layers: int = 1
...@@ -281,7 +281,7 @@ class UNetMidBlock2D(nn.Module): ...@@ -281,7 +281,7 @@ class UNetMidBlock2D(nn.Module):
def setup(self): def setup(self):
# there is always at least one resnet # there is always at least one resnet
resnets = [ resnets = [
ResnetBlock2D( FlaxResnetBlock2D(
in_channels=self.in_channels, in_channels=self.in_channels,
out_channels=self.in_channels, out_channels=self.in_channels,
dropout_prob=self.dropout, dropout_prob=self.dropout,
...@@ -292,12 +292,12 @@ class UNetMidBlock2D(nn.Module): ...@@ -292,12 +292,12 @@ class UNetMidBlock2D(nn.Module):
attentions = [] attentions = []
for _ in range(self.num_layers): for _ in range(self.num_layers):
attn_block = AttentionBlock( attn_block = FlaxAttentionBlock(
channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
) )
attentions.append(attn_block) attentions.append(attn_block)
res_block = ResnetBlock2D( res_block = FlaxResnetBlock2D(
in_channels=self.in_channels, in_channels=self.in_channels,
out_channels=self.in_channels, out_channels=self.in_channels,
dropout_prob=self.dropout, dropout_prob=self.dropout,
...@@ -317,7 +317,7 @@ class UNetMidBlock2D(nn.Module): ...@@ -317,7 +317,7 @@ class UNetMidBlock2D(nn.Module):
return hidden_states return hidden_states
class Encoder(nn.Module): class FlaxEncoder(nn.Module):
in_channels: int = 3 in_channels: int = 3
out_channels: int = 3 out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",) down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
...@@ -347,7 +347,7 @@ class Encoder(nn.Module): ...@@ -347,7 +347,7 @@ class Encoder(nn.Module):
output_channel = block_out_channels[i] output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1 is_final_block = i == len(block_out_channels) - 1
down_block = DownEncoderBlock2D( down_block = FlaxDownEncoderBlock2D(
in_channels=input_channel, in_channels=input_channel,
out_channels=output_channel, out_channels=output_channel,
num_layers=self.layers_per_block, num_layers=self.layers_per_block,
...@@ -358,7 +358,7 @@ class Encoder(nn.Module): ...@@ -358,7 +358,7 @@ class Encoder(nn.Module):
self.down_blocks = down_blocks self.down_blocks = down_blocks
# middle # middle
self.mid_block = UNetMidBlock2D( self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
) )
...@@ -392,7 +392,7 @@ class Encoder(nn.Module): ...@@ -392,7 +392,7 @@ class Encoder(nn.Module):
return sample return sample
class Decoder(nn.Module): class FlaxDecoder(nn.Module):
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
in_channels: int = 3 in_channels: int = 3
out_channels: int = 3 out_channels: int = 3
...@@ -415,7 +415,7 @@ class Decoder(nn.Module): ...@@ -415,7 +415,7 @@ class Decoder(nn.Module):
) )
# middle # middle
self.mid_block = UNetMidBlock2D( self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
) )
...@@ -429,7 +429,7 @@ class Decoder(nn.Module): ...@@ -429,7 +429,7 @@ class Decoder(nn.Module):
is_final_block = i == len(block_out_channels) - 1 is_final_block = i == len(block_out_channels) - 1
up_block = UpEncoderBlock2D( up_block = FlaxUpEncoderBlock2D(
in_channels=prev_output_channel, in_channels=prev_output_channel,
out_channels=output_channel, out_channels=output_channel,
num_layers=self.layers_per_block + 1, num_layers=self.layers_per_block + 1,
...@@ -469,7 +469,7 @@ class Decoder(nn.Module): ...@@ -469,7 +469,7 @@ class Decoder(nn.Module):
return sample return sample
class DiagonalGaussianDistribution(object): class FlaxDiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False): def __init__(self, parameters, deterministic=False):
# Last axis to account for channels-last # Last axis to account for channels-last
self.mean, self.logvar = jnp.split(parameters, 2, axis=-1) self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
...@@ -521,7 +521,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -521,7 +521,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
def setup(self): def setup(self):
self.encoder = Encoder( self.encoder = FlaxEncoder(
in_channels=self.config.in_channels, in_channels=self.config.in_channels,
out_channels=self.config.latent_channels, out_channels=self.config.latent_channels,
down_block_types=self.config.down_block_types, down_block_types=self.config.down_block_types,
...@@ -532,7 +532,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -532,7 +532,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
double_z=True, double_z=True,
dtype=self.dtype, dtype=self.dtype,
) )
self.decoder = Decoder( self.decoder = FlaxDecoder(
in_channels=self.config.latent_channels, in_channels=self.config.latent_channels,
out_channels=self.config.out_channels, out_channels=self.config.out_channels,
up_block_types=self.config.up_block_types, up_block_types=self.config.up_block_types,
...@@ -572,7 +572,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -572,7 +572,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
hidden_states = self.encoder(sample, deterministic=deterministic) hidden_states = self.encoder(sample, deterministic=deterministic)
moments = self.quant_conv(hidden_states) moments = self.quant_conv(hidden_states)
posterior = DiagonalGaussianDistribution(moments) posterior = FlaxDiagonalGaussianDistribution(moments)
if not return_dict: if not return_dict:
return (posterior,) return (posterior,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment