Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
78e99a99
Commit
78e99a99
authored
Jun 24, 2022
by
Patrick von Platen
Browse files
adapt run.py
parent
fc67917a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
67 additions
and
43 deletions
+67
-43
run.py
run.py
+67
-43
No files found.
run.py
View file @
78e99a99
...
@@ -11,6 +11,8 @@ from models import ddpm as ddpm_model
...
@@ -11,6 +11,8 @@ from models import ddpm as ddpm_model
from
models
import
layerspp
from
models
import
layerspp
from
models
import
layers
from
models
import
layers
from
models
import
normalization
from
models
import
normalization
from
models.ema
import
ExponentialMovingAverage
from
losses
import
get_optimizer
from
utils
import
restore_checkpoint
from
utils
import
restore_checkpoint
...
@@ -27,6 +29,7 @@ import datasets
...
@@ -27,6 +29,7 @@ import datasets
import
torch
import
torch
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
...
@@ -81,7 +84,6 @@ torch.manual_seed(0)
...
@@ -81,7 +84,6 @@ torch.manual_seed(0)
class
NewReverseDiffusionPredictor
:
class
NewReverseDiffusionPredictor
:
def
__init__
(
self
,
sde
,
score_fn
,
probability_flow
=
False
):
def
__init__
(
self
,
sde
,
score_fn
,
probability_flow
=
False
):
super
().
__init__
()
super
().
__init__
()
self
.
sde
=
sde
self
.
sde
=
sde
...
@@ -112,7 +114,6 @@ class NewReverseDiffusionPredictor:
...
@@ -112,7 +114,6 @@ class NewReverseDiffusionPredictor:
class
NewLangevinCorrector
:
class
NewLangevinCorrector
:
def
__init__
(
self
,
sde
,
score_fn
,
snr
,
n_steps
):
def
__init__
(
self
,
sde
,
score_fn
,
snr
,
n_steps
):
super
().
__init__
()
super
().
__init__
()
self
.
sde
=
sde
self
.
sde
=
sde
...
@@ -146,28 +147,19 @@ class NewLangevinCorrector:
...
@@ -146,28 +147,19 @@ class NewLangevinCorrector:
def
save_image
(
x
):
def
save_image
(
x
):
# image_processed = x.cpu().permute(0, 2, 3, 1)
# image_processed = (image_processed + 1.0) * 127.5
# image_processed = image_processed.numpy().astype(np.uint8)
image_processed
=
np
.
clip
(
x
.
permute
(
0
,
2
,
3
,
1
).
cpu
().
numpy
()
*
255
,
0
,
255
).
astype
(
np
.
uint8
)
image_processed
=
np
.
clip
(
x
.
permute
(
0
,
2
,
3
,
1
).
cpu
().
numpy
()
*
255
,
0
,
255
).
astype
(
np
.
uint8
)
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
# 6. save image
image_pil
.
save
(
"../images/hey.png"
)
image_pil
.
save
(
"../images/hey.png"
)
#x = np.load("cifar10.npy")
#
#save_image(x)
# @title Load the score-based model
sde
=
'VESDE'
#@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
sde
=
'VESDE'
#@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
if
sde
.
lower
()
==
'vesde'
:
if
sde
.
lower
()
==
'vesde'
:
from
configs.ve
import
cifar10_ncsnpp_continuous
as
configs
#
from configs.ve import cifar10_ncsnpp_continuous as configs
ckpt_filename
=
"exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
#
ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
#
from configs.ve import ffhq_ncsnpp_continuous as configs
from
configs.ve
import
ffhq_ncsnpp_continuous
as
configs
#
ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
ckpt_filename
=
"exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
config
=
configs
.
get_config
()
config
=
configs
.
get_config
()
config
.
model
.
num_scales
=
1000
config
.
model
.
num_scales
=
2
sde
=
VESDE
(
sigma_min
=
config
.
model
.
sigma_min
,
sigma_max
=
config
.
model
.
sigma_max
,
N
=
config
.
model
.
num_scales
)
sde
=
VESDE
(
sigma_min
=
config
.
model
.
sigma_min
,
sigma_max
=
config
.
model
.
sigma_max
,
N
=
config
.
model
.
num_scales
)
sampling_eps
=
1e-5
sampling_eps
=
1e-5
elif
sde
.
lower
()
==
'vpsde'
:
elif
sde
.
lower
()
==
'vpsde'
:
...
@@ -189,32 +181,53 @@ config.eval.batch_size = batch_size
...
@@ -189,32 +181,53 @@ config.eval.batch_size = batch_size
random_seed
=
0
#@param {"type": "integer"}
random_seed
=
0
#@param {"type": "integer"}
score_model
=
mutils
.
create_model
(
config
)
#sigmas = mutils.get_sigmas(config)
#scaler = datasets.get_data_scaler(config)
#inverse_scaler = datasets.get_data_inverse_scaler(config)
#score_model = mutils.create_model(config)
#
#optimizer = get_optimizer(config, score_model.parameters())
#ema = ExponentialMovingAverage(score_model.parameters(),
# decay=config.model.ema_rate)
#state = dict(step=0, optimizer=optimizer,
# model=score_model, ema=ema)
#
#state = restore_checkpoint(ckpt_filename, state, config.device)
#ema.copy_to(score_model.parameters())
#score_model = mutils.create_model(config)
from
diffusers
import
NCSNpp
score_model
=
NCSNpp
(
config
).
to
(
config
.
device
)
score_model
=
torch
.
nn
.
DataParallel
(
score_model
)
loaded_state
=
torch
.
load
(
ckpt_filename
)
loaded_state
=
torch
.
load
(
"./ffhq_1024_ncsnpp_continuous_ema.pt"
)
score_model
.
load_state_dict
(
loaded_state
[
"model"
],
strict
=
False
)
del
loaded_state
[
"module.sigmas"
]
score_model
.
load_state_dict
(
loaded_state
,
strict
=
False
)
inverse_scaler
=
datasets
.
get_data_inverse_scaler
(
config
)
inverse_scaler
=
datasets
.
get_data_inverse_scaler
(
config
)
predictor
=
ReverseDiffusionPredictor
#@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"}
predictor
=
ReverseDiffusionPredictor
#@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"}
corrector
=
LangevinCorrector
#@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"}
corrector
=
LangevinCorrector
#@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"}
def
image_grid
(
x
):
size
=
config
.
data
.
image_size
channels
=
config
.
data
.
num_channels
img
=
x
.
reshape
(
-
1
,
size
,
size
,
channels
)
w
=
int
(
np
.
sqrt
(
img
.
shape
[
0
]))
img
=
img
.
reshape
((
w
,
w
,
size
,
size
,
channels
)).
transpose
((
0
,
2
,
1
,
3
,
4
)).
reshape
((
w
*
size
,
w
*
size
,
channels
))
return
img
#@title PC sampling
#@title PC sampling
img_size
=
config
.
data
.
image_size
img_size
=
config
.
data
.
image_size
channels
=
config
.
data
.
num_channels
channels
=
config
.
data
.
num_channels
shape
=
(
batch_size
,
channels
,
img_size
,
img_size
)
shape
=
(
batch_size
,
channels
,
img_size
,
img_size
)
probability_flow
=
False
probability_flow
=
False
snr
=
0.1
6
#@param {"type": "number"}
snr
=
0.1
5
#@param {"type": "number"}
n_steps
=
1
#@param {"type": "integer"}
n_steps
=
1
#@param {"type": "integer"}
#sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector,
# inverse_scaler, snr, n_steps=n_steps,
# probability_flow=probability_flow,
# continuous=config.training.continuous,
# eps=sampling_eps, device=config.device)
#
#x, n = sampling_fn(score_model)
#save_image(x)
def
shared_predictor_update_fn
(
x
,
t
,
sde
,
model
,
predictor
,
probability_flow
,
continuous
):
def
shared_predictor_update_fn
(
x
,
t
,
sde
,
model
,
predictor
,
probability_flow
,
continuous
):
"""A wrapper that configures and returns the update function of predictors."""
"""A wrapper that configures and returns the update function of predictors."""
score_fn
=
mutils
.
get_score_fn
(
sde
,
model
,
train
=
False
,
continuous
=
continuous
)
score_fn
=
mutils
.
get_score_fn
(
sde
,
model
,
train
=
False
,
continuous
=
continuous
)
...
@@ -253,14 +266,14 @@ corrector_update_fn = functools.partial(shared_corrector_update_fn,
...
@@ -253,14 +266,14 @@ corrector_update_fn = functools.partial(shared_corrector_update_fn,
snr
=
snr
,
snr
=
snr
,
n_steps
=
n_steps
)
n_steps
=
n_steps
)
device
=
"cuda"
device
=
config
.
device
model
=
score_model
.
to
(
device
)
model
=
score_model
denoise
=
Fals
e
denoise
=
Tru
e
new_corrector
=
NewLangevinCorrector
(
sde
=
sde
,
score_fn
=
model
,
snr
=
snr
,
n_steps
=
n_steps
)
new_corrector
=
NewLangevinCorrector
(
sde
=
sde
,
score_fn
=
model
,
snr
=
snr
,
n_steps
=
n_steps
)
new_predictor
=
NewReverseDiffusionPredictor
(
sde
=
sde
,
score_fn
=
model
)
new_predictor
=
NewReverseDiffusionPredictor
(
sde
=
sde
,
score_fn
=
model
)
#
with
torch
.
no_grad
():
with
torch
.
no_grad
():
# Initial sample
# Initial sample
x
=
sde
.
prior_sampling
(
shape
).
to
(
device
)
x
=
sde
.
prior_sampling
(
shape
).
to
(
device
)
...
@@ -269,21 +282,32 @@ with torch.no_grad():
...
@@ -269,21 +282,32 @@ with torch.no_grad():
for
i
in
range
(
sde
.
N
):
for
i
in
range
(
sde
.
N
):
t
=
timesteps
[
i
]
t
=
timesteps
[
i
]
vec_t
=
torch
.
ones
(
shape
[
0
],
device
=
t
.
device
)
*
t
vec_t
=
torch
.
ones
(
shape
[
0
],
device
=
t
.
device
)
*
t
x
,
x_mean
=
corrector_update_fn
(
x
,
vec_t
,
model
=
model
)
#
x, x_mean = corrector_update_fn(x, vec_t, model=model)
x
,
x_mean
=
predictor_update_fn
(
x
,
vec_t
,
model
=
model
)
#
x, x_mean = predictor_update_fn(x, vec_t, model=model)
#
x, x_mean = new_corrector.update_fn(x, vec_t)
x
,
x_mean
=
new_corrector
.
update_fn
(
x
,
vec_t
)
#
x, x_mean = new_predictor.update_fn(x, vec_t)
x
,
x_mean
=
new_predictor
.
update_fn
(
x
,
vec_t
)
x
,
n
=
inverse_scaler
(
x_mean
if
denoise
else
x
),
sde
.
N
*
(
n_steps
+
1
)
x
,
n
=
inverse_scaler
(
x_mean
if
denoise
else
x
),
sde
.
N
*
(
n_steps
+
1
)
save_image
(
x
)
# for 5
#save_image(x)
#assert (x.abs().sum() - 106114.90625).cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
#assert (x.abs().mean() - 34.5426139831543).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
# for 5 cifar10
x_sum
=
106071.9922
x_mean
=
34.52864456176758
# for 1000 cifar10
x_sum
=
461.9700
x_mean
=
0.1504
# for 2 for 1024
x_sum
=
3382810112.0
x_mean
=
1075.366455078125
def
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
):
assert
(
x
.
abs
().
sum
()
-
x_sum
).
abs
().
cpu
().
item
()
<
1e-2
,
f
"sum wrong
{
x
.
abs
().
sum
()
}
"
assert
(
x
.
abs
().
mean
()
-
x_mean
).
abs
().
cpu
().
item
()
<
1e-4
,
f
"mean wrong
{
x
.
abs
().
mean
()
}
"
# for 1000
assert
(
x
.
abs
().
sum
()
-
436.5811
).
abs
().
sum
().
cpu
().
item
()
<
1e-2
,
f
"sum wrong
{
x
.
abs
().
sum
()
}
"
assert
(
x
.
abs
().
mean
()
-
0.1421
).
abs
().
mean
().
cpu
().
item
()
<
1e-4
,
f
"mean wrong
{
x
.
abs
().
mean
()
}
"
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment