"vscode:/vscode.git/clone" did not exist on "7facedda38da928843e9ed0de1810d45ce1b9224"
Unverified Commit 36f43ea7 authored by Will Rice's avatar Will Rice Committed by GitHub
Browse files

Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275)

The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument.
parent 27522b58
......@@ -734,7 +734,7 @@ class AttnDownBlock2D(nn.Module):
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states, temb=None, upsample_size=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
......@@ -1720,7 +1720,7 @@ class AttnUpBlock2D(nn.Module):
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment