Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d5c527a4
"src/vscode:/vscode.git/clone" did not exist on "d5dd8df3b4e978e3f0549fa35e245a8785116af7"
Commit
d5c527a4
authored
Jun 26, 2022
by
Patrick von Platen
Browse files
clean up
parent
135acd83
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
53 deletions
+2
-53
src/diffusers/pipelines/pipeline_score_sde_ve.py
src/diffusers/pipelines/pipeline_score_sde_ve.py
+2
-53
No files found.
src/diffusers/pipelines/pipeline_score_sde_ve.py
View file @
d5c527a4
#!/usr/bin/env python3
#!/usr/bin/env python3
import
numpy
as
np
import
torch
import
torch
import
PIL
from
diffusers
import
DiffusionPipeline
from
diffusers
import
DiffusionPipeline
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class
ScoreSdeVePipeline
(
DiffusionPipeline
):
class
ScoreSdeVePipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
model
,
scheduler
):
def
__init__
(
self
,
model
,
scheduler
):
super
().
__init__
()
super
().
__init__
()
...
@@ -23,7 +18,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
...
@@ -23,7 +18,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
model
=
self
.
model
.
to
(
device
)
model
=
self
.
model
.
to
(
device
)
centered
=
False
# TODO(Patrick) move to scheduler config
n_steps
=
1
n_steps
=
1
x
=
torch
.
randn
(
*
shape
)
*
self
.
scheduler
.
config
.
sigma_max
x
=
torch
.
randn
(
*
shape
)
*
self
.
scheduler
.
config
.
sigma_max
...
@@ -45,50 +40,4 @@ class ScoreSdeVePipeline(DiffusionPipeline):
...
@@ -45,50 +40,4 @@ class ScoreSdeVePipeline(DiffusionPipeline):
x
,
x_mean
=
self
.
scheduler
.
step_pred
(
result
,
x
,
t
)
x
,
x_mean
=
self
.
scheduler
.
step_pred
(
result
,
x
,
t
)
x
=
x_mean
return
x_mean
if
centered
:
x
=
(
x
+
1.0
)
/
2.0
return
x
# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
# pipeline = ScoreSdeVePipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
# x = pipeline(num_inference_steps=2)
# for 5 cifar10
# x_sum = 106071.9922
# x_mean = 34.52864456176758
# for 1000 cifar10
# x_sum = 461.9700
# x_mean = 0.1504
# for N=2 for 1024
# x_sum = 3382810112.0
# x_mean = 1075.366455078125
#
#
# def check_x_sum_x_mean(x, x_sum, x_mean):
# assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
# assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
#
#
# check_x_sum_x_mean(x, x_sum, x_mean)
#
#
# def save_image(x):
# image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
# image_pil = PIL.Image.fromarray(image_processed[0])
# image_pil.save("../images/hey.png")
#
#
# save_image(x)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment