Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
bc2d586d
Commit
bc2d586d
authored
Jun 25, 2022
by
Patrick von Platen
Browse files
remove more dependencies
parent
49a81f9f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
106 additions
and
167 deletions
+106
-167
run.py
run.py
+18
-128
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+88
-39
No files found.
run.py
View file @
bc2d586d
...
...
@@ -2,105 +2,14 @@
import
numpy
as
np
import
PIL
import
torch
import
ml_collections
#from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ffhq_ncsnpp_continuous config
def
get_config
():
config
=
ml_collections
.
ConfigDict
()
# training
config
.
training
=
training
=
ml_collections
.
ConfigDict
()
training
.
batch_size
=
8
training
.
n_iters
=
2400001
training
.
snapshot_freq
=
50000
training
.
log_freq
=
50
training
.
eval_freq
=
100
training
.
snapshot_freq_for_preemption
=
5000
training
.
snapshot_sampling
=
True
training
.
sde
=
'vesde'
training
.
continuous
=
True
training
.
likelihood_weighting
=
False
training
.
reduce_mean
=
True
# sampling
config
.
sampling
=
sampling
=
ml_collections
.
ConfigDict
()
sampling
.
method
=
'pc'
sampling
.
predictor
=
'reverse_diffusion'
sampling
.
corrector
=
'langevin'
sampling
.
probability_flow
=
False
sampling
.
snr
=
0.15
sampling
.
n_steps_each
=
1
sampling
.
noise_removal
=
True
# eval
config
.
eval
=
evaluate
=
ml_collections
.
ConfigDict
()
evaluate
.
batch_size
=
1024
evaluate
.
num_samples
=
50000
evaluate
.
begin_ckpt
=
1
evaluate
.
end_ckpt
=
96
# data
config
.
data
=
data
=
ml_collections
.
ConfigDict
()
data
.
dataset
=
'FFHQ'
data
.
image_size
=
1024
data
.
centered
=
False
data
.
random_flip
=
True
data
.
uniform_dequantization
=
False
data
.
num_channels
=
3
# Plug in your own path to the tfrecords file.
data
.
tfrecords_path
=
'/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'
# model
config
.
model
=
model
=
ml_collections
.
ConfigDict
()
model
.
name
=
'ncsnpp'
model
.
scale_by_sigma
=
True
model
.
sigma_max
=
1348
model
.
num_scales
=
2000
model
.
ema_rate
=
0.9999
model
.
sigma_min
=
0.01
model
.
normalization
=
'GroupNorm'
model
.
nonlinearity
=
'swish'
model
.
nf
=
16
model
.
ch_mult
=
(
1
,
2
,
4
,
8
,
16
,
32
,
32
,
32
)
model
.
num_res_blocks
=
1
model
.
attn_resolutions
=
(
16
,)
model
.
dropout
=
0.
model
.
resamp_with_conv
=
True
model
.
conditional
=
True
model
.
fir
=
True
model
.
fir_kernel
=
[
1
,
3
,
3
,
1
]
model
.
skip_rescale
=
True
model
.
resblock_type
=
'biggan'
model
.
progressive
=
'output_skip'
model
.
progressive_input
=
'input_skip'
model
.
progressive_combine
=
'sum'
model
.
attention_type
=
'ddpm'
model
.
init_scale
=
0.
model
.
fourier_scale
=
16
model
.
conv_size
=
3
model
.
embedding_type
=
'fourier'
# optim
config
.
optim
=
optim
=
ml_collections
.
ConfigDict
()
optim
.
weight_decay
=
0
optim
.
optimizer
=
'Adam'
optim
.
lr
=
2e-4
optim
.
beta1
=
0.9
optim
.
amsgrad
=
False
optim
.
eps
=
1e-8
optim
.
warmup
=
5000
optim
.
grad_clip
=
1.
config
.
seed
=
42
config
.
device
=
torch
.
device
(
'cuda:0'
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
'cpu'
)
return
config
device
=
torch
.
device
(
'cuda:0'
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
'cpu'
)
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
manual_seed
(
3
)
torch
.
manual_seed
(
0
)
class
NewReverseDiffusionPredictor
:
...
...
@@ -182,46 +91,25 @@ def save_image(x):
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
config
=
get_config
()
sigma_min
,
sigma_max
=
config
.
model
.
sigma_min
,
config
.
model
.
sigma_max
N
=
config
.
model
.
num_scales
N
=
2
sigma_min
=
0.01
sigma_max
=
1348
sampling_eps
=
1e-5
batch_size
=
1
#@param {"type":"integer"}
config
.
training
.
batch_size
=
batch_size
config
.
eval
.
batch_size
=
batch_size
batch_size
=
1
centered
=
False
from
diffusers
import
NCSNpp
model
=
NCSNpp
(
config
).
to
(
config
.
device
)
model
=
torch
.
nn
.
DataParallel
(
model
)
loaded_state
=
torch
.
load
(
"../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt"
)
del
loaded_state
[
"module.sigmas"
]
model
.
load_state_dict
(
loaded_state
,
strict
=
False
)
def
get_data_inverse_scaler
(
config
):
"""Inverse data normalizer."""
if
config
.
data
.
centered
:
# Rescale [-1, 1] to [0, 1]
return
lambda
x
:
(
x
+
1.
)
/
2.
else
:
return
lambda
x
:
x
inverse_scaler
=
get_data_inverse_scaler
(
config
)
model
=
NCSNpp
.
from_pretrained
(
"/home/patrick/ffhq_ncsnpp"
).
to
(
device
)
model
=
torch
.
nn
.
DataParallel
(
model
)
img_size
=
config
.
data
.
image_size
channels
=
config
.
data
.
num_channels
img_size
=
model
.
module
.
config
.
image_size
channels
=
model
.
module
.
config
.
num_channels
shape
=
(
batch_size
,
channels
,
img_size
,
img_size
)
probability_flow
=
False
snr
=
0.15
#@param {"type": "number"}
n_steps
=
1
#@param {"type": "integer"}
snr
=
0.15
n_steps
=
1
device
=
config
.
device
new_corrector
=
NewLangevinCorrector
(
score_fn
=
model
,
snr
=
snr
,
n_steps
=
n_steps
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
)
new_predictor
=
NewReverseDiffusionPredictor
(
score_fn
=
model
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
,
N
=
N
)
...
...
@@ -238,10 +126,12 @@ with torch.no_grad():
x
,
x_mean
=
new_corrector
.
update_fn
(
x
,
vec_t
)
x
,
x_mean
=
new_predictor
.
update_fn
(
x
,
vec_t
)
x
=
inverse_scaler
(
x_mean
)
x
=
x_mean
if
centered
:
x
=
(
x
+
1.
)
/
2.
save_image
(
x
)
#
save_image(x)
# for 5 cifar10
x_sum
=
106071.9922
...
...
@@ -260,4 +150,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean):
assert
(
x
.
abs
().
mean
()
-
x_mean
).
abs
().
cpu
().
item
()
<
1e-4
,
f
"mean wrong
{
x
.
abs
().
mean
()
}
"
#
check_x_sum_x_mean(x, x_sum, x_mean)
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
)
src/diffusers/models/unet_sde_score_estimation.py
View file @
bc2d586d
...
...
@@ -15,6 +15,9 @@
# helpers functions
from
..modeling_utils
import
ModelMixin
from
..configuration_utils
import
ConfigMixin
import
functools
import
math
...
...
@@ -372,16 +375,16 @@ class NIN(nn.Module):
return
y
.
permute
(
0
,
3
,
1
,
2
)
def
get_act
(
c
on
fig
):
def
get_act
(
n
on
linearity
):
"""Get activation functions from the config file."""
if
config
.
model
.
nonlinearity
.
lower
()
==
"elu"
:
if
nonlinearity
.
lower
()
==
"elu"
:
return
nn
.
ELU
()
elif
config
.
model
.
nonlinearity
.
lower
()
==
"relu"
:
elif
nonlinearity
.
lower
()
==
"relu"
:
return
nn
.
ReLU
()
elif
config
.
model
.
nonlinearity
.
lower
()
==
"lrelu"
:
elif
nonlinearity
.
lower
()
==
"lrelu"
:
return
nn
.
LeakyReLU
(
negative_slope
=
0.2
)
elif
config
.
model
.
nonlinearity
.
lower
()
==
"swish"
:
elif
nonlinearity
.
lower
()
==
"swish"
:
return
nn
.
SiLU
()
else
:
raise
NotImplementedError
(
"activation function does not exist!"
)
...
...
@@ -710,46 +713,93 @@ class ResnetBlockBigGANpp(nn.Module):
return
(
x
+
h
)
/
np
.
sqrt
(
2.0
)
class
NCSNpp
(
nn
.
Module
):
class
NCSNpp
(
ModelMixin
,
ConfigMixin
):
"""NCSN++ model"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
centered
=
False
,
image_size
=
1024
,
num_channels
=
3
,
attention_type
=
"ddpm"
,
attn_resolutions
=
(
16
,),
ch_mult
=
(
1
,
2
,
4
,
8
,
16
,
32
,
32
,
32
),
conditional
=
True
,
conv_size
=
3
,
dropout
=
0.0
,
embedding_type
=
"fourier"
,
fir
=
True
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
fourier_scale
=
16
,
init_scale
=
0.0
,
nf
=
16
,
nonlinearity
=
"swish"
,
normalization
=
"GroupNorm"
,
num_res_blocks
=
1
,
progressive
=
"output_skip"
,
progressive_combine
=
"sum"
,
progressive_input
=
"input_skip"
,
resamp_with_conv
=
True
,
resblock_type
=
"biggan"
,
scale_by_sigma
=
True
,
skip_rescale
=
True
,
continuous
=
True
,
):
super
().
__init__
()
self
.
config
=
config
self
.
act
=
act
=
get_act
(
config
)
self
.
register_to_config
(
centered
=
centered
,
image_size
=
image_size
,
num_channels
=
num_channels
,
attention_type
=
attention_type
,
attn_resolutions
=
attn_resolutions
,
ch_mult
=
ch_mult
,
conditional
=
conditional
,
conv_size
=
conv_size
,
dropout
=
dropout
,
embedding_type
=
embedding_type
,
fir
=
fir
,
fir_kernel
=
fir_kernel
,
fourier_scale
=
fourier_scale
,
init_scale
=
init_scale
,
nf
=
nf
,
nonlinearity
=
nonlinearity
,
normalization
=
normalization
,
num_res_blocks
=
num_res_blocks
,
progressive
=
progressive
,
progressive_combine
=
progressive_combine
,
progressive_input
=
progressive_input
,
resamp_with_conv
=
resamp_with_conv
,
resblock_type
=
resblock_type
,
scale_by_sigma
=
scale_by_sigma
,
skip_rescale
=
skip_rescale
,
continuous
=
continuous
,
)
self
.
act
=
act
=
get_act
(
nonlinearity
)
# self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))
self
.
nf
=
nf
=
config
.
model
.
nf
ch_mult
=
config
.
model
.
ch_mult
self
.
num_res_blocks
=
num_res_blocks
=
config
.
model
.
num_res_blocks
self
.
attn_resolutions
=
attn_resolutions
=
config
.
model
.
attn_resolutions
dropout
=
config
.
model
.
dropout
resamp_with_conv
=
config
.
model
.
resamp_with_conv
self
.
num_resolutions
=
num_resolutions
=
len
(
ch_mult
)
self
.
all_resolutions
=
all_resolutions
=
[
config
.
data
.
image_size
//
(
2
**
i
)
for
i
in
range
(
num_resolutions
)]
self
.
conditional
=
conditional
=
config
.
model
.
conditional
# noise-conditional
fir
=
config
.
model
.
fir
fir_kernel
=
config
.
model
.
fir_kernel
self
.
skip_rescale
=
skip_rescale
=
config
.
model
.
skip_rescale
self
.
resblock_type
=
resblock_type
=
config
.
model
.
resblock_type
.
lower
()
self
.
progressive
=
progressive
=
config
.
model
.
progressive
.
lower
()
self
.
progressive_input
=
progressive_input
=
config
.
model
.
progressive_input
.
lower
()
self
.
embedding_type
=
embedding_type
=
config
.
model
.
embedding_type
.
lower
()
init_scale
=
config
.
model
.
init_scale
self
.
nf
=
nf
self
.
num_res_blocks
=
num_res_blocks
self
.
attn_resolutions
=
attn_resolutions
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
all_resolutions
=
all_resolutions
=
[
image_size
//
(
2
**
i
)
for
i
in
range
(
self
.
num_resolutions
)]
self
.
conditional
=
conditional
self
.
skip_rescale
=
skip_rescale
self
.
resblock_type
=
resblock_type
self
.
progressive
=
progressive
self
.
progressive_input
=
progressive_input
self
.
embedding_type
=
embedding_type
assert
progressive
in
[
"none"
,
"output_skip"
,
"residual"
]
assert
progressive_input
in
[
"none"
,
"input_skip"
,
"residual"
]
assert
embedding_type
in
[
"fourier"
,
"positional"
]
combine_method
=
config
.
model
.
progressive_combine
.
lower
()
combine_method
=
progressive_combine
.
lower
()
combiner
=
functools
.
partial
(
Combine
,
method
=
combine_method
)
modules
=
[]
# timestep/noise_level embedding; only for continuous training
if
embedding_type
==
"fourier"
:
# Gaussian Fourier features embeddings.
assert
config
.
training
.
continuous
,
"Fourier features are only used for continuous training."
modules
.
append
(
GaussianFourierProjection
(
embedding_size
=
nf
,
scale
=
config
.
model
.
fourier_scale
))
modules
.
append
(
GaussianFourierProjection
(
embedding_size
=
nf
,
scale
=
fourier_scale
))
embed_dim
=
2
*
nf
elif
embedding_type
==
"positional"
:
...
...
@@ -809,7 +859,7 @@ class NCSNpp(nn.Module):
# Downsampling block
channels
=
config
.
data
.
num_channels
channels
=
num_channels
if
progressive_input
!=
"none"
:
input_pyramid_ch
=
channels
...
...
@@ -817,7 +867,7 @@ class NCSNpp(nn.Module):
hs_c
=
[
nf
]
in_ch
=
nf
for
i_level
in
range
(
num_resolutions
):
for
i_level
in
range
(
self
.
num_resolutions
):
# Residual blocks for this resolution
for
i_block
in
range
(
num_res_blocks
):
out_ch
=
nf
*
ch_mult
[
i_level
]
...
...
@@ -828,7 +878,7 @@ class NCSNpp(nn.Module):
modules
.
append
(
AttnBlock
(
channels
=
in_ch
))
hs_c
.
append
(
in_ch
)
if
i_level
!=
num_resolutions
-
1
:
if
i_level
!=
self
.
num_resolutions
-
1
:
if
resblock_type
==
"ddpm"
:
modules
.
append
(
Downsample
(
in_ch
=
in_ch
))
else
:
...
...
@@ -852,7 +902,7 @@ class NCSNpp(nn.Module):
pyramid_ch
=
0
# Upsampling block
for
i_level
in
reversed
(
range
(
num_resolutions
)):
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
num_res_blocks
+
1
):
out_ch
=
nf
*
ch_mult
[
i_level
]
modules
.
append
(
ResnetBlock
(
in_ch
=
in_ch
+
hs_c
.
pop
(),
out_ch
=
out_ch
))
...
...
@@ -862,7 +912,7 @@ class NCSNpp(nn.Module):
modules
.
append
(
AttnBlock
(
channels
=
in_ch
))
if
progressive
!=
"none"
:
if
i_level
==
num_resolutions
-
1
:
if
i_level
==
self
.
num_resolutions
-
1
:
if
progressive
==
"output_skip"
:
modules
.
append
(
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
))
modules
.
append
(
conv3x3
(
in_ch
,
channels
,
init_scale
=
init_scale
))
...
...
@@ -899,7 +949,6 @@ class NCSNpp(nn.Module):
self
.
all_modules
=
nn
.
ModuleList
(
modules
)
def
forward
(
self
,
x
,
time_cond
):
# import ipdb; ipdb.set_trace()
# timestep/noise_level embedding; only for continuous training
modules
=
self
.
all_modules
m_idx
=
0
...
...
@@ -926,7 +975,7 @@ class NCSNpp(nn.Module):
else
:
temb
=
None
if
not
self
.
config
.
data
.
centered
:
if
not
self
.
config
.
centered
:
# If input data is in [0, 1]
x
=
2
*
x
-
1.0
...
...
@@ -1044,7 +1093,7 @@ class NCSNpp(nn.Module):
m_idx
+=
1
assert
m_idx
==
len
(
modules
)
if
self
.
config
.
model
.
scale_by_sigma
:
if
self
.
config
.
scale_by_sigma
:
used_sigmas
=
used_sigmas
.
reshape
((
x
.
shape
[
0
],
*
([
1
]
*
len
(
x
.
shape
[
1
:]))))
h
=
h
/
used_sigmas
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment