renzhc / diffusers_dcu / Commits

Commit abedfb08 (unverified), authored Jul 01, 2022 by Patrick von Platen, committed by GitHub on Jul 01, 2022.
Merge pull request #57 from huggingface/big_clean_up

[Clean up] Clean up unused code

Parents: 810c0e4f, 61ea57c5
Showing 3 changed files with 105 additions and 536 deletions (+105, -536).
src/diffusers/models/unet.py        +0    -42
src/diffusers/models/unet_glide.py  +0    -13
src/diffusers/models/unet_ldm.py    +105  -481
src/diffusers/models/unet.py

@@ -34,48 +34,6 @@ def Normalize(in_channels):

Context:
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

Removed (previously commented-out ResnetBlock):
# class ResnetBlock(nn.Module):
# def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
# super().__init__()
# self.in_channels = in_channels
# out_channels = in_channels if out_channels is None else out_channels
# self.out_channels = out_channels
# self.use_conv_shortcut = conv_shortcut
#
# self.norm1 = Normalize(in_channels)
# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
# self.norm2 = Normalize(out_channels)
# self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# else:
# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
#
# def forward(self, x, temb):
# h = x
# h = self.norm1(h)
# h = nonlinearity(h)
# h = self.conv1(h)
#
# h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
#
# h = self.norm2(h)
# h = nonlinearity(h)
# h = self.dropout(h)
# h = self.conv2(h)
#
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# x = self.conv_shortcut(x)
# else:
# x = self.nin_shortcut(x)
#
# return x + h
Context:
class UNetModel(ModelMixin, ConfigMixin):
    def __init__(
        self,
    ...
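Note: the commented-out ResnetBlock removed above injects the timestep embedding by projecting it to the block's channel count and broadcasting over the spatial axes via [:, :, None, None]. A minimal standalone sketch of that broadcast (not part of the diff; shapes are illustrative, and silu stands in for the file's nonlinearity helper):

import torch

batch, out_channels, height, width, temb_channels = 2, 64, 8, 8, 512
h = torch.randn(batch, out_channels, height, width)   # feature map after conv1
temb = torch.randn(batch, temb_channels)              # timestep embedding
temb_proj = torch.nn.Linear(temb_channels, out_channels)

# (batch, out_channels) -> (batch, out_channels, 1, 1), then broadcast-add over H and W
h = h + temb_proj(torch.nn.functional.silu(temb))[:, :, None, None]
print(h.shape)  # torch.Size([2, 64, 8, 8])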
src/diffusers/models/unet_glide.py

@@ -29,19 +29,6 @@ def convert_module_to_f32(l):

Context:
        l.bias.data = l.bias.data.float()

Removed:
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
Context:
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    ...
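Note: the removed avg_pool_nd mirrors the conv_nd helper that stays in the file: a small factory that dispatches on dims. The conv_nd body below is a sketch matching its docstring (the diff only shows the signature), with illustrative arguments:

import torch
from torch import nn

def conv_nd(dims, *args, **kwargs):
    # Same dispatch idea as avg_pool_nd above: pick the 1D/2D/3D module by `dims`.
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

# With dims=2 this behaves like nn.Conv2d(4, 32, 3, padding=1).
conv = conv_nd(2, 4, 32, 3, padding=1)
x = torch.randn(1, 4, 64, 64)
print(conv(x).shape)  # torch.Size([1, 32, 64, 64])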
src/diffusers/models/unet_ldm.py

@@ -78,182 +78,6 @@ def Normalize(in_channels):

Context:
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

Removed (previously commented-out LinearAttention and SpatialSelfAttention):
# class LinearAttention(nn.Module):
# def __init__(self, dim, heads=4, dim_head=32):
# super().__init__()
# self.heads = heads
# hidden_dim = dim_head * heads
# self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
# self.to_out = nn.Conv2d(hidden_dim, dim, 1)
#
# def forward(self, x):
# b, c, h, w = x.shape
# qkv = self.to_qkv(x)
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
# import ipdb; ipdb.set_trace()
# k = k.softmax(dim=-1)
# context = torch.einsum("bhdn,bhen->bhde", k, v)
# out = torch.einsum("bhde,bhdn->bhen", context, q)
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
# return self.to_out(out)
#
# class SpatialSelfAttention(nn.Module):
# def __init__(self, in_channels):
# super().__init__()
# self.in_channels = in_channels
#
# self.norm = Normalize(in_channels)
# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
#
# def forward(self, x):
# h_ = x
# h_ = self.norm(h_)
# q = self.q(h_)
# k = self.k(h_)
# v = self.v(h_)
#
# compute attention
# b, c, h, w = q.shape
# q = rearrange(q, "b c h w -> b (h w) c")
# k = rearrange(k, "b c h w -> b c (h w)")
# w_ = torch.einsum("bij,bjk->bik", q, k)
#
# w_ = w_ * (int(c) ** (-0.5))
# w_ = torch.nn.functional.softmax(w_, dim=2)
#
# attend to values
# v = rearrange(v, "b c h w -> b c (h w)")
# w_ = rearrange(w_, "b i j -> b j i")
# h_ = torch.einsum("bij,bjk->bik", v, w_)
# h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
# h_ = self.proj_out(h_)
#
# return x + h_
#
Also removed from this position (CrossAttention, BasicTransformerBlock, and SpatialTransformer are re-inserted lower in the file; see the final hunk):

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def forward(self, x, context=None, mask=None):
        batch_size, sequence_length, dim = x.shape
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q = self.reshape_heads_to_batch_dim(q)
        k = self.reshape_heads_to_batch_dim(k)
        v = self.reshape_heads_to_batch_dim(v)

        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            mask = mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = mask[:, None, :].repeat(h, 1, 1)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        out = self.reshape_batch_dim_to_heads(out)
        return self.to_out(out)
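Note: reshape_heads_to_batch_dim / reshape_batch_dim_to_heads in the class above fold the head dimension into the batch dimension and back. A standalone sketch of that round trip (shapes are illustrative assumptions):

import torch

batch, seq_len, heads, dim_head = 2, 16, 8, 64
x = torch.randn(batch, seq_len, heads * dim_head)

# fold heads into the batch dimension: (b, n, h*d) -> (b*h, n, d)
folded = x.reshape(batch, seq_len, heads, dim_head).permute(0, 2, 1, 3).reshape(batch * heads, seq_len, dim_head)

# and back: (b*h, n, d) -> (b, n, h*d)
restored = folded.reshape(batch, heads, seq_len, dim_head).permute(0, 2, 1, 3).reshape(batch, seq_len, heads * dim_head)

print(folded.shape)              # torch.Size([16, 16, 64])
print(torch.equal(restored, x))  # True: the two reshapes are inverses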
class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
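Note: the block above uses the pre-norm pattern: each sub-layer sees a normalized input, and its output is added back to the un-normalized residual stream. A toy standalone sketch of that wiring (the attention and feed-forward sub-layers are replaced by placeholder Linear layers purely for illustration):

import torch
from torch import nn

class PreNormResidualToy(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.sublayer1, self.sublayer2 = nn.Linear(dim, dim), nn.Linear(dim, dim)  # stand-ins for attn / ff

    def forward(self, x):
        x = self.sublayer1(self.norm1(x)) + x  # same residual order as BasicTransformerBlock
        x = self.sublayer2(self.norm2(x)) + x
        return x

block = PreNormResidualToy(32)
print(block(torch.randn(2, 16, 32)).shape)  # torch.Size([2, 16, 32])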
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
    standard transformer action. Finally, reshape to image
    """

    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for d in range(depth)
            ]
        )

        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        x = self.proj_out(x)
        return x + x_in
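Note: the forward pass above moves between image and sequence layouts before and after the transformer blocks. A standalone sketch of just that reshaping, with illustrative shapes:

import torch

b, c, h, w = 2, 32, 8, 8
x = torch.randn(b, c, h, w)

# image -> sequence of h*w tokens with c features, as in SpatialTransformer.forward
tokens = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
print(tokens.shape)  # torch.Size([2, 64, 32])

# ... the transformer blocks would run on `tokens` here ...

# sequence -> image; the projected result is then added to the residual x_in
x_out = tokens.reshape(b, h, w, c).permute(0, 3, 1, 2)
print(torch.equal(x_out, x))  # True: the round trip is lossless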
Context:
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    ...

@@ -274,19 +98,6 @@ def convert_module_to_f32(l):

Context:
        l.bias.data = l.bias.data.float()

Removed:
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

Context:
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    ...
@@ -330,36 +141,6 @@ def normalization(channels, swish=0.0):

Context:
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)

Removed:
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]
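Note: the removed AttentionPool2d flattens the spatial grid, prepends a mean token, and reads the pooled result from position 0. A standalone sketch of that token layout (the attention step itself is omitted; shapes are illustrative assumptions):

import torch

n, c, h, w = 2, 64, 7, 7
x = torch.randn(n, c, h, w)

x = x.reshape(n, c, -1)                                   # NC(HW)
x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1): mean token at index 0
pos_emb = torch.randn(c, h * w + 1) / c**0.5              # one embedding per token, incl. the mean token
x = x + pos_emb[None, :, :]

pooled = x[:, :, 0]                                        # the class returns this position after attention
print(pooled.shape)                                        # torch.Size([2, 64])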
Context:
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that support it as an extra input.
    ...
@@ -376,39 +157,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):

Context:
        return x

Removed:
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
        T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
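Note: the removed QKVAttention expects a fused tensor of shape [N, 3*H*C, T] and applies 1/sqrt(sqrt(C)) to both q and k, so the product is effectively scaled by 1/sqrt(C) with better float16 stability. A standalone sketch of the shapes and the split scaling (dimensions are illustrative):

import math
import torch

n, heads, ch, t = 2, 4, 32, 49          # batch, heads, channels per head, tokens
qkv = torch.randn(n, 3 * heads * ch, t)

q, k, v = qkv.chunk(3, dim=1)           # each [N, H*C, T]
scale = 1 / math.sqrt(math.sqrt(ch))    # applied to both q and k
weight = torch.einsum(
    "bct,bcs->bts",
    (q * scale).view(n * heads, ch, t),
    (k * scale).view(n * heads, ch, t),
)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v.reshape(n * heads, ch, t))
print(a.reshape(n, -1, t).shape)        # torch.Size([2, 128, 49]) == [N, H*C, T]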
Context:
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an attention operation. Meant to be used like:
    ...

@@ -602,21 +350,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
Context:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(

Removed:
                        # ResBlock(
                        #     ch,
                        #     time_embed_dim,
                        #     dropout,
                        #     out_channels=out_ch,
                        #     dims=dims,
                        #     use_checkpoint=use_checkpoint,
                        #     use_scale_shift_norm=use_scale_shift_norm,
                        #     down=True,
                        # )
                        None
                        if resblock_updown
                        else Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")

Added:
                        Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")

Context:
                    )
                )
                ch = out_ch
    ...
@@ -703,21 +437,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):

Context:
                    )
                if level and i == num_res_blocks:
                    out_ch = ch

Removed:
                    layers.append(
                        # ResBlock(
                        #     ch,
                        #     time_embed_dim,
                        #     dropout,
                        #     out_channels=out_ch,
                        #     dims=dims,
                        #     use_checkpoint=use_checkpoint,
                        #     use_scale_shift_norm=use_scale_shift_norm,
                        #     up=True,
                        # )
                        None
                        if resblock_updown
                        else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
                    )

Added:
                    layers.append(Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch))

Context:
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
    ...
@@ -784,215 +504,119 @@ class UNetLDMModel(ModelMixin, ConfigMixin):

Context:
        return self.out(h)

Added: the CrossAttention, BasicTransformerBlock, and SpatialTransformer classes removed from the top of the file (reproduced in full under the first hunk above) are re-inserted here unchanged.

Removed:
class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding. For usage, see UNet.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
        *args,
        **kwargs,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResnetBlock(
                        in_channels=ch,
                        out_channels=model_channels * mult,
                        dropout=dropout,
                        temb_channels=time_embed_dim,
                        eps=1e-5,
                        non_linearity="silu",
                        overwrite_for_ldm=True,
                    ),
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        # ResBlock(
                        #     ch,
                        #     time_embed_dim,
                        #     dropout,
                        #     out_channels=out_ch,
                        #     dims=dims,
                        #     use_checkpoint=use_checkpoint,
                        #     use_scale_shift_norm=use_scale_shift_norm,
                        #     down=True,
                        # )
                        None
                        if resblock_updown
                        else Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResnetBlock(
                in_channels=ch,
                out_channels=None,
                dropout=dropout,
                temb_channels=time_embed_dim,
                eps=1e-5,
                non_linearity="silu",
                overwrite_for_ldm=True,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResnetBlock(
                in_channels=ch,
                out_channels=None,
                dropout=dropout,
                temb_channels=time_embed_dim,
                eps=1e-5,
                non_linearity="silu",
                overwrite_for_ldm=True,
            ),
        )
        self._feature_size += ch
        self.pool = pool
        if pool == "adaptive":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch
        of timesteps. :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(
            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        )

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = torch.cat(results, axis=-1)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return self.out(h)
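Note: in the removed EncoderUNetModel.forward, the "spatial" pooling modes average each intermediate feature map over its spatial dimensions and concatenate the results along the channel axis before the final head. A standalone sketch of that readout, with illustrative shapes:

import torch

# Stand-ins for per-resolution feature maps collected from input_blocks and middle_block.
feature_maps = [
    torch.randn(2, 32, 32, 32),
    torch.randn(2, 64, 16, 16),
    torch.randn(2, 128, 8, 8),
]

results = [h.mean(dim=(2, 3)) for h in feature_maps]  # each becomes (batch, channels)
h = torch.cat(results, dim=-1)                         # (batch, 32 + 64 + 128)
print(h.shape)                                         # torch.Size([2, 224])
# self.out would then map this concatenated vector to out_channels.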