Unverified Commit 3fc8ef72 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

Replace `dropout_prob` by `dropout` in `vae` (#595)

replace `dropout_prob` by `dropout` in `vae`
parent 86856993
......@@ -89,7 +89,7 @@ class FlaxDownsample2D(nn.Module):
class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout_prob: float = 0.0
dropout: float = 0.0
use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32
......@@ -106,7 +106,7 @@ class FlaxResnetBlock2D(nn.Module):
)
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.dropout = nn.Dropout(self.dropout_prob)
self.dropout_layer = nn.Dropout(self.dropout)
self.conv2 = nn.Conv(
out_channels,
kernel_size=(3, 3),
......@@ -135,7 +135,7 @@ class FlaxResnetBlock2D(nn.Module):
hidden_states = self.norm2(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic)
hidden_states = self.dropout_layer(hidden_states, deterministic)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
......@@ -217,7 +217,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dropout=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
......@@ -251,7 +251,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dropout=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
......@@ -284,7 +284,7 @@ class FlaxUNetMidBlock2D(nn.Module):
FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
dropout=self.dropout,
dtype=self.dtype,
)
]
......@@ -300,7 +300,7 @@ class FlaxUNetMidBlock2D(nn.Module):
res_block = FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
dropout=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment