Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
fff981df
Commit
fff981df
authored
Jun 28, 2022
by
Patrick von Platen
Browse files
all attentions collected
parent
a42b900d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
245 additions
and
0 deletions
+245
-0
src/diffusers/models/attention2d.py
src/diffusers/models/attention2d.py
+245
-0
No files found.
src/diffusers/models/attention2d.py
View file @
fff981df
# unet_grad_tts.py
class
LinearAttention
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
dim
,
heads
=
4
,
dim_head
=
32
):
super
(
LinearAttention
,
self
).
__init__
()
self
.
heads
=
heads
self
.
dim_head
=
dim_head
hidden_dim
=
dim_head
*
heads
self
.
to_qkv
=
torch
.
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_out
=
torch
.
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
def
forward
(
self
,
x
):
b
,
c
,
h
,
w
=
x
.
shape
qkv
=
self
.
to_qkv
(
x
)
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
q
,
k
,
v
=
(
qkv
.
reshape
(
b
,
3
,
self
.
heads
,
self
.
dim_head
,
h
,
w
)
.
permute
(
1
,
0
,
2
,
3
,
4
,
5
)
.
reshape
(
3
,
b
,
self
.
heads
,
self
.
dim_head
,
-
1
)
)
k
=
k
.
softmax
(
dim
=-
1
)
context
=
torch
.
einsum
(
"bhdn,bhen->bhde"
,
k
,
v
)
out
=
torch
.
einsum
(
"bhde,bhdn->bhen"
,
context
,
q
)
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
out
=
out
.
reshape
(
b
,
self
.
heads
,
self
.
dim_head
,
h
,
w
).
reshape
(
b
,
self
.
heads
*
self
.
dim_head
,
h
,
w
)
return
self
.
to_out
(
out
)
# unet.py
class
AttnBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
norm
=
Normalize
(
in_channels
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
h_
=
x
h_
=
self
.
norm
(
h_
)
q
=
self
.
q
(
h_
)
k
=
self
.
k
(
h_
)
v
=
self
.
v
(
h_
)
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
q
=
q
.
permute
(
0
,
2
,
1
)
# b,hw,c
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# b,c,hw
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
# attend to values
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
h_
=
torch
.
bmm
(
v
,
w_
)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
h_
=
self
.
proj_out
(
h_
)
return
x
+
h_
# unet_glide.py
class
AttentionBlock
(
nn
.
Module
):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def
__init__
(
self
,
channels
,
num_heads
=
1
,
num_head_channels
=-
1
,
use_checkpoint
=
False
,
encoder_channels
=
None
,
):
super
().
__init__
()
self
.
channels
=
channels
if
num_head_channels
==
-
1
:
self
.
num_heads
=
num_heads
else
:
assert
(
channels
%
num_head_channels
==
0
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
use_checkpoint
=
use_checkpoint
self
.
norm
=
normalization
(
channels
,
swish
=
0.0
)
self
.
qkv
=
conv_nd
(
1
,
channels
,
channels
*
3
,
1
)
self
.
attention
=
QKVAttention
(
self
.
num_heads
)
if
encoder_channels
is
not
None
:
self
.
encoder_kv
=
conv_nd
(
1
,
encoder_channels
,
channels
*
2
,
1
)
self
.
proj_out
=
zero_module
(
conv_nd
(
1
,
channels
,
channels
,
1
))
def
forward
(
self
,
x
,
encoder_out
=
None
):
b
,
c
,
*
spatial
=
x
.
shape
qkv
=
self
.
qkv
(
self
.
norm
(
x
).
view
(
b
,
c
,
-
1
))
if
encoder_out
is
not
None
:
encoder_out
=
self
.
encoder_kv
(
encoder_out
)
h
=
self
.
attention
(
qkv
,
encoder_out
)
else
:
h
=
self
.
attention
(
qkv
)
h
=
self
.
proj_out
(
h
)
return
x
+
h
.
reshape
(
b
,
c
,
*
spatial
)
class
QKVAttention
(
nn
.
Module
):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def
__init__
(
self
,
n_heads
):
super
().
__init__
()
self
.
n_heads
=
n_heads
def
forward
(
self
,
qkv
,
encoder_kv
=
None
):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after
attention.
"""
bs
,
width
,
length
=
qkv
.
shape
assert
width
%
(
3
*
self
.
n_heads
)
==
0
ch
=
width
//
(
3
*
self
.
n_heads
)
q
,
k
,
v
=
qkv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
3
,
length
).
split
(
ch
,
dim
=
1
)
if
encoder_kv
is
not
None
:
assert
encoder_kv
.
shape
[
1
]
==
self
.
n_heads
*
ch
*
2
ek
,
ev
=
encoder_kv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
2
,
-
1
).
split
(
ch
,
dim
=
1
)
k
=
torch
.
cat
([
ek
,
k
],
dim
=-
1
)
v
=
torch
.
cat
([
ev
,
v
],
dim
=-
1
)
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
weight
=
torch
.
einsum
(
"bct,bcs->bts"
,
q
*
scale
,
k
*
scale
)
# More stable with f16 than dividing afterwards
weight
=
torch
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
a
=
torch
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
)
return
a
.
reshape
(
bs
,
-
1
,
length
)
# unet_ldm.py
class
AttentionBlock
(
nn
.
Module
):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def
__init__
(
self
,
channels
,
num_heads
=
1
,
num_head_channels
=-
1
,
use_checkpoint
=
False
,
use_new_attention_order
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
if
num_head_channels
==
-
1
:
self
.
num_heads
=
num_heads
else
:
assert
(
channels
%
num_head_channels
==
0
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
use_checkpoint
=
use_checkpoint
self
.
norm
=
normalization
(
channels
)
self
.
qkv
=
conv_nd
(
1
,
channels
,
channels
*
3
,
1
)
# split heads before split qkv
self
.
attention
=
QKVAttentionLegacy
(
self
.
num_heads
)
self
.
proj_out
=
zero_module
(
conv_nd
(
1
,
channels
,
channels
,
1
))
def
forward
(
self
,
x
):
b
,
c
,
*
spatial
=
x
.
shape
x
=
x
.
reshape
(
b
,
c
,
-
1
)
qkv
=
self
.
qkv
(
self
.
norm
(
x
))
h
=
self
.
attention
(
qkv
)
h
=
self
.
proj_out
(
h
)
return
(
x
+
h
).
reshape
(
b
,
c
,
*
spatial
)
class
QKVAttention
(
nn
.
Module
):
"""
A module which performs QKV attention and splits in a different order.
"""
def
__init__
(
self
,
n_heads
):
super
().
__init__
()
self
.
n_heads
=
n_heads
def
forward
(
self
,
qkv
):
"""
Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
T] tensor after attention.
"""
bs
,
width
,
length
=
qkv
.
shape
assert
width
%
(
3
*
self
.
n_heads
)
==
0
ch
=
width
//
(
3
*
self
.
n_heads
)
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=
1
)
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
weight
=
torch
.
einsum
(
"bct,bcs->bts"
,
(
q
*
scale
).
view
(
bs
*
self
.
n_heads
,
ch
,
length
),
(
k
*
scale
).
view
(
bs
*
self
.
n_heads
,
ch
,
length
),
)
# More stable with f16 than dividing afterwards
weight
=
torch
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
a
=
torch
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
.
reshape
(
bs
*
self
.
n_heads
,
ch
,
length
))
return
a
.
reshape
(
bs
,
-
1
,
length
)
@
staticmethod
def
count_flops
(
model
,
_x
,
y
):
return
count_flops_attn
(
model
,
_x
,
y
)
# unet_score_estimation.py
class
AttnBlockpp
(
nn
.
Module
):
"""Channel-wise self-attention block. Modified from DDPM."""
def
__init__
(
self
,
channels
,
skip_rescale
=
False
,
init_scale
=
0.0
):
super
().
__init__
()
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
min
(
channels
//
4
,
32
),
num_channels
=
channels
,
eps
=
1e-6
)
self
.
NIN_0
=
NIN
(
channels
,
channels
)
self
.
NIN_1
=
NIN
(
channels
,
channels
)
self
.
NIN_2
=
NIN
(
channels
,
channels
)
self
.
NIN_3
=
NIN
(
channels
,
channels
,
init_scale
=
init_scale
)
self
.
skip_rescale
=
skip_rescale
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
h
=
self
.
GroupNorm_0
(
x
)
q
=
self
.
NIN_0
(
h
)
k
=
self
.
NIN_1
(
h
)
v
=
self
.
NIN_2
(
h
)
w
=
torch
.
einsum
(
"bchw,bcij->bhwij"
,
q
,
k
)
*
(
int
(
C
)
**
(
-
0.5
))
w
=
torch
.
reshape
(
w
,
(
B
,
H
,
W
,
H
*
W
))
w
=
F
.
softmax
(
w
,
dim
=-
1
)
w
=
torch
.
reshape
(
w
,
(
B
,
H
,
W
,
H
,
W
))
h
=
torch
.
einsum
(
"bhwij,bcij->bchw"
,
w
,
v
)
h
=
self
.
NIN_3
(
h
)
if
not
self
.
skip_rescale
:
return
x
+
h
else
:
return
(
x
+
h
)
/
np
.
sqrt
(
2.0
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment