renzhc / diffusers_dcu · Commits · 94566e6d

Commit 94566e6d (unverified), authored Jul 04, 2022 by Patrick von Platen, committed by GitHub on Jul 04, 2022
Parent: 4e267493

update mid block (#70)

* update mid block
* finish mid block

Showing 7 changed files with 361 additions and 147 deletions (+361 -147):

    src/diffusers/models/attention.py                  +154  -114
    src/diffusers/models/unet.py                         +7   -11
    src/diffusers/models/unet_glide.py                  +21    -5
    src/diffusers/models/unet_grad_tts.py                +2    -0
    src/diffusers/models/unet_ldm.py                    +25   -10
    src/diffusers/models/unet_new.py                   +128    -0
    src/diffusers/models/unet_sde_score_estimation.py   +24    -7
src/diffusers/models/attention.py

 import math
+from inspect import isfunction

 import torch
+import torch.nn.functional as F
 from torch import nn

...
@@ -43,18 +45,16 @@ class AttentionBlock(nn.Module):
         self,
         channels,
         num_heads=1,
-        num_head_channels=-1,
+        num_head_channels=None,
         num_groups=32,
-        use_checkpoint=False,
         encoder_channels=None,
-        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
         overwrite_qkv=False,
         overwrite_linear=False,
         rescale_output_factor=1.0,
     ):
         super().__init__()
         self.channels = channels
-        if num_head_channels == -1:
+        if num_head_channels is None:
             self.num_heads = num_heads
         else:
             assert (
...
@@ -62,7 +62,6 @@ class AttentionBlock(nn.Module):
             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
             self.num_heads = channels // num_head_channels

-        self.use_checkpoint = use_checkpoint
         self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
         self.qkv = nn.Conv1d(channels, channels * 3, 1)
         self.n_heads = self.num_heads
...
@@ -160,115 +159,135 @@ class AttentionBlock(nn.Module):
         return result


-# unet_score_estimation.py
-# class AttnBlockpp(nn.Module):
-#     """Channel-wise self-attention block. Modified from DDPM."""
-#
-#     def __init__(
-#         self,
-#         channels,
-#         skip_rescale=False,
-#         init_scale=0.0,
-#         num_heads=1,
-#         num_head_channels=-1,
-#         use_checkpoint=False,
-#         encoder_channels=None,
-#         use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
-#         overwrite_qkv=False,
-#         overwrite_from_grad_tts=False,
-#     ):
-#         super().__init__()
-#         num_groups = min(channels // 4, 32)
-#         self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
-#         self.NIN_0 = NIN(channels, channels)
-#         self.NIN_1 = NIN(channels, channels)
-#         self.NIN_2 = NIN(channels, channels)
-#         self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
-#         self.skip_rescale = skip_rescale
-#
-#         self.channels = channels
-#         if num_head_channels == -1:
-#             self.num_heads = num_heads
-#         else:
-#             assert (
-#                 channels % num_head_channels == 0
-#             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-#             self.num_heads = channels // num_head_channels
-#
-#         self.use_checkpoint = use_checkpoint
-#         self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-#         self.qkv = nn.Conv1d(channels, channels * 3, 1)
-#         self.n_heads = self.num_heads
-#
-#         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
-#
-#         self.is_weight_set = False
-#
-#     def set_weights(self):
-#         self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
-#         self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
-#
-#         self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
-#         self.proj_out.bias.data = self.NIN_3.b.data
-#
-#         self.norm.weight.data = self.GroupNorm_0.weight.data
-#         self.norm.bias.data = self.GroupNorm_0.bias.data
-#
-#     def forward(self, x):
-#         if not self.is_weight_set:
-#             self.set_weights()
-#             self.is_weight_set = True
-#
-#         B, C, H, W = x.shape
-#         h = self.GroupNorm_0(x)
-#         q = self.NIN_0(h)
-#         k = self.NIN_1(h)
-#         v = self.NIN_2(h)
-#
-#         w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
-#         w = torch.reshape(w, (B, H, W, H * W))
-#         w = F.softmax(w, dim=-1)
-#         w = torch.reshape(w, (B, H, W, H, W))
-#         h = torch.einsum("bhwij,bcij->bchw", w, v)
-#         h = self.NIN_3(h)
-#
-#         if not self.skip_rescale:
-#             result = x + h
-#         else:
-#             result = (x + h) / np.sqrt(2.0)
-#
-#         result = self.forward_2(x)
-#
-#         return result
-#
-#     def forward_2(self, x, encoder_out=None):
-#         b, c, *spatial = x.shape
-#         hid_states = self.norm(x).view(b, c, -1)
-#
-#         qkv = self.qkv(hid_states)
-#         bs, width, length = qkv.shape
-#         assert width % (3 * self.n_heads) == 0
-#         ch = width // (3 * self.n_heads)
-#         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-#
-#         if encoder_out is not None:
-#             encoder_kv = self.encoder_kv(encoder_out)
-#             assert encoder_kv.shape[1] == self.n_heads * ch * 2
-#             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-#             k = torch.cat([ek, k], dim=-1)
-#             v = torch.cat([ev, v], dim=-1)
-#
-#         scale = 1 / math.sqrt(math.sqrt(ch))
-#         weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-#         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-#
-#         a = torch.einsum("bts,bcs->bct", weight, v)
-#         h = a.reshape(bs, -1, length)
-#
-#         h = self.proj_out(h)
-#         h = h.reshape(b, c, *spatial)
-#
-#         return (x + h) / np.sqrt(2.0)
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
+    standard transformer action. Finally, reshape to image
+    """
+
+    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
+        super().__init__()
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+                for d in range(depth)
+            ]
+        )
+
+        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+
+    def forward(self, x, context=None):
+        # note: if no context is given, cross-attention defaults to self-attention
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        x = self.proj_in(x)
+        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+        for block in self.transformer_blocks:
+            x = block(x, context=context)
+        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
+        x = self.proj_out(x)
+        return x + x_in
+
+
+class BasicTransformerBlock(nn.Module):
+    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
+        super().__init__()
+        self.attn1 = CrossAttention(
+            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+        )  # is a self-attention
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = CrossAttention(
+            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+        )  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None):
+        x = self.attn1(self.norm1(x)) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
+    def forward(self, x, context=None, mask=None):
+        batch_size, sequence_length, dim = x.shape
+
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        q = self.reshape_heads_to_batch_dim(q)
+        k = self.reshape_heads_to_batch_dim(k)
+        v = self.reshape_heads_to_batch_dim(v)
+
+        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+        if exists(mask):
+            mask = mask.reshape(batch_size, -1)
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = mask[:, None, :].repeat(h, 1, 1)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        attn = sim.softmax(dim=-1)
+
+        out = torch.einsum("b i j, b j d -> b i d", attn, v)
+        out = self.reshape_batch_dim_to_heads(out)
+        return self.to_out(out)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+
+        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+
+    def forward(self, x):
+        return self.net(x)


 # TODO(Patrick) - this can and should be removed
...
@@ -287,3 +306,24 @@ class NIN(nn.Module):
         super().__init__()
         self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
         self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+# feedforward
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
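
The head bookkeeping in the new CrossAttention is the only subtle part: reshape_heads_to_batch_dim folds the heads into the batch axis before the einsum, and reshape_batch_dim_to_heads undoes it afterwards. A minimal round-trip check of that reshape logic with plain torch (sizes are illustrative, not taken from the commit):

import torch

# illustrative sizes only
batch_size, seq_len, heads, dim_head = 2, 16, 8, 64
x = torch.randn(batch_size, seq_len, heads * dim_head)

# reshape_heads_to_batch_dim: (b, t, h*d) -> (b*h, t, d)
t = x.reshape(batch_size, seq_len, heads, dim_head).permute(0, 2, 1, 3)
t = t.reshape(batch_size * heads, seq_len, dim_head)

# reshape_batch_dim_to_heads: (b*h, t, d) -> (b, t, h*d)
back = t.reshape(batch_size, heads, seq_len, dim_head).permute(0, 2, 1, 3)
back = back.reshape(batch_size, seq_len, heads * dim_head)

assert back.shape == x.shape and torch.equal(back, x)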
src/diffusers/models/unet.py

...
@@ -23,6 +23,7 @@ from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D


 def nonlinearity(x):
...
@@ -105,13 +106,8 @@ class UNetModel(ModelMixin, ConfigMixin):
            self.down.append(down)

         # middle
-        self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock2D(
-            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
-        )
-        self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
-        self.mid.block_2 = ResnetBlock2D(
-            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
-        )
+        self.mid = UNetMidBlock2D(
+            in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, overwrite_qkv=True, overwrite_unet=True
+        )

         # upsampling
...
@@ -171,10 +167,10 @@ class UNetModel(ModelMixin, ConfigMixin):
                hs.append(self.down[i_level].downsample(hs[-1]))

         # middle
-        h = hs[-1]
-        h = self.mid.block_1(h, temb)
-        h = self.mid.attn_1(h)
-        h = self.mid.block_2(h, temb)
+        h = self.mid(hs[-1], temb)
+        # h = self.mid.block_1(h, temb)
+        # h = self.mid.attn_1(h)
+        # h = self.mid.block_2(h, temb)

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
...
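
The forward-pass change here is purely a consolidation: the single call h = self.mid(hs[-1], temb) performs the same three steps that were previously inlined. A sketch of what the new block does internally for this model (no encoder states; attribute names follow UNetMidBlock2D in unet_new.py further down):

def mid_forward(mid, hidden_states, temb):
    # the self-attention path of UNetMidBlock2D.forward when encoder_states is None
    hidden_states = mid.resnet_1(hidden_states, temb)
    hidden_states = mid.attn(hidden_states)
    hidden_states = mid.resnet_2(hidden_states, temb)
    return hidden_states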
src/diffusers/models/unet_glide.py

...
@@ -7,6 +7,7 @@ from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D


 def convert_module_to_f16(l):
...
@@ -193,7 +194,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 layers.append(
                     AttentionBlock(
                         ch,
-                        use_checkpoint=use_checkpoint,
                         num_heads=num_heads,
                         num_head_channels=num_head_channels,
                         encoder_channels=transformer_dim,
...
@@ -226,6 +226,20 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                ds *= 2
                self._feature_size += ch

+        self.mid = UNetMidBlock2D(
+            in_channels=ch,
+            dropout=dropout,
+            temb_channels=time_embed_dim,
+            resnet_eps=1e-5,
+            resnet_act_fn="silu",
+            resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
+            attn_num_heads=num_heads,
+            attn_num_head_channels=num_head_channels,
+            attn_encoder_channels=transformer_dim,
+        )
+
+        # TODO(Patrick) - delete after weight conversion
+        # init to be able to overwrite `self.mid`
         self.middle_block = TimestepEmbedSequential(
             ResnetBlock2D(
                 in_channels=ch,
...
@@ -238,7 +252,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
             ),
             AttentionBlock(
                 ch,
-                use_checkpoint=use_checkpoint,
                 num_heads=num_heads,
                 num_head_channels=num_head_channels,
                 encoder_channels=transformer_dim,
...
@@ -253,6 +266,10 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 overwrite_for_glide=True,
             ),
         )
+        self.mid.resnet_1 = self.middle_block[0]
+        self.mid.attn = self.middle_block[1]
+        self.mid.resnet_2 = self.middle_block[2]

         self._feature_size += ch

         self.output_blocks = nn.ModuleList([])
...
@@ -276,7 +293,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 layers.append(
                     AttentionBlock(
                         ch,
-                        use_checkpoint=use_checkpoint,
                         num_heads=num_heads_upsample,
                         num_head_channels=num_head_channels,
                         encoder_channels=transformer_dim,
...
@@ -343,7 +359,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         for module in self.input_blocks:
             h = module(h, emb)
             hs.append(h)
-        h = self.middle_block(h, emb)
+        h = self.mid(h, emb)
         for module in self.output_blocks:
             h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
...
@@ -438,7 +454,7 @@ class GlideTextToImageUNetModel(GlideUNetModel):
         for module in self.input_blocks:
             h = module(h, emb, transformer_out)
             hs.append(h)
-        h = self.middle_block(h, emb, transformer_out)
+        h = self.mid(h, emb, transformer_out)
         for module in self.output_blocks:
             other = hs.pop()
             h = torch.cat([h, other], dim=1)
...
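
For the GLIDE text-to-image model the extra transformer_out argument is forwarded as encoder states, which UNetMidBlock2D hands to its attention layer. A sketch of that path, again using the attribute names from unet_new.py:

def mid_forward_glide(mid, h, emb, transformer_out):
    h = mid.resnet_1(h, emb)
    h = mid.attn(h, transformer_out)  # encoder_states branch of UNetMidBlock2D.forward
    h = mid.resnet_2(h, emb)
    return h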
src/diffusers/models/unet_grad_tts.py

...
@@ -133,6 +133,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                 overwrite_for_grad_tts=True,
             )
+
+        # self.mid = UNetMidBlock2D

         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
                 torch.nn.ModuleList(
...
src/diffusers/models/unet_ldm.py

...
@@ -11,6 +11,7 @@ from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D


 # from .resnet import ResBlock
...
@@ -239,14 +240,12 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             conv_resample=conv_resample,
             dims=dims,
             num_classes=num_classes,
-            use_checkpoint=use_checkpoint,
             use_fp16=use_fp16,
             num_heads=num_heads,
             num_head_channels=num_head_channels,
             num_heads_upsample=num_heads_upsample,
             use_scale_shift_norm=use_scale_shift_norm,
             resblock_updown=resblock_updown,
-            use_new_attention_order=use_new_attention_order,
             use_spatial_transformer=use_spatial_transformer,
             transformer_depth=transformer_depth,
             context_dim=context_dim,
...
@@ -283,7 +282,6 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.num_classes = num_classes
-        self.use_checkpoint = use_checkpoint
         self.dtype_ = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
...
@@ -333,10 +331,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                    layers.append(
                        AttentionBlock(
                            ch,
-                           use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
-                           use_new_attention_order=use_new_attention_order,
                        )
                        if not use_spatial_transformer
                        else SpatialTransformer(
...
@@ -366,6 +362,25 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         if legacy:
             # num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        if dim_head < 0:
+            dim_head = None
+
+        self.mid = UNetMidBlock2D(
+            in_channels=ch,
+            dropout=dropout,
+            temb_channels=time_embed_dim,
+            resnet_eps=1e-5,
+            resnet_act_fn="silu",
+            resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
+            attention_layer_type="self" if not use_spatial_transformer else "spatial",
+            attn_num_heads=num_heads,
+            attn_num_head_channels=dim_head,
+            attn_depth=transformer_depth,
+            attn_encoder_channels=context_dim,
+        )
+
+        # TODO(Patrick) - delete after weight conversion
+        # init to be able to overwrite `self.mid`
         self.middle_block = TimestepEmbedSequential(
             ResnetBlock2D(
                 in_channels=ch,
...
@@ -378,10 +393,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             ),
             AttentionBlock(
                 ch,
-                use_checkpoint=use_checkpoint,
                 num_heads=num_heads,
                 num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
             )
             if not use_spatial_transformer
             else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
...
@@ -395,6 +408,10 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 overwrite_for_ldm=True,
             ),
         )
+        self.mid.resnet_1 = self.middle_block[0]
+        self.mid.attn = self.middle_block[1]
+        self.mid.resnet_2 = self.middle_block[2]

         self._feature_size += ch

         self.output_blocks = nn.ModuleList([])
...
@@ -425,10 +442,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                    layers.append(
                        AttentionBlock(
                            ch,
-                           use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
-                           use_new_attention_order=use_new_attention_order,
                        )
                        if not use_spatial_transformer
                        else SpatialTransformer(
...
@@ -493,7 +508,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         for module in self.input_blocks:
             h = module(h, emb, context)
             hs.append(h)
-        h = self.middle_block(h, emb, context)
+        h = self.mid(h, emb, context)
         for module in self.output_blocks:
             h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb, context)
...
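
A convention change rides along with this refactor: AttentionBlock now takes num_head_channels=None instead of the old -1 sentinel, so the LDM call site normalizes the legacy value before constructing the mid block. The pattern, restated as a short sketch of the lines above:

dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if dim_head < 0:       # legacy "-1 means unset"
    dim_head = None    # new convention expected by AttentionBlock / UNetMidBlock2D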
src/diffusers/models/unet_new.py (new file, 0 → 100644)

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn

from .attention import AttentionBlock, SpatialTransformer
from .resnet import ResnetBlock2D


class UNetMidBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        attention_layer_type: str = "self",
        attn_num_heads=1,
        attn_num_head_channels=None,
        attn_encoder_channels=None,
        attn_dim_head=None,
        attn_depth=None,
        output_scale_factor=1.0,
        overwrite_qkv=False,
        overwrite_unet=False,
    ):
        super().__init__()

        self.resnet_1 = ResnetBlock2D(
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
        )

        if attention_layer_type == "self":
            self.attn = AttentionBlock(
                in_channels,
                num_heads=attn_num_heads,
                num_head_channels=attn_num_head_channels,
                encoder_channels=attn_encoder_channels,
                overwrite_qkv=overwrite_qkv,
                rescale_output_factor=output_scale_factor,
            )
        elif attention_layer_type == "spatial":
            self.attn = (
                SpatialTransformer(
                    in_channels,
                    attn_num_heads,
                    attn_num_head_channels,
                    depth=attn_depth,
                    context_dim=attn_encoder_channels,
                ),
            )

        self.resnet_2 = ResnetBlock2D(
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
        )

        # TODO(Patrick) - delete all of the following code
        self.is_overwritten = False
        self.overwrite_unet = overwrite_unet
        if self.overwrite_unet:
            block_in = in_channels
            self.temb_ch = temb_channels
            self.block_1 = ResnetBlock2D(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
                eps=resnet_eps,
            )
            self.attn_1 = AttentionBlock(
                block_in,
                num_heads=attn_num_heads,
                num_head_channels=attn_num_head_channels,
                encoder_channels=attn_encoder_channels,
                overwrite_qkv=True,
            )
            self.block_2 = ResnetBlock2D(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
                eps=resnet_eps,
            )

    def forward(self, hidden_states, temb=None, encoder_states=None):
        if not self.is_overwritten and self.overwrite_unet:
            self.resnet_1 = self.block_1
            self.attn = self.attn_1
            self.resnet_2 = self.block_2
            self.is_overwritten = True

        hidden_states = self.resnet_1(hidden_states, temb)

        if encoder_states is None:
            hidden_states = self.attn(hidden_states)
        else:
            hidden_states = self.attn(hidden_states, encoder_states)

        hidden_states = self.resnet_2(hidden_states, temb)

        return hidden_states
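
A minimal way to exercise the new block on its own, assuming this checkout of the repository is importable as diffusers; the sizes below are made up for illustration, and the call mirrors how unet.py and unet_sde_score_estimation.py use the block:

import torch
from diffusers.models.unet_new import UNetMidBlock2D

mid = UNetMidBlock2D(in_channels=128, temb_channels=512, dropout=0.0)

sample = torch.randn(1, 128, 32, 32)  # (batch, channels, height, width)
temb = torch.randn(1, 512)            # timestep embedding, matching temb_channels

out = mid(sample, temb)               # resnet_1 -> attn -> resnet_2
print(out.shape)                      # expected to match sample.shape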
src/diffusers/models/unet_sde_score_estimation.py

...
@@ -27,6 +27,7 @@ from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D


 class Combine(nn.Module):
...
@@ -214,6 +215,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
                hs_c.append(in_ch)

+        # mid
+        self.mid = UNetMidBlock2D(
+            in_channels=in_ch,
+            temb_channels=4 * nf,
+            output_scale_factor=math.sqrt(2.0),
+            resnet_act_fn="silu",
+            resnet_groups=min(in_ch // 4, 32),
+            dropout=dropout,
+        )
+
         in_ch = hs_c[-1]
         modules.append(
             ResnetBlock2D(
...
@@ -238,6 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 overwrite_for_score_vde=True,
             )
         )
+        self.mid.resnet_1 = modules[len(modules) - 3]
+        self.mid.attn = modules[len(modules) - 2]
+        self.mid.resnet_2 = modules[len(modules) - 1]

         pyramid_ch = 0
         # Upsampling block
...
@@ -378,13 +392,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
                hs.append(h)

-        h = hs[-1]
-        h = modules[m_idx](h, temb)
-        m_idx += 1
-        h = modules[m_idx](h)
-        m_idx += 1
-        h = modules[m_idx](h, temb)
-        m_idx += 1
+        # h = hs[-1]
+        # h = modules[m_idx](h, temb)
+        # m_idx += 1
+        # h = modules[m_idx](h)
+        # m_idx += 1
+        # h = modules[m_idx](h, temb)
+        # m_idx += 1
+        h = self.mid(h, temb)
+        m_idx += 3

         pyramid = None
...
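
The subtle part in the score-SDE model is the index bookkeeping: the three legacy mid modules stay registered in modules so pretrained weights keep their positions, they are re-bound onto self.mid, and the forward pass then advances m_idx by three after the single self.mid(h, temb) call. A small self-contained sketch of the same pattern with stand-in modules (illustrative only, not code from the commit):

import torch
from torch import nn

# stand-ins for the legacy ResnetBlock2D / AttentionBlock / ResnetBlock2D entries
modules = nn.ModuleList([nn.Identity(), nn.Identity(), nn.Identity()])

class Mid(nn.Module):
    def forward(self, h):
        return self.resnet_2(self.attn(self.resnet_1(h)))

mid = Mid()
mid.resnet_1 = modules[len(modules) - 3]
mid.attn = modules[len(modules) - 2]
mid.resnet_2 = modules[len(modules) - 1]

m_idx = 0
h = torch.randn(1, 8, 4, 4)
h = mid(h)   # one call covers what used to be three modules[m_idx](...) calls
m_idx += 3   # so the running index over `modules` still lines up afterwards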