Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
78e99a99
Commit
78e99a99
authored
Jun 24, 2022
by
Patrick von Platen
Browse files
adapt run.py
parent
fc67917a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
67 additions
and
43 deletions
+67
-43
run.py
run.py
+67
-43
No files found.
run.py
View file @
78e99a99
...
...
@@ -11,6 +11,8 @@ from models import ddpm as ddpm_model
from
models
import
layerspp
from
models
import
layers
from
models
import
normalization
from
models.ema
import
ExponentialMovingAverage
from
losses
import
get_optimizer
from
utils
import
restore_checkpoint
...
...
@@ -27,6 +29,7 @@ import datasets
import
torch
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
manual_seed
(
0
)
...
...
@@ -81,7 +84,6 @@ torch.manual_seed(0)
class
NewReverseDiffusionPredictor
:
def
__init__
(
self
,
sde
,
score_fn
,
probability_flow
=
False
):
super
().
__init__
()
self
.
sde
=
sde
...
...
@@ -112,7 +114,6 @@ class NewReverseDiffusionPredictor:
class
NewLangevinCorrector
:
def
__init__
(
self
,
sde
,
score_fn
,
snr
,
n_steps
):
super
().
__init__
()
self
.
sde
=
sde
...
...
@@ -146,28 +147,19 @@ class NewLangevinCorrector:
def
save_image
(
x
):
# image_processed = x.cpu().permute(0, 2, 3, 1)
# image_processed = (image_processed + 1.0) * 127.5
# image_processed = image_processed.numpy().astype(np.uint8)
image_processed
=
np
.
clip
(
x
.
permute
(
0
,
2
,
3
,
1
).
cpu
().
numpy
()
*
255
,
0
,
255
).
astype
(
np
.
uint8
)
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
# 6. save image
image_pil
.
save
(
"../images/hey.png"
)
#x = np.load("cifar10.npy")
#
#save_image(x)
# @title Load the score-based model
sde
=
'VESDE'
#@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
if
sde
.
lower
()
==
'vesde'
:
from
configs.ve
import
cifar10_ncsnpp_continuous
as
configs
ckpt_filename
=
"exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
#
from configs.ve import ffhq_ncsnpp_continuous as configs
#
ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
#
from configs.ve import cifar10_ncsnpp_continuous as configs
#
ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
from
configs.ve
import
ffhq_ncsnpp_continuous
as
configs
ckpt_filename
=
"exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
config
=
configs
.
get_config
()
config
.
model
.
num_scales
=
1000
config
.
model
.
num_scales
=
2
sde
=
VESDE
(
sigma_min
=
config
.
model
.
sigma_min
,
sigma_max
=
config
.
model
.
sigma_max
,
N
=
config
.
model
.
num_scales
)
sampling_eps
=
1e-5
elif
sde
.
lower
()
==
'vpsde'
:
...
...
@@ -189,32 +181,53 @@ config.eval.batch_size = batch_size
random_seed
=
0
#@param {"type": "integer"}
score_model
=
mutils
.
create_model
(
config
)
#sigmas = mutils.get_sigmas(config)
#scaler = datasets.get_data_scaler(config)
#inverse_scaler = datasets.get_data_inverse_scaler(config)
#score_model = mutils.create_model(config)
#
#optimizer = get_optimizer(config, score_model.parameters())
#ema = ExponentialMovingAverage(score_model.parameters(),
# decay=config.model.ema_rate)
#state = dict(step=0, optimizer=optimizer,
# model=score_model, ema=ema)
#
#state = restore_checkpoint(ckpt_filename, state, config.device)
#ema.copy_to(score_model.parameters())
#score_model = mutils.create_model(config)
from
diffusers
import
NCSNpp
score_model
=
NCSNpp
(
config
).
to
(
config
.
device
)
score_model
=
torch
.
nn
.
DataParallel
(
score_model
)
loaded_state
=
torch
.
load
(
ckpt_filename
)
score_model
.
load_state_dict
(
loaded_state
[
"model"
],
strict
=
False
)
loaded_state
=
torch
.
load
(
"./ffhq_1024_ncsnpp_continuous_ema.pt"
)
del
loaded_state
[
"module.sigmas"
]
score_model
.
load_state_dict
(
loaded_state
,
strict
=
False
)
inverse_scaler
=
datasets
.
get_data_inverse_scaler
(
config
)
predictor
=
ReverseDiffusionPredictor
#@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"}
corrector
=
LangevinCorrector
#@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"}
def
image_grid
(
x
):
size
=
config
.
data
.
image_size
channels
=
config
.
data
.
num_channels
img
=
x
.
reshape
(
-
1
,
size
,
size
,
channels
)
w
=
int
(
np
.
sqrt
(
img
.
shape
[
0
]))
img
=
img
.
reshape
((
w
,
w
,
size
,
size
,
channels
)).
transpose
((
0
,
2
,
1
,
3
,
4
)).
reshape
((
w
*
size
,
w
*
size
,
channels
))
return
img
#@title PC sampling
img_size
=
config
.
data
.
image_size
channels
=
config
.
data
.
num_channels
shape
=
(
batch_size
,
channels
,
img_size
,
img_size
)
probability_flow
=
False
snr
=
0.1
6
#@param {"type": "number"}
snr
=
0.1
5
#@param {"type": "number"}
n_steps
=
1
#@param {"type": "integer"}
#sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector,
# inverse_scaler, snr, n_steps=n_steps,
# probability_flow=probability_flow,
# continuous=config.training.continuous,
# eps=sampling_eps, device=config.device)
#
#x, n = sampling_fn(score_model)
#save_image(x)
def
shared_predictor_update_fn
(
x
,
t
,
sde
,
model
,
predictor
,
probability_flow
,
continuous
):
"""A wrapper that configures and returns the update function of predictors."""
score_fn
=
mutils
.
get_score_fn
(
sde
,
model
,
train
=
False
,
continuous
=
continuous
)
...
...
@@ -253,14 +266,14 @@ corrector_update_fn = functools.partial(shared_corrector_update_fn,
snr
=
snr
,
n_steps
=
n_steps
)
device
=
"cuda"
model
=
score_model
.
to
(
device
)
denoise
=
Fals
e
device
=
config
.
device
model
=
score_model
denoise
=
Tru
e
new_corrector
=
NewLangevinCorrector
(
sde
=
sde
,
score_fn
=
model
,
snr
=
snr
,
n_steps
=
n_steps
)
new_predictor
=
NewReverseDiffusionPredictor
(
sde
=
sde
,
score_fn
=
model
)
#
with
torch
.
no_grad
():
# Initial sample
x
=
sde
.
prior_sampling
(
shape
).
to
(
device
)
...
...
@@ -269,21 +282,32 @@ with torch.no_grad():
for
i
in
range
(
sde
.
N
):
t
=
timesteps
[
i
]
vec_t
=
torch
.
ones
(
shape
[
0
],
device
=
t
.
device
)
*
t
x
,
x_mean
=
corrector_update_fn
(
x
,
vec_t
,
model
=
model
)
x
,
x_mean
=
predictor_update_fn
(
x
,
vec_t
,
model
=
model
)
#
x, x_mean = new_corrector.update_fn(x, vec_t)
#
x, x_mean = new_predictor.update_fn(x, vec_t)
#
x, x_mean = corrector_update_fn(x, vec_t, model=model)
#
x, x_mean = predictor_update_fn(x, vec_t, model=model)
x
,
x_mean
=
new_corrector
.
update_fn
(
x
,
vec_t
)
x
,
x_mean
=
new_predictor
.
update_fn
(
x
,
vec_t
)
x
,
n
=
inverse_scaler
(
x_mean
if
denoise
else
x
),
sde
.
N
*
(
n_steps
+
1
)
save_image
(
x
)
# for 5
#assert (x.abs().sum() - 106114.90625).cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
#assert (x.abs().mean() - 34.5426139831543).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
#save_image(x)
# for 5 cifar10
x_sum
=
106071.9922
x_mean
=
34.52864456176758
# for 1000 cifar10
x_sum
=
461.9700
x_mean
=
0.1504
# for 2 for 1024
x_sum
=
3382810112.0
x_mean
=
1075.366455078125
def
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
):
assert
(
x
.
abs
().
sum
()
-
x_sum
).
abs
().
cpu
().
item
()
<
1e-2
,
f
"sum wrong
{
x
.
abs
().
sum
()
}
"
assert
(
x
.
abs
().
mean
()
-
x_mean
).
abs
().
cpu
().
item
()
<
1e-4
,
f
"mean wrong
{
x
.
abs
().
mean
()
}
"
# for 1000
assert
(
x
.
abs
().
sum
()
-
436.5811
).
abs
().
sum
().
cpu
().
item
()
<
1e-2
,
f
"sum wrong
{
x
.
abs
().
sum
()
}
"
assert
(
x
.
abs
().
mean
()
-
0.1421
).
abs
().
mean
().
cpu
().
item
()
<
1e-4
,
f
"mean wrong
{
x
.
abs
().
mean
()
}
"
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment