"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "0a94ce70dec5b90613b460fd42cc1a0c4d1df101"
Unverified commit b33b64f5, authored by Martin Müller and committed via GitHub
Browse files

Make mid block optional for flax UNet (#7083)

* make mid block optional for flax UNet

* make style
parent 9d974407
...@@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
The tuple of downsample blocks to use. The tuple of downsample blocks to use.
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`): up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
The tuple of upsample blocks to use. The tuple of upsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
    Block type for the middle of the UNet; it can only be `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block. The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): layers_per_block (`int`, *optional*, defaults to 2):
...@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
"DownBlock2D", "DownBlock2D",
) )
up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
only_cross_attention: Union[bool, Tuple[bool]] = False only_cross_attention: Union[bool, Tuple[bool]] = False
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280) block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
layers_per_block: int = 2 layers_per_block: int = 2
...@@ -252,16 +255,21 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -252,16 +255,21 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
self.down_blocks = down_blocks self.down_blocks = down_blocks
# mid # mid
self.mid_block = FlaxUNetMidBlock2DCrossAttn( if self.config.mid_block_type == "UNetMidBlock2DCrossAttn":
in_channels=block_out_channels[-1], self.mid_block = FlaxUNetMidBlock2DCrossAttn(
dropout=self.dropout, in_channels=block_out_channels[-1],
num_attention_heads=num_attention_heads[-1], dropout=self.dropout,
transformer_layers_per_block=transformer_layers_per_block[-1], num_attention_heads=num_attention_heads[-1],
use_linear_projection=self.use_linear_projection, transformer_layers_per_block=transformer_layers_per_block[-1],
use_memory_efficient_attention=self.use_memory_efficient_attention, use_linear_projection=self.use_linear_projection,
split_head_dim=self.split_head_dim, use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype, split_head_dim=self.split_head_dim,
) dtype=self.dtype,
)
elif self.config.mid_block_type is None:
self.mid_block = None
else:
raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}")
# up # up
up_blocks = [] up_blocks = []
...@@ -412,7 +420,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -412,7 +420,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
down_block_res_samples = new_down_block_res_samples down_block_res_samples = new_down_block_res_samples
# 4. mid # 4. mid
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) if self.mid_block is not None:
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
if mid_block_additional_residual is not None: if mid_block_additional_residual is not None:
sample += mid_block_additional_residual sample += mid_block_additional_residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment