Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
394243ce
Commit
394243ce
authored
Jul 21, 2022
by
Patrick von Platen
Browse files
finish pndm sampler
parent
fe985746
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
78 additions
and
71 deletions
+78
-71
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+2
-5
src/diffusers/pipelines/pndm/pipeline_pndm.py
src/diffusers/pipelines/pndm/pipeline_pndm.py
+12
-7
src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
...diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
+5
-11
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+2
-2
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+0
-1
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+39
-27
tests/test_scheduler.py
tests/test_scheduler.py
+18
-18
No files found.
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
394243ce
...
@@ -48,11 +48,8 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -48,11 +48,8 @@ class DDPMPipeline(DiffusionPipeline):
# 1. predict noise model_output
# 1. predict noise model_output
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
# 2. predict previous mean of image x_t-1
# 2. compute previous image: x_t -> t_t-1
pred_prev_image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
)[
"prev_sample"
]
image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
)[
"prev_sample"
]
# 3. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
...
...
src/diffusers/pipelines/pndm/pipeline_pndm.py
View file @
394243ce
...
@@ -44,15 +44,20 @@ class PNDMPipeline(DiffusionPipeline):
...
@@ -44,15 +44,20 @@ class PNDMPipeline(DiffusionPipeline):
image
=
image
.
to
(
torch_device
)
image
=
image
.
to
(
torch_device
)
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
for
i
,
t
in
enumerate
(
tqdm
(
self
.
scheduler
.
prk_
timesteps
)
)
:
for
t
in
tqdm
(
self
.
scheduler
.
timesteps
):
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
image
=
self
.
scheduler
.
step_prk
(
model_output
,
i
,
image
,
num_inference_steps
)[
"prev_sample"
]
image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
)[
"prev_sample"
]
for
i
,
t
in
enumerate
(
tqdm
(
self
.
scheduler
.
plms_timesteps
)):
# for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
# model_output = self.unet(image, t)["sample"]
#
image
=
self
.
scheduler
.
step_plms
(
model_output
,
i
,
image
,
num_inference_steps
)[
"prev_sample"
]
# image = self.scheduler.step_prk(model_output, t, image, i=i)["prev_sample"]
#
# for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
# model_output = self.unet(image, t)["sample"]
#
# image = self.scheduler.step_plms(model_output, t, image, i=i)["prev_sample"]
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
...
...
src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
View file @
394243ce
...
@@ -28,21 +28,15 @@ class ScoreSdeVePipeline(DiffusionPipeline):
...
@@ -28,21 +28,15 @@ class ScoreSdeVePipeline(DiffusionPipeline):
for
i
,
t
in
tqdm
(
enumerate
(
self
.
scheduler
.
timesteps
)):
for
i
,
t
in
tqdm
(
enumerate
(
self
.
scheduler
.
timesteps
)):
sigma_t
=
self
.
scheduler
.
sigmas
[
i
]
*
torch
.
ones
(
shape
[
0
],
device
=
device
)
sigma_t
=
self
.
scheduler
.
sigmas
[
i
]
*
torch
.
ones
(
shape
[
0
],
device
=
device
)
# correction step
for
_
in
range
(
self
.
scheduler
.
correct_steps
):
for
_
in
range
(
self
.
scheduler
.
correct_steps
):
model_output
=
self
.
model
(
sample
,
sigma_t
)
model_output
=
self
.
model
(
sample
,
sigma_t
)[
"sample"
]
if
isinstance
(
model_output
,
dict
):
model_output
=
model_output
[
"sample"
]
sample
=
self
.
scheduler
.
step_correct
(
model_output
,
sample
)[
"prev_sample"
]
sample
=
self
.
scheduler
.
step_correct
(
model_output
,
sample
)[
"prev_sample"
]
with
torch
.
no_grad
():
# prediction step
model_output
=
model
(
sample
,
sigma_t
)
model_output
=
model
(
sample
,
sigma_t
)[
"sample"
]
if
isinstance
(
model_output
,
dict
):
model_output
=
model_output
[
"sample"
]
output
=
self
.
scheduler
.
step_pred
(
model_output
,
t
,
sample
)
output
=
self
.
scheduler
.
step_pred
(
model_output
,
t
,
sample
)
sample
,
sample_mean
=
output
[
"prev_sample"
],
output
[
"prev_sample_mean"
]
sample
,
sample_mean
=
output
[
"prev_sample"
],
output
[
"prev_sample_mean"
]
sample
=
sample
.
clamp
(
0
,
1
)
sample
=
sample
.
clamp
(
0
,
1
)
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
394243ce
...
@@ -106,8 +106,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -106,8 +106,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
eta
,
eta
:
float
=
0.0
,
use_clipped_model_output
=
False
,
use_clipped_model_output
:
bool
=
False
,
generator
=
None
,
generator
=
None
,
):
):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
394243ce
...
@@ -56,7 +56,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -56,7 +56,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_end
=
0.02
,
beta_end
=
0.02
,
beta_schedule
=
"linear"
,
beta_schedule
=
"linear"
,
trained_betas
=
None
,
trained_betas
=
None
,
timestep_values
=
None
,
variance_type
=
"fixed_small"
,
variance_type
=
"fixed_small"
,
clip_sample
=
True
,
clip_sample
=
True
,
tensor_format
=
"pt"
,
tensor_format
=
"pt"
,
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
394243ce
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
import
math
import
pdb
from
typing
import
Union
from
typing
import
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -79,78 +78,91 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -79,78 +78,91 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# running values
# running values
self
.
cur_model_output
=
0
self
.
cur_model_output
=
0
self
.
counter
=
0
self
.
cur_sample
=
None
self
.
cur_sample
=
None
self
.
ets
=
[]
self
.
ets
=
[]
# setable values
# setable values
self
.
num_inference_steps
=
None
self
.
num_inference_steps
=
None
self
.
timesteps
=
np
.
arange
(
0
,
num_train_timesteps
)[::
-
1
].
copy
()
self
.
_
timesteps
=
np
.
arange
(
0
,
num_train_timesteps
)[::
-
1
].
copy
()
self
.
prk_timesteps
=
None
self
.
prk_timesteps
=
None
self
.
plms_timesteps
=
None
self
.
plms_timesteps
=
None
self
.
timesteps
=
None
self
.
tensor_format
=
tensor_format
self
.
tensor_format
=
tensor_format
self
.
set_format
(
tensor_format
=
tensor_format
)
self
.
set_format
(
tensor_format
=
tensor_format
)
def
set_timesteps
(
self
,
num_inference_steps
):
def
set_timesteps
(
self
,
num_inference_steps
):
self
.
num_inference_steps
=
num_inference_steps
self
.
num_inference_steps
=
num_inference_steps
self
.
timesteps
=
list
(
self
.
_
timesteps
=
list
(
range
(
0
,
self
.
config
.
num_train_timesteps
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
)
range
(
0
,
self
.
config
.
num_train_timesteps
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
)
)
)
prk_time
_
steps
=
np
.
array
(
self
.
timesteps
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
prk_timesteps
=
np
.
array
(
self
.
_
timesteps
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
np
.
array
([
0
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
)
self
.
prk_timesteps
=
list
(
reversed
(
prk_time_steps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
prk_timesteps
=
list
(
reversed
(
prk_timesteps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
plms_timesteps
=
list
(
reversed
(
self
.
timesteps
[:
-
3
]))
self
.
plms_timesteps
=
list
(
reversed
(
self
.
_timesteps
[:
-
3
]))
self
.
timesteps
=
self
.
prk_timesteps
+
self
.
plms_timesteps
self
.
counter
=
0
self
.
set_format
(
tensor_format
=
self
.
tensor_format
)
self
.
set_format
(
tensor_format
=
self
.
tensor_format
)
def
step
(
self
,
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
):
if
self
.
counter
<
len
(
self
.
prk_timesteps
):
return
self
.
step_prk
(
model_output
=
model_output
,
timestep
=
timestep
,
sample
=
sample
)
else
:
return
self
.
step_plms
(
model_output
=
model_output
,
timestep
=
timestep
,
sample
=
sample
)
def
step_prk
(
def
step_prk
(
self
,
self
,
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
num_inference_steps
,
):
):
"""
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.
solution to the differential equation.
"""
"""
t
=
timestep
diff_to_prev
=
0
if
self
.
counter
%
2
else
self
.
config
.
num_train_timesteps
//
self
.
num_inference_steps
//
2
prk_time_steps
=
self
.
prk_timesteps
prev_timestep
=
max
(
timestep
-
diff_to_prev
,
self
.
prk_timesteps
[
-
1
])
timestep
=
self
.
prk_timesteps
[
self
.
counter
//
4
*
4
]
t_orig
=
prk_time_steps
[
t
//
4
*
4
]
t_orig_prev
=
prk_time_steps
[
min
(
t
+
1
,
len
(
prk_time_steps
)
-
1
)]
if
t
%
4
==
0
:
if
self
.
counter
%
4
==
0
:
self
.
cur_model_output
+=
1
/
6
*
model_output
self
.
cur_model_output
+=
1
/
6
*
model_output
self
.
ets
.
append
(
model_output
)
self
.
ets
.
append
(
model_output
)
self
.
cur_sample
=
sample
self
.
cur_sample
=
sample
elif
(
t
-
1
)
%
4
==
0
:
elif
(
self
.
counter
-
1
)
%
4
==
0
:
self
.
cur_model_output
+=
1
/
3
*
model_output
self
.
cur_model_output
+=
1
/
3
*
model_output
elif
(
t
-
2
)
%
4
==
0
:
elif
(
self
.
counter
-
2
)
%
4
==
0
:
self
.
cur_model_output
+=
1
/
3
*
model_output
self
.
cur_model_output
+=
1
/
3
*
model_output
elif
(
t
-
3
)
%
4
==
0
:
elif
(
self
.
counter
-
3
)
%
4
==
0
:
model_output
=
self
.
cur_model_output
+
1
/
6
*
model_output
model_output
=
self
.
cur_model_output
+
1
/
6
*
model_output
self
.
cur_model_output
=
0
self
.
cur_model_output
=
0
# cur_sample should not be `None`
# cur_sample should not be `None`
cur_sample
=
self
.
cur_sample
if
self
.
cur_sample
is
not
None
else
sample
cur_sample
=
self
.
cur_sample
if
self
.
cur_sample
is
not
None
else
sample
return
{
"prev_sample"
:
self
.
get_prev_sample
(
cur_sample
,
t_orig
,
t_orig_prev
,
model_output
)}
prev_sample
=
self
.
_get_prev_sample
(
cur_sample
,
timestep
,
prev_timestep
,
model_output
)
self
.
counter
+=
1
return
{
"prev_sample"
:
prev_sample
}
def
step_plms
(
def
step_plms
(
self
,
self
,
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
num_inference_steps
,
):
):
"""
"""
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
times to approximate the solution.
"""
"""
t
=
timestep
if
len
(
self
.
ets
)
<
3
:
if
len
(
self
.
ets
)
<
3
:
raise
ValueError
(
raise
ValueError
(
f
"
{
self
.
__class__
}
can only be run AFTER scheduler has been run "
f
"
{
self
.
__class__
}
can only be run AFTER scheduler has been run "
...
@@ -159,17 +171,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -159,17 +171,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information."
"for more information."
)
)
timesteps
=
self
.
plms_timesteps
prev_timestep
=
max
(
timestep
-
self
.
config
.
num_train_timesteps
//
self
.
num_inference_steps
,
0
)
t_orig
=
timesteps
[
t
]
t_orig_prev
=
timesteps
[
min
(
t
+
1
,
len
(
timesteps
)
-
1
)]
self
.
ets
.
append
(
model_output
)
self
.
ets
.
append
(
model_output
)
model_output
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
model_output
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
return
{
"prev_sample"
:
self
.
get_prev_sample
(
sample
,
t_orig
,
t_orig_prev
,
model_output
)}
prev_sample
=
self
.
_get_prev_sample
(
sample
,
timestep
,
prev_timestep
,
model_output
)
self
.
counter
+=
1
return
{
"prev_sample"
:
prev_sample
}
def
get_prev_sample
(
self
,
sample
,
t
_orig
,
t_orig
_prev
,
model_output
):
def
_
get_prev_sample
(
self
,
sample
,
t
imestep
,
timestep
_prev
,
model_output
):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(t−δ) using the formula of (9)
# this function computes x_(t−δ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
# Note that x_t needs to be added to both sides of the equation
...
@@ -182,8 +194,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -182,8 +194,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t
# sample -> x_t
# model_output -> e_θ(x_t, t)
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
# prev_sample -> x_(t−δ)
alpha_prod_t
=
self
.
alphas_cumprod
[
t
_orig
+
1
]
alpha_prod_t
=
self
.
alphas_cumprod
[
t
imestep
+
1
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
t
_orig
_prev
+
1
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
t
imestep
_prev
+
1
]
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
...
...
tests/test_scheduler.py
View file @
394243ce
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
pdb
import
tempfile
import
tempfile
import
unittest
import
unittest
...
@@ -383,6 +382,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -383,6 +382,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def
check_over_configs
(
self
,
time_step
=
0
,
**
config
):
def
check_over_configs
(
self
,
time_step
=
0
,
**
config
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
...
@@ -390,14 +390,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -390,14 +390,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
(
**
config
)
scheduler_config
=
self
.
get_scheduler_config
(
**
config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
kwargs
[
"
num_inference_steps
"
]
)
scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler
.
ets
=
dummy_past_residuals
[:]
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
.
set_timesteps
(
kwargs
[
"
num_inference_steps
"
]
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
ets
=
dummy_past_residuals
[:]
...
@@ -416,7 +416,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -416,7 +416,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
.
update
(
forward_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
...
@@ -424,7 +424,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -424,7 +424,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
kwargs
[
"
num_inference_steps
"
]
)
scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler
.
ets
=
dummy_past_residuals
[:]
...
@@ -434,7 +434,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -434,7 +434,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_timesteps
(
kwargs
[
"
num_inference_steps
"
]
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
output
=
scheduler
.
step_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
output
=
scheduler
.
step_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
...
@@ -474,12 +474,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -474,12 +474,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output
=
scheduler
.
step_prk
(
residual
,
1
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output
=
scheduler
.
step_prk
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step_prk
(
residual_pt
,
1
,
sample_pt
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step_prk
(
residual_pt
,
1
,
sample_pt
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
output
=
scheduler
.
step_plms
(
residual
,
1
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output
=
scheduler
.
step_plms
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step_plms
(
residual_pt
,
1
,
sample_pt
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step_plms
(
residual_pt
,
1
,
sample_pt
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
...
@@ -503,14 +503,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -503,14 +503,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output_0
=
scheduler
.
step_prk
(
residual
,
0
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_0
=
scheduler
.
step_prk
(
residual
,
0
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step_prk
(
residual
,
1
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step_prk
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
output_0
=
scheduler
.
step_plms
(
residual
,
0
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_0
=
scheduler
.
step_plms
(
residual
,
0
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step_plms
(
residual
,
1
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step_plms
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
...
@@ -541,7 +541,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -541,7 +541,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
step_plms
(
self
.
dummy_sample
,
1
,
self
.
dummy_sample
,
50
)[
"prev_sample"
]
scheduler
.
step_plms
(
self
.
dummy_sample
,
1
,
self
.
dummy_sample
)[
"prev_sample"
]
def
test_full_loop_no_noise
(
self
):
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_class
=
self
.
scheduler_classes
[
0
]
...
@@ -555,11 +555,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -555,11 +555,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for
i
,
t
in
enumerate
(
scheduler
.
prk_timesteps
):
for
i
,
t
in
enumerate
(
scheduler
.
prk_timesteps
):
residual
=
model
(
sample
,
t
)
residual
=
model
(
sample
,
t
)
sample
=
scheduler
.
step_prk
(
residual
,
i
,
sample
,
num_inference_steps
)[
"prev_sample"
]
sample
=
scheduler
.
step_prk
(
residual
,
i
,
sample
)[
"prev_sample"
]
for
i
,
t
in
enumerate
(
scheduler
.
plms_timesteps
):
for
i
,
t
in
enumerate
(
scheduler
.
plms_timesteps
):
residual
=
model
(
sample
,
t
)
residual
=
model
(
sample
,
t
)
sample
=
scheduler
.
step_plms
(
residual
,
i
,
sample
,
num_inference_steps
)[
"prev_sample"
]
sample
=
scheduler
.
step_plms
(
residual
,
i
,
sample
)[
"prev_sample"
]
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
...
@@ -706,7 +706,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
...
@@ -706,7 +706,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
model_output
=
model
(
sample
,
sigma_t
)
model_output
=
model
(
sample
,
sigma_t
)
output
=
scheduler
.
step_pred
(
model_output
,
t
,
sample
,
**
kwargs
)
output
=
scheduler
.
step_pred
(
model_output
,
t
,
sample
,
**
kwargs
)
sample
,
sample_mean
=
output
[
"prev_sample"
],
output
[
"prev_sample_mean"
]
sample
,
_
=
output
[
"prev_sample"
],
output
[
"prev_sample_mean"
]
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment