"vscode:/vscode.git/clone" did not exist on "9088c6359299978390430821c23a2cfd0cb8ffeb"
Unverified Commit 6bd30ba7 authored by Miguel Farinha's avatar Miguel Farinha Committed by GitHub
Browse files

Allow image resolutions multiple of 8 instead of 64 in SVD pipeline (#6646)



allow resolutions not multiple of 64 in SVD
Co-authored-by: default avatarMiguel Farinha <mignha@CSL15958.local>
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent cef0e367
......@@ -1375,6 +1375,7 @@ class UpBlockSpatioTemporal(nn.Module):
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
) -> torch.Tensor:
for resnet in self.resnets:
# pop res hidden states
......@@ -1415,7 +1416,7 @@ class UpBlockSpatioTemporal(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
......@@ -1485,6 +1486,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
) -> torch.Tensor:
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
......@@ -1533,6 +1535,6 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
......@@ -382,6 +382,20 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
returned, otherwise a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
......@@ -457,15 +471,23 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
image_only_indicator=image_only_indicator,
)
else:
......@@ -473,6 +495,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
image_only_indicator=image_only_indicator,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment