Unverified Commit 3fc8ef72 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

Replace `dropout_prob` by `dropout` in `vae` (#595)

replace `dropout_prob` by `dropout` in `vae`
parent 86856993
...@@ -89,7 +89,7 @@ class FlaxDownsample2D(nn.Module): ...@@ -89,7 +89,7 @@ class FlaxDownsample2D(nn.Module):
class FlaxResnetBlock2D(nn.Module): class FlaxResnetBlock2D(nn.Module):
in_channels: int in_channels: int
out_channels: int = None out_channels: int = None
dropout_prob: float = 0.0 dropout: float = 0.0
use_nin_shortcut: bool = None use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
...@@ -106,7 +106,7 @@ class FlaxResnetBlock2D(nn.Module): ...@@ -106,7 +106,7 @@ class FlaxResnetBlock2D(nn.Module):
) )
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6) self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.dropout = nn.Dropout(self.dropout_prob) self.dropout_layer = nn.Dropout(self.dropout)
self.conv2 = nn.Conv( self.conv2 = nn.Conv(
out_channels, out_channels,
kernel_size=(3, 3), kernel_size=(3, 3),
...@@ -135,7 +135,7 @@ class FlaxResnetBlock2D(nn.Module): ...@@ -135,7 +135,7 @@ class FlaxResnetBlock2D(nn.Module):
hidden_states = self.norm2(hidden_states) hidden_states = self.norm2(hidden_states)
hidden_states = nn.swish(hidden_states) hidden_states = nn.swish(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic) hidden_states = self.dropout_layer(hidden_states, deterministic)
hidden_states = self.conv2(hidden_states) hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None: if self.conv_shortcut is not None:
...@@ -217,7 +217,7 @@ class FlaxDownEncoderBlock2D(nn.Module): ...@@ -217,7 +217,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
res_block = FlaxResnetBlock2D( res_block = FlaxResnetBlock2D(
in_channels=in_channels, in_channels=in_channels,
out_channels=self.out_channels, out_channels=self.out_channels,
dropout_prob=self.dropout, dropout=self.dropout,
dtype=self.dtype, dtype=self.dtype,
) )
resnets.append(res_block) resnets.append(res_block)
...@@ -251,7 +251,7 @@ class FlaxUpEncoderBlock2D(nn.Module): ...@@ -251,7 +251,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
res_block = FlaxResnetBlock2D( res_block = FlaxResnetBlock2D(
in_channels=in_channels, in_channels=in_channels,
out_channels=self.out_channels, out_channels=self.out_channels,
dropout_prob=self.dropout, dropout=self.dropout,
dtype=self.dtype, dtype=self.dtype,
) )
resnets.append(res_block) resnets.append(res_block)
...@@ -284,7 +284,7 @@ class FlaxUNetMidBlock2D(nn.Module): ...@@ -284,7 +284,7 @@ class FlaxUNetMidBlock2D(nn.Module):
FlaxResnetBlock2D( FlaxResnetBlock2D(
in_channels=self.in_channels, in_channels=self.in_channels,
out_channels=self.in_channels, out_channels=self.in_channels,
dropout_prob=self.dropout, dropout=self.dropout,
dtype=self.dtype, dtype=self.dtype,
) )
] ]
...@@ -300,7 +300,7 @@ class FlaxUNetMidBlock2D(nn.Module): ...@@ -300,7 +300,7 @@ class FlaxUNetMidBlock2D(nn.Module):
res_block = FlaxResnetBlock2D( res_block = FlaxResnetBlock2D(
in_channels=self.in_channels, in_channels=self.in_channels,
out_channels=self.in_channels, out_channels=self.in_channels,
dropout_prob=self.dropout, dropout=self.dropout,
dtype=self.dtype, dtype=self.dtype,
) )
resnets.append(res_block) resnets.append(res_block)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment