        assert all(s % patch_size == 0 for s in latents_size), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}."
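        # rope_sizes: number of patches along each latent axis, i.e. the RoPE grid size per dimension.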
        rope_sizes = [s // patch_size for s in latents_size]
    elif isinstance(patch_size, list):
        assert all(s % patch_size[idx] == 0 for idx, s in enumerate(latents_size)), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}."
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warn(f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}.")
        attention_head_dim = num_attention_heads

    raise ValueError(f"{down_block_type} does not exist.")
def get_up_block3d(
    up_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    prev_output_channel: int,
    temb_channels: int,
    add_upsample: bool,
    upsample_scale_factor: Tuple,
    resnet_eps: float,
    resnet_act_fn: str,
    resolution_idx: Optional[int] = None,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    upsample_type: Optional[str] = None,
    dropout: float = 0.0,
) -> nn.Module:
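    """Instantiate the upsampling block named by `up_block_type` and return it as an `nn.Module`.

    If `attention_head_dim` is None, it defaults to `num_attention_heads`.
    """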
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warn(f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}.")
        attention_head_dim = num_attention_heads