renzhc / diffusers_dcu · Commit c1efda70 (Unverified)
Authored Aug 25, 2022 by Patrick von Platen; committed by GitHub on Aug 25, 2022

[Clean up] Clean unused code (#245)

* CleanResNet
* refactor more
* correct

Parent: 47893164
Changes: 4 files, 39 additions and 247 deletions (+39 −247)
src/diffusers/modeling_utils.py      +1  −1
src/diffusers/models/attention.py    +6  −214
src/diffusers/models/resnet.py       +1  −1
src/diffusers/models/unet_blocks.py  +31 −31
src/diffusers/modeling_utils.py

@@ -390,7 +390,7 @@ class ModelMixin(torch.nn.Module):
             )
         except EntryNotFoundError:
             raise EnvironmentError(
-                f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
+                f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
             )
         except HTTPError as err:
             raise EnvironmentError(
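The change above only affects the wording of the error: the message now names the canonical weights file constant instead of the local model_file variable. A minimal sketch of the same error-reporting pattern, assuming hf_hub_download and EntryNotFoundError from huggingface_hub and an assumed value for WEIGHTS_NAME (this is not the library code itself):

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

WEIGHTS_NAME = "diffusion_pytorch_model.bin"  # assumed value of the constant


def download_weights(pretrained_model_name_or_path: str) -> str:
    # Re-raise the hub's "file not found" error with the canonical filename,
    # so the message no longer depends on a local variable.
    try:
        return hf_hub_download(pretrained_model_name_or_path, filename=WEIGHTS_NAME)
    except EntryNotFoundError:
        raise EnvironmentError(
            f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
        )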
src/diffusers/models/attention.py

 import math
-from inspect import isfunction

 import torch
 import torch.nn.functional as F
 from torch import nn


-class AttentionBlockNew(nn.Module):
+class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
@@ -82,55 +81,6 @@ class AttentionBlockNew(nn.Module):
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
         return hidden_states

-    def set_weight(self, attn_layer):
-        self.group_norm.weight.data = attn_layer.norm.weight.data
-        self.group_norm.bias.data = attn_layer.norm.bias.data
-
-        if hasattr(attn_layer, "q"):
-            self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0]
-            self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0]
-            self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0]
-
-            self.query.bias.data = attn_layer.q.bias.data
-            self.key.bias.data = attn_layer.k.bias.data
-            self.value.bias.data = attn_layer.v.bias.data
-
-            self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
-            self.proj_attn.bias.data = attn_layer.proj_out.bias.data
-        elif hasattr(attn_layer, "NIN_0"):
-            self.query.weight.data = attn_layer.NIN_0.W.data.T
-            self.key.weight.data = attn_layer.NIN_1.W.data.T
-            self.value.weight.data = attn_layer.NIN_2.W.data.T
-
-            self.query.bias.data = attn_layer.NIN_0.b.data
-            self.key.bias.data = attn_layer.NIN_1.b.data
-            self.value.bias.data = attn_layer.NIN_2.b.data
-
-            self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T
-            self.proj_attn.bias.data = attn_layer.NIN_3.b.data
-
-            self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data
-            self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data
-        else:
-            qkv_weight = attn_layer.qkv.weight.data.reshape(
-                self.num_heads, 3 * self.channels // self.num_heads, self.channels
-            )
-            qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
-
-            q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
-            q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
-
-            self.query.weight.data = q_w.reshape(-1, self.channels)
-            self.key.weight.data = k_w.reshape(-1, self.channels)
-            self.value.weight.data = v_w.reshape(-1, self.channels)
-
-            self.query.bias.data = q_b.reshape(-1)
-            self.key.bias.data = k_b.reshape(-1)
-            self.value.bias.data = v_b.reshape(-1)
-
-            self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
-            self.proj_attn.bias.data = attn_layer.proj.bias.data
-

 class SpatialTransformer(nn.Module):
     """
@@ -170,12 +120,6 @@ class SpatialTransformer(nn.Module):
         x = self.proj_out(x)
         return x + x_in

-    def set_weight(self, layer):
-        self.norm = layer.norm
-        self.proj_in = layer.proj_in
-        self.transformer_blocks = layer.transformer_blocks
-        self.proj_out = layer.proj_out
-

 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
@@ -203,7 +147,7 @@ class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
         inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
+        context_dim = context_dim if context_dim is not None else query_dim

         self.scale = dim_head**-0.5
         self.heads = heads
@@ -234,7 +178,7 @@ class CrossAttention(nn.Module):
         h = self.heads

         q = self.to_q(x)
-        context = default(context, x)
+        context = context if context is not None else x
         k = self.to_k(context)
         v = self.to_v(context)
@@ -244,7 +188,7 @@ class CrossAttention(nn.Module):
         sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

-        if exists(mask):
+        if mask is not None:
             mask = mask.reshape(batch_size, -1)
             max_neg_value = -torch.finfo(sim.dtype).max
             mask = mask[:, None, :].repeat(h, 1, 1)
@@ -262,8 +206,8 @@ class FeedForward(nn.Module):
     def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
         super().__init__()
         inner_dim = int(dim * mult)
-        dim_out = default(dim_out, dim)
-        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+        dim_out = dim_out if dim_out is not None else dim
+        project_in = GEGLU(dim, inner_dim)

         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
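With the change above, FeedForward always projects through GEGLU instead of switching on the glu flag. For context, a minimal GEGLU sketch consistent with the forward shown as context in the next hunk; the Linear(dim_in, dim_out * 2) projection is an assumption, not something stated in this diff:

import torch
import torch.nn.functional as F
from torch import nn


class GEGLU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # Project to twice the output width; one half gates the other through GELU.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)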
@@ -280,155 +224,3 @@ class GEGLU(nn.Module):
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
         return x * F.gelu(gate)
-
-
-# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-
-def exists(val):
-    return val is not None
-
-
-def default(val, d):
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-# the main attention block that is used for all models
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=None,
-        num_groups=32,
-        encoder_channels=None,
-        overwrite_qkv=False,
-        overwrite_linear=False,
-        rescale_output_factor=1.0,
-        eps=1e-5,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels is None:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-
-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
-        self.qkv = nn.Conv1d(channels, channels * 3, 1)
-        self.n_heads = self.num_heads
-        self.rescale_output_factor = rescale_output_factor
-
-        if encoder_channels is not None:
-            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-
-        self.proj = nn.Conv1d(channels, channels, 1)
-
-        self.overwrite_qkv = overwrite_qkv
-        self.overwrite_linear = overwrite_linear
-
-        if overwrite_qkv:
-            in_channels = channels
-            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        elif self.overwrite_linear:
-            num_groups = min(channels // 4, 32)
-            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-            self.NIN_0 = NIN(channels, channels)
-            self.NIN_1 = NIN(channels, channels)
-            self.NIN_2 = NIN(channels, channels)
-            self.NIN_3 = NIN(channels, channels)
-
-            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
-        else:
-            self.proj_out = nn.Conv1d(channels, channels, 1)
-            self.set_weights(self)
-
-        self.is_overwritten = False
-
-    def set_weights(self, module):
-        if self.overwrite_qkv:
-            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
-                :, :, :, 0
-            ]
-            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
-
-            self.qkv.weight.data = qkv_weight
-            self.qkv.bias.data = qkv_bias
-
-            proj_out = nn.Conv1d(self.channels, self.channels, 1)
-            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
-            proj_out.bias.data = module.proj_out.bias.data
-
-            self.proj = proj_out
-        elif self.overwrite_linear:
-            self.qkv.weight.data = torch.concat(
-                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
-            )[:, :, None]
-            self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
-
-            self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
-            self.proj.bias.data = self.NIN_3.b.data
-
-            self.norm.weight.data = self.GroupNorm_0.weight.data
-            self.norm.bias.data = self.GroupNorm_0.bias.data
-        else:
-            self.proj.weight.data = self.proj_out.weight.data
-            self.proj.bias.data = self.proj_out.bias.data
-
-    def forward(self, x, encoder_out=None):
-        if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
-            self.set_weights(self)
-            self.is_overwritten = True
-
-        b, c, *spatial = x.shape
-        hid_states = self.norm(x).view(b, c, -1)
-
-        qkv = self.qkv(hid_states)
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-
-        if encoder_out is not None:
-            encoder_kv = self.encoder_kv(encoder_out)
-            assert encoder_kv.shape[1] == self.n_heads * ch * 2
-            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = torch.cat([ek, k], dim=-1)
-            v = torch.cat([ev, v], dim=-1)
-
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-
-        a = torch.einsum("bts,bcs->bct", weight, v)
-        h = a.reshape(bs, -1, length)
-
-        h = self.proj(h)
-        h = h.reshape(b, c, *spatial)
-
-        result = x + h
-
-        result = result / self.rescale_output_factor
-
-        return result
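Net effect for attention.py: the deprecated AttentionBlock (with its NIN layers and overwrite_* weight copying) and the exists / default helpers are removed, and AttentionBlockNew now carries the AttentionBlock name. A hedged usage sketch of the surviving class, with the constructor arguments inferred from the call sites in unet_blocks.py below and the 4D input layout assumed rather than shown in this diff:

import torch
from diffusers.models.attention import AttentionBlock

# channels=32 with 8 channels per head -> 4 heads (assumed semantics of num_head_channels)
block = AttentionBlock(32, num_head_channels=8, rescale_output_factor=1.0)

hidden_states = torch.randn(2, 32, 16, 16)  # (batch, channels, height, width), assumed layout
out = block(hidden_states)                  # residual spatial self-attention; same shape as the input
print(out.shape)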
src/diffusers/models/resnet.py

@@ -248,7 +248,7 @@ class FirDownsample2D(nn.Module):
         return x


-class ResnetBlock(nn.Module):
+class ResnetBlock2D(nn.Module):
     def __init__(
         self,
         *,
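ResnetBlock is renamed to ResnetBlock2D; the bare `*` in its signature keeps the constructor keyword-only, which matches every call site in unet_blocks.py below. A hedged construction sketch using only the keywords visible in this diff; the forward signature, tensor layouts, and remaining defaults are assumptions:

import torch
from diffusers.models.resnet import ResnetBlock2D

# Keyword-only construction, mirroring the call sites in unet_blocks.py.
resnet = ResnetBlock2D(in_channels=64, out_channels=64, temb_channels=512)

sample = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width), assumed layout
temb = torch.randn(1, 512)           # time embedding, assumed shape
out = resnet(sample, temb)           # assumed forward signature: (hidden_states, temb)
print(out.shape)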
src/diffusers/models/unet_blocks.py

@@ -17,8 +17,8 @@ import numpy as np
 import torch
 from torch import nn

-from .attention import AttentionBlockNew, SpatialTransformer
-from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
+from .attention import AttentionBlock, SpatialTransformer
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D


 def get_down_block(
@@ -219,7 +219,7 @@ class UNetMidBlock2D(nn.Module):
         # there is always at least one resnet
         resnets = [
-            ResnetBlock(
+            ResnetBlock2D(
                 in_channels=in_channels,
                 out_channels=in_channels,
                 temb_channels=temb_channels,

@@ -236,7 +236,7 @@ class UNetMidBlock2D(nn.Module):
         for _ in range(num_layers):
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     in_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,

@@ -245,7 +245,7 @@ class UNetMidBlock2D(nn.Module):
                 )
             )
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=in_channels,
                     temb_channels=temb_channels,

@@ -299,7 +299,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         # there is always at least one resnet
         resnets = [
-            ResnetBlock(
+            ResnetBlock2D(
                 in_channels=in_channels,
                 out_channels=in_channels,
                 temb_channels=temb_channels,

@@ -325,7 +325,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 )
             )
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=in_channels,
                     temb_channels=temb_channels,

@@ -379,7 +379,7 @@ class AttnDownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,

@@ -393,7 +393,7 @@ class AttnDownBlock2D(nn.Module):
                 )
             )
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,

@@ -461,7 +461,7 @@ class CrossAttnDownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,

@@ -537,7 +537,7 @@ class DownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,

@@ -602,7 +602,7 @@ class DownEncoderBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=None,

@@ -664,7 +664,7 @@ class AttnDownEncoderBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=None,

@@ -678,7 +678,7 @@ class AttnDownEncoderBlock2D(nn.Module):
                 )
             )
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -740,7 +740,7 @@ class AttnSkipDownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             self.resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,

@@ -755,7 +755,7 @@ class AttnSkipDownBlock2D(nn.Module):
                 )
             )
             self.attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,

@@ -764,7 +764,7 @@ class AttnSkipDownBlock2D(nn.Module):
             )

         if add_downsample:
-            self.resnet_down = ResnetBlock(
+            self.resnet_down = ResnetBlock2D(
                 in_channels=out_channels,
                 out_channels=out_channels,
                 temb_channels=temb_channels,

@@ -828,7 +828,7 @@ class SkipDownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             self.resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,

@@ -844,7 +844,7 @@ class SkipDownBlock2D(nn.Module):
             )

         if add_downsample:
-            self.resnet_down = ResnetBlock(
+            self.resnet_down = ResnetBlock2D(
                 in_channels=out_channels,
                 out_channels=out_channels,
                 temb_channels=temb_channels,

@@ -915,7 +915,7 @@ class AttnUpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels

             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,

@@ -929,7 +929,7 @@ class AttnUpBlock2D(nn.Module):
                 )
             )
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,

@@ -995,7 +995,7 @@ class CrossAttnUpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels

             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,

@@ -1068,7 +1068,7 @@ class UpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels

             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -1128,7 +1128,7 @@ class UpDecoderBlock2D(nn.Module):
             input_channels = in_channels if i == 0 else out_channels

             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=input_channels,
                     out_channels=out_channels,
                     temb_channels=None,

@@ -1184,7 +1184,7 @@ class AttnUpDecoderBlock2D(nn.Module):
             input_channels = in_channels if i == 0 else out_channels

             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=input_channels,
                     out_channels=out_channels,
                     temb_channels=None,

@@ -1198,7 +1198,7 @@ class AttnUpDecoderBlock2D(nn.Module):
                 )
             )
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,

@@ -1257,7 +1257,7 @@ class AttnSkipUpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels

             self.resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,

@@ -1273,7 +1273,7 @@ class AttnSkipUpBlock2D(nn.Module):
                 )
             self.attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,

@@ -1283,7 +1283,7 @@ class AttnSkipUpBlock2D(nn.Module):
             self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)

         if add_upsample:
-            self.resnet_up = ResnetBlock(
+            self.resnet_up = ResnetBlock2D(
                 in_channels=out_channels,
                 out_channels=out_channels,
                 temb_channels=temb_channels,

@@ -1363,7 +1363,7 @@ class SkipUpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels

             self.resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,

@@ -1380,7 +1380,7 @@ class SkipUpBlock2D(nn.Module):
             self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)

         if add_upsample:
-            self.resnet_up = ResnetBlock(
+            self.resnet_up = ResnetBlock2D(
                 in_channels=out_channels,
                 out_channels=out_channels,
                 temb_channels=temb_channels,