Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b1fe1706
Unverified
Commit
b1fe1706
authored
Sep 03, 2022
by
Sid Sahai
Committed by
GitHub
Sep 03, 2022
Browse files
[Type Hint] Unet Models (#330)
* add void check * remove void, add types for params
parent
9b704f76
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
36 deletions
+41
-36
src/diffusers/models/unet_2d.py
src/diffusers/models/unet_2d.py
+18
-18
src/diffusers/models/unet_2d_condition.py
src/diffusers/models/unet_2d_condition.py
+23
-18
No files found.
src/diffusers/models/unet_2d.py
View file @
b1fe1706
from
typing
import
Dict
,
Union
from
typing
import
Dict
,
Optional
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -13,23 +13,23 @@ class UNet2DModel(ModelMixin, ConfigMixin):
...
@@ -13,23 +13,23 @@ class UNet2DModel(ModelMixin, ConfigMixin):
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
sample_size
=
None
,
sample_size
:
Optional
[
int
]
=
None
,
in_channels
=
3
,
in_channels
:
int
=
3
,
out_channels
=
3
,
out_channels
:
int
=
3
,
center_input_sample
=
False
,
center_input_sample
:
bool
=
False
,
time_embedding_type
=
"positional"
,
time_embedding_type
:
str
=
"positional"
,
freq_shift
=
0
,
freq_shift
:
int
=
0
,
flip_sin_to_cos
=
True
,
flip_sin_to_cos
:
bool
=
True
,
down_block_types
=
(
"DownBlock2D"
,
"AttnDownBlock2D"
,
"AttnDownBlock2D"
,
"AttnDownBlock2D"
),
down_block_types
:
Tuple
[
str
]
=
(
"DownBlock2D"
,
"AttnDownBlock2D"
,
"AttnDownBlock2D"
,
"AttnDownBlock2D"
),
up_block_types
=
(
"AttnUpBlock2D"
,
"AttnUpBlock2D"
,
"AttnUpBlock2D"
,
"UpBlock2D"
),
up_block_types
:
Tuple
[
str
]
=
(
"AttnUpBlock2D"
,
"AttnUpBlock2D"
,
"AttnUpBlock2D"
,
"UpBlock2D"
),
block_out_channels
=
(
224
,
448
,
672
,
896
),
block_out_channels
:
Tuple
[
int
]
=
(
224
,
448
,
672
,
896
),
layers_per_block
=
2
,
layers_per_block
:
int
=
2
,
mid_block_scale_factor
=
1
,
mid_block_scale_factor
:
float
=
1
,
downsample_padding
=
1
,
downsample_padding
:
int
=
1
,
act_fn
=
"silu"
,
act_fn
:
str
=
"silu"
,
attention_head_dim
=
8
,
attention_head_dim
:
int
=
8
,
norm_num_groups
=
32
,
norm_num_groups
:
int
=
32
,
norm_eps
=
1e-5
,
norm_eps
:
float
=
1e-5
,
):
):
super
().
__init__
()
super
().
__init__
()
...
...
src/diffusers/models/unet_2d_condition.py
View file @
b1fe1706
from
typing
import
Dict
,
Union
from
typing
import
Dict
,
Optional
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -13,23 +13,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
...
@@ -13,23 +13,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
sample_size
=
None
,
sample_size
:
Optional
[
int
]
=
None
,
in_channels
=
4
,
in_channels
:
int
=
4
,
out_channels
=
4
,
out_channels
:
int
=
4
,
center_input_sample
=
False
,
center_input_sample
:
bool
=
False
,
flip_sin_to_cos
=
True
,
flip_sin_to_cos
:
bool
=
True
,
freq_shift
=
0
,
freq_shift
:
int
=
0
,
down_block_types
=
(
"CrossAttnDownBlock2D"
,
"CrossAttnDownBlock2D"
,
"CrossAttnDownBlock2D"
,
"DownBlock2D"
),
down_block_types
:
Tuple
[
str
]
=
(
up_block_types
=
(
"UpBlock2D"
,
"CrossAttnUpBlock2D"
,
"CrossAttnUpBlock2D"
,
"CrossAttnUpBlock2D"
),
"CrossAttnDownBlock2D"
,
block_out_channels
=
(
320
,
640
,
1280
,
1280
),
"CrossAttnDownBlock2D"
,
layers_per_block
=
2
,
"CrossAttnDownBlock2D"
,
downsample_padding
=
1
,
"DownBlock2D"
,
mid_block_scale_factor
=
1
,
),
act_fn
=
"silu"
,
up_block_types
:
Tuple
[
str
]
=
(
"UpBlock2D"
,
"CrossAttnUpBlock2D"
,
"CrossAttnUpBlock2D"
,
"CrossAttnUpBlock2D"
),
norm_num_groups
=
32
,
block_out_channels
:
Tuple
[
int
]
=
(
320
,
640
,
1280
,
1280
),
norm_eps
=
1e-5
,
layers_per_block
:
int
=
2
,
cross_attention_dim
=
1280
,
downsample_padding
:
int
=
1
,
attention_head_dim
=
8
,
mid_block_scale_factor
:
float
=
1
,
act_fn
:
str
=
"silu"
,
norm_num_groups
:
int
=
32
,
norm_eps
:
float
=
1e-5
,
cross_attention_dim
:
int
=
1280
,
attention_head_dim
:
int
=
8
,
):
):
super
().
__init__
()
super
().
__init__
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment