renzhc / diffusers_dcu

Commit 52b3ff5e, authored Jun 28, 2022 by Patrick von Platen
unify ldm and glide attention
parent fff981df

Showing 3 changed files with 80 additions and 246 deletions:

    src/diffusers/models/attention2d.py   +77  -90
    src/diffusers/models/unet_glide.py     +1   -79
    src/diffusers/models/unet_ldm.py       +2   -77
src/diffusers/models/attention2d.py
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


# unet_grad_tts.py
class LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
...
@@ -24,6 +35,7 @@ class LinearAttention(torch.nn.Module):
        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
        return self.to_out(out)


# unet.py
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
...
@@ -62,7 +74,8 @@ class AttnBlock(nn.Module):
        return x + h_


-# unet_glide.py
+# unet_glide.py & unet_ldm.py
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
...
@@ -78,6 +91,7 @@ class AttentionBlock(nn.Module):
        num_head_channels=-1,
        use_checkpoint=False,
        encoder_channels=None,
+        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
    ):
        super().__init__()
        self.channels = channels
...
@@ -108,6 +122,7 @@ class AttentionBlock(nn.Module):
        h = self.proj_out(h)
        return x + h.reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
...
@@ -140,106 +155,78 @@ class QKVAttention(nn.Module):
        return a.reshape(bs, -1, length)
-# unet_ldm.py
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
-    to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=-1,
-        use_checkpoint=False,
-        use_new_attention_order=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels == -1:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-        self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels)
-        self.qkv = conv_nd(1, channels, channels * 3, 1)
-        # split heads before split qkv
-        self.attention = QKVAttentionLegacy(self.num_heads)
-        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
-    def forward(self, x):
-        b, c, *spatial = x.shape
-        x = x.reshape(b, c, -1)
-        qkv = self.qkv(self.norm(x))
-        h = self.attention(qkv)
-        h = self.proj_out(h)
-        return (x + h).reshape(b, c, *spatial)
-
-
-class QKVAttention(nn.Module):
-    """
-    A module which performs QKV attention and splits in a different order.
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
-        T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.chunk(3, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum(
-            "bct,bcs->bts",
-            (q * scale).view(bs * self.n_heads, ch, length),
-            (k * scale).view(bs * self.n_heads, ch, length),
-        )  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
-        return a.reshape(bs, -1, length)
-
-    @staticmethod
-    def count_flops(model, _x, y):
-        return count_flops_attn(model, _x, y)
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class GroupNorm32(nn.GroupNorm):
+    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
+        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
+        self.swish = swish
+
+    def forward(self, x):
+        y = super().forward(x.float()).to(x.dtype)
+        if self.swish == 1.0:
+            y = F.silu(y)
+        elif self.swish:
+            y = y * F.sigmoid(y * float(self.swish))
+        return y
+
+
+def normalization(channels, swish=0.0):
+    """
+    Make a standard normalization layer, with an optional swish activation.
+    :param channels: number of input channels. :return: an nn.Module for normalization.
+    """
+    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
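The "# More stable with f16 than dividing afterwards" comment above refers to scaling q and k by 1 / sqrt(sqrt(ch)) before the einsum instead of dividing the attention logits by sqrt(ch) afterwards. The two are mathematically equivalent; pre-scaling just keeps intermediate values smaller, which is friendlier to float16. A minimal sketch of the equivalence, with illustrative shapes that are not taken from the commit:

import math
import torch

bs, ch, length = 4, 64, 32
q = torch.randn(bs, ch, length)
k = torch.randn(bs, ch, length)

# Scale q and k each by 1/ch**0.25, so the product carries the usual 1/sqrt(ch) factor.
scale = 1 / math.sqrt(math.sqrt(ch))
pre_scaled = torch.einsum("bct,bcs->bts", q * scale, k * scale)
post_scaled = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)

assert torch.allclose(pre_scaled, post_scaled, atol=1e-5)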
# unet_score_estimation.py
-class AttnBlockpp(nn.Module):
-    """Channel-wise self-attention block. Modified from DDPM."""
-
-    def __init__(self, channels, skip_rescale=False, init_scale=0.0):
-        super().__init__()
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
-        self.NIN_0 = NIN(channels, channels)
-        self.NIN_1 = NIN(channels, channels)
-        self.NIN_2 = NIN(channels, channels)
-        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
-        self.skip_rescale = skip_rescale
-
-    def forward(self, x):
-        B, C, H, W = x.shape
-        h = self.GroupNorm_0(x)
-        q = self.NIN_0(h)
-        k = self.NIN_1(h)
-        v = self.NIN_2(h)
-
-        w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
-        w = torch.reshape(w, (B, H, W, H * W))
-        w = F.softmax(w, dim=-1)
-        w = torch.reshape(w, (B, H, W, H, W))
-        h = torch.einsum("bhwij,bcij->bchw", w, v)
-        h = self.NIN_3(h)
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
+# class AttnBlockpp(nn.Module):
+#     """Channel-wise self-attention block. Modified from DDPM."""
+#
+#     def __init__(self, channels, skip_rescale=False, init_scale=0.0):
+#         super().__init__()
+#         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
+#         self.NIN_0 = NIN(channels, channels)
+#         self.NIN_1 = NIN(channels, channels)
+#         self.NIN_2 = NIN(channels, channels)
+#         self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
+#         self.skip_rescale = skip_rescale
+#
+#     def forward(self, x):
+#         B, C, H, W = x.shape
+#         h = self.GroupNorm_0(x)
+#         q = self.NIN_0(h)
+#         k = self.NIN_1(h)
+#         v = self.NIN_2(h)
+#
+#         w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
+#         w = torch.reshape(w, (B, H, W, H * W))
+#         w = F.softmax(w, dim=-1)
+#         w = torch.reshape(w, (B, H, W, H, W))
+#         h = torch.einsum("bhwij,bcij->bchw", w, v)
+#         h = self.NIN_3(h)
+#         if not self.skip_rescale:
+#             return x + h
+#         else:
+#             return (x + h) / np.sqrt(2.0)
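As a quick orientation, here is a minimal usage sketch of the unified AttentionBlock that attention2d.py now provides. It assumes this checkout is importable as the diffusers package and that the unified class keeps the constructor and forward signature of the unet_glide.py copy removed below; the channel and head values are illustrative, not taken from the commit.

import torch

from diffusers.models.attention2d import AttentionBlock

# 64 channels split across 4 heads; encoder_channels stays None, so this is pure
# spatial self-attention with a residual connection around the block.
block = AttentionBlock(channels=64, num_heads=4)
x = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
out = block(x)
assert out.shape == x.shape  # the block returns x plus the attention output, same shape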
src/diffusers/models/unet_glide.py
-import math
from abc import abstractmethod

import torch
...
@@ -7,6 +6,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
+from .attention2d import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
...
@@ -226,84 +226,6 @@ class ResBlock(TimestepBlock):
        return self.skip_connection(x) + h


-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=-1,
-        use_checkpoint=False,
-        encoder_channels=None,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels == -1:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-        self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels, swish=0.0)
-        self.qkv = conv_nd(1, channels, channels * 3, 1)
-        self.attention = QKVAttention(self.num_heads)
-
-        if encoder_channels is not None:
-            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
-        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
-    def forward(self, x, encoder_out=None):
-        b, c, *spatial = x.shape
-        qkv = self.qkv(self.norm(x).view(b, c, -1))
-        if encoder_out is not None:
-            encoder_out = self.encoder_kv(encoder_out)
-            h = self.attention(qkv, encoder_out)
-        else:
-            h = self.attention(qkv)
-        h = self.proj_out(h)
-        return x + h.reshape(b, c, *spatial)
-
-
-class QKVAttention(nn.Module):
-    """
-    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv, encoder_kv=None):
-        """
-        Apply QKV attention.
-        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after
-        attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-        if encoder_kv is not None:
-            assert encoder_kv.shape[1] == self.n_heads * ch * 2
-            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = torch.cat([ek, k], dim=-1)
-            v = torch.cat([ev, v], dim=-1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = torch.einsum("bts,bcs->bct", weight, v)
-        return a.reshape(bs, -1, length)


class GlideUNetModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding.
...
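The QKVAttention removed above (now served by attention2d.py) optionally prepends projected encoder tokens to the keys and values, which is how GLIDE lets every spatial position also attend to its text conditioning. The standalone sketch below walks through that concatenation; the shapes are illustrative assumptions, not values from the commit.

import math
import torch

bs, heads, ch, length, enc_len = 2, 4, 16, 64, 8

# Fused q/k/v projection output for the spatial tokens, already split per head.
qkv = torch.randn(bs * heads, ch * 3, length)
q, k, v = qkv.split(ch, dim=1)

# Projected encoder tokens carrying extra key/value pairs.
encoder_kv = torch.randn(bs * heads, ch * 2, enc_len)
ek, ev = encoder_kv.split(ch, dim=1)

# Encoder tokens are prepended along the sequence axis, so the softmax runs over
# conditioning tokens and spatial tokens together.
k = torch.cat([ek, k], dim=-1)  # (bs*heads, ch, enc_len + length)
v = torch.cat([ev, v], dim=-1)

scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
assert a.shape == (bs * heads, ch, length)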
src/diffusers/models/unet_ldm.py
...
@@ -9,6 +9,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
+from .attention2d import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
...
@@ -172,8 +173,6 @@ class CrossAttention(nn.Module):
        k = self.to_k(context)
        v = self.to_v(context)

-        # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        q = self.reshape_heads_to_batch_dim(q)
        k = self.reshape_heads_to_batch_dim(k)
        v = self.reshape_heads_to_batch_dim(v)
...
@@ -181,12 +180,9 @@ class CrossAttention(nn.Module):
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
-            # mask = rearrange(mask, "b ... -> b (...)")
-            mask = mask.reshape(batch_size, -1)
+            maks = mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
-            # mask = repeat(mask, "b j -> (b h) () j", h=h)
            mask = mask[:, None, :].repeat(h, 1, 1)
-            # x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
...
@@ -194,7 +190,6 @@ class CrossAttention(nn.Module):
        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        out = self.reshape_batch_dim_to_heads(out)
-        # out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
...
@@ -487,47 +482,6 @@ class ResBlock(TimestepBlock):
        return self.skip_connection(x) + h


-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
-    to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=-1,
-        use_checkpoint=False,
-        use_new_attention_order=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels == -1:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-        self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels)
-        self.qkv = conv_nd(1, channels, channels * 3, 1)
-        # split heads before split qkv
-        self.attention = QKVAttentionLegacy(self.num_heads)
-        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
-    def forward(self, x):
-        b, c, *spatial = x.shape
-        x = x.reshape(b, c, -1)
-        qkv = self.qkv(self.norm(x))
-        h = self.attention(qkv)
-        h = self.proj_out(h)
-        return (x + h).reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
...
@@ -577,35 +531,6 @@ def count_flops_attn(model, _x, y):
    model.total_ops += torch.DoubleTensor([matmul_ops])


-class QKVAttentionLegacy(nn.Module):
-    """
-    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
-        T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = torch.einsum("bts,bcs->bct", weight, v)
-        return a.reshape(bs, -1, length)
-
-    @staticmethod
-    def count_flops(model, _x, y):
-        return count_flops_attn(model, _x, y)


class UNetLDMModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
...
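The removed rearrange comments in CrossAttention document what reshape_heads_to_batch_dim and reshape_batch_dim_to_heads do around the attention einsums. The helpers below are a plain-PyTorch sketch of those two reshapes written for illustration only; the function bodies and shapes are assumptions, not copied from the repository.

import torch

def reshape_heads_to_batch_dim(t: torch.Tensor, heads: int) -> torch.Tensor:
    # Equivalent of rearrange(t, "b n (h d) -> (b h) n d", h=heads)
    b, n, hd = t.shape
    d = hd // heads
    return t.reshape(b, n, heads, d).permute(0, 2, 1, 3).reshape(b * heads, n, d)

def reshape_batch_dim_to_heads(t: torch.Tensor, heads: int) -> torch.Tensor:
    # Equivalent of rearrange(t, "(b h) n d -> b n (h d)", h=heads)
    bh, n, d = t.shape
    b = bh // heads
    return t.reshape(b, heads, n, d).permute(0, 2, 1, 3).reshape(b, n, heads * d)

x = torch.randn(2, 10, 4 * 8)               # (batch, tokens, heads * head_dim)
y = reshape_heads_to_batch_dim(x, heads=4)  # (8, 10, 8): heads folded into the batch axis
assert torch.allclose(reshape_batch_dim_to_heads(y, heads=4), x)  # round-trips exactly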