"docs/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "b0dd0c882173f74517160ece7f79104e37bc6b2c"
Unverified Commit dc7cd893 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Add resnet_time_scale_shift to VD layers (#1757)

parent 88907588
...@@ -33,6 +33,7 @@ def get_down_block( ...@@ -33,6 +33,7 @@ def get_down_block(
use_linear_projection=False, use_linear_projection=False,
only_cross_attention=False, only_cross_attention=False,
upcast_attention=False, upcast_attention=False,
resnet_time_scale_shift="default",
): ):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlockFlat": if down_block_type == "DownBlockFlat":
...@@ -46,6 +47,7 @@ def get_down_block( ...@@ -46,6 +47,7 @@ def get_down_block(
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups, resnet_groups=resnet_groups,
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
) )
elif down_block_type == "CrossAttnDownBlockFlat": elif down_block_type == "CrossAttnDownBlockFlat":
if cross_attention_dim is None: if cross_attention_dim is None:
...@@ -65,6 +67,7 @@ def get_down_block( ...@@ -65,6 +67,7 @@ def get_down_block(
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention, only_cross_attention=only_cross_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
) )
raise ValueError(f"{down_block_type} is not supported.") raise ValueError(f"{down_block_type} is not supported.")
...@@ -86,6 +89,7 @@ def get_up_block( ...@@ -86,6 +89,7 @@ def get_up_block(
use_linear_projection=False, use_linear_projection=False,
only_cross_attention=False, only_cross_attention=False,
upcast_attention=False, upcast_attention=False,
resnet_time_scale_shift="default",
): ):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlockFlat": if up_block_type == "UpBlockFlat":
...@@ -99,6 +103,7 @@ def get_up_block( ...@@ -99,6 +103,7 @@ def get_up_block(
resnet_eps=resnet_eps, resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups, resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
) )
elif up_block_type == "CrossAttnUpBlockFlat": elif up_block_type == "CrossAttnUpBlockFlat":
if cross_attention_dim is None: if cross_attention_dim is None:
...@@ -118,6 +123,7 @@ def get_up_block( ...@@ -118,6 +123,7 @@ def get_up_block(
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention, only_cross_attention=only_cross_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
) )
raise ValueError(f"{up_block_type} is not supported.") raise ValueError(f"{up_block_type} is not supported.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment