renzhc / diffusers_dcu · Commits

Commit 49a81f9f · authored Jun 24, 2022 by Patrick von Platen

    port first 1024 model

Parent: 78e99a99
Showing 1 changed file with 150 additions and 200 deletions.

run.py  (+150, -200)  — view file @ 49a81f9f
#!/usr/bin/env python3
import numpy as np
import PIL
import functools
import models
from models import utils as mutils
from models import ncsnv2
from models import ncsnpp
from models import ddpm as ddpm_model
from models import layerspp
from models import layers
from models import normalization
from models.ema import ExponentialMovingAverage
from losses import get_optimizer
from utils import restore_checkpoint
import sampling
from sde_lib import VESDE, VPSDE, subVPSDE
from sampling import (
    NoneCorrector,
    ReverseDiffusionPredictor,
    LangevinCorrector,
    EulerMaruyamaPredictor,
    AncestralSamplingPredictor,
    NonePredictor,
    AnnealedLangevinDynamics,
)
import datasets
import torch
import ml_collections

# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ffhq_ncsnpp_continuous config
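# The helper below appears to inline the ffhq_ncsnpp_continuous config from
# score_sde_pytorch (VE SDE, NCSN++ on FFHQ at 1024x1024), so the script can
# build its config without importing the configs/ package.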
def get_config():
    config = ml_collections.ConfigDict()

    # training
    config.training = training = ml_collections.ConfigDict()
    training.batch_size = 8
    training.n_iters = 2400001
    training.snapshot_freq = 50000
    training.log_freq = 50
    training.eval_freq = 100
    training.snapshot_freq_for_preemption = 5000
    training.snapshot_sampling = True
    training.sde = 'vesde'
    training.continuous = True
    training.likelihood_weighting = False
    training.reduce_mean = True

    # sampling
    config.sampling = sampling = ml_collections.ConfigDict()
    sampling.method = 'pc'
    sampling.predictor = 'reverse_diffusion'
    sampling.corrector = 'langevin'
    sampling.probability_flow = False
    sampling.snr = 0.15
    sampling.n_steps_each = 1
    sampling.noise_removal = True

    # eval
    config.eval = evaluate = ml_collections.ConfigDict()
    evaluate.batch_size = 1024
    evaluate.num_samples = 50000
    evaluate.begin_ckpt = 1
    evaluate.end_ckpt = 96

    # data
    config.data = data = ml_collections.ConfigDict()
    data.dataset = 'FFHQ'
    data.image_size = 1024
    data.centered = False
    data.random_flip = True
    data.uniform_dequantization = False
    data.num_channels = 3
    # Plug in your own path to the tfrecords file.
    data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'

    # model
    config.model = model = ml_collections.ConfigDict()
    model.name = 'ncsnpp'
    model.scale_by_sigma = True
    model.sigma_max = 1348
    model.num_scales = 2000
    model.ema_rate = 0.9999
    model.sigma_min = 0.01
    model.normalization = 'GroupNorm'
    model.nonlinearity = 'swish'
    model.nf = 16
    model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
    model.num_res_blocks = 1
    model.attn_resolutions = (16,)
    model.dropout = 0.
    model.resamp_with_conv = True
    model.conditional = True
    model.fir = True
    model.fir_kernel = [1, 3, 3, 1]
    model.skip_rescale = True
    model.resblock_type = 'biggan'
    model.progressive = 'output_skip'
    model.progressive_input = 'input_skip'
    model.progressive_combine = 'sum'
    model.attention_type = 'ddpm'
    model.init_scale = 0.
    model.fourier_scale = 16
    model.conv_size = 3
    model.embedding_type = 'fourier'

    # optim
    config.optim = optim = ml_collections.ConfigDict()
    optim.weight_decay = 0
    optim.optimizer = 'Adam'
    optim.lr = 2e-4
    optim.beta1 = 0.9
    optim.amsgrad = False
    optim.eps = 1e-8
    optim.warmup = 5000
    optim.grad_clip = 1.

    config.seed = 42
    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    return config
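# Illustration (not part of the original script): with this config the VE SDE noise
# levels form a geometric grid from sigma_min to sigma_max, which is exactly what
# NewReverseDiffusionPredictor recomputes below:
#   sigmas = torch.exp(torch.linspace(np.log(0.01), np.log(1348), 2000))
#   # sigmas[0] == 0.01 and sigmas[-1] == 1348, up to float error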
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(3)
#class NewVESDE(SDE):
# def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):
# """Construct a Variance Exploding SDE.
#
# Args:
# sigma_min: smallest sigma.
# sigma_max: largest sigma.
# N: number of discretization steps
# """
# super().__init__(N)
# self.sigma_min = sigma_min
# self.sigma_max = sigma_max
# self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
# self.N = N
#
# @property
# def T(self):
# return 1
#
# def sde(self, x, t):
# sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
# drift = torch.zeros_like(x)
# diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),
# device=t.device))
# return drift, diffusion
#
# def marginal_prob(self, x, t):
# std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
# mean = x
# return mean, std
#
# def prior_sampling(self, shape):
# return torch.randn(*shape) * self.sigma_max
#
# def prior_logp(self, z):
# shape = z.shape
# N = np.prod(shape[1:])
# return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)
#
# def discretize(self, x, t):
# """SMLD(NCSN) discretization."""
# timestep = (t * (self.N - 1) / self.T).long()
# sigma = self.discrete_sigmas.to(t.device)[timestep]
# adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
# self.discrete_sigmas[timestep - 1].to(t.device))
# f = torch.zeros_like(x)
# G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
# return f, G
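# The reworked predictor below drops the `sde` object entirely: it takes sigma_min,
# sigma_max and N directly and rebuilds the discrete sigma schedule itself, so it
# depends only on torch/numpy and the score model.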
class NewReverseDiffusionPredictor:
    def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.N = N
        self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
        self.probability_flow = probability_flow
        self.score_fn = score_fn

    def discretize(self, x, t):
        timestep = (t * (self.N - 1)).long()
        sigma = self.discrete_sigmas.to(t.device)[timestep]
        adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
                                     self.discrete_sigmas[timestep - 1].to(t.device))
        f = torch.zeros_like(x)
        G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
        labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        result = self.score_fn(x, labels)
        rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.)
...
@@ -114,26 +138,27 @@ class NewReverseDiffusionPredictor:
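# Same refactor for the Langevin corrector: the sde argument is replaced by explicit
# sigma_min/sigma_max, and the VP/subVP alpha lookup is commented out since only the
# VE case (alpha = 1) is exercised here.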
class NewLangevinCorrector:
    def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
        super().__init__()
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def update_fn(self, x, t):
        score_fn = self.score_fn
        n_steps = self.n_steps
        target_snr = self.snr
        # if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
        #     timestep = (t * (sde.N - 1) / sde.T).long()
        #     alpha = sde.alphas.to(t.device)[timestep]
        # else:
        alpha = torch.ones_like(t)

        for i in range(n_steps):
            labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
            grad = score_fn(x, labels)
            noise = torch.randn_like(x)
            grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
...
@@ -152,64 +177,42 @@ def save_image(x):
    image_pil.save("../images/hey.png")
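# The #@param annotations below suggest this block was lifted from the score_sde
# Colab demo; only the 'VESDE' branch is relevant for the FFHQ-1024 checkpoint.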
sde = 'VESDE'  #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
if sde.lower() == 'vesde':
    # from configs.ve import cifar10_ncsnpp_continuous as configs
    # ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
    from configs.ve import ffhq_ncsnpp_continuous as configs
    ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
    # Note usually we need to restore ema etc...
    config = configs.get_config()
    # ema restored checkpoint used from below
    config.model.num_scales = 2
    sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
    sampling_eps = 1e-5
elif sde.lower() == 'vpsde':
    from configs.vp import cifar10_ddpmpp_continuous as configs
    ckpt_filename = "exp/vp/cifar10_ddpmpp_continuous/checkpoint_8.pth"
    config = configs.get_config()
    sde = VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
    sampling_eps = 1e-3
elif sde.lower() == 'subvpsde':
    from configs.subvp import cifar10_ddpmpp_continuous as configs
    ckpt_filename = "exp/subvp/cifar10_ddpmpp_continuous/checkpoint_26.pth"
    config = configs.get_config()
    sde = subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
    sampling_eps = 1e-3
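# Quick-test overrides: num_scales was already reduced to 2 above, and the batch
# size is forced to 1 below, presumably to keep a 1024x1024 smoke test cheap.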
random_seed = 0  #@param {"type": "integer"}

config = get_config()
# sigmas = mutils.get_sigmas(config)
sigma_min, sigma_max = config.model.sigma_min, config.model.sigma_max
# scaler = datasets.get_data_scaler(config)
N = config.model.num_scales
# inverse_scaler = datasets.get_data_inverse_scaler(config)
# score_model = mutils.create_model(config)
#
# optimizer = get_optimizer(config, score_model.parameters())
# ema = ExponentialMovingAverage(score_model.parameters(),
#                                decay=config.model.ema_rate)
# state = dict(step=0, optimizer=optimizer,
#              model=score_model, ema=ema)
#
# state = restore_checkpoint(ckpt_filename, state, config.device)
# ema.copy_to(score_model.parameters())

sampling_eps = 1e-5
batch_size = 1  #@param {"type":"integer"}
config.training.batch_size = batch_size
config.eval.batch_size = batch_size
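# This is the core of the port: the NCSN++ network now comes from diffusers instead
# of the local models/ package, while the weights are loaded from the ema checkpoint
# exported from score_sde_pytorch.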
from diffusers import NCSNpp

model = NCSNpp(config).to(config.device)
model = torch.nn.DataParallel(model)

loaded_state = torch.load("../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt")
del loaded_state["module.sigmas"]
model.load_state_dict(loaded_state, strict=False)
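# "module.sigmas" is a registered buffer in the original score_sde model; it is
# deleted and strict=False is used presumably because the diffusers port does not
# carry that buffer (the sigma schedule is recomputed from the config instead).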
def get_data_inverse_scaler(config):
    """Inverse data normalizer."""
    if config.data.centered:
        # Rescale [-1, 1] to [0, 1]
        return lambda x: (x + 1.) / 2.
    else:
        return lambda x: x

inverse_scaler = get_data_inverse_scaler(config)

predictor = ReverseDiffusionPredictor  #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"}
corrector = LangevinCorrector  #@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"}
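# Note: with data.centered = False in the config above, this inverse scaler is the
# identity function.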
#@title PC sampling
img_size = config.data.image_size
channels = config.data.num_channels
shape = (batch_size, channels, img_size, img_size)
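# With batch_size = 1 this is shape == (1, 3, 1024, 1024): one 3-channel
# 1024x1024 sample per run.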
...
@@ -218,80 +221,27 @@ snr = 0.15 #@param {"type": "number"}
n_steps = 1  #@param {"type": "integer"}

# sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector,
#                                       inverse_scaler, snr, n_steps=n_steps,
#                                       probability_flow=probability_flow,
#                                       continuous=config.training.continuous,
#                                       eps=sampling_eps, device=config.device)
#
# x, n = sampling_fn(score_model)
# save_image(x)


def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous):
    """A wrapper that configures and returns the update function of predictors."""
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
    if predictor is None:
        # Corrector-only sampler
        predictor_obj = NonePredictor(sde, score_fn, probability_flow)
    else:
        predictor_obj = predictor(sde, score_fn, probability_flow)
    return predictor_obj.update_fn(x, t)


def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps):
    """A wrapper that configures and returns the update function of correctors."""
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
    if corrector is None:
        # Predictor-only sampler
        corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps)
    else:
        corrector_obj = corrector(sde, score_fn, snr, n_steps)
    return corrector_obj.update_fn(x, t)
continuous = config.training.continuous
predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                        sde=sde,
                                        predictor=predictor,
                                        probability_flow=probability_flow,
                                        continuous=continuous)
corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                        sde=sde,
                                        corrector=corrector,
                                        continuous=continuous,
                                        snr=snr,
                                        n_steps=n_steps)
device = config.device
denoise = True

new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps,
                                     sigma_min=sigma_min, sigma_max=sigma_max)
new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min,
                                             sigma_max=sigma_max, N=N)
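# Predictor-corrector sampling, written out inline: start from the VE prior
# x ~ N(0, sigma_max^2 I), walk t from 1 down to sampling_eps in N steps, applying
# one Langevin corrector step and one reverse-diffusion predictor step per
# iteration; x_mean is the denoised iterate used for the final image.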
with torch.no_grad():
    # Initial sample
    x = torch.randn(*shape) * sigma_max
    x = x.to(device)
    timesteps = torch.linspace(1, sampling_eps, N, device=device)

    for i in range(N):
        t = timesteps[i]
        vec_t = torch.ones(shape[0], device=t.device) * t
        # x, x_mean = corrector_update_fn(x, vec_t, model=model)
        # x, x_mean = predictor_update_fn(x, vec_t, model=model)
        x, x_mean = new_corrector.update_fn(x, vec_t)
        x, x_mean = new_predictor.update_fn(x, vec_t)

    x = inverse_scaler(x_mean)
    save_image(x)

# for 5 cifar10
x_sum = 106071.9922
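# The hard-coded x_sum regression target above ("for 5 cifar10") comes from a
# CIFAR-10 run and does not apply to FFHQ-1024, which is presumably why the final
# check_x_sum_x_mean call below is commented out.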
...
@@ -310,4 +260,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean):
    assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"

# check_x_sum_x_mean(x, x_sum, x_mean)