Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
f1aade05
Commit
f1aade05
authored
Jul 01, 2022
by
Patrick von Platen
Browse files
up
parent
abedfb08
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
1946 additions
and
4 deletions
+1946
-4
@
@
+859
-0
G
G
+883
-0
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+202
-3
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+2
-1
No files found.
@
0 → 100644
View file @
f1aade05
This diff is collapsed.
Click to expand it.
G
0 → 100644
View file @
f1aade05
This diff is collapsed.
Click to expand it.
src/diffusers/models/resnet.py
View file @
f1aade05
...
...
@@ -160,6 +160,208 @@ class Downsample(nn.Module):
# RESNETS
# unet_score_estimation.py
class
ResnetBlockBigGANppNew
(
nn
.
Module
):
def
__init__
(
self
,
act
,
in_ch
,
out_ch
=
None
,
temb_dim
=
None
,
up
=
False
,
down
=
False
,
dropout
=
0.1
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
skip_rescale
=
True
,
init_scale
=
0.0
,
overwrite
=
True
,
):
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
)
self
.
up
=
up
self
.
down
=
down
self
.
fir_kernel
=
fir_kernel
self
.
Conv_0
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
padding
=
1
)
if
temb_dim
is
not
None
:
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
self
.
Dense_0
.
weight
.
data
=
variance_scaling
()(
self
.
Dense_0
.
weight
.
shape
)
nn
.
init
.
zeros_
(
self
.
Dense_0
.
bias
)
self
.
GroupNorm_1
=
nn
.
GroupNorm
(
num_groups
=
min
(
out_ch
//
4
,
32
),
num_channels
=
out_ch
,
eps
=
1e-6
)
self
.
Dropout_0
=
nn
.
Dropout
(
dropout
)
self
.
Conv_1
=
conv2d
(
out_ch
,
out_ch
,
init_scale
=
init_scale
,
kernel_size
=
3
,
padding
=
1
)
if
in_ch
!=
out_ch
or
up
or
down
:
# 1x1 convolution with DDPM initialization.
self
.
Conv_2
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
1
,
padding
=
0
)
self
.
skip_rescale
=
skip_rescale
self
.
act
=
act
self
.
in_ch
=
in_ch
self
.
out_ch
=
out_ch
self
.
is_overwritten
=
False
self
.
overwrite
=
overwrite
if
overwrite
:
self
.
output_scale_factor
=
np
.
sqrt
(
2.0
)
self
.
in_channels
=
in_channels
=
in_ch
self
.
out_channels
=
out_channels
=
out_ch
groups
=
min
(
in_ch
//
4
,
32
)
out_groups
=
min
(
out_ch
//
4
,
32
)
eps
=
1e-6
self
.
pre_norm
=
True
temb_channels
=
temb_dim
non_linearity
=
"silu"
self
.
time_embedding_norm
=
time_embedding_norm
=
"default"
if
self
.
pre_norm
:
self
.
norm1
=
Normalize
(
in_channels
,
num_groups
=
groups
,
eps
=
eps
)
else
:
self
.
norm1
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
time_embedding_norm
==
"default"
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
elif
time_embedding_norm
==
"scale_shift"
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
2
*
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
,
num_groups
=
out_groups
,
eps
=
eps
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
non_linearity
==
"swish"
:
self
.
nonlinearity
=
nonlinearity
elif
non_linearity
==
"mish"
:
self
.
nonlinearity
=
Mish
()
elif
non_linearity
==
"silu"
:
self
.
nonlinearity
=
nn
.
SiLU
()
if
up
:
self
.
h_upd
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
self
.
x_upd
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
elif
down
:
self
.
h_upd
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
self
.
x_upd
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
if
self
.
in_channels
!=
self
.
out_channels
or
self
.
up
or
self
.
down
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
set_weights
(
self
):
self
.
conv1
.
weight
.
data
=
self
.
Conv_0
.
weight
.
data
self
.
conv1
.
bias
.
data
=
self
.
Conv_0
.
bias
.
data
self
.
norm1
.
weight
.
data
=
self
.
GroupNorm_0
.
weight
.
data
self
.
norm1
.
bias
.
data
=
self
.
GroupNorm_0
.
bias
.
data
self
.
conv2
.
weight
.
data
=
self
.
Conv_1
.
weight
.
data
self
.
conv2
.
bias
.
data
=
self
.
Conv_1
.
bias
.
data
self
.
norm2
.
weight
.
data
=
self
.
GroupNorm_1
.
weight
.
data
self
.
norm2
.
bias
.
data
=
self
.
GroupNorm_1
.
bias
.
data
self
.
temb_proj
.
weight
.
data
=
self
.
Dense_0
.
weight
.
data
self
.
temb_proj
.
bias
.
data
=
self
.
Dense_0
.
bias
.
data
if
self
.
in_channels
!=
self
.
out_channels
or
self
.
up
or
self
.
down
:
self
.
nin_shortcut
.
weight
.
data
=
self
.
Conv_2
.
weight
.
data
self
.
nin_shortcut
.
bias
.
data
=
self
.
Conv_2
.
bias
.
data
def
forward
(
self
,
x
,
temb
=
None
):
if
self
.
overwrite
and
not
self
.
is_overwritten
:
self
.
set_weights
()
self
.
is_overwritten
=
True
orig_x
=
x
h
=
self
.
act
(
self
.
GroupNorm_0
(
x
))
if
self
.
up
:
h
=
upsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
elif
self
.
down
:
h
=
downsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
h
=
self
.
Conv_0
(
h
)
# Add bias to each feature map conditioned on the time embedding
if
temb
is
not
None
:
h
+=
self
.
Dense_0
(
self
.
act
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
act
(
self
.
GroupNorm_1
(
h
))
h
=
self
.
Dropout_0
(
h
)
h
=
self
.
Conv_1
(
h
)
if
self
.
in_ch
!=
self
.
out_ch
or
self
.
up
or
self
.
down
:
x
=
self
.
Conv_2
(
x
)
if
not
self
.
skip_rescale
:
raise
ValueError
(
"Is this branch run?!"
)
# import ipdb; ipdb.set_trace()
result
=
x
+
h
else
:
result
=
(
x
+
h
)
/
np
.
sqrt
(
2.0
)
result_2
=
self
.
forward_2
(
orig_x
,
temb
)
return
result_2
def
forward_2
(
self
,
x
,
temb
,
mask
=
1.0
):
h
=
x
h
=
h
*
mask
if
self
.
pre_norm
:
h
=
self
.
norm1
(
h
)
h
=
self
.
nonlinearity
(
h
)
# if self.up or self.down:
# x = self.x_upd(x)
# h = self.h_upd(h)
if
self
.
up
:
h
=
upsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
elif
self
.
down
:
h
=
downsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
h
=
self
.
conv1
(
h
)
if
not
self
.
pre_norm
:
h
=
self
.
norm1
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
h
*
mask
temb
=
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
if
self
.
time_embedding_norm
==
"scale_shift"
:
scale
,
shift
=
torch
.
chunk
(
temb
,
2
,
dim
=
1
)
h
=
self
.
norm2
(
h
)
h
=
h
+
h
*
scale
+
shift
h
=
self
.
nonlinearity
(
h
)
elif
self
.
time_embedding_norm
==
"default"
:
h
=
h
+
temb
h
=
h
*
mask
if
self
.
pre_norm
:
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
else
:
raise
ValueError
(
"Nananan nanana - don't go here!"
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
not
self
.
pre_norm
:
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
h
*
mask
x
=
x
*
mask
# if self.in_channels != self.out_channels:
if
self
.
in_channels
!=
self
.
out_channels
or
self
.
up
or
self
.
down
:
x
=
self
.
nin_shortcut
(
x
)
result
=
x
+
h
return
result
/
self
.
output_scale_factor
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
class
ResnetBlock
(
nn
.
Module
):
...
...
@@ -245,9 +447,6 @@ class ResnetBlock(nn.Module):
self
.
res_conv
=
torch
.
nn
.
Identity
()
elif
self
.
overwrite_for_ldm
:
dims
=
2
# eps = 1e-5
# non_linearity = "silu"
# overwrite_for_ldm
channels
=
in_channels
emb_channels
=
temb_channels
use_scale_shift_norm
=
False
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
f1aade05
...
...
@@ -27,7 +27,8 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.resnet
import
ResnetBlockBigGANpp
,
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
ResnetBlockBigGANppNew
as
ResnetBlockBigGANpp
def
_setup_kernel
(
k
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment