"...resnet50_tensorflow.git" did not exist on "5d39608b5bcbeb6c0c4560d975c48b8560b1b31d"
Commit 5b4e3127 authored by comfyanonymous's avatar comfyanonymous
Browse files

Use inplace operations for less OOM issues.

parent 3fd87cbd
...@@ -96,6 +96,7 @@ class ResnetBlock(nn.Module): ...@@ -96,6 +96,7 @@ class ResnetBlock(nn.Module):
self.out_channels = out_channels self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut self.use_conv_shortcut = conv_shortcut
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels) self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, self.conv1 = torch.nn.Conv2d(in_channels,
out_channels, out_channels,
...@@ -106,7 +107,7 @@ class ResnetBlock(nn.Module): ...@@ -106,7 +107,7 @@ class ResnetBlock(nn.Module):
self.temb_proj = torch.nn.Linear(temb_channels, self.temb_proj = torch.nn.Linear(temb_channels,
out_channels) out_channels)
self.norm2 = Normalize(out_channels) self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout) self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = torch.nn.Conv2d(out_channels, self.conv2 = torch.nn.Conv2d(out_channels,
out_channels, out_channels,
kernel_size=3, kernel_size=3,
...@@ -129,14 +130,14 @@ class ResnetBlock(nn.Module): ...@@ -129,14 +130,14 @@ class ResnetBlock(nn.Module):
def forward(self, x, temb): def forward(self, x, temb):
h = x h = x
h = self.norm1(h) h = self.norm1(h)
h = nonlinearity(h) h = self.swish(h)
h = self.conv1(h) h = self.conv1(h)
if temb is not None: if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
h = self.norm2(h) h = self.norm2(h)
h = nonlinearity(h) h = self.swish(h)
h = self.dropout(h) h = self.dropout(h)
h = self.conv2(h) h = self.conv2(h)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment