renzhc / diffusers_dcu · Commit 707341ae

Authored Apr 08, 2023 by William Berman; committed Apr 09, 2023 by Will Berman.
resnet skip time activation and output scale factor
Parent: 26b4319a

Showing 4 changed files with 49 additions and 1 deletion (+49, -1):
    src/diffusers/models/resnet.py                                        +5   -1
    src/diffusers/models/unet_2d_blocks.py                                +27   -0
    src/diffusers/models/unet_2d_condition.py                             +7   -0
    src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py    +10   -0
src/diffusers/models/resnet.py

@@ -459,6 +459,7 @@ class ResnetBlock2D(nn.Module):
         pre_norm=True,
         eps=1e-6,
         non_linearity="swish",
+        skip_time_act=False,
         time_embedding_norm="default",  # default, scale_shift, ada_group
         kernel=None,
         output_scale_factor=1.0,
@@ -479,6 +480,7 @@ class ResnetBlock2D(nn.Module):
         self.down = down
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm
+        self.skip_time_act = skip_time_act

         if groups_out is None:
             groups_out = groups
@@ -570,7 +572,9 @@ class ResnetBlock2D(nn.Module):
         hidden_states = self.conv1(hidden_states)

         if self.time_emb_proj is not None:
-            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+            if not self.skip_time_act:
+                temb = self.nonlinearity(temb)
+            temb = self.time_emb_proj(temb)[:, :, None, None]

         if temb is not None and self.time_embedding_norm == "default":
             hidden_states = hidden_states + temb
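
For reference, a minimal usage sketch (not part of the commit) of the two knobs on ResnetBlock2D; the constructor arguments besides the two highlighted ones already existed, and the shapes are illustrative:

# Illustrative only: exercise skip_time_act and output_scale_factor directly.
import torch
from diffusers.models.resnet import ResnetBlock2D

block = ResnetBlock2D(
    in_channels=64,
    out_channels=64,
    temb_channels=128,
    skip_time_act=True,       # new: pass temb to time_emb_proj without the nonlinearity
    output_scale_factor=2.0,  # divides the residual sum; now reachable from the block helpers
)

hidden_states = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
temb = torch.randn(1, 128)                  # time embedding

out = block(hidden_states, temb)            # same shape as hidden_states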
src/diffusers/models/unet_2d_blocks.py

@@ -42,6 +42,8 @@ def get_down_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    resnet_skip_time_act=False,
+    resnet_out_scale_factor=1.0,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -68,6 +70,8 @@ def get_down_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif down_block_type == "AttnDownBlock2D":
         return AttnDownBlock2D(
@@ -119,6 +123,8 @@ def get_down_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -214,6 +220,8 @@ def get_up_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    resnet_skip_time_act=False,
+    resnet_out_scale_factor=1.0,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -241,6 +249,8 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif up_block_type == "CrossAttnUpBlock2D":
         if cross_attention_dim is None:
@@ -279,6 +289,8 @@ def get_up_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -562,6 +574,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
+        skip_time_act=False,
     ):
         super().__init__()
@@ -585,6 +598,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                 non_linearity=resnet_act_fn,
                 output_scale_factor=output_scale_factor,
                 pre_norm=resnet_pre_norm,
+                skip_time_act=skip_time_act,
             )
         ]
         attentions = []
@@ -615,6 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
@@ -1247,6 +1262,7 @@ class ResnetDownsampleBlock2D(nn.Module):
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_downsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
         resnets = []
@@ -1265,6 +1281,7 @@ class ResnetDownsampleBlock2D(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
@@ -1284,6 +1301,7 @@ class ResnetDownsampleBlock2D(nn.Module):
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         down=True,
                     )
                 ]
@@ -1337,6 +1355,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_downsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
@@ -1362,6 +1381,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
             attentions.append(
@@ -1394,6 +1414,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         down=True,
                     )
                 ]
@@ -2237,6 +2258,7 @@ class ResnetUpsampleBlock2D(nn.Module):
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_upsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
         resnets = []
@@ -2257,6 +2279,7 @@ class ResnetUpsampleBlock2D(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
@@ -2276,6 +2299,7 @@ class ResnetUpsampleBlock2D(nn.Module):
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         up=True,
                     )
                 ]
@@ -2329,6 +2353,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_upsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
         resnets = []
@@ -2355,6 +2380,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
             attentions.append(
@@ -2387,6 +2413,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         up=True,
                     )
                 ]
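
A hedged sketch of how the new get_down_block kwargs are meant to be used; the branch that receives them is not visible in the hunks above, so the block type below ("ResnetDownsampleBlock2D") is an assumption based on which blocks gained skip_time_act in this file:

# Sketch only: build a down block through the updated factory function.
from diffusers.models.unet_2d_blocks import get_down_block

down_block = get_down_block(
    down_block_type="ResnetDownsampleBlock2D",  # assumed: a block type that accepts the new args
    num_layers=2,
    in_channels=64,
    out_channels=128,
    temb_channels=512,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    attn_num_head_channels=None,                # unused by this block type
    resnet_groups=32,
    resnet_skip_time_act=True,                  # forwarded as skip_time_act
    resnet_out_scale_factor=1.0,                # forwarded as output_scale_factor
)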
src/diffusers/models/unet_2d_condition.py

@@ -146,6 +146,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -291,6 +293,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.down_blocks.append(down_block)
@@ -321,6 +325,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                skip_time_act=resnet_skip_time_act,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -369,6 +374,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
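
At the model level the two flags become config arguments. A toy configuration sketch (values and block-type combination are hypothetical, not from the commit) that routes them down to the resnets; they only take effect for block types that accept skip_time_act / output_scale_factor, such as the resnet-downsample and simple-cross-attention blocks touched above:

# Sketch only: toy UNet2DConditionModel config exercising the new flags.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(64, 128),
    down_block_types=("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
    up_block_types=("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
    mid_block_type="UNetMidBlock2DSimpleCrossAttn",
    cross_attention_dim=128,
    attention_head_dim=8,
    resnet_skip_time_act=True,    # new config flag
    resnet_out_scale_factor=1.0,  # new config flag
)

print(unet.config.resnet_skip_time_act)  # True, stored alongside the rest of the config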
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

@@ -232,6 +232,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -382,6 +384,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.down_blocks.append(down_block)
@@ -412,6 +416,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                skip_time_act=resnet_skip_time_act,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -460,6 +465,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -1434,6 +1441,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
+        skip_time_act=False,
     ):
         super().__init__()
@@ -1457,6 +1465,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                 non_linearity=resnet_act_fn,
                 output_scale_factor=output_scale_factor,
                 pre_norm=resnet_pre_norm,
+                skip_time_act=skip_time_act,
             )
         ]
         attentions = []
@@ -1487,6 +1496,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                 non_linearity=resnet_act_fn,
                 output_scale_factor=output_scale_factor,
                 pre_norm=resnet_pre_norm,
+                skip_time_act=skip_time_act,
             )
         )
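
modeling_text_unet.py mirrors unet_2d_condition.py (diffusers keeps the flat text UNet in sync with the 2D conditional UNet), so the same two config arguments appear on UNetFlatConditionModel. A small sanity-check sketch, assuming the parameter names shown in the hunks above:

# Sketch only: confirm the mirrored signature on the flat text UNet.
import inspect
from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel

params = inspect.signature(UNetFlatConditionModel.__init__).parameters
print("resnet_skip_time_act" in params)     # expected: True
print("resnet_out_scale_factor" in params)  # expected: True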