Unverified Commit a1242044 authored by Akash Pannu, committed by GitHub

Flax: Trickle down `norm_num_groups` (#789)

* pass norm_num_groups param and add tests

* set resnet_groups for FlaxUNetMidBlock2D

* fixed docstrings

* fixed typo

* use the is_flax_available util and create a require_flax decorator
parent 66a5279a
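In effect, the top-level `norm_num_groups` config value now reaches every GroupNorm inside the Flax VAE instead of a hard-coded 32. A minimal sketch of the resulting behavior (the constructor kwargs mirror the test file at the bottom of this diff; treat the exact values as illustrative, not as the only valid configuration):

    import jax

    from diffusers import FlaxAutoencoderKL

    # norm_num_groups trickles down through FlaxEncoder/FlaxDecoder into the
    # ResNet, attention, and output GroupNorm layers of every block.
    vae = FlaxAutoencoderKL(
        in_channels=3,
        out_channels=3,
        block_out_channels=(16, 32),  # each entry must be divisible by norm_num_groups
        down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
        up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
        latent_channels=4,
        norm_num_groups=16,           # previously fixed at 32 inside the blocks
    )

    prng_key = jax.random.PRNGKey(0)
    sample = jax.random.uniform(prng_key, (1, 3, 32, 32))
    variables = vae.init(prng_key, sample)

    output = vae.apply(variables, sample)
    if isinstance(output, dict):
        output = output.sample
    assert output.shape == sample.shape  # reconstruction keeps the input shape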
@@ -119,6 +119,8 @@ class FlaxResnetBlock2D(nn.Module):
             Output channels
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for group norm.
         use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
             Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -128,13 +130,14 @@ class FlaxResnetBlock2D(nn.Module):
     in_channels: int
     out_channels: int = None
     dropout: float = 0.0
+    groups: int = 32
     use_nin_shortcut: bool = None
     dtype: jnp.dtype = jnp.float32

     def setup(self):
         out_channels = self.in_channels if self.out_channels is None else self.out_channels

-        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
         self.conv1 = nn.Conv(
             out_channels,
             kernel_size=(3, 3),
@@ -143,7 +146,7 @@ class FlaxResnetBlock2D(nn.Module):
             dtype=self.dtype,
         )

-        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
         self.dropout_layer = nn.Dropout(self.dropout)
         self.conv2 = nn.Conv(
             out_channels,
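For context (illustrative, not part of this diff): flax.linen.GroupNorm splits the channel axis into `num_groups` groups, so the channel count at every normalized layer must be divisible by the group count. This is why test_forward_with_norm_groups below pairs norm_num_groups=16 with block_out_channels=(16, 32).

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    norm = nn.GroupNorm(num_groups=16, epsilon=1e-6)
    x = jnp.ones((1, 8, 8, 32))  # 32 channels: divisible by 16
    params = norm.init(jax.random.PRNGKey(0), x)
    y = norm.apply(params, x)    # works

    # jnp.ones((1, 8, 8, 24)) would fail here: 24 % 16 != 0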
@@ -191,12 +194,15 @@ class FlaxAttentionBlock(nn.Module):
             Input channels
         num_head_channels (:obj:`int`, *optional*, defaults to `None`):
             Number of attention heads
+        num_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for group norm
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
     channels: int
     num_head_channels: int = None
+    num_groups: int = 32
     dtype: jnp.dtype = jnp.float32

     def setup(self):
@@ -204,7 +210,7 @@ class FlaxAttentionBlock(nn.Module):
         dense = partial(nn.Dense, self.channels, dtype=self.dtype)

-        self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
         self.query, self.key, self.value = dense(), dense(), dense()
         self.proj_attn = dense()
@@ -264,6 +270,8 @@ class FlaxDownEncoderBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet block group norm
         add_downsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add downsample layer
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -273,6 +281,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
     out_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     add_downsample: bool = True
     dtype: jnp.dtype = jnp.float32
@@ -285,6 +294,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
                 in_channels=in_channels,
                 out_channels=self.out_channels,
                 dropout=self.dropout,
+                groups=self.resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -303,9 +313,9 @@ class FlaxDownEncoderBlock2D(nn.Module):
         return hidden_states


-class FlaxUpEncoderBlock2D(nn.Module):
+class FlaxUpDecoderBlock2D(nn.Module):
     r"""
-    Flax Resnet blocks-based Encoder block for diffusion-based VAE.
+    Flax Resnet blocks-based Decoder block for diffusion-based VAE.

     Parameters:
         in_channels (:obj:`int`):
@@ -316,8 +326,10 @@ class FlaxUpEncoderBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
-        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
-            Whether to add downsample layer
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet block group norm
+        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add upsample layer
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -325,6 +337,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
     out_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     add_upsample: bool = True
     dtype: jnp.dtype = jnp.float32
@@ -336,6 +349,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
                 in_channels=in_channels,
                 out_channels=self.out_channels,
                 dropout=self.dropout,
+                groups=self.resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -366,6 +380,8 @@ class FlaxUNetMidBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet and Attention block group norm
         attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
             Number of attention heads for each attention block
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -374,16 +390,20 @@ class FlaxUNetMidBlock2D(nn.Module):
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     attn_num_head_channels: int = 1
     dtype: jnp.dtype = jnp.float32

     def setup(self):
+        resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
+
         # there is always at least one resnet
         resnets = [
             FlaxResnetBlock2D(
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
                 dropout=self.dropout,
+                groups=resnet_groups,
                 dtype=self.dtype,
             )
         ]
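A worked example of the resnet_groups fallback above (illustrative; note the dataclass default is 32, so the min(in_channels // 4, 32) branch only fires when a caller explicitly passes resnet_groups=None):

    for in_channels in (64, 128, 256):
        groups = min(in_channels // 4, 32)
        print(in_channels, "->", groups)
    # 64  -> 16  (64 // 4 = 16, under the cap of 32)
    # 128 -> 32  (128 // 4 = 32, exactly the cap)
    # 256 -> 32  (256 // 4 = 64, clamped to 32)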
@@ -392,7 +412,10 @@ class FlaxUNetMidBlock2D(nn.Module):
         for _ in range(self.num_layers):
             attn_block = FlaxAttentionBlock(
-                channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
+                channels=self.in_channels,
+                num_head_channels=self.attn_num_head_channels,
+                num_groups=resnet_groups,
+                dtype=self.dtype,
             )
             attentions.append(attn_block)
@@ -400,6 +423,7 @@ class FlaxUNetMidBlock2D(nn.Module):
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
                 dropout=self.dropout,
+                groups=resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -441,7 +465,7 @@ class FlaxEncoder(nn.Module):
             Tuple containing the number of output channels for each block
         layers_per_block (:obj:`int`, *optional*, defaults to `2`):
             Number of Resnet layer for each block
-        norm_num_groups (:obj:`int`, *optional*, defaults to `2`):
+        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
             norm num group
         act_fn (:obj:`str`, *optional*, defaults to `silu`):
             Activation function
@@ -483,6 +507,7 @@ class FlaxEncoder(nn.Module):
                 in_channels=input_channel,
                 out_channels=output_channel,
                 num_layers=self.layers_per_block,
+                resnet_groups=self.norm_num_groups,
                 add_downsample=not is_final_block,
                 dtype=self.dtype,
             )
@@ -491,12 +516,15 @@ class FlaxEncoder(nn.Module):
         # middle
         self.mid_block = FlaxUNetMidBlock2D(
-            in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
+            in_channels=block_out_channels[-1],
+            resnet_groups=self.norm_num_groups,
+            attn_num_head_channels=None,
+            dtype=self.dtype,
         )

         # end
         conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
-        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
         self.conv_out = nn.Conv(
             conv_out_channels,
             kernel_size=(3, 3),
@@ -581,7 +609,10 @@ class FlaxDecoder(nn.Module):
         # middle
         self.mid_block = FlaxUNetMidBlock2D(
-            in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
+            in_channels=block_out_channels[-1],
+            resnet_groups=self.norm_num_groups,
+            attn_num_head_channels=None,
+            dtype=self.dtype,
         )

         # upsampling
@@ -594,10 +625,11 @@ class FlaxDecoder(nn.Module):
             is_final_block = i == len(block_out_channels) - 1

-            up_block = FlaxUpEncoderBlock2D(
+            up_block = FlaxUpDecoderBlock2D(
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
                 num_layers=self.layers_per_block + 1,
+                resnet_groups=self.norm_num_groups,
                 add_upsample=not is_final_block,
                 dtype=self.dtype,
             )
@@ -607,7 +639,7 @@ class FlaxDecoder(nn.Module):
         self.up_blocks = up_blocks

         # end
-        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
         self.conv_out = nn.Conv(
             self.out_channels,
             kernel_size=(3, 3),
...
@@ -14,6 +14,8 @@
 import PIL.ImageOps
 import requests
 from packaging import version

+from .import_utils import is_flax_available
+
 global_rng = random.Random()
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -89,6 +91,13 @@ def slow(test_case):
     return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)


+def require_flax(test_case):
+    """
+    Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
+    """
+    return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
+
+
 def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
     """
     Args:
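A usage sketch for the new decorator (hypothetical test, not part of this diff): since it is built on unittest.skipUnless, require_flax can decorate either a whole TestCase class, as the new test files below do, or a single test method.

    import unittest

    from diffusers.utils.testing_utils import require_flax


    class MixedTests(unittest.TestCase):
        @require_flax
        def test_flax_only(self):
            import jax  # only runs when JAX & Flax are installed

            self.assertEqual(int(jax.numpy.add(1, 1)), 2)

        def test_always_runs(self):
            self.assertTrue(True)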
...
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax

if is_flax_available():
    import jax


@require_flax
class FlaxModelTesterMixin:
    def test_output(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
        jax.lax.stop_gradient(variables)

        output = model.apply(variables, inputs_dict["sample"])

        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_with_norm_groups(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        model = self.model_class(**init_dict)
        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
        jax.lax.stop_gradient(variables)

        output = model.apply(variables, inputs_dict["sample"])

        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
...

import unittest

from diffusers import FlaxAutoencoderKL
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax

from .test_modeling_common_flax import FlaxModelTesterMixin

if is_flax_available():
    import jax


@require_flax
class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
    model_class = FlaxAutoencoderKL

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        prng_key = jax.random.PRNGKey(0)
        image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))

        return {"sample": image, "prng_key": prng_key}

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "latent_channels": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
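To run the new tests (file and directory names are assumptions inferred from the imports above; this page does not show them): the mixin lives in a module importable as `test_modeling_common_flax`, so from the tests directory something like `python -m pytest test_models_vae_flax.py -q` would exercise both `test_output` and `test_forward_with_norm_groups`. When JAX/Flax are missing, `@require_flax` reports the whole class as skipped instead of failing at import time, because the `import jax` is guarded by `is_flax_available()`.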