renzhc / diffusers_dcu · Commits · bc2d586d

Commit bc2d586d, authored Jun 25, 2022 by Patrick von Platen
Parent: 49a81f9f

    remove more dependencies

Showing 2 changed files, with 106 additions and 167 deletions (+106 -167):

    run.py                                              +18 -128
    src/diffusers/models/unet_sde_score_estimation.py   +88 -39
run.py
@@ -2,105 +2,14 @@
 import numpy as np
 import PIL
 import torch
-import ml_collections

 #from configs.ve import ffhq_ncsnpp_continuous as configs
 # from configs.ve import cifar10_ncsnpp_continuous as configs

-# ffhq_ncsnpp_continuous config
-def get_config():
-    config = ml_collections.ConfigDict()
-
-    # training
-    config.training = training = ml_collections.ConfigDict()
-    training.batch_size = 8
-    training.n_iters = 2400001
-    training.snapshot_freq = 50000
-    training.log_freq = 50
-    training.eval_freq = 100
-    training.snapshot_freq_for_preemption = 5000
-    training.snapshot_sampling = True
-    training.sde = 'vesde'
-    training.continuous = True
-    training.likelihood_weighting = False
-    training.reduce_mean = True
-
-    # sampling
-    config.sampling = sampling = ml_collections.ConfigDict()
-    sampling.method = 'pc'
-    sampling.predictor = 'reverse_diffusion'
-    sampling.corrector = 'langevin'
-    sampling.probability_flow = False
-    sampling.snr = 0.15
-    sampling.n_steps_each = 1
-    sampling.noise_removal = True
-
-    # eval
-    config.eval = evaluate = ml_collections.ConfigDict()
-    evaluate.batch_size = 1024
-    evaluate.num_samples = 50000
-    evaluate.begin_ckpt = 1
-    evaluate.end_ckpt = 96
-
-    # data
-    config.data = data = ml_collections.ConfigDict()
-    data.dataset = 'FFHQ'
-    data.image_size = 1024
-    data.centered = False
-    data.random_flip = True
-    data.uniform_dequantization = False
-    data.num_channels = 3
-    # Plug in your own path to the tfrecords file.
-    data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'
-
-    # model
-    config.model = model = ml_collections.ConfigDict()
-    model.name = 'ncsnpp'
-    model.scale_by_sigma = True
-    model.sigma_max = 1348
-    model.num_scales = 2000
-    model.ema_rate = 0.9999
-    model.sigma_min = 0.01
-    model.normalization = 'GroupNorm'
-    model.nonlinearity = 'swish'
-    model.nf = 16
-    model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
-    model.num_res_blocks = 1
-    model.attn_resolutions = (16,)
-    model.dropout = 0.
-    model.resamp_with_conv = True
-    model.conditional = True
-    model.fir = True
-    model.fir_kernel = [1, 3, 3, 1]
-    model.skip_rescale = True
-    model.resblock_type = 'biggan'
-    model.progressive = 'output_skip'
-    model.progressive_input = 'input_skip'
-    model.progressive_combine = 'sum'
-    model.attention_type = 'ddpm'
-    model.init_scale = 0.
-    model.fourier_scale = 16
-    model.conv_size = 3
-    model.embedding_type = 'fourier'
-
-    # optim
-    config.optim = optim = ml_collections.ConfigDict()
-    optim.weight_decay = 0
-    optim.optimizer = 'Adam'
-    optim.lr = 2e-4
-    optim.beta1 = 0.9
-    optim.amsgrad = False
-    optim.eps = 1e-8
-    optim.warmup = 5000
-    optim.grad_clip = 1.
-
-    config.seed = 42
-    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
-
-    return config
+device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

 torch.backends.cuda.matmul.allow_tf32 = False
-torch.manual_seed(3)
+torch.manual_seed(0)

 class NewReverseDiffusionPredictor:
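The two module-level settings kept here matter for the numeric checks at the bottom of the file: disabling TF32 forces true fp32 matmuls on Ampere GPUs, and the fixed seed makes every noise draw reproducible. A minimal sketch of why the seed keeps the later x_sum/x_mean assertions stable:

```python
import torch

torch.manual_seed(0)
a = torch.randn(3)
torch.manual_seed(0)
b = torch.randn(3)
assert torch.equal(a, b)  # same seed, identical noise, so sampled images repeat exactly
```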
@@ -182,46 +91,25 @@ def save_image(x):
 # Note usually we need to restore ema etc...
 # ema restored checkpoint used from below

-config = get_config()
-sigma_min, sigma_max = config.model.sigma_min, config.model.sigma_max
-N = config.model.num_scales
+N = 2
+sigma_min = 0.01
+sigma_max = 1348

 sampling_eps = 1e-5
-batch_size = 1 #@param {"type":"integer"}
-config.training.batch_size = batch_size
-config.eval.batch_size = batch_size
+batch_size = 1
+centered = False

 from diffusers import NCSNpp

-model = NCSNpp(config).to(config.device)
-model = torch.nn.DataParallel(model)
-
-loaded_state = torch.load("../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt")
-del loaded_state["module.sigmas"]
-model.load_state_dict(loaded_state, strict=False)
-
-def get_data_inverse_scaler(config):
-    """Inverse data normalizer."""
-    if config.data.centered:
-        # Rescale [-1, 1] to [0, 1]
-        return lambda x: (x + 1.) / 2.
-    else:
-        return lambda x: x
-
-inverse_scaler = get_data_inverse_scaler(config)
+model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device)
+model = torch.nn.DataParallel(model)

-img_size = config.data.image_size
-channels = config.data.num_channels
+img_size = model.module.config.image_size
+channels = model.module.config.num_channels
 shape = (batch_size, channels, img_size, img_size)

 probability_flow = False
-snr = 0.15 #@param {"type": "number"}
-n_steps = 1 #@param {"type": "integer"}
-device = config.device
+snr = 0.15
+n_steps = 1

 new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
 new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)
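For orientation, here is a rough sketch of the math the `NewReverseDiffusionPredictor` / `NewLangevinCorrector` pair implements for the VE SDE, based on the upstream score_sde code this script ports. The `score_fn` interface and the function names below are illustrative simplifications (the real classes condition the model on a time vector), so treat this as a sketch, not the code in this file:

```python
import math
import torch

# Discrete geometric grid of noise scales implied by sigma_min, sigma_max, N:
sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), N))

def predictor_step(score_fn, x, t_idx):
    # Reverse-diffusion step between two adjacent noise scales.
    sigma = sigmas[t_idx]
    adjacent_sigma = sigmas[t_idx - 1] if t_idx > 0 else torch.tensor(0.0)
    beta = sigma**2 - adjacent_sigma**2          # variance added between the two scales
    score = score_fn(x, sigma)
    x_mean = x + beta * score                    # deterministic denoising drift
    x = x_mean + torch.sqrt(beta) * torch.randn_like(x)
    return x, x_mean

def corrector_step(score_fn, x, sigma, snr):
    # Langevin MCMC correction; step size is set from the target signal-to-noise ratio.
    grad = score_fn(x, sigma)
    noise = torch.randn_like(x)
    grad_norm = grad.reshape(grad.shape[0], -1).norm(dim=-1).mean()
    noise_norm = noise.reshape(noise.shape[0], -1).norm(dim=-1).mean()
    step_size = (snr * noise_norm / grad_norm) ** 2 * 2
    x_mean = x + step_size * grad
    x = x_mean + torch.sqrt(2 * step_size) * noise
    return x, x_mean
```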
@@ -238,10 +126,12 @@ with torch.no_grad():
         x, x_mean = new_corrector.update_fn(x, vec_t)
         x, x_mean = new_predictor.update_fn(x, vec_t)

-    x = inverse_scaler(x_mean)
-    save_image(x)
+    x = x_mean
+    if centered:
+        x = (x + 1.) / 2.
+
+    # save_image(x)

 # for 5 cifar10
 x_sum = 106071.9922
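The deleted `get_data_inverse_scaler` helper returned `(x + 1.) / 2.` for data centered in `[-1, 1]` and the identity otherwise; the new inline code is the same logic keyed off the local `centered` flag (here `False`, so samples pass through unchanged). As an equivalence sketch:

```python
# What the inlined `if centered:` branch amounts to, minus the config plumbing:
inverse_scaler = (lambda x: (x + 1.0) / 2.0) if centered else (lambda x: x)
x = inverse_scaler(x_mean)
```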
@@ -260,4 +150,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean):
     assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"

-# check_x_sum_x_mean(x, x_sum, x_mean)
+check_x_sum_x_mean(x, x_sum, x_mean)
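The `/home/patrick/ffhq_ncsnpp` directory consumed by `from_pretrained` is not produced in this commit. Inferred from the deleted loading code (the `module.` key prefix comes from the `DataParallel` wrapper), a hypothetical one-time conversion of the raw score_sde checkpoint might look like the following; the prefix handling and file layout follow the mid-2022 `ModelMixin` API and are assumptions:

```python
import torch
from diffusers import NCSNpp

# The new NCSNpp defaults reproduce the deleted ffhq_ncsnpp_continuous config.
model = NCSNpp()
state = torch.load("../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt")
state.pop("module.sigmas")  # the sigmas buffer is no longer registered on the model
state = {k.replace("module.", "", 1): v for k, v in state.items()}  # strip DataParallel prefix
model.load_state_dict(state, strict=False)
model.save_pretrained("/home/patrick/ffhq_ncsnpp")  # writes the weights plus config.json
```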
src/diffusers/models/unet_sde_score_estimation.py
@@ -15,6 +15,9 @@
 # helpers functions

+from ..modeling_utils import ModelMixin
+from ..configuration_utils import ConfigMixin
+
 import functools
 import math
@@ -372,16 +375,16 @@ class NIN(nn.Module):
         return y.permute(0, 3, 1, 2)

-def get_act(config):
+def get_act(nonlinearity):
     """Get activation functions from the config file."""
-    if config.model.nonlinearity.lower() == "elu":
+    if nonlinearity.lower() == "elu":
         return nn.ELU()
-    elif config.model.nonlinearity.lower() == "relu":
+    elif nonlinearity.lower() == "relu":
         return nn.ReLU()
-    elif config.model.nonlinearity.lower() == "lrelu":
+    elif nonlinearity.lower() == "lrelu":
         return nn.LeakyReLU(negative_slope=0.2)
-    elif config.model.nonlinearity.lower() == "swish":
+    elif nonlinearity.lower() == "swish":
         return nn.SiLU()
     else:
         raise NotImplementedError("activation function does not exist!")
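After this change the helper needs only the activation name rather than a full config object (note the docstring still says "from the config file"). Usage sketch:

```python
act = get_act("swish")  # nn.SiLU()
act = get_act("lrelu")  # nn.LeakyReLU(negative_slope=0.2)
```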
@@ -710,46 +713,93 @@ class ResnetBlockBigGANpp(nn.Module):
         return (x + h) / np.sqrt(2.0)

-class NCSNpp(nn.Module):
+class NCSNpp(ModelMixin, ConfigMixin):
     """NCSN++ model"""

-    def __init__(self, config):
+    def __init__(
+        self,
+        centered=False,
+        image_size=1024,
+        num_channels=3,
+        attention_type="ddpm",
+        attn_resolutions=(16,),
+        ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
+        conditional=True,
+        conv_size=3,
+        dropout=0.0,
+        embedding_type="fourier",
+        fir=True,
+        fir_kernel=(1, 3, 3, 1),
+        fourier_scale=16,
+        init_scale=0.0,
+        nf=16,
+        nonlinearity="swish",
+        normalization="GroupNorm",
+        num_res_blocks=1,
+        progressive="output_skip",
+        progressive_combine="sum",
+        progressive_input="input_skip",
+        resamp_with_conv=True,
+        resblock_type="biggan",
+        scale_by_sigma=True,
+        skip_rescale=True,
+        continuous=True,
+    ):
         super().__init__()
-        self.config = config
-        self.act = act = get_act(config)
+        self.register_to_config(
+            centered=centered,
+            image_size=image_size,
+            num_channels=num_channels,
+            attention_type=attention_type,
+            attn_resolutions=attn_resolutions,
+            ch_mult=ch_mult,
+            conditional=conditional,
+            conv_size=conv_size,
+            dropout=dropout,
+            embedding_type=embedding_type,
+            fir=fir,
+            fir_kernel=fir_kernel,
+            fourier_scale=fourier_scale,
+            init_scale=init_scale,
+            nf=nf,
+            nonlinearity=nonlinearity,
+            normalization=normalization,
+            num_res_blocks=num_res_blocks,
+            progressive=progressive,
+            progressive_combine=progressive_combine,
+            progressive_input=progressive_input,
+            resamp_with_conv=resamp_with_conv,
+            resblock_type=resblock_type,
+            scale_by_sigma=scale_by_sigma,
+            skip_rescale=skip_rescale,
+            continuous=continuous,
+        )
+        self.act = act = get_act(nonlinearity)
         # self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))

-        self.nf = nf = config.model.nf
-        ch_mult = config.model.ch_mult
-        self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
-        self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
-        dropout = config.model.dropout
-        resamp_with_conv = config.model.resamp_with_conv
-        self.num_resolutions = num_resolutions = len(ch_mult)
-        self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)]
+        self.nf = nf
+        self.num_res_blocks = num_res_blocks
+        self.attn_resolutions = attn_resolutions
+        self.num_resolutions = len(ch_mult)
+        self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(self.num_resolutions)]

-        self.conditional = conditional = config.model.conditional  # noise-conditional
-        fir = config.model.fir
-        fir_kernel = config.model.fir_kernel
-        self.skip_rescale = skip_rescale = config.model.skip_rescale
-        self.resblock_type = resblock_type = config.model.resblock_type.lower()
-        self.progressive = progressive = config.model.progressive.lower()
-        self.progressive_input = progressive_input = config.model.progressive_input.lower()
-        self.embedding_type = embedding_type = config.model.embedding_type.lower()
-        init_scale = config.model.init_scale
+        self.conditional = conditional
+        self.skip_rescale = skip_rescale
+        self.resblock_type = resblock_type
+        self.progressive = progressive
+        self.progressive_input = progressive_input
+        self.embedding_type = embedding_type
         assert progressive in ["none", "output_skip", "residual"]
         assert progressive_input in ["none", "input_skip", "residual"]
         assert embedding_type in ["fourier", "positional"]

-        combine_method = config.model.progressive_combine.lower()
+        combine_method = progressive_combine.lower()
         combiner = functools.partial(Combine, method=combine_method)

         modules = []
         # timestep/noise_level embedding; only for continuous training
         if embedding_type == "fourier":
             # Gaussian Fourier features embeddings.
-            assert config.training.continuous, "Fourier features are only used for continuous training."
-            modules.append(GaussianFourierProjection(embedding_size=nf, scale=config.model.fourier_scale))
+            modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale))
             embed_dim = 2 * nf
         elif embedding_type == "positional":
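`register_to_config` is what makes the `NCSNpp.from_pretrained(...)` call in run.py possible: every constructor argument is recorded on the model and serialized alongside the weights. A minimal round-trip sketch, assuming the mid-2022 `ModelMixin`/`ConfigMixin` API and an illustrative path:

```python
from diffusers import NCSNpp

model = NCSNpp(image_size=256, ch_mult=(1, 2, 2, 2))  # explicit kwargs, no ml_collections
model.save_pretrained("./my_ncsnpp")                  # stores weights plus the registered config
restored = NCSNpp.from_pretrained("./my_ncsnpp")      # re-instantiates with the saved kwargs
assert restored.config.image_size == 256              # run.py reads hyperparameters the same way
```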
@@ -809,7 +859,7 @@ class NCSNpp(nn.Module):
         # Downsampling block
-        channels = config.data.num_channels
+        channels = num_channels
         if progressive_input != "none":
             input_pyramid_ch = channels
@@ -817,7 +867,7 @@ class NCSNpp(nn.Module):
         hs_c = [nf]

         in_ch = nf
-        for i_level in range(num_resolutions):
+        for i_level in range(self.num_resolutions):
             # Residual blocks for this resolution
             for i_block in range(num_res_blocks):
                 out_ch = nf * ch_mult[i_level]
@@ -828,7 +878,7 @@ class NCSNpp(nn.Module):
                 modules.append(AttnBlock(channels=in_ch))
             hs_c.append(in_ch)

-            if i_level != num_resolutions - 1:
+            if i_level != self.num_resolutions - 1:
                 if resblock_type == "ddpm":
                     modules.append(Downsample(in_ch=in_ch))
                 else:
@@ -852,7 +902,7 @@ class NCSNpp(nn.Module):
         pyramid_ch = 0
         # Upsampling block
-        for i_level in reversed(range(num_resolutions)):
+        for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(num_res_blocks + 1):
                 out_ch = nf * ch_mult[i_level]
                 modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
@@ -862,7 +912,7 @@ class NCSNpp(nn.Module):
                 modules.append(AttnBlock(channels=in_ch))

             if progressive != "none":
-                if i_level == num_resolutions - 1:
+                if i_level == self.num_resolutions - 1:
                     if progressive == "output_skip":
                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
                         modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
@@ -899,7 +949,6 @@ class NCSNpp(nn.Module):
         self.all_modules = nn.ModuleList(modules)

     def forward(self, x, time_cond):
-        # import ipdb; ipdb.set_trace()
         # timestep/noise_level embedding; only for continuous training
         modules = self.all_modules
         m_idx = 0
@@ -926,7 +975,7 @@ class NCSNpp(nn.Module):
         else:
             temb = None

-        if not self.config.data.centered:
+        if not self.config.centered:
             # If input data is in [0, 1]
             x = 2 * x - 1.0
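This rescale is the model-side counterpart of the sampling change in run.py: `[0, 1]` inputs are mapped to `[-1, 1]` here, and `(x + 1.) / 2.` undoes it after sampling. A quick check that the two mappings are inverses:

```python
import torch

x01 = torch.rand(2, 3, 8, 8)                   # data in [0, 1]
x11 = 2 * x01 - 1.0                            # this hunk: rescale to [-1, 1]
assert torch.allclose((x11 + 1.0) / 2.0, x01)  # run.py's inverse after sampling
```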
@@ -1044,7 +1093,7 @@ class NCSNpp(nn.Module):
             m_idx += 1

         assert m_idx == len(modules)

-        if self.config.model.scale_by_sigma:
+        if self.config.scale_by_sigma:
             used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
             h = h / used_sigmas
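The reshape only appends singleton dimensions so that one sigma per batch element broadcasts across channels and spatial positions; a small demonstration with illustrative shapes:

```python
import torch

h = torch.randn(4, 3, 32, 32)                  # e.g. a batch of model outputs
used_sigmas = torch.rand(4)                    # one noise scale per batch element
used_sigmas = used_sigmas.reshape((h.shape[0], *([1] * len(h.shape[1:]))))
assert used_sigmas.shape == (4, 1, 1, 1)       # broadcasts over C, H, W
h = h / used_sigmas                            # per-sample rescaling of the output
```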