Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
2bbf8b67
Unverified
Commit
2bbf8b67
authored
Dec 01, 2022
by
Suraj Patil
Committed by
GitHub
Dec 01, 2022
Browse files
simplyfy AttentionBlock (#1492)
parent
5a5bf7ef
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
46 deletions
+32
-46
src/diffusers/models/attention.py
src/diffusers/models/attention.py
+32
-46
No files found.
src/diffusers/models/attention.py
View file @
2bbf8b67
...
...
@@ -290,11 +290,19 @@ class AttentionBlock(nn.Module):
self
.
rescale_output_factor
=
rescale_output_factor
self
.
proj_attn
=
nn
.
Linear
(
channels
,
channels
,
1
)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
    """Split the last dimension into heads and move the head axis to position 1.

    Args:
        projection: 3-D tensor laid out as (B, T, H * D), where H is
            ``self.num_heads``.

    Returns:
        A permuted view of ``projection`` with shape (B, H, T, D).
    """
    # (B, T, H * D) -> (B, T, H, D); -1 lets torch infer the per-head width D.
    split_shape = (*projection.shape[:-1], self.num_heads, -1)
    # Move heads to the 2nd position: (B, T, H, D) -> (B, H, T, D).
    return projection.view(split_shape).permute(0, 2, 1, 3)
def reshape_heads_to_batch_dim(self, tensor):
    """Fold the attention heads into the batch dimension.

    Args:
        tensor: 3-D tensor of shape (batch_size, seq_len, dim).

    Returns:
        Tensor of shape
        (batch_size * self.num_heads, seq_len, dim // self.num_heads).
    """
    batch_size, seq_len, dim = tensor.shape
    num_heads = self.num_heads
    head_dim = dim // num_heads
    # (batch, seq, dim) -> (batch, seq, heads, head_dim)
    tensor = tensor.reshape(batch_size, seq_len, num_heads, head_dim)
    # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
    tensor = tensor.permute(0, 2, 1, 3)
    # Merge batch and heads so the result can feed 3-D batched matmuls.
    return tensor.reshape(batch_size * num_heads, seq_len, head_dim)
def reshape_batch_dim_to_heads(self, tensor):
    """Unfold heads out of the batch dimension and merge them into the last axis.

    Args:
        tensor: 3-D tensor of shape (batch_size, seq_len, dim), where the
            leading dimension is split by ``self.num_heads``.

    Returns:
        Tensor of shape
        (batch_size // self.num_heads, seq_len, dim * self.num_heads).
    """
    batch_size, seq_len, dim = tensor.shape
    num_heads = self.num_heads
    outer_batch = batch_size // num_heads
    # (batch * heads, seq, dim) -> (batch, heads, seq, dim)
    tensor = tensor.reshape(outer_batch, num_heads, seq_len, dim)
    # (batch, heads, seq, dim) -> (batch, seq, heads, dim)
    tensor = tensor.permute(0, 2, 1, 3)
    # Concatenate the per-head features along the last axis.
    return tensor.reshape(outer_batch, seq_len, dim * num_heads)
def
forward
(
self
,
hidden_states
):
residual
=
hidden_states
...
...
@@ -312,50 +320,28 @@ class AttentionBlock(nn.Module):
scale
=
1
/
math
.
sqrt
(
self
.
channels
/
self
.
num_heads
)
# get scores
if
self
.
num_heads
>
1
:
query_states
=
self
.
transpose_for_scores
(
query_proj
)
key_states
=
self
.
transpose_for_scores
(
key_proj
)
value_states
=
self
.
transpose_for_scores
(
value_proj
)
# TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
# or reformulate this into a 3D problem?
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
attention_scores
=
torch
.
matmul
(
query_states
,
key_states
.
transpose
(
-
1
,
-
2
))
*
scale
else
:
query_states
,
key_states
,
value_states
=
query_proj
,
key_proj
,
value_proj
attention_scores
=
torch
.
baddbmm
(
torch
.
empty
(
query_states
.
shape
[
0
],
query_states
.
shape
[
1
],
key_states
.
shape
[
1
],
dtype
=
query_states
.
dtype
,
device
=
query_states
.
device
,
),
query_states
,
key_states
.
transpose
(
-
1
,
-
2
),
beta
=
0
,
alpha
=
scale
,
)
query_proj
=
self
.
reshape_heads_to_batch_dim
(
query_proj
)
key_proj
=
self
.
reshape_heads_to_batch_dim
(
key_proj
)
value_proj
=
self
.
reshape_heads_to_batch_dim
(
value_proj
)
attention_scores
=
torch
.
baddbmm
(
torch
.
empty
(
query_proj
.
shape
[
0
],
query_proj
.
shape
[
1
],
key_proj
.
shape
[
1
],
dtype
=
query_proj
.
dtype
,
device
=
query_proj
.
device
,
),
query_proj
,
key_proj
.
transpose
(
-
1
,
-
2
),
beta
=
0
,
alpha
=
scale
,
)
attention_probs
=
torch
.
softmax
(
attention_scores
.
float
(),
dim
=-
1
).
type
(
attention_scores
.
dtype
)
hidden_states
=
torch
.
bmm
(
attention_probs
,
value_proj
)
# compute attention output
if
self
.
num_heads
>
1
:
# TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
# or reformulate this into a 3D problem?
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
hidden_states
=
torch
.
matmul
(
attention_probs
,
value_states
)
hidden_states
=
hidden_states
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
new_hidden_states_shape
=
hidden_states
.
size
()[:
-
2
]
+
(
self
.
channels
,)
hidden_states
=
hidden_states
.
view
(
new_hidden_states_shape
)
else
:
hidden_states
=
torch
.
bmm
(
attention_probs
,
value_states
)
# reshape hidden_states
hidden_states
=
self
.
reshape_batch_dim_to_heads
(
hidden_states
)
# compute next hidden_states
hidden_states
=
self
.
proj_attn
(
hidden_states
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment