Unverified commit e5d9baf0, authored Jun 29, 2022 by Patrick von Platen, committed by GitHub on Jun 29, 2022

Merge pull request #38 from huggingface/one_attentino_module

Unify attention modules

Parents: e47c97a4, c482d7bd

Showing 8 changed files with 299 additions and 94 deletions (+299 −94)
  src/diffusers/models/attention.py                  +289  −0
  src/diffusers/models/embeddings.py                    +0  −16
  src/diffusers/models/unet.py                          +1  −15
  src/diffusers/models/unet_glide.py                    +1  −1
  src/diffusers/models/unet_grad_tts.py                 +1  −27
  src/diffusers/models/unet_ldm.py                      +1  −1
  src/diffusers/models/unet_sde_score_estimation.py     +3  −33
  tests/test_modeling_utils.py                          +3  −1
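The practical effect for downstream code is that every UNet variant now pulls its attention layers from a single `attention` module instead of the old `attention2d` module or per-model copies. A minimal sketch of the unified import surface after this commit (assuming the package layout at commit e5d9baf0; not an official usage guide):

```python
# Both blocks now live in one module; the per-model copies in the UNet files are deleted.
from diffusers.models.attention import AttentionBlock, LinearAttention
```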
src/diffusers/models/attention2d.py → src/diffusers/models/attention.py
 import math

 import torch
+import torch.nn.functional as F
 from torch import nn


 # unet_grad_tts.py
+# TODO(Patrick) - weird linear attention layer. Check with: https://github.com/huawei-noah/Speech-Backbones/issues/15
 class LinearAttention(torch.nn.Module):
     def __init__(self, dim, heads=4, dim_head=32):
         super(LinearAttention, self).__init__()

@@ -18,7 +18,6 @@ class LinearAttention(torch.nn.Module):
     def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
-        # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
         q, k, v = (
             qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
             .permute(1, 0, 2, 3, 4, 5)

@@ -27,12 +26,11 @@ class LinearAttention(torch.nn.Module):
         k = k.softmax(dim=-1)
         context = torch.einsum("bhdn,bhen->bhde", k, v)
         out = torch.einsum("bhde,bhdn->bhen", context, q)
-        # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
         out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
         return self.to_out(out)


-# unet_glide.py & unet_ldm.py
+# the main attention block that is used for all models
 class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other.

@@ -46,10 +44,13 @@ class AttentionBlock(nn.Module):
         channels,
         num_heads=1,
         num_head_channels=-1,
+        num_groups=32,
         use_checkpoint=False,
         encoder_channels=None,
         use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
         overwrite_qkv=False,
+        overwrite_linear=False,
+        rescale_output_factor=1.0,
     ):
         super().__init__()
         self.channels = channels

@@ -62,41 +63,67 @@ class AttentionBlock(nn.Module):
             self.num_heads = channels // num_head_channels
         self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels, swish=0.0)
-        self.qkv = conv_nd(1, channels, channels * 3, 1)
+        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+        self.qkv = nn.Conv1d(channels, channels * 3, 1)
         self.n_heads = self.num_heads
+        self.rescale_output_factor = rescale_output_factor

         if encoder_channels is not None:
-            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
+            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)

-        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))

         self.overwrite_qkv = overwrite_qkv
         if overwrite_qkv:
             in_channels = channels
+            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
             self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

+        self.overwrite_linear = overwrite_linear
+        if self.overwrite_linear:
+            num_groups = min(channels // 4, 32)
+            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
+            self.NIN_0 = NIN(channels, channels)
+            self.NIN_1 = NIN(channels, channels)
+            self.NIN_2 = NIN(channels, channels)
+            self.NIN_3 = NIN(channels, channels)
+
+            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)

         self.is_overwritten = False

     def set_weights(self, module):
         if self.overwrite_qkv:
             qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0]
             qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)

             self.qkv.weight.data = qkv_weight
             self.qkv.bias.data = qkv_bias

-            proj_out = zero_module(conv_nd(1, self.channels, self.channels, 1))
+            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
             proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
             proj_out.bias.data = module.proj_out.bias.data

             self.proj_out = proj_out
+        elif self.overwrite_linear:
+            self.qkv.weight.data = torch.concat(
+                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
+            )[:, :, None]
+            self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
+
+            self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
+            self.proj_out.bias.data = self.NIN_3.b.data
+
+            self.norm.weight.data = self.GroupNorm_0.weight.data
+            self.norm.bias.data = self.GroupNorm_0.bias.data

     def forward(self, x, encoder_out=None):
-        if self.overwrite_qkv and not self.is_overwritten:
+        if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten:
             self.set_weights(self)
             self.is_overwritten = True
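The block converts legacy weights lazily: on the first call to `forward`, if either `overwrite_qkv` or `overwrite_linear` is set and the block has not yet been overwritten, it calls `set_weights(self)` once and flips `is_overwritten`. A stripped-down sketch of that overwrite-on-first-forward pattern (toy module with hypothetical names, not part of the library):

```python
import torch
from torch import nn


class LazyConvert(nn.Module):
    """Toy illustration of the one-time weight-porting pattern used by AttentionBlock."""

    def __init__(self, overwrite=False):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.overwrite = overwrite
        self.is_overwritten = False

    def set_weights(self, module):
        # Stand-in for copying parameters from a legacy layout into the unified one.
        nn.init.eye_(module.linear.weight)

    def forward(self, x):
        if self.overwrite and not self.is_overwritten:
            self.set_weights(self)  # runs exactly once, on first use
            self.is_overwritten = True
        return self.linear(x)


out = LazyConvert(overwrite=True)(torch.randn(2, 4))
```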
@@ -124,69 +151,74 @@ class AttentionBlock(nn.Module):
         h = a.reshape(bs, -1, length)
         h = self.proj_out(h)
-        return x + h.reshape(b, c, *spatial)
-
-
-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class GroupNorm32(nn.GroupNorm):
-    def __init__(self, num_groups, num_channels, swish, eps=1e-5, affine=True):
-        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine)
-        self.swish = swish
-
-    def forward(self, x):
-        y = super().forward(x.float()).to(x.dtype)
-        if self.swish == 1.0:
-            y = F.silu(y)
-        elif self.swish:
-            y = y * F.sigmoid(y * float(self.swish))
-        return y
-
-
-def normalization(channels, swish=0.0, eps=1e-5):
-    """
-    Make a standard normalization layer, with an optional swish activation.
-    :param channels: number of input channels. :return: an nn.Module for normalization.
-    """
-    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish, eps=eps, affine=True)
-
-
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
+        h = h.reshape(b, c, *spatial)
+
+        result = x + h
+
+        result = result / self.rescale_output_factor
+
+        return result


 # unet_score_estimation.py
 # class AttnBlockpp(nn.Module):
 #    """Channel-wise self-attention block. Modified from DDPM."""
 #
-#    def __init__(self, channels, skip_rescale=False, init_scale=0.0):
+#    def __init__(
+#        self,
+#        channels,
+#        skip_rescale=False,
+#        init_scale=0.0,
+#        num_heads=1,
+#        num_head_channels=-1,
+#        use_checkpoint=False,
+#        encoder_channels=None,
+#        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
+#        overwrite_qkv=False,
+#        overwrite_from_grad_tts=False,
+#    ):
 #        super().__init__()
-#        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
+#        num_groups = min(channels // 4, 32)
+#        self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
 #        self.NIN_0 = NIN(channels, channels)
 #        self.NIN_1 = NIN(channels, channels)
 #        self.NIN_2 = NIN(channels, channels)
 #        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
 #        self.skip_rescale = skip_rescale
 #
+#        self.channels = channels
+#        if num_head_channels == -1:
+#            self.num_heads = num_heads
+#        else:
+#            assert (
+#                channels % num_head_channels == 0
+#            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+#            self.num_heads = channels // num_head_channels
+#
+#        self.use_checkpoint = use_checkpoint
+#        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
+#        self.qkv = nn.Conv1d(channels, channels * 3, 1)
+#        self.n_heads = self.num_heads
+#
+#        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+#
+#        self.is_weight_set = False
+#
+#    def set_weights(self):
+#        self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
+#        self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
+#
+#        self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
+#        self.proj_out.bias.data = self.NIN_3.b.data
+#
+#        self.norm.weight.data = self.GroupNorm_0.weight.data
+#        self.norm.bias.data = self.GroupNorm_0.bias.data
+#
 #    def forward(self, x):
+#        if not self.is_weight_set:
+#            self.set_weights()
+#            self.is_weight_set = True
+#
 #        B, C, H, W = x.shape
 #        h = self.GroupNorm_0(x)
 #        q = self.NIN_0(h)

@@ -199,7 +231,59 @@ def zero_module(module):
 #        w = torch.reshape(w, (B, H, W, H, W))
 #        h = torch.einsum("bhwij,bcij->bchw", w, v)
 #        h = self.NIN_3(h)
+#
 #        if not self.skip_rescale:
-#            return x + h
+#            result = x + h
 #        else:
-#            return (x + h) / np.sqrt(2.0)
+#            result = (x + h) / np.sqrt(2.0)
+#
+#        result = self.forward_2(x)
+#
+#        return result
+#
+#    def forward_2(self, x, encoder_out=None):
+#        b, c, *spatial = x.shape
+#        hid_states = self.norm(x).view(b, c, -1)
+#
+#        qkv = self.qkv(hid_states)
+#        bs, width, length = qkv.shape
+#        assert width % (3 * self.n_heads) == 0
+#        ch = width // (3 * self.n_heads)
+#        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+#
+#        if encoder_out is not None:
+#            encoder_kv = self.encoder_kv(encoder_out)
+#            assert encoder_kv.shape[1] == self.n_heads * ch * 2
+#            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
+#            k = torch.cat([ek, k], dim=-1)
+#            v = torch.cat([ev, v], dim=-1)
+#
+#        scale = 1 / math.sqrt(math.sqrt(ch))
+#        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+#        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+#
+#        a = torch.einsum("bts,bcs->bct", weight, v)
+#        h = a.reshape(bs, -1, length)
+#
+#        h = self.proj_out(h)
+#        h = h.reshape(b, c, *spatial)
+#
+#        return (x + h) / np.sqrt(2.0)


+# TODO(Patrick) - this can and should be removed
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
+class NIN(nn.Module):
+    def __init__(self, in_dim, num_units, init_scale=0.1):
+        super().__init__()
+        self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
+        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
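The `set_weights` conversion above relies on the fact that a dense `NIN`-style weight of shape `(in_dim, out_dim)` can be loaded into a kernel-size-1 `nn.Conv1d` by transposing it and appending a trailing kernel axis, exactly the `W.data.T[:, :, None]` pattern in the diff. A self-contained sketch of that equivalence (hypothetical sizes, plain tensors standing in for the `NIN` parameters):

```python
import torch
from torch import nn

in_dim, out_dim = 8, 8
W = torch.randn(in_dim, out_dim)      # dense weight, stored the way NIN stores self.W
b = torch.randn(out_dim)

conv = nn.Conv1d(in_dim, out_dim, 1)
conv.weight.data = W.T[:, :, None]    # (out_dim, in_dim, 1): the layout set_weights builds
conv.bias.data = b

x = torch.randn(2, in_dim, 5)         # (batch, channels, length)
dense = torch.einsum("bcl,cd->bdl", x, W) + b[None, :, None]
assert torch.allclose(conv(x), dense, atol=1e-5)
```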
src/diffusers/models/embeddings.py

@@ -65,19 +65,3 @@ class GaussianFourierProjection(nn.Module):
     def forward(self, x):
         x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
         return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
-
-
-# unet_rl.py - TODO(need test)
-class SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
src/diffusers/models/unet.py

@@ -15,24 +15,14 @@
 # helpers functions
-import copy
-import math
-from pathlib import Path
-
 import torch
 from torch import nn
-from torch.cuda.amp import GradScaler, autocast
-from torch.optim import Adam
-from torch.utils import data
-from PIL import Image
-from tqdm import tqdm

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
-from .attention2d import AttentionBlock


 def nonlinearity(x):

@@ -219,11 +209,7 @@ class UNetModel(ModelMixin, ConfigMixin):
             for i_block in range(self.num_res_blocks):
                 h = self.down[i_level].block[i_block](hs[-1], temb)
                 if len(self.down[i_level].attn) > 0:
-                    # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
-                    # h = self.down[i_level].attn_2[i_block](h)
                     h = self.down[i_level].attn[i_block](h)
-                    # print("Result", (h - h_2).abs().sum())
                 hs.append(h)
             if i_level != self.num_resolutions - 1:
                 hs.append(self.down[i_level].downsample(hs[-1]))
src/diffusers/models/unet_glide.py

@@ -6,7 +6,7 @@ import torch.nn.functional as F

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .attention2d import AttentionBlock
+from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
src/diffusers/models/unet_grad_tts.py

 import torch
 from numpy import pad

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample

@@ -54,32 +54,6 @@ class ResnetBlock(torch.nn.Module):
         return output


-class LinearAttention(torch.nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super(LinearAttention, self).__init__()
-        self.heads = heads
-        self.dim_head = dim_head
-        hidden_dim = dim_head * heads
-        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x)
-        # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
-        q, k, v = (
-            qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
-            .permute(1, 0, 2, 3, 4, 5)
-            .reshape(3, b, self.heads, self.dim_head, -1)
-        )
-        k = k.softmax(dim=-1)
-        context = torch.einsum("bhdn,bhen->bhde", k, v)
-        out = torch.einsum("bhde,bhdn->bhen", context, q)
-        # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
-        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
-        return self.to_out(out)
-
-
 class Residual(torch.nn.Module):
     def __init__(self, fn):
         super(Residual, self).__init__()
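With its local copy deleted, `unet_grad_tts.py` now uses the shared `LinearAttention` shown in `attention.py` above. A quick shape check of that block (assuming it is importable from the unified module as in this commit; the sizes are arbitrary):

```python
import torch
from diffusers.models.attention import LinearAttention  # import path introduced by this commit

attn = LinearAttention(dim=64, heads=4, dim_head=32)
x = torch.randn(2, 64, 16, 16)   # (batch, channels, height, width)
out = attn(x)
assert out.shape == x.shape      # to_out projects heads * dim_head back down to dim
```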
src/diffusers/models/unet_ldm.py

@@ -9,7 +9,7 @@ import torch.nn.functional as F

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .attention2d import AttentionBlock
+from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
src/diffusers/models/unet_sde_score_estimation.py

@@ -16,6 +16,7 @@
 # helpers functions
 import functools
+import math
 import string

 import numpy as np

@@ -25,6 +26,7 @@ import torch.nn.functional as F

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding

@@ -414,37 +416,6 @@ class Combine(nn.Module):
             raise ValueError(f"Method {self.method} not recognized.")


-class AttnBlockpp(nn.Module):
-    """Channel-wise self-attention block. Modified from DDPM."""
-
-    def __init__(self, channels, skip_rescale=False, init_scale=0.0):
-        super().__init__()
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
-        self.NIN_0 = NIN(channels, channels)
-        self.NIN_1 = NIN(channels, channels)
-        self.NIN_2 = NIN(channels, channels)
-        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
-        self.skip_rescale = skip_rescale
-
-    def forward(self, x):
-        B, C, H, W = x.shape
-        h = self.GroupNorm_0(x)
-        q = self.NIN_0(h)
-        k = self.NIN_1(h)
-        v = self.NIN_2(h)
-
-        w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
-        w = torch.reshape(w, (B, H, W, H * W))
-        w = F.softmax(w, dim=-1)
-        w = torch.reshape(w, (B, H, W, H, W))
-        h = torch.einsum("bhwij,bcij->bchw", w, v)
-        h = self.NIN_3(h)
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
-
-
 class Upsample(nn.Module):
     def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()

@@ -756,8 +727,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
             modules[-1].weight.data = default_init()(modules[-1].weight.shape)
             nn.init.zeros_(modules[-1].bias)

-        AttnBlock = functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale)
+        AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

         Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

         if progressive == "output_skip":
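In NCSNpp the old `AttnBlockpp` returned `(x + h) / np.sqrt(2.0)` when `skip_rescale` was set. The replacement binds `overwrite_linear=True` (so the legacy `NIN_*` weights are ported into the unified block) and `rescale_output_factor=math.sqrt(2.0)` (so the residual output is divided by the same factor). A sketch of the bound constructor, following the same `functools.partial` pattern as the diff (the channel count here is hypothetical; NCSNpp fills it in per resolution):

```python
import functools
import math

from diffusers.models.attention import AttentionBlock  # unified block from this commit

AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
attn = AttnBlock(channels=128)  # hypothetical channel count
```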
tests/test_modeling_utils.py

@@ -859,7 +859,9 @@ class PipelineTesterMixin(unittest.TestCase):
         image_slice = image[0, -1, -3:, -3:].cpu()

         assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor([-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105])
+        expected_slice = torch.tensor(
+            [-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105]
+        )
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

     @slow