Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
52b3ff5e
Commit
52b3ff5e
authored
Jun 28, 2022
by
Patrick von Platen
Browse files
unify ldm and glide attention
parent
fff981df
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
80 additions
and
246 deletions
+80
-246
src/diffusers/models/attention2d.py
src/diffusers/models/attention2d.py
+77
-90
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+1
-79
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+2
-77
No files found.
src/diffusers/models/attention2d.py
View file @
52b3ff5e
import
math
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
# unet_grad_tts.py
# unet_grad_tts.py
class
LinearAttention
(
torch
.
nn
.
Module
):
class
LinearAttention
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
dim
,
heads
=
4
,
dim_head
=
32
):
def
__init__
(
self
,
dim
,
heads
=
4
,
dim_head
=
32
):
...
@@ -24,6 +35,7 @@ class LinearAttention(torch.nn.Module):
...
@@ -24,6 +35,7 @@ class LinearAttention(torch.nn.Module):
out
=
out
.
reshape
(
b
,
self
.
heads
,
self
.
dim_head
,
h
,
w
).
reshape
(
b
,
self
.
heads
*
self
.
dim_head
,
h
,
w
)
out
=
out
.
reshape
(
b
,
self
.
heads
,
self
.
dim_head
,
h
,
w
).
reshape
(
b
,
self
.
heads
*
self
.
dim_head
,
h
,
w
)
return
self
.
to_out
(
out
)
return
self
.
to_out
(
out
)
# unet.py
# unet.py
class
AttnBlock
(
nn
.
Module
):
class
AttnBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
def
__init__
(
self
,
in_channels
):
...
@@ -62,7 +74,8 @@ class AttnBlock(nn.Module):
...
@@ -62,7 +74,8 @@ class AttnBlock(nn.Module):
return
x
+
h_
return
x
+
h_
# unet_glide.py
# unet_glide.py & unet_ldm.py
class
AttentionBlock
(
nn
.
Module
):
class
AttentionBlock
(
nn
.
Module
):
"""
"""
An attention block that allows spatial positions to attend to each other.
An attention block that allows spatial positions to attend to each other.
...
@@ -78,6 +91,7 @@ class AttentionBlock(nn.Module):
...
@@ -78,6 +91,7 @@ class AttentionBlock(nn.Module):
num_head_channels
=-
1
,
num_head_channels
=-
1
,
use_checkpoint
=
False
,
use_checkpoint
=
False
,
encoder_channels
=
None
,
encoder_channels
=
None
,
use_new_attention_order
=
False
,
# TODO(Patrick) -> is never used, maybe delete?
):
):
super
().
__init__
()
super
().
__init__
()
self
.
channels
=
channels
self
.
channels
=
channels
...
@@ -108,6 +122,7 @@ class AttentionBlock(nn.Module):
...
@@ -108,6 +122,7 @@ class AttentionBlock(nn.Module):
h
=
self
.
proj_out
(
h
)
h
=
self
.
proj_out
(
h
)
return
x
+
h
.
reshape
(
b
,
c
,
*
spatial
)
return
x
+
h
.
reshape
(
b
,
c
,
*
spatial
)
class
QKVAttention
(
nn
.
Module
):
class
QKVAttention
(
nn
.
Module
):
"""
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
...
@@ -140,106 +155,78 @@ class QKVAttention(nn.Module):
...
@@ -140,106 +155,78 @@ class QKVAttention(nn.Module):
return
a
.
reshape
(
bs
,
-
1
,
length
)
return
a
.
reshape
(
bs
,
-
1
,
length
)
# unet_ldm.py
def
conv_nd
(
dims
,
*
args
,
**
kwargs
):
class
AttentionBlock
(
nn
.
Module
):
"""
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
Create a 1D, 2D, or 3D convolution module.
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
"""
if
dims
==
1
:
return
nn
.
Conv1d
(
*
args
,
**
kwargs
)
elif
dims
==
2
:
return
nn
.
Conv2d
(
*
args
,
**
kwargs
)
elif
dims
==
3
:
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
def
__init__
(
self
,
channels
,
num_heads
=
1
,
num_head_channels
=-
1
,
use_checkpoint
=
False
,
use_new_attention_order
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
if
num_head_channels
==
-
1
:
self
.
num_heads
=
num_heads
else
:
assert
(
channels
%
num_head_channels
==
0
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
use_checkpoint
=
use_checkpoint
self
.
norm
=
normalization
(
channels
)
self
.
qkv
=
conv_nd
(
1
,
channels
,
channels
*
3
,
1
)
# split heads before split qkv
self
.
attention
=
QKVAttentionLegacy
(
self
.
num_heads
)
self
.
proj_out
=
zero_module
(
conv_nd
(
1
,
channels
,
channels
,
1
))
class
GroupNorm32
(
nn
.
GroupNorm
):
def
__init__
(
self
,
num_groups
,
num_channels
,
swish
,
eps
=
1e-5
):
super
().
__init__
(
num_groups
=
num_groups
,
num_channels
=
num_channels
,
eps
=
eps
)
self
.
swish
=
swish
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
b
,
c
,
*
spatial
=
x
.
sha
pe
y
=
super
().
forward
(
x
.
float
()).
to
(
x
.
dty
pe
)
x
=
x
.
reshape
(
b
,
c
,
-
1
)
if
self
.
swish
==
1.0
:
qkv
=
self
.
qkv
(
self
.
norm
(
x
)
)
y
=
F
.
silu
(
y
)
h
=
self
.
attention
(
qkv
)
elif
self
.
swish
:
h
=
self
.
proj_out
(
h
)
y
=
y
*
F
.
sigmoid
(
y
*
float
(
self
.
swis
h
)
)
return
(
x
+
h
).
reshape
(
b
,
c
,
*
spatial
)
return
y
class
QKVAttention
(
nn
.
Module
):
def
normalization
(
channels
,
swish
=
0.0
):
"""
"""
A module which performs QKV attention and splits in a different order.
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels. :return: an nn.Module for normalization.
"""
"""
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
def
__init__
(
self
,
n_heads
):
super
().
__init__
()
self
.
n_heads
=
n_heads
def
forward
(
self
,
qkv
):
def
zero_module
(
module
):
"""
"""
Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
Zero out the parameters of a module and return it.
T] tensor after attention.
"""
"""
for
p
in
module
.
parameters
():
bs
,
width
,
length
=
qkv
.
shape
p
.
detach
().
zero_
()
assert
width
%
(
3
*
self
.
n_heads
)
==
0
return
module
ch
=
width
//
(
3
*
self
.
n_heads
)
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=
1
)
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
weight
=
torch
.
einsum
(
"bct,bcs->bts"
,
(
q
*
scale
).
view
(
bs
*
self
.
n_heads
,
ch
,
length
),
(
k
*
scale
).
view
(
bs
*
self
.
n_heads
,
ch
,
length
),
)
# More stable with f16 than dividing afterwards
weight
=
torch
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
a
=
torch
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
.
reshape
(
bs
*
self
.
n_heads
,
ch
,
length
))
return
a
.
reshape
(
bs
,
-
1
,
length
)
@
staticmethod
def
count_flops
(
model
,
_x
,
y
):
return
count_flops_attn
(
model
,
_x
,
y
)
# unet_score_estimation.py
# unet_score_estimation.py
class
AttnBlockpp
(
nn
.
Module
):
#
class AttnBlockpp(nn.Module):
"""Channel-wise self-attention block. Modified from DDPM."""
#
"""Channel-wise self-attention block. Modified from DDPM."""
#
def
__init__
(
self
,
channels
,
skip_rescale
=
False
,
init_scale
=
0.0
):
#
def __init__(self, channels, skip_rescale=False, init_scale=0.0):
super
().
__init__
()
#
super().__init__()
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
min
(
channels
//
4
,
32
),
num_channels
=
channels
,
eps
=
1e-6
)
#
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
self
.
NIN_0
=
NIN
(
channels
,
channels
)
#
self.NIN_0 = NIN(channels, channels)
self
.
NIN_1
=
NIN
(
channels
,
channels
)
#
self.NIN_1 = NIN(channels, channels)
self
.
NIN_2
=
NIN
(
channels
,
channels
)
#
self.NIN_2 = NIN(channels, channels)
self
.
NIN_3
=
NIN
(
channels
,
channels
,
init_scale
=
init_scale
)
#
self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
self
.
skip_rescale
=
skip_rescale
#
self.skip_rescale = skip_rescale
#
def
forward
(
self
,
x
):
#
def forward(self, x):
B
,
C
,
H
,
W
=
x
.
shape
#
B, C, H, W = x.shape
h
=
self
.
GroupNorm_0
(
x
)
#
h = self.GroupNorm_0(x)
q
=
self
.
NIN_0
(
h
)
#
q = self.NIN_0(h)
k
=
self
.
NIN_1
(
h
)
#
k = self.NIN_1(h)
v
=
self
.
NIN_2
(
h
)
#
v = self.NIN_2(h)
#
w
=
torch
.
einsum
(
"bchw,bcij->bhwij"
,
q
,
k
)
*
(
int
(
C
)
**
(
-
0.5
))
#
w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
w
=
torch
.
reshape
(
w
,
(
B
,
H
,
W
,
H
*
W
))
#
w = torch.reshape(w, (B, H, W, H * W))
w
=
F
.
softmax
(
w
,
dim
=-
1
)
#
w = F.softmax(w, dim=-1)
w
=
torch
.
reshape
(
w
,
(
B
,
H
,
W
,
H
,
W
))
#
w = torch.reshape(w, (B, H, W, H, W))
h
=
torch
.
einsum
(
"bhwij,bcij->bchw"
,
w
,
v
)
#
h = torch.einsum("bhwij,bcij->bchw", w, v)
h
=
self
.
NIN_3
(
h
)
#
h = self.NIN_3(h)
if
not
self
.
skip_rescale
:
#
if not self.skip_rescale:
return
x
+
h
#
return x + h
else
:
#
else:
return
(
x
+
h
)
/
np
.
sqrt
(
2.0
)
#
return (x + h) / np.sqrt(2.0)
src/diffusers/models/unet_glide.py
View file @
52b3ff5e
import
math
from
abc
import
abstractmethod
from
abc
import
abstractmethod
import
torch
import
torch
...
@@ -7,6 +6,7 @@ import torch.nn.functional as F
...
@@ -7,6 +6,7 @@ import torch.nn.functional as F
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention2d
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
Upsample
from
.resnet
import
Downsample
,
Upsample
...
@@ -226,84 +226,6 @@ class ResBlock(TimestepBlock):
...
@@ -226,84 +226,6 @@ class ResBlock(TimestepBlock):
return
self
.
skip_connection
(
x
)
+
h
return
self
.
skip_connection
(
x
)
+
h
class
AttentionBlock
(
nn
.
Module
):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def
__init__
(
self
,
channels
,
num_heads
=
1
,
num_head_channels
=-
1
,
use_checkpoint
=
False
,
encoder_channels
=
None
,
):
super
().
__init__
()
self
.
channels
=
channels
if
num_head_channels
==
-
1
:
self
.
num_heads
=
num_heads
else
:
assert
(
channels
%
num_head_channels
==
0
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
use_checkpoint
=
use_checkpoint
self
.
norm
=
normalization
(
channels
,
swish
=
0.0
)
self
.
qkv
=
conv_nd
(
1
,
channels
,
channels
*
3
,
1
)
self
.
attention
=
QKVAttention
(
self
.
num_heads
)
if
encoder_channels
is
not
None
:
self
.
encoder_kv
=
conv_nd
(
1
,
encoder_channels
,
channels
*
2
,
1
)
self
.
proj_out
=
zero_module
(
conv_nd
(
1
,
channels
,
channels
,
1
))
def
forward
(
self
,
x
,
encoder_out
=
None
):
b
,
c
,
*
spatial
=
x
.
shape
qkv
=
self
.
qkv
(
self
.
norm
(
x
).
view
(
b
,
c
,
-
1
))
if
encoder_out
is
not
None
:
encoder_out
=
self
.
encoder_kv
(
encoder_out
)
h
=
self
.
attention
(
qkv
,
encoder_out
)
else
:
h
=
self
.
attention
(
qkv
)
h
=
self
.
proj_out
(
h
)
return
x
+
h
.
reshape
(
b
,
c
,
*
spatial
)
class
QKVAttention
(
nn
.
Module
):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def
__init__
(
self
,
n_heads
):
super
().
__init__
()
self
.
n_heads
=
n_heads
def
forward
(
self
,
qkv
,
encoder_kv
=
None
):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after
attention.
"""
bs
,
width
,
length
=
qkv
.
shape
assert
width
%
(
3
*
self
.
n_heads
)
==
0
ch
=
width
//
(
3
*
self
.
n_heads
)
q
,
k
,
v
=
qkv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
3
,
length
).
split
(
ch
,
dim
=
1
)
if
encoder_kv
is
not
None
:
assert
encoder_kv
.
shape
[
1
]
==
self
.
n_heads
*
ch
*
2
ek
,
ev
=
encoder_kv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
2
,
-
1
).
split
(
ch
,
dim
=
1
)
k
=
torch
.
cat
([
ek
,
k
],
dim
=-
1
)
v
=
torch
.
cat
([
ev
,
v
],
dim
=-
1
)
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
weight
=
torch
.
einsum
(
"bct,bcs->bts"
,
q
*
scale
,
k
*
scale
)
# More stable with f16 than dividing afterwards
weight
=
torch
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
a
=
torch
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
)
return
a
.
reshape
(
bs
,
-
1
,
length
)
class
GlideUNetModel
(
ModelMixin
,
ConfigMixin
):
class
GlideUNetModel
(
ModelMixin
,
ConfigMixin
):
"""
"""
The full UNet model with attention and timestep embedding.
The full UNet model with attention and timestep embedding.
...
...
src/diffusers/models/unet_ldm.py
View file @
52b3ff5e
...
@@ -9,6 +9,7 @@ import torch.nn.functional as F
...
@@ -9,6 +9,7 @@ import torch.nn.functional as F
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention2d
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
Upsample
from
.resnet
import
Downsample
,
Upsample
...
@@ -172,8 +173,6 @@ class CrossAttention(nn.Module):
...
@@ -172,8 +173,6 @@ class CrossAttention(nn.Module):
k
=
self
.
to_k
(
context
)
k
=
self
.
to_k
(
context
)
v
=
self
.
to_v
(
context
)
v
=
self
.
to_v
(
context
)
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
q
=
self
.
reshape_heads_to_batch_dim
(
q
)
q
=
self
.
reshape_heads_to_batch_dim
(
q
)
k
=
self
.
reshape_heads_to_batch_dim
(
k
)
k
=
self
.
reshape_heads_to_batch_dim
(
k
)
v
=
self
.
reshape_heads_to_batch_dim
(
v
)
v
=
self
.
reshape_heads_to_batch_dim
(
v
)
...
@@ -181,12 +180,9 @@ class CrossAttention(nn.Module):
...
@@ -181,12 +180,9 @@ class CrossAttention(nn.Module):
sim
=
torch
.
einsum
(
"b i d, b j d -> b i j"
,
q
,
k
)
*
self
.
scale
sim
=
torch
.
einsum
(
"b i d, b j d -> b i j"
,
q
,
k
)
*
self
.
scale
if
exists
(
mask
):
if
exists
(
mask
):
# mask = rearrange(mask, "b ... -> b (...)")
mask
=
mask
.
reshape
(
batch_size
,
-
1
)
maks
=
mask
.
reshape
(
batch_size
,
-
1
)
max_neg_value
=
-
torch
.
finfo
(
sim
.
dtype
).
max
max_neg_value
=
-
torch
.
finfo
(
sim
.
dtype
).
max
# mask = repeat(mask, "b j -> (b h) () j", h=h)
mask
=
mask
[:,
None
,
:].
repeat
(
h
,
1
,
1
)
mask
=
mask
[:,
None
,
:].
repeat
(
h
,
1
,
1
)
# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
# attention, what we cannot get enough of
# attention, what we cannot get enough of
...
@@ -194,7 +190,6 @@ class CrossAttention(nn.Module):
...
@@ -194,7 +190,6 @@ class CrossAttention(nn.Module):
out
=
torch
.
einsum
(
"b i j, b j d -> b i d"
,
attn
,
v
)
out
=
torch
.
einsum
(
"b i j, b j d -> b i d"
,
attn
,
v
)
out
=
self
.
reshape_batch_dim_to_heads
(
out
)
out
=
self
.
reshape_batch_dim_to_heads
(
out
)
# out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return
self
.
to_out
(
out
)
return
self
.
to_out
(
out
)
...
@@ -487,47 +482,6 @@ class ResBlock(TimestepBlock):
...
@@ -487,47 +482,6 @@ class ResBlock(TimestepBlock):
return
self
.
skip_connection
(
x
)
+
h
return
self
.
skip_connection
(
x
)
+
h
class
AttentionBlock
(
nn
.
Module
):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def
__init__
(
self
,
channels
,
num_heads
=
1
,
num_head_channels
=-
1
,
use_checkpoint
=
False
,
use_new_attention_order
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
if
num_head_channels
==
-
1
:
self
.
num_heads
=
num_heads
else
:
assert
(
channels
%
num_head_channels
==
0
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
use_checkpoint
=
use_checkpoint
self
.
norm
=
normalization
(
channels
)
self
.
qkv
=
conv_nd
(
1
,
channels
,
channels
*
3
,
1
)
# split heads before split qkv
self
.
attention
=
QKVAttentionLegacy
(
self
.
num_heads
)
self
.
proj_out
=
zero_module
(
conv_nd
(
1
,
channels
,
channels
,
1
))
def
forward
(
self
,
x
):
b
,
c
,
*
spatial
=
x
.
shape
x
=
x
.
reshape
(
b
,
c
,
-
1
)
qkv
=
self
.
qkv
(
self
.
norm
(
x
))
h
=
self
.
attention
(
qkv
)
h
=
self
.
proj_out
(
h
)
return
(
x
+
h
).
reshape
(
b
,
c
,
*
spatial
)
class
QKVAttention
(
nn
.
Module
):
class
QKVAttention
(
nn
.
Module
):
"""
"""
A module which performs QKV attention and splits in a different order.
A module which performs QKV attention and splits in a different order.
...
@@ -577,35 +531,6 @@ def count_flops_attn(model, _x, y):
...
@@ -577,35 +531,6 @@ def count_flops_attn(model, _x, y):
model
.
total_ops
+=
torch
.
DoubleTensor
([
matmul_ops
])
model
.
total_ops
+=
torch
.
DoubleTensor
([
matmul_ops
])
class
QKVAttentionLegacy
(
nn
.
Module
):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def
__init__
(
self
,
n_heads
):
super
().
__init__
()
self
.
n_heads
=
n_heads
def
forward
(
self
,
qkv
):
"""
Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
T] tensor after attention.
"""
bs
,
width
,
length
=
qkv
.
shape
assert
width
%
(
3
*
self
.
n_heads
)
==
0
ch
=
width
//
(
3
*
self
.
n_heads
)
q
,
k
,
v
=
qkv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
3
,
length
).
split
(
ch
,
dim
=
1
)
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
weight
=
torch
.
einsum
(
"bct,bcs->bts"
,
q
*
scale
,
k
*
scale
)
# More stable with f16 than dividing afterwards
weight
=
torch
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
a
=
torch
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
)
return
a
.
reshape
(
bs
,
-
1
,
length
)
@
staticmethod
def
count_flops
(
model
,
_x
,
y
):
return
count_flops_attn
(
model
,
_x
,
y
)
class
UNetLDMModel
(
ModelMixin
,
ConfigMixin
):
class
UNetLDMModel
(
ModelMixin
,
ConfigMixin
):
"""
"""
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment