OpenDAS / diffusers · Commits

Commit 94566e6d (unverified), authored Jul 04, 2022 by Patrick von Platen, committed via GitHub on Jul 04, 2022
Parent: 4e267493

update mid block (#70)

* update mid block
* finish mid block
Showing 7 changed files with 361 additions and 147 deletions (+361 −147)
src/diffusers/models/attention.py                   +154  −114
src/diffusers/models/unet.py                          +7   −11
src/diffusers/models/unet_glide.py                   +21    −5
src/diffusers/models/unet_grad_tts.py                 +2    −0
src/diffusers/models/unet_ldm.py                     +25   −10
src/diffusers/models/unet_new.py                    +128    −0
src/diffusers/models/unet_sde_score_estimation.py    +24    −7
src/diffusers/models/attention.py
 import math
+from inspect import isfunction

 import torch
+import torch.nn.functional as F
 from torch import nn
...
@@ -43,18 +45,16 @@ class AttentionBlock(nn.Module):
         self,
         channels,
         num_heads=1,
-        num_head_channels=-1,
+        num_head_channels=None,
         num_groups=32,
-        use_checkpoint=False,
         encoder_channels=None,
-        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
         overwrite_qkv=False,
         overwrite_linear=False,
         rescale_output_factor=1.0,
     ):
         super().__init__()
         self.channels = channels
-        if num_head_channels == -1:
+        if num_head_channels is None:
             self.num_heads = num_heads
         else:
             assert (
@@ -62,7 +62,6 @@ class AttentionBlock(nn.Module):
...
@@ -62,7 +62,6 @@ class AttentionBlock(nn.Module):
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
num_heads
=
channels
//
num_head_channels
self
.
use_checkpoint
=
use_checkpoint
self
.
norm
=
nn
.
GroupNorm
(
num_channels
=
channels
,
num_groups
=
num_groups
,
eps
=
1e-5
,
affine
=
True
)
self
.
norm
=
nn
.
GroupNorm
(
num_channels
=
channels
,
num_groups
=
num_groups
,
eps
=
1e-5
,
affine
=
True
)
self
.
qkv
=
nn
.
Conv1d
(
channels
,
channels
*
3
,
1
)
self
.
qkv
=
nn
.
Conv1d
(
channels
,
channels
*
3
,
1
)
self
.
n_heads
=
self
.
num_heads
self
.
n_heads
=
self
.
num_heads
...
@@ -160,115 +159,135 @@ class AttentionBlock(nn.Module):
...
@@ -160,115 +159,135 @@ class AttentionBlock(nn.Module):
return
result
return
result
# unet_score_estimation.py
class
SpatialTransformer
(
nn
.
Module
):
# class AttnBlockpp(nn.Module):
"""
# """Channel-wise self-attention block. Modified from DDPM."""
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
#
standard transformer action. Finally, reshape to image
# def __init__(
"""
# self,
# channels,
def
__init__
(
self
,
in_channels
,
n_heads
,
d_head
,
depth
=
1
,
dropout
=
0.0
,
context_dim
=
None
):
# skip_rescale=False,
super
().
__init__
()
# init_scale=0.0,
self
.
in_channels
=
in_channels
# num_heads=1,
inner_dim
=
n_heads
*
d_head
# num_head_channels=-1,
self
.
norm
=
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
# use_checkpoint=False,
# encoder_channels=None,
self
.
proj_in
=
nn
.
Conv2d
(
in_channels
,
inner_dim
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
# use_new_attention_order=False, # TODO(Patrick) -> is never used, maybe delete?
# overwrite_qkv=False,
self
.
transformer_blocks
=
nn
.
ModuleList
(
# overwrite_from_grad_tts=False,
[
# ):
BasicTransformerBlock
(
inner_dim
,
n_heads
,
d_head
,
dropout
=
dropout
,
context_dim
=
context_dim
)
# super().__init__()
for
d
in
range
(
depth
)
# num_groups = min(channels // 4, 32)
]
# self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
)
# self.NIN_0 = NIN(channels, channels)
# self.NIN_1 = NIN(channels, channels)
self
.
proj_out
=
zero_module
(
nn
.
Conv2d
(
inner_dim
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
))
# self.NIN_2 = NIN(channels, channels)
# self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
def
forward
(
self
,
x
,
context
=
None
):
# self.skip_rescale = skip_rescale
# note: if no context is given, cross-attention defaults to self-attention
#
b
,
c
,
h
,
w
=
x
.
shape
# self.channels = channels
x_in
=
x
# if num_head_channels == -1:
x
=
self
.
norm
(
x
)
# self.num_heads = num_heads
x
=
self
.
proj_in
(
x
)
# else:
x
=
x
.
permute
(
0
,
2
,
3
,
1
).
reshape
(
b
,
h
*
w
,
c
)
# assert (
for
block
in
self
.
transformer_blocks
:
# channels % num_head_channels == 0
x
=
block
(
x
,
context
=
context
)
# ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
x
=
x
.
reshape
(
b
,
h
,
w
,
c
).
permute
(
0
,
3
,
1
,
2
)
# self.num_heads = channels // num_head_channels
x
=
self
.
proj_out
(
x
)
#
return
x
+
x_in
# self.use_checkpoint = use_checkpoint
# self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
# self.qkv = nn.Conv1d(channels, channels * 3, 1)
class
BasicTransformerBlock
(
nn
.
Module
):
# self.n_heads = self.num_heads
def
__init__
(
self
,
dim
,
n_heads
,
d_head
,
dropout
=
0.0
,
context_dim
=
None
,
gated_ff
=
True
,
checkpoint
=
True
):
#
super
().
__init__
()
# self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self
.
attn1
=
CrossAttention
(
#
query_dim
=
dim
,
heads
=
n_heads
,
dim_head
=
d_head
,
dropout
=
dropout
# self.is_weight_set = False
)
# is a self-attention
#
self
.
ff
=
FeedForward
(
dim
,
dropout
=
dropout
,
glu
=
gated_ff
)
# def set_weights(self):
self
.
attn2
=
CrossAttention
(
# self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
query_dim
=
dim
,
context_dim
=
context_dim
,
heads
=
n_heads
,
dim_head
=
d_head
,
dropout
=
dropout
# self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
)
# is self-attn if context is none
#
self
.
norm1
=
nn
.
LayerNorm
(
dim
)
# self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
self
.
norm2
=
nn
.
LayerNorm
(
dim
)
# self.proj_out.bias.data = self.NIN_3.b.data
self
.
norm3
=
nn
.
LayerNorm
(
dim
)
#
self
.
checkpoint
=
checkpoint
# self.norm.weight.data = self.GroupNorm_0.weight.data
# self.norm.bias.data = self.GroupNorm_0.bias.data
def
forward
(
self
,
x
,
context
=
None
):
#
x
=
self
.
attn1
(
self
.
norm1
(
x
))
+
x
# def forward(self, x):
x
=
self
.
attn2
(
self
.
norm2
(
x
),
context
=
context
)
+
x
# if not self.is_weight_set:
x
=
self
.
ff
(
self
.
norm3
(
x
))
+
x
# self.set_weights()
return
x
# self.is_weight_set = True
#
# B, C, H, W = x.shape
class
CrossAttention
(
nn
.
Module
):
# h = self.GroupNorm_0(x)
def
__init__
(
self
,
query_dim
,
context_dim
=
None
,
heads
=
8
,
dim_head
=
64
,
dropout
=
0.0
):
# q = self.NIN_0(h)
super
().
__init__
()
# k = self.NIN_1(h)
inner_dim
=
dim_head
*
heads
# v = self.NIN_2(h)
context_dim
=
default
(
context_dim
,
query_dim
)
#
# w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
self
.
scale
=
dim_head
**-
0.5
# w = torch.reshape(w, (B, H, W, H * W))
self
.
heads
=
heads
# w = F.softmax(w, dim=-1)
# w = torch.reshape(w, (B, H, W, H, W))
self
.
to_q
=
nn
.
Linear
(
query_dim
,
inner_dim
,
bias
=
False
)
# h = torch.einsum("bhwij,bcij->bchw", w, v)
self
.
to_k
=
nn
.
Linear
(
context_dim
,
inner_dim
,
bias
=
False
)
# h = self.NIN_3(h)
self
.
to_v
=
nn
.
Linear
(
context_dim
,
inner_dim
,
bias
=
False
)
#
# if not self.skip_rescale:
self
.
to_out
=
nn
.
Sequential
(
nn
.
Linear
(
inner_dim
,
query_dim
),
nn
.
Dropout
(
dropout
))
# result = x + h
# else:
def
reshape_heads_to_batch_dim
(
self
,
tensor
):
# result = (x + h) / np.sqrt(2.0)
batch_size
,
seq_len
,
dim
=
tensor
.
shape
#
head_size
=
self
.
heads
# result = self.forward_2(x)
tensor
=
tensor
.
reshape
(
batch_size
,
seq_len
,
head_size
,
dim
//
head_size
)
#
tensor
=
tensor
.
permute
(
0
,
2
,
1
,
3
).
reshape
(
batch_size
*
head_size
,
seq_len
,
dim
//
head_size
)
# return result
return
tensor
#
# def forward_2(self, x, encoder_out=None):
def
reshape_batch_dim_to_heads
(
self
,
tensor
):
# b, c, *spatial = x.shape
batch_size
,
seq_len
,
dim
=
tensor
.
shape
# hid_states = self.norm(x).view(b, c, -1)
head_size
=
self
.
heads
#
tensor
=
tensor
.
reshape
(
batch_size
//
head_size
,
head_size
,
seq_len
,
dim
)
# qkv = self.qkv(hid_states)
tensor
=
tensor
.
permute
(
0
,
2
,
1
,
3
).
reshape
(
batch_size
//
head_size
,
seq_len
,
dim
*
head_size
)
# bs, width, length = qkv.shape
return
tensor
# assert width % (3 * self.n_heads) == 0
# ch = width // (3 * self.n_heads)
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
):
# q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
batch_size
,
sequence_length
,
dim
=
x
.
shape
#
# if encoder_out is not None:
h
=
self
.
heads
# encoder_kv = self.encoder_kv(encoder_out)
# assert encoder_kv.shape[1] == self.n_heads * ch * 2
q
=
self
.
to_q
(
x
)
# ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
context
=
default
(
context
,
x
)
# k = torch.cat([ek, k], dim=-1)
k
=
self
.
to_k
(
context
)
# v = torch.cat([ev, v], dim=-1)
v
=
self
.
to_v
(
context
)
#
# scale = 1 / math.sqrt(math.sqrt(ch))
q
=
self
.
reshape_heads_to_batch_dim
(
q
)
# weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
k
=
self
.
reshape_heads_to_batch_dim
(
k
)
# weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
v
=
self
.
reshape_heads_to_batch_dim
(
v
)
#
# a = torch.einsum("bts,bcs->bct", weight, v)
sim
=
torch
.
einsum
(
"b i d, b j d -> b i j"
,
q
,
k
)
*
self
.
scale
# h = a.reshape(bs, -1, length)
#
if
exists
(
mask
):
# h = self.proj_out(h)
mask
=
mask
.
reshape
(
batch_size
,
-
1
)
# h = h.reshape(b, c, *spatial)
max_neg_value
=
-
torch
.
finfo
(
sim
.
dtype
).
max
#
mask
=
mask
[:,
None
,
:].
repeat
(
h
,
1
,
1
)
# return (x + h) / np.sqrt(2.0)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
# attention, what we cannot get enough of
attn
=
sim
.
softmax
(
dim
=-
1
)
out
=
torch
.
einsum
(
"b i j, b j d -> b i d"
,
attn
,
v
)
out
=
self
.
reshape_batch_dim_to_heads
(
out
)
return
self
.
to_out
(
out
)
class
FeedForward
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
dim_out
=
None
,
mult
=
4
,
glu
=
False
,
dropout
=
0.0
):
super
().
__init__
()
inner_dim
=
int
(
dim
*
mult
)
dim_out
=
default
(
dim_out
,
dim
)
project_in
=
nn
.
Sequential
(
nn
.
Linear
(
dim
,
inner_dim
),
nn
.
GELU
())
if
not
glu
else
GEGLU
(
dim
,
inner_dim
)
self
.
net
=
nn
.
Sequential
(
project_in
,
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
inner_dim
,
dim_out
))
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
# TODO(Patrick) - this can and should be removed
# TODO(Patrick) - this can and should be removed
...
@@ -287,3 +306,24 @@ class NIN(nn.Module):
         super().__init__()
         self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
         self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+# feedforward
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
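Editor's note: the `reshape_heads_to_batch_dim` / `reshape_batch_dim_to_heads` helpers added above fold the head dimension into the batch dimension so a single batched einsum covers all heads at once. A minimal shape-check sketch of that folding (the dimensions below are illustrative, not taken from the commit):

```python
# Shape-check sketch for the head folding used by the new CrossAttention.
import torch

batch, tokens, heads, dim_head = 2, 16, 8, 8
x = torch.randn(batch, tokens, heads * dim_head)

# reshape_heads_to_batch_dim: (b, t, h*d) -> (b*h, t, d)
folded = x.reshape(batch, tokens, heads, dim_head).permute(0, 2, 1, 3).reshape(batch * heads, tokens, dim_head)

# attention scores are then a single batched einsum over all heads
sim = torch.einsum("b i d, b j d -> b i j", folded, folded) * dim_head**-0.5
out = torch.einsum("b i j, b j d -> b i d", sim.softmax(dim=-1), folded)

# reshape_batch_dim_to_heads: (b*h, t, d) -> (b, t, h*d)
unfolded = out.reshape(batch, heads, tokens, dim_head).permute(0, 2, 1, 3).reshape(batch, tokens, heads * dim_head)
assert unfolded.shape == x.shape
```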
src/diffusers/models/unet.py
...
@@ -23,6 +23,7 @@ from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D


 def nonlinearity(x):
...
@@ -105,13 +106,8 @@ class UNetModel(ModelMixin, ConfigMixin):
             self.down.append(down)

         # middle
-        self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock2D(
-            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
-        )
-        self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
-        self.mid.block_2 = ResnetBlock2D(
-            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
+        self.mid = UNetMidBlock2D(
+            in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, overwrite_qkv=True, overwrite_unet=True
         )

         # upsampling
...
@@ -171,10 +167,10 @@ class UNetModel(ModelMixin, ConfigMixin):
                     hs.append(self.down[i_level].downsample(hs[-1]))

         # middle
-        h = hs[-1]
-        h = self.mid.block_1(h, temb)
-        h = self.mid.attn_1(h)
-        h = self.mid.block_2(h, temb)
+        h = self.mid(hs[-1], temb)
+        # h = self.mid.block_1(h, temb)
+        # h = self.mid.attn_1(h)
+        # h = self.mid.block_2(h, temb)

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
...
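Editor's note: the forward change above collapses the three explicit mid-block calls into a single module call. A sketch of the equivalence (names follow the diff; `mid` is a stand-in for the block object, not code from the commit):

```python
# Before: the caller drove the mid block step by step.
def mid_forward_old(mid, h, temb):
    h = mid.block_1(h, temb)
    h = mid.attn_1(h)
    h = mid.block_2(h, temb)
    return h

# After: UNetMidBlock2D.forward runs resnet_1 -> attn -> resnet_2 internally.
def mid_forward_new(mid, h, temb):
    return mid(h, temb)
```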
src/diffusers/models/unet_glide.py
...
@@ -7,6 +7,7 @@ from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D


 def convert_module_to_f16(l):
...
@@ -193,7 +194,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 layers.append(
                     AttentionBlock(
                         ch,
-                        use_checkpoint=use_checkpoint,
                         num_heads=num_heads,
                         num_head_channels=num_head_channels,
                         encoder_channels=transformer_dim,
...
@@ -226,6 +226,20 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 ds *= 2
                 self._feature_size += ch

+        self.mid = UNetMidBlock2D(
+            in_channels=ch,
+            dropout=dropout,
+            temb_channels=time_embed_dim,
+            resnet_eps=1e-5,
+            resnet_act_fn="silu",
+            resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
+            attn_num_heads=num_heads,
+            attn_num_head_channels=num_head_channels,
+            attn_encoder_channels=transformer_dim,
+        )
+
+        # TODO(Patrick) - delete after weight conversion
+        # init to be able to overwrite `self.mid`
         self.middle_block = TimestepEmbedSequential(
             ResnetBlock2D(
                 in_channels=ch,
...
@@ -238,7 +252,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
             ),
             AttentionBlock(
                 ch,
-                use_checkpoint=use_checkpoint,
                 num_heads=num_heads,
                 num_head_channels=num_head_channels,
                 encoder_channels=transformer_dim,
...
@@ -253,6 +266,10 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 overwrite_for_glide=True,
             ),
         )
+        self.mid.resnet_1 = self.middle_block[0]
+        self.mid.attn = self.middle_block[1]
+        self.mid.resnet_2 = self.middle_block[2]

         self._feature_size += ch

         self.output_blocks = nn.ModuleList([])
...
@@ -276,7 +293,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 layers.append(
                     AttentionBlock(
                         ch,
-                        use_checkpoint=use_checkpoint,
                         num_heads=num_heads_upsample,
                         num_head_channels=num_head_channels,
                         encoder_channels=transformer_dim,
...
@@ -343,7 +359,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         for module in self.input_blocks:
             h = module(h, emb)
             hs.append(h)
-        h = self.middle_block(h, emb)
+        h = self.mid(h, emb)
         for module in self.output_blocks:
             h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
...
@@ -438,7 +454,7 @@ class GlideTextToImageUNetModel(GlideUNetModel):
         for module in self.input_blocks:
             h = module(h, emb, transformer_out)
             hs.append(h)
-        h = self.middle_block(h, emb, transformer_out)
+        h = self.mid(h, emb, transformer_out)
         for module in self.output_blocks:
             other = hs.pop()
             h = torch.cat([h, other], dim=1)
...
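Editor's note: the `self.mid.resnet_1 = self.middle_block[0]` assignments above alias the legacy `TimestepEmbedSequential` submodules onto the new block. A minimal sketch (not from the commit; `legacy_mid` / `new_mid` are illustrative stand-ins) of why that keeps old checkpoints loadable while `forward` routes through `self.mid`:

```python
# Assigning the legacy submodules onto the new container means both attribute paths
# point at the same Parameter objects, so `middle_block.*` state-dict keys still load.
from torch import nn

legacy_mid = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.Identity(), nn.Conv2d(4, 4, 3, padding=1))

new_mid = nn.Module()
new_mid.resnet_1 = legacy_mid[0]  # shared modules, not copies
new_mid.attn = legacy_mid[1]
new_mid.resnet_2 = legacy_mid[2]

assert new_mid.resnet_1.weight is legacy_mid[0].weight  # same underlying parameters
```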
src/diffusers/models/unet_grad_tts.py
...
@@ -133,6 +133,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                 overwrite_for_grad_tts=True,
             )
         )

+        # self.mid = UNetMidBlock2D
+
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
                 torch.nn.ModuleList(
...
src/diffusers/models/unet_ldm.py
...
@@ -11,6 +11,7 @@ from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D

 # from .resnet import ResBlock
...
@@ -239,14 +240,12 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             conv_resample=conv_resample,
             dims=dims,
             num_classes=num_classes,
-            use_checkpoint=use_checkpoint,
             use_fp16=use_fp16,
             num_heads=num_heads,
             num_head_channels=num_head_channels,
             num_heads_upsample=num_heads_upsample,
             use_scale_shift_norm=use_scale_shift_norm,
             resblock_updown=resblock_updown,
-            use_new_attention_order=use_new_attention_order,
             use_spatial_transformer=use_spatial_transformer,
             transformer_depth=transformer_depth,
             context_dim=context_dim,
...
@@ -283,7 +282,6 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.num_classes = num_classes
-        self.use_checkpoint = use_checkpoint
         self.dtype_ = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
...
@@ -333,10 +331,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                     layers.append(
                         AttentionBlock(
                             ch,
-                            use_checkpoint=use_checkpoint,
                             num_heads=num_heads,
                             num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
                         )
                         if not use_spatial_transformer
                         else SpatialTransformer(
...
@@ -366,6 +362,25 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         if legacy:
             # num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels

+        if dim_head < 0:
+            dim_head = None
+
+        self.mid = UNetMidBlock2D(
+            in_channels=ch,
+            dropout=dropout,
+            temb_channels=time_embed_dim,
+            resnet_eps=1e-5,
+            resnet_act_fn="silu",
+            resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
+            attention_layer_type="self" if not use_spatial_transformer else "spatial",
+            attn_num_heads=num_heads,
+            attn_num_head_channels=dim_head,
+            attn_depth=transformer_depth,
+            attn_encoder_channels=context_dim,
+        )
+
+        # TODO(Patrick) - delete after weight conversion
+        # init to be able to overwrite `self.mid`
         self.middle_block = TimestepEmbedSequential(
             ResnetBlock2D(
                 in_channels=ch,
...
@@ -378,10 +393,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             ),
             AttentionBlock(
                 ch,
-                use_checkpoint=use_checkpoint,
                 num_heads=num_heads,
                 num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
             )
             if not use_spatial_transformer
             else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
...
@@ -395,6 +408,10 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 overwrite_for_ldm=True,
             ),
         )
+        self.mid.resnet_1 = self.middle_block[0]
+        self.mid.attn = self.middle_block[1]
+        self.mid.resnet_2 = self.middle_block[2]

         self._feature_size += ch

         self.output_blocks = nn.ModuleList([])
...
@@ -425,10 +442,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                     layers.append(
                         AttentionBlock(
                             ch,
-                            use_checkpoint=use_checkpoint,
                             num_heads=num_heads_upsample,
                             num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
                         )
                         if not use_spatial_transformer
                         else SpatialTransformer(
...
@@ -493,7 +508,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         for module in self.input_blocks:
             h = module(h, emb, context)
             hs.append(h)
-        h = self.middle_block(h, emb, context)
+        h = self.mid(h, emb, context)
         for module in self.output_blocks:
             h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb, context)
...
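Editor's note: the new `if dim_head < 0: dim_head = None` lines bridge the sentinel change made in attention.py, where `AttentionBlock` now takes `num_head_channels=None` instead of the legacy `-1`. A minimal sketch of that normalization (the numeric values below are illustrative, not from the commit):

```python
# Legacy configs use -1 to mean "derive head channels from num_heads"; the refactored
# AttentionBlock expects None for the same thing, so the value is normalized first.
num_head_channels = -1            # legacy sentinel
use_spatial_transformer = False
num_heads, ch = 8, 512

dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if dim_head < 0:
    dim_head = None               # new sentinel understood by AttentionBlock
assert dim_head is None
```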
src/diffusers/models/unet_new.py (new file, mode 0 → 100644)
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn

from .attention import AttentionBlock, SpatialTransformer
from .resnet import ResnetBlock2D


class UNetMidBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        attention_layer_type: str = "self",
        attn_num_heads=1,
        attn_num_head_channels=None,
        attn_encoder_channels=None,
        attn_dim_head=None,
        attn_depth=None,
        output_scale_factor=1.0,
        overwrite_qkv=False,
        overwrite_unet=False,
    ):
        super().__init__()

        self.resnet_1 = ResnetBlock2D(
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
        )

        if attention_layer_type == "self":
            self.attn = AttentionBlock(
                in_channels,
                num_heads=attn_num_heads,
                num_head_channels=attn_num_head_channels,
                encoder_channels=attn_encoder_channels,
                overwrite_qkv=overwrite_qkv,
                rescale_output_factor=output_scale_factor,
            )
        elif attention_layer_type == "spatial":
            # note: assigned directly (not wrapped in a tuple) so that self.attn is callable
            self.attn = SpatialTransformer(
                in_channels,
                attn_num_heads,
                attn_num_head_channels,
                depth=attn_depth,
                context_dim=attn_encoder_channels,
            )

        self.resnet_2 = ResnetBlock2D(
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
        )

        # TODO(Patrick) - delete all of the following code
        self.is_overwritten = False
        self.overwrite_unet = overwrite_unet
        if self.overwrite_unet:
            block_in = in_channels
            self.temb_ch = temb_channels
            self.block_1 = ResnetBlock2D(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
                eps=resnet_eps,
            )
            self.attn_1 = AttentionBlock(
                block_in,
                num_heads=attn_num_heads,
                num_head_channels=attn_num_head_channels,
                encoder_channels=attn_encoder_channels,
                overwrite_qkv=True,
            )
            self.block_2 = ResnetBlock2D(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
                eps=resnet_eps,
            )

    def forward(self, hidden_states, temb=None, encoder_states=None):
        if not self.is_overwritten and self.overwrite_unet:
            self.resnet_1 = self.block_1
            self.attn = self.attn_1
            self.resnet_2 = self.block_2
            self.is_overwritten = True

        hidden_states = self.resnet_1(hidden_states, temb)

        if encoder_states is None:
            hidden_states = self.attn(hidden_states)
        else:
            hidden_states = self.attn(hidden_states, encoder_states)

        hidden_states = self.resnet_2(hidden_states, temb)
        return hidden_states
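Editor's note: a quick smoke-test sketch for the new block, assuming the module is importable exactly as added in this commit; the channel counts, spatial size, and batch size below are illustrative and chosen so the default `resnet_groups=32` divides evenly.

```python
# Minimal forward-pass sketch for UNetMidBlock2D (sizes are illustrative).
import torch
from diffusers.models.unet_new import UNetMidBlock2D

block = UNetMidBlock2D(in_channels=64, temb_channels=128, dropout=0.0)
sample = torch.randn(1, 64, 32, 32)
temb = torch.randn(1, 128)

out = block(sample, temb)          # resnet_1 -> attn -> resnet_2
assert out.shape == sample.shape   # the mid block preserves channels and spatial size
```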
src/diffusers/models/unet_sde_score_estimation.py
...
@@ -27,6 +27,7 @@ from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D


 class Combine(nn.Module):
...
@@ -214,6 +215,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs_c.append(in_ch)

+        # mid
+        self.mid = UNetMidBlock2D(
+            in_channels=in_ch,
+            temb_channels=4 * nf,
+            output_scale_factor=math.sqrt(2.0),
+            resnet_act_fn="silu",
+            resnet_groups=min(in_ch // 4, 32),
+            dropout=dropout,
+        )
+
         in_ch = hs_c[-1]
         modules.append(
             ResnetBlock2D(
...
@@ -238,6 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 overwrite_for_score_vde=True,
             )
         )
+        self.mid.resnet_1 = modules[len(modules) - 3]
+        self.mid.attn = modules[len(modules) - 2]
+        self.mid.resnet_2 = modules[len(modules) - 1]

         pyramid_ch = 0
         # Upsampling block
...
@@ -378,13 +392,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs.append(h)

-        h = hs[-1]
-        h = modules[m_idx](h, temb)
-        m_idx += 1
-        h = modules[m_idx](h)
-        m_idx += 1
-        h = modules[m_idx](h, temb)
-        m_idx += 1
+        # h = hs[-1]
+        # h = modules[m_idx](h, temb)
+        # m_idx += 1
+        # h = modules[m_idx](h)
+        # m_idx += 1
+        # h = modules[m_idx](h, temb)
+        # m_idx += 1
+        h = self.mid(h, temb)
+        m_idx += 3

         pyramid = None
...
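Editor's note: the three mid-block modules are still appended to `modules` (and aliased onto `self.mid`), so after the single `self.mid(h, temb)` call the running index has to jump past all three. A trivial sketch (not from the commit; list entries are stand-ins):

```python
# Index bookkeeping behind `m_idx += 3` in NCSNpp.forward.
modules = ["mid_resnet_1", "mid_attn", "mid_resnet_2", "first_upsampling_module"]
m_idx = 0

# h = self.mid(h, temb) replaces three consecutive modules[m_idx](...) calls here
m_idx += 3
assert modules[m_idx] == "first_upsampling_module"
```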