"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c10d6854c08e722f68a3a4932b347df938eb98e1"
Unverified Commit 6bd30ba7 authored by Miguel Farinha's avatar Miguel Farinha Committed by GitHub
Browse files

Allow image resolutions multiple of 8 instead of 64 in SVD pipeline (#6646)



allow resolutions not multiple of 64 in SVD
Co-authored-by: default avatarMiguel Farinha <mignha@CSL15958.local>
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent cef0e367
...@@ -1375,6 +1375,7 @@ class UpBlockSpatioTemporal(nn.Module): ...@@ -1375,6 +1375,7 @@ class UpBlockSpatioTemporal(nn.Module):
res_hidden_states_tuple: Tuple[torch.Tensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
for resnet in self.resnets: for resnet in self.resnets:
# pop res hidden states # pop res hidden states
...@@ -1415,7 +1416,7 @@ class UpBlockSpatioTemporal(nn.Module): ...@@ -1415,7 +1416,7 @@ class UpBlockSpatioTemporal(nn.Module):
if self.upsamplers is not None: if self.upsamplers is not None:
for upsampler in self.upsamplers: for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states) hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states return hidden_states
...@@ -1485,6 +1486,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module): ...@@ -1485,6 +1486,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
temb: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
for resnet, attn in zip(self.resnets, self.attentions): for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states # pop res hidden states
...@@ -1533,6 +1535,6 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module): ...@@ -1533,6 +1535,6 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
if self.upsamplers is not None: if self.upsamplers is not None:
for upsampler in self.upsamplers: for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states) hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states return hidden_states
...@@ -382,6 +382,20 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL ...@@ -382,6 +382,20 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
returned, otherwise a `tuple` is returned where the first element is the sample tensor. returned, otherwise a `tuple` is returned where the first element is the sample tensor.
""" """
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 1. time # 1. time
timesteps = timestep timesteps = timestep
if not torch.is_tensor(timesteps): if not torch.is_tensor(timesteps):
...@@ -457,15 +471,23 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL ...@@ -457,15 +471,23 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
# 5. up # 5. up
for i, upsample_block in enumerate(self.up_blocks): for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :] res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block( sample = upsample_block(
hidden_states=sample, hidden_states=sample,
temb=emb, temb=emb,
res_hidden_states_tuple=res_samples, res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
image_only_indicator=image_only_indicator, image_only_indicator=image_only_indicator,
) )
else: else:
...@@ -473,6 +495,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL ...@@ -473,6 +495,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
hidden_states=sample, hidden_states=sample,
temb=emb, temb=emb,
res_hidden_states_tuple=res_samples, res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
image_only_indicator=image_only_indicator, image_only_indicator=image_only_indicator,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment