renzhc / diffusers_dcu · Commits · c1efda70

Unverified commit c1efda70, authored Aug 25, 2022 by Patrick von Platen, committed by GitHub on Aug 25, 2022
Parent: 47893164

[Clean up] Clean unused code (#245)

* CleanResNet
* refactor more
* correct
Changes: 4 files
Showing 4 changed files with 39 additions and 247 deletions (+39, −247)
src/diffusers/modeling_utils.py      +1   −1
src/diffusers/models/attention.py    +6   −214
src/diffusers/models/resnet.py       +1   −1
src/diffusers/models/unet_blocks.py  +31  −31
src/diffusers/modeling_utils.py
@@ -390,7 +390,7 @@ class ModelMixin(torch.nn.Module):
            )
        except EntryNotFoundError:
            raise EnvironmentError(
-                f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
+                f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
            )
        except HTTPError as err:
            raise EnvironmentError(
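For context, a minimal, self-contained sketch of the error-handling pattern touched above: when the Hub lookup fails, the raised EnvironmentError now names the canonical weights file (WEIGHTS_NAME) rather than the local model_file variable. The constant value, the stand-in exception class, the helper name fetch_weights, and the repo id below are illustrative assumptions, not the actual diffusers definitions.

```python
# Illustrative sketch only; not the actual ModelMixin.from_pretrained code.
WEIGHTS_NAME = "diffusion_pytorch_model.bin"  # assumed value of the diffusers constant


class EntryNotFoundError(Exception):
    """Stand-in for the Hub client's 'file not found in this repo' error."""


def fetch_weights(pretrained_model_name_or_path: str, model_file: str) -> str:
    try:
        # Pretend the remote lookup failed to find `model_file` in the repo.
        raise EntryNotFoundError(model_file)
    except EntryNotFoundError:
        # After this commit the message points at the canonical filename.
        raise EnvironmentError(
            f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
        )


if __name__ == "__main__":
    try:
        fetch_weights("some-org/some-model", "missing.bin")
    except EnvironmentError as err:
        print(err)
```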
src/diffusers/models/attention.py
 import math
-from inspect import isfunction
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
 
-class AttentionBlockNew(nn.Module):
+class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
@@ -82,55 +81,6 @@ class AttentionBlockNew(nn.Module):
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
         return hidden_states
 
-    def set_weight(self, attn_layer):
-        self.group_norm.weight.data = attn_layer.norm.weight.data
-        self.group_norm.bias.data = attn_layer.norm.bias.data
-
-        if hasattr(attn_layer, "q"):
-            self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0]
-            self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0]
-            self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0]
-
-            self.query.bias.data = attn_layer.q.bias.data
-            self.key.bias.data = attn_layer.k.bias.data
-            self.value.bias.data = attn_layer.v.bias.data
-
-            self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
-            self.proj_attn.bias.data = attn_layer.proj_out.bias.data
-        elif hasattr(attn_layer, "NIN_0"):
-            self.query.weight.data = attn_layer.NIN_0.W.data.T
-            self.key.weight.data = attn_layer.NIN_1.W.data.T
-            self.value.weight.data = attn_layer.NIN_2.W.data.T
-
-            self.query.bias.data = attn_layer.NIN_0.b.data
-            self.key.bias.data = attn_layer.NIN_1.b.data
-            self.value.bias.data = attn_layer.NIN_2.b.data
-
-            self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T
-            self.proj_attn.bias.data = attn_layer.NIN_3.b.data
-
-            self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data
-            self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data
-        else:
-            qkv_weight = attn_layer.qkv.weight.data.reshape(
-                self.num_heads, 3 * self.channels // self.num_heads, self.channels
-            )
-            qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
-
-            q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
-            q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
-
-            self.query.weight.data = q_w.reshape(-1, self.channels)
-            self.key.weight.data = k_w.reshape(-1, self.channels)
-            self.value.weight.data = v_w.reshape(-1, self.channels)
-
-            self.query.bias.data = q_b.reshape(-1)
-            self.key.bias.data = k_b.reshape(-1)
-            self.value.bias.data = v_b.reshape(-1)
-
-            self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
-            self.proj_attn.bias.data = attn_layer.proj.bias.data
-
 
 class SpatialTransformer(nn.Module):
     """
@@ -170,12 +120,6 @@ class SpatialTransformer(nn.Module):
         x = self.proj_out(x)
         return x + x_in
 
-    def set_weight(self, layer):
-        self.norm = layer.norm
-        self.proj_in = layer.proj_in
-        self.transformer_blocks = layer.transformer_blocks
-        self.proj_out = layer.proj_out
-
 
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
@@ -203,7 +147,7 @@ class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
         inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
+        context_dim = context_dim if context_dim is not None else query_dim
 
         self.scale = dim_head**-0.5
         self.heads = heads
@@ -234,7 +178,7 @@ class CrossAttention(nn.Module):
         h = self.heads
 
         q = self.to_q(x)
-        context = default(context, x)
+        context = context if context is not None else x
         k = self.to_k(context)
         v = self.to_v(context)
@@ -244,7 +188,7 @@ class CrossAttention(nn.Module):
         sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
 
-        if exists(mask):
+        if mask is not None:
             mask = mask.reshape(batch_size, -1)
             max_neg_value = -torch.finfo(sim.dtype).max
             mask = mask[:, None, :].repeat(h, 1, 1)
@@ -262,8 +206,8 @@ class FeedForward(nn.Module):
     def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
         super().__init__()
         inner_dim = int(dim * mult)
-        dim_out = default(dim_out, dim)
-        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+        dim_out = dim_out if dim_out is not None else dim
+        project_in = GEGLU(dim, inner_dim)
 
         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
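The hunks above replace calls to the `exists`/`default` helpers (removed in the next hunk) with plain `x if x is not None else y` expressions. One subtlety: the removed `default` also accepted a callable and invoked it lazily, while the inlined form does not; at the call sites touched here the fallbacks are plain values, so behavior is unchanged. A small self-contained sketch of the equivalence (the helper bodies mirror the removed code; the demo values are made up):

```python
from inspect import isfunction


def exists(val):
    return val is not None


def default(val, d):
    # The removed helper: lazily calls `d` when it is a function.
    if exists(val):
        return val
    return d() if isfunction(d) else d


# For plain (non-callable) fallbacks the inlined idiom is equivalent:
context_dim = None
query_dim = 320  # made-up example value
assert default(context_dim, query_dim) == (context_dim if context_dim is not None else query_dim)


# The difference only shows up with callable defaults, which these call sites never use:
def fallback():
    return 64


assert default(None, fallback) == 64                           # helper invokes the callable
assert (None if None is not None else fallback) is fallback    # inlined idiom hands back the callable itself
```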
@@ -280,155 +224,3 @@ class GEGLU(nn.Module):
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
         return x * F.gelu(gate)
-
-
-# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-
-def exists(val):
-    return val is not None
-
-
-def default(val, d):
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-# the main attention block that is used for all models
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=None,
-        num_groups=32,
-        encoder_channels=None,
-        overwrite_qkv=False,
-        overwrite_linear=False,
-        rescale_output_factor=1.0,
-        eps=1e-5,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels is None:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-
-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
-        self.qkv = nn.Conv1d(channels, channels * 3, 1)
-        self.n_heads = self.num_heads
-        self.rescale_output_factor = rescale_output_factor
-
-        if encoder_channels is not None:
-            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-
-        self.proj = nn.Conv1d(channels, channels, 1)
-
-        self.overwrite_qkv = overwrite_qkv
-        self.overwrite_linear = overwrite_linear
-
-        if overwrite_qkv:
-            in_channels = channels
-
-            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        elif self.overwrite_linear:
-            num_groups = min(channels // 4, 32)
-            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-            self.NIN_0 = NIN(channels, channels)
-            self.NIN_1 = NIN(channels, channels)
-            self.NIN_2 = NIN(channels, channels)
-            self.NIN_3 = NIN(channels, channels)
-
-            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
-        else:
-            self.proj_out = nn.Conv1d(channels, channels, 1)
-            self.set_weights(self)
-
-        self.is_overwritten = False
-
-    def set_weights(self, module):
-        if self.overwrite_qkv:
-            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
-                :, :, :, 0
-            ]
-            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
-
-            self.qkv.weight.data = qkv_weight
-            self.qkv.bias.data = qkv_bias
-
-            proj_out = nn.Conv1d(self.channels, self.channels, 1)
-            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
-            proj_out.bias.data = module.proj_out.bias.data
-
-            self.proj = proj_out
-        elif self.overwrite_linear:
-            self.qkv.weight.data = torch.concat(
-                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
-            )[:, :, None]
-            self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
-
-            self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
-            self.proj.bias.data = self.NIN_3.b.data
-
-            self.norm.weight.data = self.GroupNorm_0.weight.data
-            self.norm.bias.data = self.GroupNorm_0.bias.data
-        else:
-            self.proj.weight.data = self.proj_out.weight.data
-            self.proj.bias.data = self.proj_out.bias.data
-
-    def forward(self, x, encoder_out=None):
-        if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
-            self.set_weights(self)
-            self.is_overwritten = True
-
-        b, c, *spatial = x.shape
-        hid_states = self.norm(x).view(b, c, -1)
-
-        qkv = self.qkv(hid_states)
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-
-        if encoder_out is not None:
-            encoder_kv = self.encoder_kv(encoder_out)
-            assert encoder_kv.shape[1] == self.n_heads * ch * 2
-            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = torch.cat([ek, k], dim=-1)
-            v = torch.cat([ev, v], dim=-1)
-
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-
-        a = torch.einsum("bts,bcs->bct", weight, v)
-        h = a.reshape(bs, -1, length)
-
-        h = self.proj(h)
-        h = h.reshape(b, c, *spatial)
-
-        result = x + h
-        result = result / self.rescale_output_factor
-
-        return result
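With the legacy block and its weight-porting helpers deleted, the former `AttentionBlockNew` is the single implementation and now carries the `AttentionBlock` name. A minimal usage sketch, assuming the renamed class keeps the constructor arguments used at the call sites in this diff (`channels`, `num_head_channels`, `rescale_output_factor`, `eps`) and accepts a 4-D image-style input; the tensor sizes are made up and the exact signature is not shown in this commit:

```python
import torch

from diffusers.models.attention import AttentionBlock  # post-rename import path

# Hypothetical sizes for illustration only.
batch, channels, height, width = 2, 64, 32, 32

block = AttentionBlock(
    channels,
    num_head_channels=32,           # kwargs mirror the call sites in unet_blocks.py
    rescale_output_factor=1.0,
    eps=1e-5,                       # assumed to still be accepted after the rename
)

hidden_states = torch.randn(batch, channels, height, width)
out = block(hidden_states)          # assumed single-argument forward
print(out.shape)                    # expected to match the input shape
```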
src/diffusers/models/resnet.py
@@ -248,7 +248,7 @@ class FirDownsample2D(nn.Module):
         return x
 
 
-class ResnetBlock(nn.Module):
+class ResnetBlock2D(nn.Module):
     def __init__(
         self,
         *,
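The block keeps its keyword-only constructor and is only renamed from ResnetBlock to ResnetBlock2D. A construction sketch using the keyword arguments visible at the call sites in this diff; any further constructor arguments and the forward signature are not shown here, so treat them as illustrative assumptions rather than the full API:

```python
import torch

from diffusers.models.resnet import ResnetBlock2D  # post-rename import path

# Keyword-only arguments, mirroring the call sites in unet_blocks.py.
# Defaults beyond these three are assumed, not shown in this diff.
resnet = ResnetBlock2D(
    in_channels=64,
    out_channels=64,
    temb_channels=512,
)

sample = torch.randn(1, 64, 32, 32)   # made-up batch of feature maps
temb = torch.randn(1, 512)            # made-up timestep embedding
out = resnet(sample, temb)            # assumed forward(sample, temb) signature
print(out.shape)
```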
src/diffusers/models/unet_blocks.py
@@ -17,8 +17,8 @@ import numpy as np
 import torch
 from torch import nn
 
-from .attention import AttentionBlockNew, SpatialTransformer
-from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
+from .attention import AttentionBlock, SpatialTransformer
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
 
 
 def get_down_block(
@@ -219,7 +219,7 @@ class UNetMidBlock2D(nn.Module):
         # there is always at least one resnet
         resnets = [
-            ResnetBlock(
+            ResnetBlock2D(
                 in_channels=in_channels,
                 out_channels=in_channels,
                 temb_channels=temb_channels,
@@ -236,7 +236,7 @@ class UNetMidBlock2D(nn.Module):
         for _ in range(num_layers):
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     in_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -245,7 +245,7 @@ class UNetMidBlock2D(nn.Module):
                 )
             )
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=in_channels,
                     temb_channels=temb_channels,
@@ -299,7 +299,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         # there is always at least one resnet
         resnets = [
-            ResnetBlock(
+            ResnetBlock2D(
                 in_channels=in_channels,
                 out_channels=in_channels,
                 temb_channels=temb_channels,
@@ -325,7 +325,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 )
             )
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=in_channels,
                     temb_channels=temb_channels,
@@ -379,7 +379,7 @@ class AttnDownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -393,7 +393,7 @@ class AttnDownBlock2D(nn.Module):
                 )
             )
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -461,7 +461,7 @@ class CrossAttnDownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -537,7 +537,7 @@ class DownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -602,7 +602,7 @@ class DownEncoderBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=None,
@@ -664,7 +664,7 @@ class AttnDownEncoderBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=None,
@@ -678,7 +678,7 @@ class AttnDownEncoderBlock2D(nn.Module):
                 )
             )
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -740,7 +740,7 @@ class AttnSkipDownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             self.resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -755,7 +755,7 @@ class AttnSkipDownBlock2D(nn.Module):
                 )
             )
             self.attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -764,7 +764,7 @@ class AttnSkipDownBlock2D(nn.Module):
             )
 
         if add_downsample:
-            self.resnet_down = ResnetBlock(
+            self.resnet_down = ResnetBlock2D(
                 in_channels=out_channels,
                 out_channels=out_channels,
                 temb_channels=temb_channels,
@@ -828,7 +828,7 @@ class SkipDownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             self.resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -844,7 +844,7 @@ class SkipDownBlock2D(nn.Module):
             )
 
         if add_downsample:
-            self.resnet_down = ResnetBlock(
+            self.resnet_down = ResnetBlock2D(
                 in_channels=out_channels,
                 out_channels=out_channels,
                 temb_channels=temb_channels,
@@ -915,7 +915,7 @@ class AttnUpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
 
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -929,7 +929,7 @@ class AttnUpBlock2D(nn.Module):
                 )
             )
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -995,7 +995,7 @@ class CrossAttnUpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
 
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -1068,7 +1068,7 @@ class UpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
 
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -1128,7 +1128,7 @@ class UpDecoderBlock2D(nn.Module):
             input_channels = in_channels if i == 0 else out_channels
 
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=input_channels,
                     out_channels=out_channels,
                     temb_channels=None,
@@ -1184,7 +1184,7 @@ class AttnUpDecoderBlock2D(nn.Module):
            input_channels = in_channels if i == 0 else out_channels
 
            resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=input_channels,
                     out_channels=out_channels,
                     temb_channels=None,
@@ -1198,7 +1198,7 @@ class AttnUpDecoderBlock2D(nn.Module):
                 )
             )
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -1257,7 +1257,7 @@ class AttnSkipUpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
 
             self.resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -1273,7 +1273,7 @@ class AttnSkipUpBlock2D(nn.Module):
                 )
 
             self.attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -1283,7 +1283,7 @@ class AttnSkipUpBlock2D(nn.Module):
         self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
 
         if add_upsample:
-            self.resnet_up = ResnetBlock(
+            self.resnet_up = ResnetBlock2D(
                 in_channels=out_channels,
                 out_channels=out_channels,
                 temb_channels=temb_channels,
@@ -1363,7 +1363,7 @@ class SkipUpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
 
             self.resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -1380,7 +1380,7 @@ class SkipUpBlock2D(nn.Module):
         self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
 
         if add_upsample:
-            self.resnet_up = ResnetBlock(
+            self.resnet_up = ResnetBlock2D(
                 in_channels=out_channels,
                 out_channels=out_channels,
                 temb_channels=temb_channels,
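Since this file only renames call sites (AttentionBlockNew becomes AttentionBlock, ResnetBlock becomes ResnetBlock2D), a quick way to confirm the sweep is complete is to scan the package for leftover references to the old names. A small illustrative script; the source-root path is an assumption about where the checkout lives, and the substring check is deliberately simple:

```python
from pathlib import Path

# Assumed location of the package checkout; adjust to your clone.
SRC_ROOT = Path("src/diffusers")

# "ResnetBlock(" (with the opening parenthesis) avoids matching the new ResnetBlock2D name.
OLD_NAMES = ("AttentionBlockNew", "ResnetBlock(")

for path in sorted(SRC_ROOT.rglob("*.py")):
    text = path.read_text(encoding="utf-8")
    for lineno, line in enumerate(text.splitlines(), start=1):
        if any(old in line for old in OLD_NAMES):
            # Any hit here is a call site the rename missed.
            print(f"{path}:{lineno}: {line.strip()}")
```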