Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
170ebd28
Unverified
Commit
170ebd28
authored
Dec 07, 2022
by
Suraj Patil
Committed by
GitHub
Dec 07, 2022
Browse files
[UNet2DConditionModel] add an option to upcast attention to fp32 (#1590)
upcast attention
parent
dc87f526
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
1 deletion
+50
-1
src/diffusers/models/attention.py
src/diffusers/models/attention.py
+26
-1
src/diffusers/models/unet_2d_blocks.py
src/diffusers/models/unet_2d_blocks.py
+10
-0
src/diffusers/models/unet_2d_condition.py
src/diffusers/models/unet_2d_condition.py
+4
-0
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
...users/pipelines/versatile_diffusion/modeling_text_unet.py
+10
-0
No files found.
src/diffusers/models/attention.py
View file @
170ebd28
...
...
@@ -101,6 +101,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
num_embeds_ada_norm
:
Optional
[
int
]
=
None
,
use_linear_projection
:
bool
=
False
,
only_cross_attention
:
bool
=
False
,
upcast_attention
:
bool
=
False
,
):
super
().
__init__
()
self
.
use_linear_projection
=
use_linear_projection
...
...
@@ -159,6 +160,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
num_embeds_ada_norm
=
num_embeds_ada_norm
,
attention_bias
=
attention_bias
,
only_cross_attention
=
only_cross_attention
,
upcast_attention
=
upcast_attention
,
)
for
d
in
range
(
num_layers
)
]
...
...
@@ -403,6 +405,7 @@ class BasicTransformerBlock(nn.Module):
num_embeds_ada_norm
:
Optional
[
int
]
=
None
,
attention_bias
:
bool
=
False
,
only_cross_attention
:
bool
=
False
,
upcast_attention
:
bool
=
False
,
):
super
().
__init__
()
self
.
only_cross_attention
=
only_cross_attention
...
...
@@ -416,6 +419,7 @@ class BasicTransformerBlock(nn.Module):
dropout
=
dropout
,
bias
=
attention_bias
,
cross_attention_dim
=
cross_attention_dim
if
only_cross_attention
else
None
,
upcast_attention
=
upcast_attention
,
)
# is a self-attention
self
.
ff
=
FeedForward
(
dim
,
dropout
=
dropout
,
activation_fn
=
activation_fn
)
...
...
@@ -428,6 +432,7 @@ class BasicTransformerBlock(nn.Module):
dim_head
=
attention_head_dim
,
dropout
=
dropout
,
bias
=
attention_bias
,
upcast_attention
=
upcast_attention
,
)
# is self-attn if context is none
else
:
self
.
attn2
=
None
...
...
@@ -525,10 +530,12 @@ class CrossAttention(nn.Module):
dim_head
:
int
=
64
,
dropout
:
float
=
0.0
,
bias
=
False
,
upcast_attention
:
bool
=
False
,
):
super
().
__init__
()
inner_dim
=
dim_head
*
heads
cross_attention_dim
=
cross_attention_dim
if
cross_attention_dim
is
not
None
else
query_dim
self
.
upcast_attention
=
upcast_attention
self
.
scale
=
dim_head
**-
0.5
self
.
heads
=
heads
...
...
@@ -601,6 +608,10 @@ class CrossAttention(nn.Module):
return
hidden_states
def
_attention
(
self
,
query
,
key
,
value
):
if
self
.
upcast_attention
:
query
=
query
.
float
()
key
=
key
.
float
()
attention_scores
=
torch
.
baddbmm
(
torch
.
empty
(
query
.
shape
[
0
],
query
.
shape
[
1
],
key
.
shape
[
1
],
dtype
=
query
.
dtype
,
device
=
query
.
device
),
query
,
...
...
@@ -609,8 +620,11 @@ class CrossAttention(nn.Module):
alpha
=
self
.
scale
,
)
attention_probs
=
attention_scores
.
softmax
(
dim
=-
1
)
# compute attention output
# cast back to the original dtype
attention_probs
=
attention_probs
.
to
(
value
.
dtype
)
# compute attention output
hidden_states
=
torch
.
bmm
(
attention_probs
,
value
)
# reshape hidden_states
...
...
@@ -626,6 +640,14 @@ class CrossAttention(nn.Module):
for
i
in
range
(
hidden_states
.
shape
[
0
]
//
slice_size
):
start_idx
=
i
*
slice_size
end_idx
=
(
i
+
1
)
*
slice_size
query_slice
=
query
[
start_idx
:
end_idx
]
key_slice
=
key
[
start_idx
:
end_idx
]
if
self
.
upcast_attention
:
query_slice
=
query_slice
.
float
()
key_slice
=
key_slice
.
float
()
attn_slice
=
torch
.
baddbmm
(
torch
.
empty
(
slice_size
,
query
.
shape
[
1
],
key
.
shape
[
1
],
dtype
=
query
.
dtype
,
device
=
query
.
device
),
query
[
start_idx
:
end_idx
],
...
...
@@ -634,6 +656,9 @@ class CrossAttention(nn.Module):
alpha
=
self
.
scale
,
)
attn_slice
=
attn_slice
.
softmax
(
dim
=-
1
)
# cast back to the original dtype
attn_slice
=
attn_slice
.
to
(
value
.
dtype
)
attn_slice
=
torch
.
bmm
(
attn_slice
,
value
[
start_idx
:
end_idx
])
hidden_states
[
start_idx
:
end_idx
]
=
attn_slice
...
...
src/diffusers/models/unet_2d_blocks.py
View file @
170ebd28
...
...
@@ -35,6 +35,7 @@ def get_down_block(
dual_cross_attention
=
False
,
use_linear_projection
=
False
,
only_cross_attention
=
False
,
upcast_attention
=
False
,
):
down_block_type
=
down_block_type
[
7
:]
if
down_block_type
.
startswith
(
"UNetRes"
)
else
down_block_type
if
down_block_type
==
"DownBlock2D"
:
...
...
@@ -80,6 +81,7 @@ def get_down_block(
dual_cross_attention
=
dual_cross_attention
,
use_linear_projection
=
use_linear_projection
,
only_cross_attention
=
only_cross_attention
,
upcast_attention
=
upcast_attention
,
)
elif
down_block_type
==
"SkipDownBlock2D"
:
return
SkipDownBlock2D
(
...
...
@@ -146,6 +148,7 @@ def get_up_block(
dual_cross_attention
=
False
,
use_linear_projection
=
False
,
only_cross_attention
=
False
,
upcast_attention
=
False
,
):
up_block_type
=
up_block_type
[
7
:]
if
up_block_type
.
startswith
(
"UNetRes"
)
else
up_block_type
if
up_block_type
==
"UpBlock2D"
:
...
...
@@ -178,6 +181,7 @@ def get_up_block(
dual_cross_attention
=
dual_cross_attention
,
use_linear_projection
=
use_linear_projection
,
only_cross_attention
=
only_cross_attention
,
upcast_attention
=
upcast_attention
,
)
elif
up_block_type
==
"AttnUpBlock2D"
:
return
AttnUpBlock2D
(
...
...
@@ -335,6 +339,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
cross_attention_dim
=
1280
,
dual_cross_attention
=
False
,
use_linear_projection
=
False
,
upcast_attention
=
False
,
):
super
().
__init__
()
...
...
@@ -370,6 +375,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
cross_attention_dim
=
cross_attention_dim
,
norm_num_groups
=
resnet_groups
,
use_linear_projection
=
use_linear_projection
,
upcast_attention
=
upcast_attention
,
)
)
else
:
...
...
@@ -514,6 +520,7 @@ class CrossAttnDownBlock2D(nn.Module):
dual_cross_attention
=
False
,
use_linear_projection
=
False
,
only_cross_attention
=
False
,
upcast_attention
=
False
,
):
super
().
__init__
()
resnets
=
[]
...
...
@@ -549,6 +556,7 @@ class CrossAttnDownBlock2D(nn.Module):
norm_num_groups
=
resnet_groups
,
use_linear_projection
=
use_linear_projection
,
only_cross_attention
=
only_cross_attention
,
upcast_attention
=
upcast_attention
,
)
)
else
:
...
...
@@ -1096,6 +1104,7 @@ class CrossAttnUpBlock2D(nn.Module):
dual_cross_attention
=
False
,
use_linear_projection
=
False
,
only_cross_attention
=
False
,
upcast_attention
=
False
,
):
super
().
__init__
()
resnets
=
[]
...
...
@@ -1133,6 +1142,7 @@ class CrossAttnUpBlock2D(nn.Module):
norm_num_groups
=
resnet_groups
,
use_linear_projection
=
use_linear_projection
,
only_cross_attention
=
only_cross_attention
,
upcast_attention
=
upcast_attention
,
)
)
else
:
...
...
src/diffusers/models/unet_2d_condition.py
View file @
170ebd28
...
...
@@ -111,6 +111,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
dual_cross_attention
:
bool
=
False
,
use_linear_projection
:
bool
=
False
,
num_class_embeds
:
Optional
[
int
]
=
None
,
upcast_attention
:
bool
=
False
,
):
super
().
__init__
()
...
...
@@ -163,6 +164,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
dual_cross_attention
=
dual_cross_attention
,
use_linear_projection
=
use_linear_projection
,
only_cross_attention
=
only_cross_attention
[
i
],
upcast_attention
=
upcast_attention
,
)
self
.
down_blocks
.
append
(
down_block
)
...
...
@@ -179,6 +181,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
resnet_groups
=
norm_num_groups
,
dual_cross_attention
=
dual_cross_attention
,
use_linear_projection
=
use_linear_projection
,
upcast_attention
=
upcast_attention
,
)
# count how many layers upsample the images
...
...
@@ -219,6 +222,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
dual_cross_attention
=
dual_cross_attention
,
use_linear_projection
=
use_linear_projection
,
only_cross_attention
=
only_cross_attention
[
i
],
upcast_attention
=
upcast_attention
,
)
self
.
up_blocks
.
append
(
up_block
)
prev_output_channel
=
output_channel
...
...
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
View file @
170ebd28
...
...
@@ -189,6 +189,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
dual_cross_attention
:
bool
=
False
,
use_linear_projection
:
bool
=
False
,
num_class_embeds
:
Optional
[
int
]
=
None
,
upcast_attention
:
bool
=
False
,
):
super
().
__init__
()
...
...
@@ -241,6 +242,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
dual_cross_attention
=
dual_cross_attention
,
use_linear_projection
=
use_linear_projection
,
only_cross_attention
=
only_cross_attention
[
i
],
upcast_attention
=
upcast_attention
,
)
self
.
down_blocks
.
append
(
down_block
)
...
...
@@ -257,6 +259,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_groups
=
norm_num_groups
,
dual_cross_attention
=
dual_cross_attention
,
use_linear_projection
=
use_linear_projection
,
upcast_attention
=
upcast_attention
,
)
# count how many layers upsample the images
...
...
@@ -297,6 +300,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
dual_cross_attention
=
dual_cross_attention
,
use_linear_projection
=
use_linear_projection
,
only_cross_attention
=
only_cross_attention
[
i
],
upcast_attention
=
upcast_attention
,
)
self
.
up_blocks
.
append
(
up_block
)
prev_output_channel
=
output_channel
...
...
@@ -716,6 +720,7 @@ class CrossAttnDownBlockFlat(nn.Module):
dual_cross_attention
=
False
,
use_linear_projection
=
False
,
only_cross_attention
=
False
,
upcast_attention
=
False
,
):
super
().
__init__
()
resnets
=
[]
...
...
@@ -751,6 +756,7 @@ class CrossAttnDownBlockFlat(nn.Module):
norm_num_groups
=
resnet_groups
,
use_linear_projection
=
use_linear_projection
,
only_cross_attention
=
only_cross_attention
,
upcast_attention
=
upcast_attention
,
)
)
else
:
...
...
@@ -912,6 +918,7 @@ class CrossAttnUpBlockFlat(nn.Module):
dual_cross_attention
=
False
,
use_linear_projection
=
False
,
only_cross_attention
=
False
,
upcast_attention
=
False
,
):
super
().
__init__
()
resnets
=
[]
...
...
@@ -949,6 +956,7 @@ class CrossAttnUpBlockFlat(nn.Module):
norm_num_groups
=
resnet_groups
,
use_linear_projection
=
use_linear_projection
,
only_cross_attention
=
only_cross_attention
,
upcast_attention
=
upcast_attention
,
)
)
else
:
...
...
@@ -1031,6 +1039,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
cross_attention_dim
=
1280
,
dual_cross_attention
=
False
,
use_linear_projection
=
False
,
upcast_attention
=
False
,
):
super
().
__init__
()
...
...
@@ -1066,6 +1075,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
cross_attention_dim
=
cross_attention_dim
,
norm_num_groups
=
resnet_groups
,
use_linear_projection
=
use_linear_projection
,
upcast_attention
=
upcast_attention
,
)
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment