Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
abd922bd
Unverified
Commit
abd922bd
authored
Feb 29, 2024
by
Aryan
Committed by
GitHub
Feb 28, 2024
Browse files
[docs] unet type hints (#7134)
update
parent
fa633ed6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
8 deletions
+10
-8
src/diffusers/models/unets/unet_2d_condition.py
src/diffusers/models/unets/unet_2d_condition.py
+10
-8
No files found.
src/diffusers/models/unets/unet_2d_condition.py
View file @
abd922bd
...
...
@@ -204,7 +204,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
upcast_attention
:
bool
=
False
,
resnet_time_scale_shift
:
str
=
"default"
,
resnet_skip_time_act
:
bool
=
False
,
resnet_out_scale_factor
:
in
t
=
1.0
,
resnet_out_scale_factor
:
floa
t
=
1.0
,
time_embedding_type
:
str
=
"positional"
,
time_embedding_dim
:
Optional
[
int
]
=
None
,
time_embedding_act_fn
:
Optional
[
str
]
=
None
,
...
...
@@ -217,7 +217,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
class_embeddings_concat
:
bool
=
False
,
mid_block_only_cross_attention
:
Optional
[
bool
]
=
None
,
cross_attention_norm
:
Optional
[
str
]
=
None
,
addition_embed_type_num_heads
=
64
,
addition_embed_type_num_heads
:
int
=
64
,
):
super
().
__init__
()
...
...
@@ -485,9 +485,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
up_block_types
:
Tuple
[
str
],
only_cross_attention
:
Union
[
bool
,
Tuple
[
bool
]],
block_out_channels
:
Tuple
[
int
],
layers_per_block
:
[
int
,
Tuple
[
int
]],
layers_per_block
:
Union
[
int
,
Tuple
[
int
]],
cross_attention_dim
:
Union
[
int
,
Tuple
[
int
]],
transformer_layers_per_block
:
Union
[
int
,
Tuple
[
int
],
Tuple
[
Tuple
]],
transformer_layers_per_block
:
Union
[
int
,
Tuple
[
int
],
Tuple
[
Tuple
[
int
]
]],
reverse_transformer_layers_per_block
:
bool
,
attention_head_dim
:
int
,
num_attention_heads
:
Optional
[
Union
[
int
,
Tuple
[
int
]]],
...
...
@@ -762,7 +762,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
self
.
set_attn_processor
(
processor
)
def
set_attention_slice
(
self
,
slice_size
):
def
set_attention_slice
(
self
,
slice_size
:
Union
[
str
,
int
,
List
[
int
]]
=
"auto"
):
r
"""
Enable sliced attention computation.
...
...
@@ -831,7 +831,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
if
hasattr
(
module
,
"gradient_checkpointing"
):
module
.
gradient_checkpointing
=
value
def
enable_freeu
(
self
,
s1
,
s2
,
b1
,
b2
):
def
enable_freeu
(
self
,
s1
:
float
,
s2
:
float
,
b1
:
float
,
b2
:
float
):
r
"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
...
...
@@ -953,7 +953,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
return
class_emb
def
get_aug_embed
(
self
,
emb
:
torch
.
Tensor
,
encoder_hidden_states
:
torch
.
Tensor
,
added_cond_kwargs
:
Dict
self
,
emb
:
torch
.
Tensor
,
encoder_hidden_states
:
torch
.
Tensor
,
added_cond_kwargs
:
Dict
[
str
,
Any
]
)
->
Optional
[
torch
.
Tensor
]:
aug_emb
=
None
if
self
.
config
.
addition_embed_type
==
"text"
:
...
...
@@ -1004,7 +1004,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
aug_emb
=
self
.
add_embedding
(
image_embs
,
hint
)
return
aug_emb
def
process_encoder_hidden_states
(
self
,
encoder_hidden_states
:
torch
.
Tensor
,
added_cond_kwargs
)
->
torch
.
Tensor
:
def
process_encoder_hidden_states
(
self
,
encoder_hidden_states
:
torch
.
Tensor
,
added_cond_kwargs
:
Dict
[
str
,
Any
]
)
->
torch
.
Tensor
:
if
self
.
encoder_hid_proj
is
not
None
and
self
.
config
.
encoder_hid_dim_type
==
"text_proj"
:
encoder_hidden_states
=
self
.
encoder_hid_proj
(
encoder_hidden_states
)
elif
self
.
encoder_hid_proj
is
not
None
and
self
.
config
.
encoder_hid_dim_type
==
"text_image_proj"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment