Unverified commit 8b0be935 authored by Younes Belkada, committed by GitHub

Flax documentation (#589)



* documenting `attention_flax.py` file

* documenting `embeddings_flax.py`

* documenting `unet_blocks_flax.py`

* Add new objs to doc page

* document `vae_flax.py`

* Apply suggestions from code review

* modify `unet_2d_condition_flax.py`

* make style

* Apply suggestions from code review

* make style

* Apply suggestions from code review

* fix indent

* fix typo

* fix indent unet

* Update src/diffusers/models/vae_flax.py

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
parent df80ccf7
@@ -45,3 +45,21 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
## AutoencoderKL
[[autodoc]] AutoencoderKL
## FlaxModelMixin
[[autodoc]] FlaxModelMixin
## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
## FlaxUNet2DConditionModel
[[autodoc]] FlaxUNet2DConditionModel
## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput
## FlaxAutoencoderKLOutput
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
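These Flax classes keep the module definition and its parameters separate, following the usual Flax convention. A minimal, hedged loading sketch (it assumes a Hub checkpoint that ships Flax weights; the repo id and subfolder are only illustrative):

```python
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

# `FlaxModelMixin.from_pretrained` returns the module definition together with
# the parameter pytree, which is kept separate from the module itself.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", dtype=jnp.float32
)
```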
@@ -17,6 +17,22 @@ import jax.numpy as jnp
class FlaxAttentionBlock(nn.Module):
r"""
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
Parameters:
query_dim (:obj:`int`):
Input hidden states dimension
heads (:obj:`int`, *optional*, defaults to 8):
Number of heads
dim_head (:obj:`int`, *optional*, defaults to 64):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
query_dim: int
heads: int = 8
dim_head: int = 64
@@ -74,6 +90,23 @@ class FlaxAttentionBlock(nn.Module):
class FlaxBasicTransformerBlock(nn.Module):
r"""
A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
https://arxiv.org/abs/1706.03762
Parameters:
dim (:obj:`int`):
Inner hidden states dimension
n_heads (:obj:`int`):
Number of heads
d_head (:obj:`int`):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
dim: int
n_heads: int
d_head: int
@@ -110,6 +143,25 @@ class FlaxBasicTransformerBlock(nn.Module):
class FlaxSpatialTransformer(nn.Module):
r"""
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
https://arxiv.org/pdf/1506.02025.pdf
Parameters:
in_channels (:obj:`int`):
Input number of channels
n_heads (:obj:`int`):
Number of heads
d_head (:obj:`int`):
Hidden states dimension inside each head
depth (:obj:`int`, *optional*, defaults to 1):
Number of transformer blocks
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
n_heads: int
d_head: int
@@ -162,6 +214,18 @@ class FlaxSpatialTransformer(nn.Module):
class FlaxGluFeedForward(nn.Module):
r"""
Flax module that encapsulates two Linear layers separated by a gated linear unit activation from:
https://arxiv.org/abs/2002.05202
Parameters:
dim (:obj:`int`):
Inner hidden states dimension
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
@@ -179,6 +243,18 @@ class FlaxGluFeedForward(nn.Module):
class FlaxGEGLU(nn.Module):
r"""
Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
https://arxiv.org/abs/2002.05202.
Parameters:
dim (:obj:`int`):
Input hidden states dimension
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
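A minimal sketch exercising `FlaxAttentionBlock` from this file on its own with the standard `init`/`apply` pattern; the import path follows this file's location and the shapes are illustrative assumptions:

```python
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxAttentionBlock

attn = FlaxAttentionBlock(query_dim=320)      # heads / dim_head keep their defaults
hidden_states = jnp.zeros((1, 64 * 64, 320))  # (batch, sequence_length, query_dim)
params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)       # self-attention; same shape as the input
```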
@@ -37,6 +37,15 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
class FlaxTimestepEmbedding(nn.Module):
r"""
Time step Embedding Module. Learns embeddings for input time steps.
Args:
time_embed_dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32
@@ -49,6 +58,13 @@ class FlaxTimestepEmbedding(nn.Module):
class FlaxTimesteps(nn.Module):
r"""
Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
Args:
dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
"""
dim: int = 32
freq_shift: float = 1
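A hedged sketch chaining the two modules in this file the way a UNet would consume them; the dimensions and shapes are illustrative assumptions:

```python
import jax
import jax.numpy as jnp
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps

timesteps = jnp.array([0, 250, 999])

t_proj = FlaxTimesteps(dim=320, freq_shift=1)         # parameter-free sinusoidal projection
proj_params = t_proj.init(jax.random.PRNGKey(0), timesteps)
sinusoidal = t_proj.apply(proj_params, timesteps)     # shape (3, 320)

t_embed = FlaxTimestepEmbedding(time_embed_dim=1280)  # learned MLP on top of the projection
embed_params = t_embed.init(jax.random.PRNGKey(1), sinusoidal)
t_emb = t_embed.apply(embed_params, sinusoidal)       # shape (3, 1280)
```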
@@ -39,10 +39,23 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
sample_size (`int`, *optional*):
The size of the input sample.
in_channels (`int`, *optional*, defaults to 4):
The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4):
The number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
"FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
@@ -51,10 +64,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
"FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
attention_head_dim (`int`, *optional*, defaults to 8):
The dimension of the attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0):
Dropout probability for down, up and bottleneck blocks.
"""
sample_size: int = 32
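A hedged sketch of driving the model directly with `init`/`apply` under `jax.jit`; the `(sample, timesteps, encoder_hidden_states)` call order and the NCHW sample layout are assumptions based on the documented parameters:

```python
import jax
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

unet = FlaxUNet2DConditionModel(sample_size=32, cross_attention_dim=768)
sample = jnp.zeros((1, 4, 32, 32))               # (batch, in_channels, height, width)
timesteps = jnp.array([10], dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 768))  # e.g. a CLIP text encoder output
params = unet.init(jax.random.PRNGKey(0), sample, timesteps, encoder_hidden_states)

@jax.jit
def predict_noise(params, sample, timesteps, encoder_hidden_states):
    return unet.apply(params, sample, timesteps, encoder_hidden_states).sample

noise_pred = predict_noise(params, sample, timesteps, encoder_hidden_states)
```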
@@ -19,6 +19,26 @@ from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
class FlaxCrossAttnDownBlock2D(nn.Module):
r"""
Cross Attention 2D Downsizing block - original architecture from Unet transformers:
https://arxiv.org/abs/2103.06104
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention block layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -73,6 +93,23 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
class FlaxDownBlock2D(nn.Module):
r"""
Flax 2D downsizing block
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of ResNet block layers
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -113,6 +150,26 @@ class FlaxDownBlock2D(nn.Module):
class FlaxCrossAttnUpBlock2D(nn.Module):
r"""
Cross Attention 2D Upsampling block - original architecture from Unet transformers:
https://arxiv.org/abs/2103.06104
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention block layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
prev_output_channel: int
@@ -170,6 +227,25 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
class FlaxUpBlock2D(nn.Module):
r"""
Flax 2D upsampling block
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
prev_output_channel (:obj:`int`):
Output channels from the previous block
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of ResNet block layers
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
prev_output_channel: int
@@ -214,6 +290,21 @@ class FlaxUpBlock2D(nn.Module):
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
r"""
Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
Parameters:
in_channels (:obj:`int`):
Input channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention block layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
dropout: float = 0.0
num_layers: int = 1
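A hedged sketch of one of these blocks in isolation; the channels-last hidden-state layout, the `(hidden_states, temb)` call signature and the returned skip connections are assumptions about how the Flax UNet wires the blocks together:

```python
import jax
import jax.numpy as jnp
from diffusers.models.unet_blocks_flax import FlaxDownBlock2D

block = FlaxDownBlock2D(in_channels=320, out_channels=320, num_layers=2)
hidden_states = jnp.zeros((1, 32, 32, 320))  # NHWC, as used inside the Flax UNet
temb = jnp.zeros((1, 1280))                  # output of FlaxTimestepEmbedding
params = block.init(jax.random.PRNGKey(0), hidden_states, temb)
hidden_states, skip_states = block.apply(params, hidden_states, temb)  # skips feed the up blocks
```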
@@ -23,6 +23,8 @@ class FlaxDecoderOutput(BaseOutput):
Args:
sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
Decoded output sample of the model. Output of the last layer of the model.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
""" """
sample: jnp.ndarray sample: jnp.ndarray
@@ -43,6 +45,16 @@ class FlaxAutoencoderKLOutput(BaseOutput):
class FlaxUpsample2D(nn.Module):
"""
Flax implementation of 2D Upsample layer
Args:
in_channels (`int`):
Input channels
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
dtype: jnp.dtype = jnp.float32
@@ -67,6 +79,16 @@ class FlaxUpsample2D(nn.Module):
class FlaxDownsample2D(nn.Module):
"""
Flax implementation of 2D Downsample layer
Args:
in_channels (`int`):
Input channels
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
dtype: jnp.dtype = jnp.float32
@@ -87,6 +109,22 @@ class FlaxDownsample2D(nn.Module):
class FlaxResnetBlock2D(nn.Module):
"""
Flax implementation of 2D Resnet Block.
Args:
in_channels (`int`):
Input channels
out_channels (`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
Whether to use `nin_shortcut`. This activates a new layer inside the ResNet block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int = None
dropout: float = 0.0
@@ -145,6 +183,18 @@ class FlaxResnetBlock2D(nn.Module):
class FlaxAttentionBlock(nn.Module):
r"""
Flax convolution-based multi-head attention block for diffusion-based VAE.
Parameters:
channels (:obj:`int`):
Input channels
num_head_channels (:obj:`int`, *optional*, defaults to `None`):
Number of attention heads
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
channels: int
num_head_channels: int = None
dtype: jnp.dtype = jnp.float32
@@ -202,6 +252,23 @@ class FlaxAttentionBlock(nn.Module):
class FlaxDownEncoderBlock2D(nn.Module):
r"""
Flax Resnet blocks-based Encoder block for diffusion-based VAE.
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of ResNet block layers
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsample layer
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -237,6 +304,23 @@ class FlaxDownEncoderBlock2D(nn.Module):
class FlaxUpEncoderBlock2D(nn.Module):
r"""
Flax Resnet blocks-based Encoder block for diffusion-based VAE.
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of ResNet block layers
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsample layer
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -272,6 +356,21 @@ class FlaxUpEncoderBlock2D(nn.Module):
class FlaxUNetMidBlock2D(nn.Module):
r"""
Flax Unet Mid-Block module.
Parameters:
in_channels (:obj:`int`):
Input channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of ResNet block layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
Number of attention heads for each attention block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
dropout: float = 0.0
num_layers: int = 1
@@ -318,6 +417,39 @@ class FlaxUNetMidBlock2D(nn.Module):
class FlaxEncoder(nn.Module):
r"""
Flax Implementation of VAE Encoder.
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
in_channels (:obj:`int`, *optional*, defaults to 3):
Input channels
out_channels (:obj:`int`, *optional*, defaults to 3):
Output channels
down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
DownEncoder block type
block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
Tuple containing the number of output channels for each block
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
Number of ResNet layers for each block
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
Number of groups for group normalization
act_fn (:obj:`str`, *optional*, defaults to `silu`):
Activation function
double_z (:obj:`bool`, *optional*, defaults to `False`):
Whether to double the last output channels
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
@@ -393,7 +525,39 @@ class FlaxEncoder(nn.Module):
class FlaxDecoder(nn.Module):
r"""
Flax Implementation of VAE Decoder.
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
in_channels (:obj:`int`, *optional*, defaults to 3):
Input channels
out_channels (:obj:`int`, *optional*, defaults to 3):
Output channels
up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
UpDecoder block type
block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
Tuple containing the number of output channels for each block
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
Number of ResNet layers for each block
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
Number of groups for group normalization
act_fn (:obj:`str`, *optional*, defaults to `silu`):
Activation function
double_z (:obj:`bool`, *optional*, defaults to `False`):
Whether to double the last output channels
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int = 3
out_channels: int = 3
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
@@ -401,6 +565,7 @@ class FlaxDecoder(nn.Module):
layers_per_block: int = 2
norm_num_groups: int = 32
act_fn: str = "silu"
dtype: jnp.dtype = jnp.float32
def setup(self):
block_out_channels = self.block_out_channels
@@ -508,6 +673,44 @@ class FlaxDiagonalGaussianDistribution(object):
@flax_register_to_config
class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
r"""
Flax Implementation of Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational
Bayes by Diederik P. Kingma and Max Welling.
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
in_channels (:obj:`int`, *optional*, defaults to 3):
Input channels
out_channels (:obj:`int`, *optional*, defaults to 3):
Output channels
down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
DownEncoder block type
up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
UpDecoder block type
block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
Tuple containing the number of output channels for each block
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
Number of ResNet layers for each block
act_fn (:obj:`str`, *optional*, defaults to `silu`):
Activation function
latent_channels (:obj:`int`, *optional*, defaults to `4`):
Latent space channels
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
Number of groups for group normalization
sample_size (:obj:`int`, *optional*, defaults to `32`):
Sample input size
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
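A hedged round-trip sketch of the VAE documented above, constructed directly rather than loaded from a checkpoint; the NCHW input layout and the deterministic default call are assumptions:

```python
import jax
import jax.numpy as jnp
from diffusers import FlaxAutoencoderKL

vae = FlaxAutoencoderKL(sample_size=32, latent_channels=4)
images = jnp.zeros((1, 3, 32, 32))                 # (batch, channels, height, width)
params = vae.init(jax.random.PRNGKey(0), images)
reconstruction = vae.apply(params, images).sample  # encode -> posterior mode -> decode
```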