Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
bc2d586d
Commit
bc2d586d
authored
Jun 25, 2022
by
Patrick von Platen
Browse files
remove more dependencies
parent
49a81f9f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
106 additions
and
167 deletions
+106
-167
run.py
run.py
+18
-128
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+88
-39
No files found.
run.py
View file @
bc2d586d
...
@@ -2,105 +2,14 @@
...
@@ -2,105 +2,14 @@
import
numpy
as
np
import
numpy
as
np
import
PIL
import
PIL
import
torch
import
torch
import
ml_collections
#from configs.ve import ffhq_ncsnpp_continuous as configs
#from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ffhq_ncsnpp_continuous config
device
=
torch
.
device
(
'cuda:0'
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
'cpu'
)
def
get_config
():
config
=
ml_collections
.
ConfigDict
()
# training
config
.
training
=
training
=
ml_collections
.
ConfigDict
()
training
.
batch_size
=
8
training
.
n_iters
=
2400001
training
.
snapshot_freq
=
50000
training
.
log_freq
=
50
training
.
eval_freq
=
100
training
.
snapshot_freq_for_preemption
=
5000
training
.
snapshot_sampling
=
True
training
.
sde
=
'vesde'
training
.
continuous
=
True
training
.
likelihood_weighting
=
False
training
.
reduce_mean
=
True
# sampling
config
.
sampling
=
sampling
=
ml_collections
.
ConfigDict
()
sampling
.
method
=
'pc'
sampling
.
predictor
=
'reverse_diffusion'
sampling
.
corrector
=
'langevin'
sampling
.
probability_flow
=
False
sampling
.
snr
=
0.15
sampling
.
n_steps_each
=
1
sampling
.
noise_removal
=
True
# eval
config
.
eval
=
evaluate
=
ml_collections
.
ConfigDict
()
evaluate
.
batch_size
=
1024
evaluate
.
num_samples
=
50000
evaluate
.
begin_ckpt
=
1
evaluate
.
end_ckpt
=
96
# data
config
.
data
=
data
=
ml_collections
.
ConfigDict
()
data
.
dataset
=
'FFHQ'
data
.
image_size
=
1024
data
.
centered
=
False
data
.
random_flip
=
True
data
.
uniform_dequantization
=
False
data
.
num_channels
=
3
# Plug in your own path to the tfrecords file.
data
.
tfrecords_path
=
'/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'
# model
config
.
model
=
model
=
ml_collections
.
ConfigDict
()
model
.
name
=
'ncsnpp'
model
.
scale_by_sigma
=
True
model
.
sigma_max
=
1348
model
.
num_scales
=
2000
model
.
ema_rate
=
0.9999
model
.
sigma_min
=
0.01
model
.
normalization
=
'GroupNorm'
model
.
nonlinearity
=
'swish'
model
.
nf
=
16
model
.
ch_mult
=
(
1
,
2
,
4
,
8
,
16
,
32
,
32
,
32
)
model
.
num_res_blocks
=
1
model
.
attn_resolutions
=
(
16
,)
model
.
dropout
=
0.
model
.
resamp_with_conv
=
True
model
.
conditional
=
True
model
.
fir
=
True
model
.
fir_kernel
=
[
1
,
3
,
3
,
1
]
model
.
skip_rescale
=
True
model
.
resblock_type
=
'biggan'
model
.
progressive
=
'output_skip'
model
.
progressive_input
=
'input_skip'
model
.
progressive_combine
=
'sum'
model
.
attention_type
=
'ddpm'
model
.
init_scale
=
0.
model
.
fourier_scale
=
16
model
.
conv_size
=
3
model
.
embedding_type
=
'fourier'
# optim
config
.
optim
=
optim
=
ml_collections
.
ConfigDict
()
optim
.
weight_decay
=
0
optim
.
optimizer
=
'Adam'
optim
.
lr
=
2e-4
optim
.
beta1
=
0.9
optim
.
amsgrad
=
False
optim
.
eps
=
1e-8
optim
.
warmup
=
5000
optim
.
grad_clip
=
1.
config
.
seed
=
42
config
.
device
=
torch
.
device
(
'cuda:0'
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
'cpu'
)
return
config
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
manual_seed
(
3
)
torch
.
manual_seed
(
0
)
class
NewReverseDiffusionPredictor
:
class
NewReverseDiffusionPredictor
:
...
@@ -182,46 +91,25 @@ def save_image(x):
...
@@ -182,46 +91,25 @@ def save_image(x):
# Note usually we need to restore ema etc...
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
# ema restored checkpoint used from below
N
=
2
sigma_min
=
0.01
config
=
get_config
()
sigma_max
=
1348
sigma_min
,
sigma_max
=
config
.
model
.
sigma_min
,
config
.
model
.
sigma_max
N
=
config
.
model
.
num_scales
sampling_eps
=
1e-5
sampling_eps
=
1e-5
batch_size
=
1
batch_size
=
1
#@param {"type":"integer"}
centered
=
False
config
.
training
.
batch_size
=
batch_size
config
.
eval
.
batch_size
=
batch_size
from
diffusers
import
NCSNpp
from
diffusers
import
NCSNpp
model
=
NCSNpp
(
config
).
to
(
config
.
device
)
model
=
torch
.
nn
.
DataParallel
(
model
)
loaded_state
=
torch
.
load
(
"../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt"
)
del
loaded_state
[
"module.sigmas"
]
model
.
load_state_dict
(
loaded_state
,
strict
=
False
)
def
get_data_inverse_scaler
(
config
):
"""Inverse data normalizer."""
if
config
.
data
.
centered
:
# Rescale [-1, 1] to [0, 1]
return
lambda
x
:
(
x
+
1.
)
/
2.
else
:
return
lambda
x
:
x
inverse_scaler
=
get_data_inverse_scaler
(
config
)
model
=
NCSNpp
.
from_pretrained
(
"/home/patrick/ffhq_ncsnpp"
).
to
(
device
)
model
=
torch
.
nn
.
DataParallel
(
model
)
img_size
=
config
.
data
.
image_size
img_size
=
model
.
module
.
config
.
image_size
channels
=
config
.
data
.
num_channels
channels
=
model
.
module
.
config
.
num_channels
shape
=
(
batch_size
,
channels
,
img_size
,
img_size
)
shape
=
(
batch_size
,
channels
,
img_size
,
img_size
)
probability_flow
=
False
probability_flow
=
False
snr
=
0.15
#@param {"type": "number"}
snr
=
0.15
n_steps
=
1
#@param {"type": "integer"}
n_steps
=
1
device
=
config
.
device
new_corrector
=
NewLangevinCorrector
(
score_fn
=
model
,
snr
=
snr
,
n_steps
=
n_steps
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
)
new_corrector
=
NewLangevinCorrector
(
score_fn
=
model
,
snr
=
snr
,
n_steps
=
n_steps
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
)
new_predictor
=
NewReverseDiffusionPredictor
(
score_fn
=
model
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
,
N
=
N
)
new_predictor
=
NewReverseDiffusionPredictor
(
score_fn
=
model
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
,
N
=
N
)
...
@@ -238,10 +126,12 @@ with torch.no_grad():
...
@@ -238,10 +126,12 @@ with torch.no_grad():
x
,
x_mean
=
new_corrector
.
update_fn
(
x
,
vec_t
)
x
,
x_mean
=
new_corrector
.
update_fn
(
x
,
vec_t
)
x
,
x_mean
=
new_predictor
.
update_fn
(
x
,
vec_t
)
x
,
x_mean
=
new_predictor
.
update_fn
(
x
,
vec_t
)
x
=
inverse_scaler
(
x_mean
)
x
=
x_mean
if
centered
:
x
=
(
x
+
1.
)
/
2.
save_image
(
x
)
#
save_image(x)
# for 5 cifar10
# for 5 cifar10
x_sum
=
106071.9922
x_sum
=
106071.9922
...
@@ -260,4 +150,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean):
...
@@ -260,4 +150,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean):
assert
(
x
.
abs
().
mean
()
-
x_mean
).
abs
().
cpu
().
item
()
<
1e-4
,
f
"mean wrong
{
x
.
abs
().
mean
()
}
"
assert
(
x
.
abs
().
mean
()
-
x_mean
).
abs
().
cpu
().
item
()
<
1e-4
,
f
"mean wrong
{
x
.
abs
().
mean
()
}
"
#
check_x_sum_x_mean(x, x_sum, x_mean)
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
)
src/diffusers/models/unet_sde_score_estimation.py
View file @
bc2d586d
...
@@ -15,6 +15,9 @@
...
@@ -15,6 +15,9 @@
# helpers functions
# helpers functions
from
..modeling_utils
import
ModelMixin
from
..configuration_utils
import
ConfigMixin
import
functools
import
functools
import
math
import
math
...
@@ -372,16 +375,16 @@ class NIN(nn.Module):
...
@@ -372,16 +375,16 @@ class NIN(nn.Module):
return
y
.
permute
(
0
,
3
,
1
,
2
)
return
y
.
permute
(
0
,
3
,
1
,
2
)
def
get_act
(
c
on
fig
):
def
get_act
(
n
on
linearity
):
"""Get activation functions from the config file."""
"""Get activation functions from the config file."""
if
config
.
model
.
nonlinearity
.
lower
()
==
"elu"
:
if
nonlinearity
.
lower
()
==
"elu"
:
return
nn
.
ELU
()
return
nn
.
ELU
()
elif
config
.
model
.
nonlinearity
.
lower
()
==
"relu"
:
elif
nonlinearity
.
lower
()
==
"relu"
:
return
nn
.
ReLU
()
return
nn
.
ReLU
()
elif
config
.
model
.
nonlinearity
.
lower
()
==
"lrelu"
:
elif
nonlinearity
.
lower
()
==
"lrelu"
:
return
nn
.
LeakyReLU
(
negative_slope
=
0.2
)
return
nn
.
LeakyReLU
(
negative_slope
=
0.2
)
elif
config
.
model
.
nonlinearity
.
lower
()
==
"swish"
:
elif
nonlinearity
.
lower
()
==
"swish"
:
return
nn
.
SiLU
()
return
nn
.
SiLU
()
else
:
else
:
raise
NotImplementedError
(
"activation function does not exist!"
)
raise
NotImplementedError
(
"activation function does not exist!"
)
...
@@ -710,46 +713,93 @@ class ResnetBlockBigGANpp(nn.Module):
...
@@ -710,46 +713,93 @@ class ResnetBlockBigGANpp(nn.Module):
return
(
x
+
h
)
/
np
.
sqrt
(
2.0
)
return
(
x
+
h
)
/
np
.
sqrt
(
2.0
)
class
NCSNpp
(
nn
.
Module
):
class
NCSNpp
(
ModelMixin
,
ConfigMixin
):
"""NCSN++ model"""
"""NCSN++ model"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
centered
=
False
,
image_size
=
1024
,
num_channels
=
3
,
attention_type
=
"ddpm"
,
attn_resolutions
=
(
16
,),
ch_mult
=
(
1
,
2
,
4
,
8
,
16
,
32
,
32
,
32
),
conditional
=
True
,
conv_size
=
3
,
dropout
=
0.0
,
embedding_type
=
"fourier"
,
fir
=
True
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
fourier_scale
=
16
,
init_scale
=
0.0
,
nf
=
16
,
nonlinearity
=
"swish"
,
normalization
=
"GroupNorm"
,
num_res_blocks
=
1
,
progressive
=
"output_skip"
,
progressive_combine
=
"sum"
,
progressive_input
=
"input_skip"
,
resamp_with_conv
=
True
,
resblock_type
=
"biggan"
,
scale_by_sigma
=
True
,
skip_rescale
=
True
,
continuous
=
True
,
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
register_to_config
(
self
.
act
=
act
=
get_act
(
config
)
centered
=
centered
,
image_size
=
image_size
,
num_channels
=
num_channels
,
attention_type
=
attention_type
,
attn_resolutions
=
attn_resolutions
,
ch_mult
=
ch_mult
,
conditional
=
conditional
,
conv_size
=
conv_size
,
dropout
=
dropout
,
embedding_type
=
embedding_type
,
fir
=
fir
,
fir_kernel
=
fir_kernel
,
fourier_scale
=
fourier_scale
,
init_scale
=
init_scale
,
nf
=
nf
,
nonlinearity
=
nonlinearity
,
normalization
=
normalization
,
num_res_blocks
=
num_res_blocks
,
progressive
=
progressive
,
progressive_combine
=
progressive_combine
,
progressive_input
=
progressive_input
,
resamp_with_conv
=
resamp_with_conv
,
resblock_type
=
resblock_type
,
scale_by_sigma
=
scale_by_sigma
,
skip_rescale
=
skip_rescale
,
continuous
=
continuous
,
)
self
.
act
=
act
=
get_act
(
nonlinearity
)
# self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))
# self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))
self
.
nf
=
nf
=
config
.
model
.
nf
self
.
nf
=
nf
ch_mult
=
config
.
model
.
ch_mult
self
.
num_res_blocks
=
num_res_blocks
self
.
num_res_blocks
=
num_res_blocks
=
config
.
model
.
num_res_blocks
self
.
attn_resolutions
=
attn_resolutions
self
.
attn_resolutions
=
attn_resolutions
=
config
.
model
.
attn_resolutions
self
.
num_resolutions
=
len
(
ch_mult
)
dropout
=
config
.
model
.
dropout
self
.
all_resolutions
=
all_resolutions
=
[
image_size
//
(
2
**
i
)
for
i
in
range
(
self
.
num_resolutions
)]
resamp_with_conv
=
config
.
model
.
resamp_with_conv
self
.
num_resolutions
=
num_resolutions
=
len
(
ch_mult
)
self
.
conditional
=
conditional
self
.
all_resolutions
=
all_resolutions
=
[
config
.
data
.
image_size
//
(
2
**
i
)
for
i
in
range
(
num_resolutions
)]
self
.
skip_rescale
=
skip_rescale
self
.
resblock_type
=
resblock_type
self
.
conditional
=
conditional
=
config
.
model
.
conditional
# noise-conditional
self
.
progressive
=
progressive
fir
=
config
.
model
.
fir
self
.
progressive_input
=
progressive_input
fir_kernel
=
config
.
model
.
fir_kernel
self
.
embedding_type
=
embedding_type
self
.
skip_rescale
=
skip_rescale
=
config
.
model
.
skip_rescale
self
.
resblock_type
=
resblock_type
=
config
.
model
.
resblock_type
.
lower
()
self
.
progressive
=
progressive
=
config
.
model
.
progressive
.
lower
()
self
.
progressive_input
=
progressive_input
=
config
.
model
.
progressive_input
.
lower
()
self
.
embedding_type
=
embedding_type
=
config
.
model
.
embedding_type
.
lower
()
init_scale
=
config
.
model
.
init_scale
assert
progressive
in
[
"none"
,
"output_skip"
,
"residual"
]
assert
progressive
in
[
"none"
,
"output_skip"
,
"residual"
]
assert
progressive_input
in
[
"none"
,
"input_skip"
,
"residual"
]
assert
progressive_input
in
[
"none"
,
"input_skip"
,
"residual"
]
assert
embedding_type
in
[
"fourier"
,
"positional"
]
assert
embedding_type
in
[
"fourier"
,
"positional"
]
combine_method
=
config
.
model
.
progressive_combine
.
lower
()
combine_method
=
progressive_combine
.
lower
()
combiner
=
functools
.
partial
(
Combine
,
method
=
combine_method
)
combiner
=
functools
.
partial
(
Combine
,
method
=
combine_method
)
modules
=
[]
modules
=
[]
# timestep/noise_level embedding; only for continuous training
# timestep/noise_level embedding; only for continuous training
if
embedding_type
==
"fourier"
:
if
embedding_type
==
"fourier"
:
# Gaussian Fourier features embeddings.
# Gaussian Fourier features embeddings.
assert
config
.
training
.
continuous
,
"Fourier features are only used for continuous training."
modules
.
append
(
GaussianFourierProjection
(
embedding_size
=
nf
,
scale
=
fourier_scale
))
modules
.
append
(
GaussianFourierProjection
(
embedding_size
=
nf
,
scale
=
config
.
model
.
fourier_scale
))
embed_dim
=
2
*
nf
embed_dim
=
2
*
nf
elif
embedding_type
==
"positional"
:
elif
embedding_type
==
"positional"
:
...
@@ -809,7 +859,7 @@ class NCSNpp(nn.Module):
...
@@ -809,7 +859,7 @@ class NCSNpp(nn.Module):
# Downsampling block
# Downsampling block
channels
=
config
.
data
.
num_channels
channels
=
num_channels
if
progressive_input
!=
"none"
:
if
progressive_input
!=
"none"
:
input_pyramid_ch
=
channels
input_pyramid_ch
=
channels
...
@@ -817,7 +867,7 @@ class NCSNpp(nn.Module):
...
@@ -817,7 +867,7 @@ class NCSNpp(nn.Module):
hs_c
=
[
nf
]
hs_c
=
[
nf
]
in_ch
=
nf
in_ch
=
nf
for
i_level
in
range
(
num_resolutions
):
for
i_level
in
range
(
self
.
num_resolutions
):
# Residual blocks for this resolution
# Residual blocks for this resolution
for
i_block
in
range
(
num_res_blocks
):
for
i_block
in
range
(
num_res_blocks
):
out_ch
=
nf
*
ch_mult
[
i_level
]
out_ch
=
nf
*
ch_mult
[
i_level
]
...
@@ -828,7 +878,7 @@ class NCSNpp(nn.Module):
...
@@ -828,7 +878,7 @@ class NCSNpp(nn.Module):
modules
.
append
(
AttnBlock
(
channels
=
in_ch
))
modules
.
append
(
AttnBlock
(
channels
=
in_ch
))
hs_c
.
append
(
in_ch
)
hs_c
.
append
(
in_ch
)
if
i_level
!=
num_resolutions
-
1
:
if
i_level
!=
self
.
num_resolutions
-
1
:
if
resblock_type
==
"ddpm"
:
if
resblock_type
==
"ddpm"
:
modules
.
append
(
Downsample
(
in_ch
=
in_ch
))
modules
.
append
(
Downsample
(
in_ch
=
in_ch
))
else
:
else
:
...
@@ -852,7 +902,7 @@ class NCSNpp(nn.Module):
...
@@ -852,7 +902,7 @@ class NCSNpp(nn.Module):
pyramid_ch
=
0
pyramid_ch
=
0
# Upsampling block
# Upsampling block
for
i_level
in
reversed
(
range
(
num_resolutions
)):
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
num_res_blocks
+
1
):
for
i_block
in
range
(
num_res_blocks
+
1
):
out_ch
=
nf
*
ch_mult
[
i_level
]
out_ch
=
nf
*
ch_mult
[
i_level
]
modules
.
append
(
ResnetBlock
(
in_ch
=
in_ch
+
hs_c
.
pop
(),
out_ch
=
out_ch
))
modules
.
append
(
ResnetBlock
(
in_ch
=
in_ch
+
hs_c
.
pop
(),
out_ch
=
out_ch
))
...
@@ -862,7 +912,7 @@ class NCSNpp(nn.Module):
...
@@ -862,7 +912,7 @@ class NCSNpp(nn.Module):
modules
.
append
(
AttnBlock
(
channels
=
in_ch
))
modules
.
append
(
AttnBlock
(
channels
=
in_ch
))
if
progressive
!=
"none"
:
if
progressive
!=
"none"
:
if
i_level
==
num_resolutions
-
1
:
if
i_level
==
self
.
num_resolutions
-
1
:
if
progressive
==
"output_skip"
:
if
progressive
==
"output_skip"
:
modules
.
append
(
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
))
modules
.
append
(
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
))
modules
.
append
(
conv3x3
(
in_ch
,
channels
,
init_scale
=
init_scale
))
modules
.
append
(
conv3x3
(
in_ch
,
channels
,
init_scale
=
init_scale
))
...
@@ -899,7 +949,6 @@ class NCSNpp(nn.Module):
...
@@ -899,7 +949,6 @@ class NCSNpp(nn.Module):
self
.
all_modules
=
nn
.
ModuleList
(
modules
)
self
.
all_modules
=
nn
.
ModuleList
(
modules
)
def
forward
(
self
,
x
,
time_cond
):
def
forward
(
self
,
x
,
time_cond
):
# import ipdb; ipdb.set_trace()
# timestep/noise_level embedding; only for continuous training
# timestep/noise_level embedding; only for continuous training
modules
=
self
.
all_modules
modules
=
self
.
all_modules
m_idx
=
0
m_idx
=
0
...
@@ -926,7 +975,7 @@ class NCSNpp(nn.Module):
...
@@ -926,7 +975,7 @@ class NCSNpp(nn.Module):
else
:
else
:
temb
=
None
temb
=
None
if
not
self
.
config
.
data
.
centered
:
if
not
self
.
config
.
centered
:
# If input data is in [0, 1]
# If input data is in [0, 1]
x
=
2
*
x
-
1.0
x
=
2
*
x
-
1.0
...
@@ -1044,7 +1093,7 @@ class NCSNpp(nn.Module):
...
@@ -1044,7 +1093,7 @@ class NCSNpp(nn.Module):
m_idx
+=
1
m_idx
+=
1
assert
m_idx
==
len
(
modules
)
assert
m_idx
==
len
(
modules
)
if
self
.
config
.
model
.
scale_by_sigma
:
if
self
.
config
.
scale_by_sigma
:
used_sigmas
=
used_sigmas
.
reshape
((
x
.
shape
[
0
],
*
([
1
]
*
len
(
x
.
shape
[
1
:]))))
used_sigmas
=
used_sigmas
.
reshape
((
x
.
shape
[
0
],
*
([
1
]
*
len
(
x
.
shape
[
1
:]))))
h
=
h
/
used_sigmas
h
=
h
/
used_sigmas
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment