Unverified Commit b33b64f5 authored by Martin Müller's avatar Martin Müller Committed by GitHub
Browse files

Make mid block optional for flax UNet (#7083)

* make mid block optional for flax UNet

* make style
parent 9d974407
......@@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
The tuple of downsample blocks to use.
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
......@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
"DownBlock2D",
)
up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
only_cross_attention: Union[bool, Tuple[bool]] = False
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
layers_per_block: int = 2
......@@ -252,6 +255,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
self.down_blocks = down_blocks
# mid
if self.config.mid_block_type == "UNetMidBlock2DCrossAttn":
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
dropout=self.dropout,
......@@ -262,6 +266,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
elif self.config.mid_block_type is None:
self.mid_block = None
else:
raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}")
# up
up_blocks = []
......@@ -412,6 +420,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
down_block_res_samples = new_down_block_res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
if mid_block_additional_residual is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment