Project: renzhc / diffusers_dcu · Commits

Commit 2bbf8b67 (unverified)
Authored Dec 01, 2022 by Suraj Patil; committed by GitHub on Dec 01, 2022

simplyfy AttentionBlock (#1492)

Parent: 5a5bf7ef
Changes: 1 changed file with 32 additions and 46 deletions

src/diffusers/models/attention.py (+32, -46) · view file @ 2bbf8b67
...
@@ -290,11 +290,19 @@ class AttentionBlock(nn.Module):
         self.rescale_output_factor = rescale_output_factor
         self.proj_attn = nn.Linear(channels, channels, 1)
 
-    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
-        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
-        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
-        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
-        return new_projection
+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.num_heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.num_heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
 
     def forward(self, hidden_states):
         residual = hidden_states
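
Aside (not part of the commit): a minimal, self-contained sketch of what the two new helpers do, with num_heads passed in explicitly instead of read from self. They fold the head dimension into the batch dimension, (B, T, H * D) -> (B * H, T, D), and back; the round trip is lossless. The tensor sizes below are illustrative only.

import torch


def reshape_heads_to_batch_dim(tensor: torch.Tensor, num_heads: int) -> torch.Tensor:
    # (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) -> (B * H, T, D)
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, num_heads, dim // num_heads)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * num_heads, seq_len, dim // num_heads)


def reshape_batch_dim_to_heads(tensor: torch.Tensor, num_heads: int) -> torch.Tensor:
    # (B * H, T, D) -> (B, H, T, D) -> (B, T, H, D) -> (B, T, H * D)
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size // num_heads, num_heads, seq_len, dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size // num_heads, seq_len, dim * num_heads)


x = torch.randn(2, 16, 64)                         # B=2, T=16, channels=64
flat = reshape_heads_to_batch_dim(x, num_heads=4)  # -> torch.Size([8, 16, 16])
assert torch.equal(reshape_batch_dim_to_heads(flat, num_heads=4), x)  # lossless round trip
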
...
@@ -312,50 +320,28 @@ class AttentionBlock(nn.Module):
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        # get scores
-        if self.num_heads > 1:
-            query_states = self.transpose_for_scores(query_proj)
-            key_states = self.transpose_for_scores(key_proj)
-            value_states = self.transpose_for_scores(value_proj)
-
-            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
-        else:
-            query_states, key_states, value_states = query_proj, key_proj, value_proj
-
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_states.shape[0],
-                    query_states.shape[1],
-                    key_states.shape[1],
-                    dtype=query_states.dtype,
-                    device=query_states.device,
-                ),
-                query_states,
-                key_states.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
+        query_proj = self.reshape_heads_to_batch_dim(query_proj)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+
+        attention_scores = torch.baddbmm(
+            torch.empty(
+                query_proj.shape[0],
+                query_proj.shape[1],
+                key_proj.shape[1],
+                dtype=query_proj.dtype,
+                device=query_proj.device,
+            ),
+            query_proj,
+            key_proj.transpose(-1, -2),
+            beta=0,
+            alpha=scale,
+        )
 
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+        hidden_states = torch.bmm(attention_probs, value_proj)
 
-        # compute attention output
-        if self.num_heads > 1:
-            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            hidden_states = torch.matmul(attention_probs, value_states)
-            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
-            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
-            hidden_states = hidden_states.view(new_hidden_states_shape)
-        else:
-            hidden_states = torch.bmm(attention_probs, value_states)
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
...
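
Aside (not part of the commit): a rough sketch of why the single baddbmm/bmm path can replace both of the old branches. On flattened (B * H, T, D) tensors it computes ordinary scaled dot-product attention, which matches the old num_heads > 1 branch's per-head 4D matmul (and reduces to the old else branch when num_heads == 1). All names and sizes below (q, k, v, num_heads, flatten_heads, split_heads) are illustrative, not from the diff.

import math

import torch

torch.manual_seed(0)
batch, seq_len, channels, num_heads = 2, 16, 64, 4
q, k, v = [torch.randn(batch, seq_len, channels) for _ in range(3)]
scale = 1 / math.sqrt(channels / num_heads)


def flatten_heads(t: torch.Tensor) -> torch.Tensor:
    # same layout change as the new helper: (B, T, H * D) -> (B * H, T, D)
    b, n, d = t.shape
    return t.reshape(b, n, num_heads, d // num_heads).permute(0, 2, 1, 3).reshape(b * num_heads, n, d // num_heads)


def split_heads(t: torch.Tensor) -> torch.Tensor:
    # old num_heads > 1 layout: (B, T, H * D) -> (B, H, T, D)
    return t.view(*t.shape[:-1], num_heads, -1).permute(0, 2, 1, 3)


# new path: one fused baddbmm (beta=0, so the empty input is ignored) followed by bmm
qf, kf, vf = flatten_heads(q), flatten_heads(k), flatten_heads(v)
scores_new = torch.baddbmm(
    torch.empty(qf.shape[0], qf.shape[1], kf.shape[1], dtype=qf.dtype, device=qf.device),
    qf,
    kf.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
out_new = torch.bmm(scores_new.softmax(dim=-1), vf)

# old path (num_heads > 1 branch): 4D matmul per head, then softmax and matmul again
scores_old = torch.matmul(split_heads(q), split_heads(k).transpose(-1, -2)) * scale
out_old = torch.matmul(scores_old.softmax(dim=-1), split_heads(v))

# both paths produce the same attention output up to floating-point tolerance
assert torch.allclose(out_new.reshape(batch, num_heads, seq_len, -1), out_old, atol=1e-5)
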