Unverified Commit 36f43ea7 authored by Will Rice's avatar Will Rice Committed by GitHub
Browse files

Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275)

The `upsample_size` argument needs to be added to these modules for compatibility with other blocks that require it.
parent 27522b58
...@@ -734,7 +734,7 @@ class AttnDownBlock2D(nn.Module): ...@@ -734,7 +734,7 @@ class AttnDownBlock2D(nn.Module):
else: else:
self.downsamplers = None self.downsamplers = None
def forward(self, hidden_states, temb=None): def forward(self, hidden_states, temb=None, upsample_size=None):
output_states = () output_states = ()
for resnet, attn in zip(self.resnets, self.attentions): for resnet, attn in zip(self.resnets, self.attentions):
...@@ -1720,7 +1720,7 @@ class AttnUpBlock2D(nn.Module): ...@@ -1720,7 +1720,7 @@ class AttnUpBlock2D(nn.Module):
else: else:
self.upsamplers = None self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None): def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
for resnet, attn in zip(self.resnets, self.attentions): for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states # pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment