Commit 5b4e3127 authored by comfyanonymous's avatar comfyanonymous
Browse files

Use inplace operations for less OOM issues.

parent 3fd87cbd
......@@ -96,6 +96,7 @@ class ResnetBlock(nn.Module):
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
......@@ -106,7 +107,7 @@ class ResnetBlock(nn.Module):
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
......@@ -129,14 +130,14 @@ class ResnetBlock(nn.Module):
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.swish(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.swish(h)
h = self.dropout(h)
h = self.conv2(h)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment