Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
fc67917a
Commit
fc67917a
authored
Jun 24, 2022
by
Patrick von Platen
Browse files
up
parent
7ca832ca
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
9 deletions
+10
-9
run.py
run.py
+10
-9
No files found.
run.py
View file @
fc67917a
...
@@ -269,20 +269,21 @@ with torch.no_grad():
...
@@ -269,20 +269,21 @@ with torch.no_grad():
for
i
in
range
(
sde
.
N
):
for
i
in
range
(
sde
.
N
):
t
=
timesteps
[
i
]
t
=
timesteps
[
i
]
vec_t
=
torch
.
ones
(
shape
[
0
],
device
=
t
.
device
)
*
t
vec_t
=
torch
.
ones
(
shape
[
0
],
device
=
t
.
device
)
*
t
#
x, x_mean = corrector_update_fn(x, vec_t, model=model)
x
,
x_mean
=
corrector_update_fn
(
x
,
vec_t
,
model
=
model
)
#
x, x_mean = predictor_update_fn(x, vec_t, model=model)
x
,
x_mean
=
predictor_update_fn
(
x
,
vec_t
,
model
=
model
)
x
,
x_mean
=
new_corrector
.
update_fn
(
x
,
vec_t
)
#
x, x_mean = new_corrector.update_fn(x, vec_t)
x
,
x_mean
=
new_predictor
.
update_fn
(
x
,
vec_t
)
#
x, x_mean = new_predictor.update_fn(x, vec_t)
x
,
n
=
inverse_scaler
(
x_mean
if
denoise
else
x
),
sde
.
N
*
(
n_steps
+
1
)
x
,
n
=
inverse_scaler
(
x_mean
if
denoise
else
x
),
sde
.
N
*
(
n_steps
+
1
)
save_image
(
x
)
# for 5
# for 5
#assert x.abs().sum()
.cpu().item()
- 106114.90625 < 1e-2, "sum wrong"
#assert
(
x.abs().sum() - 106114.90625
).cpu().item()
< 1e-2,
f
"sum wrong
{x.abs().sum()}
"
#assert x.abs().mean()
.cpu().item()
- 34.5426139831543 < 1e-4, "mean wrong"
#assert
(
x.abs().mean() - 34.5426139831543
).abs().cpu().item()
< 1e-4,
f
"mean wrong
{x.abs().mean()}
"
# for 1000
# for 1000
assert
x
.
abs
().
sum
().
cpu
().
item
()
-
436.5811
<
1e-2
,
"sum wrong"
assert
(
x
.
abs
().
sum
()
-
436.5811
).
abs
().
sum
()
.
cpu
().
item
()
<
1e-2
,
f
"sum wrong
{
x
.
abs
().
sum
()
}
"
assert
x
.
abs
().
mean
().
cpu
().
item
()
-
0.1421
<
1e-4
,
"mean wrong"
assert
(
x
.
abs
().
mean
()
-
0.1421
)
.
abs
().
mean
().
cpu
().
item
()
<
1e-4
,
f
"mean wrong
{
x
.
abs
().
mean
()
}
"
save_image
(
x
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment