Unverified Commit c01ec2d1 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[FlaxAutoencoderKL] rename weights to align with PT (#584)

* rename weights to align with PT

* DiagonalGaussianDistribution => FlaxDiagonalGaussianDistribution

* fix name
parent 0902449e
......@@ -34,15 +34,15 @@ class FlaxAutoencoderKLOutput(BaseOutput):
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
latent_dist (`FlaxDiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
`FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution"
latent_dist: "FlaxDiagonalGaussianDistribution"
class Upsample2D(nn.Module):
class FlaxUpsample2D(nn.Module):
in_channels: int
dtype: jnp.dtype = jnp.float32
......@@ -66,7 +66,7 @@ class Upsample2D(nn.Module):
return hidden_states
class Downsample2D(nn.Module):
class FlaxDownsample2D(nn.Module):
in_channels: int
dtype: jnp.dtype = jnp.float32
......@@ -86,7 +86,7 @@ class Downsample2D(nn.Module):
return hidden_states
class ResnetBlock2D(nn.Module):
class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout_prob: float = 0.0
......@@ -144,7 +144,7 @@ class ResnetBlock2D(nn.Module):
return hidden_states + residual
class AttentionBlock(nn.Module):
class FlaxAttentionBlock(nn.Module):
channels: int
num_head_channels: int = None
dtype: jnp.dtype = jnp.float32
......@@ -201,7 +201,7 @@ class AttentionBlock(nn.Module):
return hidden_states
class DownEncoderBlock2D(nn.Module):
class FlaxDownEncoderBlock2D(nn.Module):
in_channels: int
out_channels: int
dropout: float = 0.0
......@@ -214,7 +214,7 @@ class DownEncoderBlock2D(nn.Module):
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
res_block = ResnetBlock2D(
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
......@@ -224,19 +224,19 @@ class DownEncoderBlock2D(nn.Module):
self.resnets = resnets
if self.add_downsample:
self.downsample = Downsample2D(self.out_channels, dtype=self.dtype)
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, deterministic=deterministic)
if self.add_downsample:
hidden_states = self.downsample(hidden_states)
hidden_states = self.downsamplers_0(hidden_states)
return hidden_states
class UpEncoderBlock2D(nn.Module):
class FlaxUpEncoderBlock2D(nn.Module):
in_channels: int
out_channels: int
dropout: float = 0.0
......@@ -248,7 +248,7 @@ class UpEncoderBlock2D(nn.Module):
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
res_block = ResnetBlock2D(
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
......@@ -259,19 +259,19 @@ class UpEncoderBlock2D(nn.Module):
self.resnets = resnets
if self.add_upsample:
self.upsample = Upsample2D(self.out_channels, dtype=self.dtype)
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, deterministic=deterministic)
if self.add_upsample:
hidden_states = self.upsample(hidden_states)
hidden_states = self.upsamplers_0(hidden_states)
return hidden_states
class UNetMidBlock2D(nn.Module):
class FlaxUNetMidBlock2D(nn.Module):
in_channels: int
dropout: float = 0.0
num_layers: int = 1
......@@ -281,7 +281,7 @@ class UNetMidBlock2D(nn.Module):
def setup(self):
# there is always at least one resnet
resnets = [
ResnetBlock2D(
FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
......@@ -292,12 +292,12 @@ class UNetMidBlock2D(nn.Module):
attentions = []
for _ in range(self.num_layers):
attn_block = AttentionBlock(
attn_block = FlaxAttentionBlock(
channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
)
attentions.append(attn_block)
res_block = ResnetBlock2D(
res_block = FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
......@@ -317,7 +317,7 @@ class UNetMidBlock2D(nn.Module):
return hidden_states
class Encoder(nn.Module):
class FlaxEncoder(nn.Module):
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
......@@ -347,7 +347,7 @@ class Encoder(nn.Module):
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = DownEncoderBlock2D(
down_block = FlaxDownEncoderBlock2D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=self.layers_per_block,
......@@ -358,7 +358,7 @@ class Encoder(nn.Module):
self.down_blocks = down_blocks
# middle
self.mid_block = UNetMidBlock2D(
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
)
......@@ -392,7 +392,7 @@ class Encoder(nn.Module):
return sample
class Decoder(nn.Module):
class FlaxDecoder(nn.Module):
dtype: jnp.dtype = jnp.float32
in_channels: int = 3
out_channels: int = 3
......@@ -415,7 +415,7 @@ class Decoder(nn.Module):
)
# middle
self.mid_block = UNetMidBlock2D(
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
)
......@@ -429,7 +429,7 @@ class Decoder(nn.Module):
is_final_block = i == len(block_out_channels) - 1
up_block = UpEncoderBlock2D(
up_block = FlaxUpEncoderBlock2D(
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=self.layers_per_block + 1,
......@@ -469,7 +469,7 @@ class Decoder(nn.Module):
return sample
class DiagonalGaussianDistribution(object):
class FlaxDiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
# Last axis to account for channels-last
self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
......@@ -521,7 +521,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.encoder = Encoder(
self.encoder = FlaxEncoder(
in_channels=self.config.in_channels,
out_channels=self.config.latent_channels,
down_block_types=self.config.down_block_types,
......@@ -532,7 +532,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
double_z=True,
dtype=self.dtype,
)
self.decoder = Decoder(
self.decoder = FlaxDecoder(
in_channels=self.config.latent_channels,
out_channels=self.config.out_channels,
up_block_types=self.config.up_block_types,
......@@ -572,7 +572,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
hidden_states = self.encoder(sample, deterministic=deterministic)
moments = self.quant_conv(hidden_states)
posterior = DiagonalGaussianDistribution(moments)
posterior = FlaxDiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment