renzhc / diffusers_dcu · Commits

Commit 707341ae
Authored Apr 08, 2023 by William Berman; committed Apr 09, 2023 by Will Berman.
Parent: 26b4319a

resnet skip time activation and output scale factor
Showing 4 changed files with 49 additions and 1 deletion (+49 -1):

    src/diffusers/models/resnet.py                                        +5  -1
    src/diffusers/models/unet_2d_blocks.py                                +27 -0
    src/diffusers/models/unet_2d_condition.py                             +7  -0
    src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py     +10 -0
src/diffusers/models/resnet.py

@@ -459,6 +459,7 @@ class ResnetBlock2D(nn.Module):
         pre_norm=True,
         eps=1e-6,
         non_linearity="swish",
+        skip_time_act=False,
         time_embedding_norm="default",  # default, scale_shift, ada_group
         kernel=None,
         output_scale_factor=1.0,
@@ -479,6 +480,7 @@ class ResnetBlock2D(nn.Module):
         self.down = down
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm
+        self.skip_time_act = skip_time_act

         if groups_out is None:
             groups_out = groups
@@ -570,7 +572,9 @@ class ResnetBlock2D(nn.Module):
         hidden_states = self.conv1(hidden_states)

         if self.time_emb_proj is not None:
-            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+            if not self.skip_time_act:
+                temb = self.nonlinearity(temb)
+            temb = self.time_emb_proj(temb)[:, :, None, None]

         if temb is not None and self.time_embedding_norm == "default":
             hidden_states = hidden_states + temb
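Taken together, the resnet.py hunks let a block project the time embedding without the extra activation (skip_time_act) while the existing output_scale_factor divisor is exposed through the block helpers in the next file. Below is a minimal, hypothetical sketch of just that forward path, not the real ResnetBlock2D; the class name TinyResnetTimeProj and the layer sizes are made up for illustration.

    import torch
    import torch.nn as nn

    # Minimal sketch of the changed time-embedding path; TinyResnetTimeProj is an
    # illustration, not a diffusers class. Channel sizes are arbitrary.
    class TinyResnetTimeProj(nn.Module):
        def __init__(self, channels=64, temb_channels=128, skip_time_act=False, output_scale_factor=1.0):
            super().__init__()
            self.skip_time_act = skip_time_act
            self.output_scale_factor = output_scale_factor
            self.nonlinearity = nn.SiLU()
            self.time_emb_proj = nn.Linear(temb_channels, channels)

        def forward(self, input_tensor, hidden_states, temb):
            # Before this commit the activation was always applied:
            #   temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, None, None]
            hidden_states = hidden_states + temb
            # ResnetBlock2D divides the residual sum by output_scale_factor.
            return (input_tensor + hidden_states) / self.output_scale_factor

    x = torch.randn(2, 64, 8, 8)
    t = torch.randn(2, 128)
    out = TinyResnetTimeProj(skip_time_act=True)(x, x, t)  # activation on temb is skipped

With skip_time_act=True the linear time_emb_proj sees the raw embedding, which is the behaviour the blocks below opt into via the new flag.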
src/diffusers/models/unet_2d_blocks.py

@@ -42,6 +42,8 @@ def get_down_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    resnet_skip_time_act=False,
+    resnet_out_scale_factor=1.0,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -68,6 +70,8 @@ def get_down_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif down_block_type == "AttnDownBlock2D":
         return AttnDownBlock2D(
@@ -119,6 +123,8 @@ def get_down_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -214,6 +220,8 @@ def get_up_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    resnet_skip_time_act=False,
+    resnet_out_scale_factor=1.0,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -241,6 +249,8 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif up_block_type == "CrossAttnUpBlock2D":
         if cross_attention_dim is None:
@@ -279,6 +289,8 @@ def get_up_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -562,6 +574,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
+        skip_time_act=False,
     ):
         super().__init__()
@@ -585,6 +598,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                 non_linearity=resnet_act_fn,
                 output_scale_factor=output_scale_factor,
                 pre_norm=resnet_pre_norm,
+                skip_time_act=skip_time_act,
             )
         ]
         attentions = []
@@ -615,6 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
@@ -1247,6 +1262,7 @@ class ResnetDownsampleBlock2D(nn.Module):
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_downsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
         resnets = []
@@ -1265,6 +1281,7 @@ class ResnetDownsampleBlock2D(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
@@ -1284,6 +1301,7 @@ class ResnetDownsampleBlock2D(nn.Module):
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         down=True,
                     )
                 ]
@@ -1337,6 +1355,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_downsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
@@ -1362,6 +1381,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
             attentions.append(
@@ -1394,6 +1414,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         down=True,
                     )
                 ]
@@ -2237,6 +2258,7 @@ class ResnetUpsampleBlock2D(nn.Module):
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_upsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
         resnets = []
@@ -2257,6 +2279,7 @@ class ResnetUpsampleBlock2D(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
@@ -2276,6 +2299,7 @@ class ResnetUpsampleBlock2D(nn.Module):
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         up=True,
                     )
                 ]
@@ -2329,6 +2353,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_upsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
         resnets = []
@@ -2355,6 +2380,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
             attentions.append(
@@ -2387,6 +2413,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         up=True,
                     )
                 ]
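In unet_2d_blocks.py the two options surface as resnet_skip_time_act and resnet_out_scale_factor on get_down_block/get_up_block and, judging by the surrounding elif branches in the hunks above, are forwarded (renamed to skip_time_act / output_scale_factor) only for ResnetDownsampleBlock2D, SimpleCrossAttnDownBlock2D, ResnetUpsampleBlock2D and SimpleCrossAttnUpBlock2D, plus UNetMidBlock2DSimpleCrossAttn. A rough sketch of calling the helper directly after this commit; the channel, group and epsilon values are illustrative, and the pre-existing argument names are assumed from the hunk context for this diffusers version.

    from diffusers.models.unet_2d_blocks import get_down_block

    # Illustrative values only; resnet_skip_time_act / resnet_out_scale_factor are
    # the two parameters added by this commit.
    down_block = get_down_block(
        "ResnetDownsampleBlock2D",
        num_layers=2,
        in_channels=320,
        out_channels=320,
        temb_channels=1280,
        add_downsample=True,
        resnet_eps=1e-6,
        resnet_act_fn="silu",
        resnet_groups=32,
        attn_num_head_channels=None,
        resnet_skip_time_act=True,       # forwarded to ResnetBlock2D as skip_time_act
        resnet_out_scale_factor=1.0,     # forwarded to ResnetBlock2D as output_scale_factor
    )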
src/diffusers/models/unet_2d_condition.py

@@ -146,6 +146,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -291,6 +293,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.down_blocks.append(down_block)
@@ -321,6 +325,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                skip_time_act=resnet_skip_time_act,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -369,6 +374,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
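unet_2d_condition.py threads the same two keys through the UNet2DConditionModel config into the down, mid and up blocks. A minimal sketch of a config that actually exercises them, assuming the resnet/simple-cross-attention block types that consume the flags; all sizes below are illustrative and not taken from this commit.

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel(
        sample_size=32,
        block_out_channels=(32, 64),
        down_block_types=("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
        up_block_types=("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
        mid_block_type="UNetMidBlock2DSimpleCrossAttn",
        cross_attention_dim=32,
        attention_head_dim=8,
        resnet_skip_time_act=True,     # new config key, passed down as skip_time_act
        resnet_out_scale_factor=1.0,   # new config key, passed down as output_scale_factor
    )

Block types outside the branches touched above (for example DownBlock2D or CrossAttnDownBlock2D) simply ignore the two flags.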
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

@@ -232,6 +232,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -382,6 +384,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.down_blocks.append(down_block)
@@ -412,6 +416,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                skip_time_act=resnet_skip_time_act,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -460,6 +465,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -1434,6 +1441,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
+        skip_time_act=False,
     ):
         super().__init__()
@@ -1457,6 +1465,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                 non_linearity=resnet_act_fn,
                 output_scale_factor=output_scale_factor,
                 pre_norm=resnet_pre_norm,
+                skip_time_act=skip_time_act,
             )
         ]
         attentions = []
@@ -1487,6 +1496,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                 non_linearity=resnet_act_fn,
                 output_scale_factor=output_scale_factor,
                 pre_norm=resnet_pre_norm,
+                skip_time_act=skip_time_act,
             )
         )
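modeling_text_unet.py mirrors the same additions in the copied UNetFlatConditionModel and UNetMidBlockFlatSimpleCrossAttn, keeping the "flat" text UNet used by Versatile Diffusion in sync with UNet2DConditionModel. As a rough check (assuming the config decorator preserves the wrapped signature so inspect can read it), the new keys should show up on the flat model's constructor:

    import inspect
    from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel

    params = inspect.signature(UNetFlatConditionModel.__init__).parameters
    # Expected to report both new keys next to the existing resnet_time_scale_shift.
    print("resnet_skip_time_act" in params, "resnet_out_scale_factor" in params)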