Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
394243ce
Commit
394243ce
authored
Jul 21, 2022
by
Patrick von Platen
Browse files
finish pndm sampler
parent
fe985746
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
78 additions
and
71 deletions
+78
-71
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+2
-5
src/diffusers/pipelines/pndm/pipeline_pndm.py
src/diffusers/pipelines/pndm/pipeline_pndm.py
+12
-7
src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
...diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
+5
-11
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+2
-2
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+0
-1
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+39
-27
tests/test_scheduler.py
tests/test_scheduler.py
+18
-18
No files found.
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
394243ce
...
@@ -48,11 +48,8 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -48,11 +48,8 @@ class DDPMPipeline(DiffusionPipeline):
# 1. predict noise model_output
# 1. predict noise model_output
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
# 2. predict previous mean of image x_t-1
# 2. compute previous image: x_t -> t_t-1
pred_prev_image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
)[
"prev_sample"
]
image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
)[
"prev_sample"
]
# 3. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
...
...
src/diffusers/pipelines/pndm/pipeline_pndm.py
View file @
394243ce
...
@@ -44,15 +44,20 @@ class PNDMPipeline(DiffusionPipeline):
...
@@ -44,15 +44,20 @@ class PNDMPipeline(DiffusionPipeline):
image
=
image
.
to
(
torch_device
)
image
=
image
.
to
(
torch_device
)
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
for
i
,
t
in
enumerate
(
tqdm
(
self
.
scheduler
.
prk_
timesteps
)
)
:
for
t
in
tqdm
(
self
.
scheduler
.
timesteps
):
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
image
=
self
.
scheduler
.
step_prk
(
model_output
,
i
,
image
,
num_inference_steps
)[
"prev_sample"
]
image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
)[
"prev_sample"
]
for
i
,
t
in
enumerate
(
tqdm
(
self
.
scheduler
.
plms_timesteps
)):
# for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
# model_output = self.unet(image, t)["sample"]
#
image
=
self
.
scheduler
.
step_plms
(
model_output
,
i
,
image
,
num_inference_steps
)[
"prev_sample"
]
# image = self.scheduler.step_prk(model_output, t, image, i=i)["prev_sample"]
#
# for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
# model_output = self.unet(image, t)["sample"]
#
# image = self.scheduler.step_plms(model_output, t, image, i=i)["prev_sample"]
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
...
...
src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
View file @
394243ce
...
@@ -28,21 +28,15 @@ class ScoreSdeVePipeline(DiffusionPipeline):
...
@@ -28,21 +28,15 @@ class ScoreSdeVePipeline(DiffusionPipeline):
for
i
,
t
in
tqdm
(
enumerate
(
self
.
scheduler
.
timesteps
)):
for
i
,
t
in
tqdm
(
enumerate
(
self
.
scheduler
.
timesteps
)):
sigma_t
=
self
.
scheduler
.
sigmas
[
i
]
*
torch
.
ones
(
shape
[
0
],
device
=
device
)
sigma_t
=
self
.
scheduler
.
sigmas
[
i
]
*
torch
.
ones
(
shape
[
0
],
device
=
device
)
# correction step
for
_
in
range
(
self
.
scheduler
.
correct_steps
):
for
_
in
range
(
self
.
scheduler
.
correct_steps
):
model_output
=
self
.
model
(
sample
,
sigma_t
)
model_output
=
self
.
model
(
sample
,
sigma_t
)[
"sample"
]
if
isinstance
(
model_output
,
dict
):
model_output
=
model_output
[
"sample"
]
sample
=
self
.
scheduler
.
step_correct
(
model_output
,
sample
)[
"prev_sample"
]
sample
=
self
.
scheduler
.
step_correct
(
model_output
,
sample
)[
"prev_sample"
]
with
torch
.
no_grad
():
# prediction step
model_output
=
model
(
sample
,
sigma_t
)
model_output
=
model
(
sample
,
sigma_t
)[
"sample"
]
if
isinstance
(
model_output
,
dict
):
model_output
=
model_output
[
"sample"
]
output
=
self
.
scheduler
.
step_pred
(
model_output
,
t
,
sample
)
output
=
self
.
scheduler
.
step_pred
(
model_output
,
t
,
sample
)
sample
,
sample_mean
=
output
[
"prev_sample"
],
output
[
"prev_sample_mean"
]
sample
,
sample_mean
=
output
[
"prev_sample"
],
output
[
"prev_sample_mean"
]
sample
=
sample
.
clamp
(
0
,
1
)
sample
=
sample
.
clamp
(
0
,
1
)
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
394243ce
...
@@ -106,8 +106,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -106,8 +106,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
eta
,
eta
:
float
=
0.0
,
use_clipped_model_output
=
False
,
use_clipped_model_output
:
bool
=
False
,
generator
=
None
,
generator
=
None
,
):
):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
394243ce
...
@@ -56,7 +56,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -56,7 +56,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_end
=
0.02
,
beta_end
=
0.02
,
beta_schedule
=
"linear"
,
beta_schedule
=
"linear"
,
trained_betas
=
None
,
trained_betas
=
None
,
timestep_values
=
None
,
variance_type
=
"fixed_small"
,
variance_type
=
"fixed_small"
,
clip_sample
=
True
,
clip_sample
=
True
,
tensor_format
=
"pt"
,
tensor_format
=
"pt"
,
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
394243ce
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
import
math
import
pdb
from
typing
import
Union
from
typing
import
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -79,78 +78,91 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -79,78 +78,91 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# running values
# running values
self
.
cur_model_output
=
0
self
.
cur_model_output
=
0
self
.
counter
=
0
self
.
cur_sample
=
None
self
.
cur_sample
=
None
self
.
ets
=
[]
self
.
ets
=
[]
# setable values
# setable values
self
.
num_inference_steps
=
None
self
.
num_inference_steps
=
None
self
.
timesteps
=
np
.
arange
(
0
,
num_train_timesteps
)[::
-
1
].
copy
()
self
.
_
timesteps
=
np
.
arange
(
0
,
num_train_timesteps
)[::
-
1
].
copy
()
self
.
prk_timesteps
=
None
self
.
prk_timesteps
=
None
self
.
plms_timesteps
=
None
self
.
plms_timesteps
=
None
self
.
timesteps
=
None
self
.
tensor_format
=
tensor_format
self
.
tensor_format
=
tensor_format
self
.
set_format
(
tensor_format
=
tensor_format
)
self
.
set_format
(
tensor_format
=
tensor_format
)
def
set_timesteps
(
self
,
num_inference_steps
):
def
set_timesteps
(
self
,
num_inference_steps
):
self
.
num_inference_steps
=
num_inference_steps
self
.
num_inference_steps
=
num_inference_steps
self
.
timesteps
=
list
(
self
.
_
timesteps
=
list
(
range
(
0
,
self
.
config
.
num_train_timesteps
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
)
range
(
0
,
self
.
config
.
num_train_timesteps
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
)
)
)
prk_time
_
steps
=
np
.
array
(
self
.
timesteps
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
prk_timesteps
=
np
.
array
(
self
.
_
timesteps
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
np
.
array
([
0
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
)
self
.
prk_timesteps
=
list
(
reversed
(
prk_time_steps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
prk_timesteps
=
list
(
reversed
(
prk_timesteps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
plms_timesteps
=
list
(
reversed
(
self
.
timesteps
[:
-
3
]))
self
.
plms_timesteps
=
list
(
reversed
(
self
.
_timesteps
[:
-
3
]))
self
.
timesteps
=
self
.
prk_timesteps
+
self
.
plms_timesteps
self
.
counter
=
0
self
.
set_format
(
tensor_format
=
self
.
tensor_format
)
self
.
set_format
(
tensor_format
=
self
.
tensor_format
)
def
step
(
self
,
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
):
if
self
.
counter
<
len
(
self
.
prk_timesteps
):
return
self
.
step_prk
(
model_output
=
model_output
,
timestep
=
timestep
,
sample
=
sample
)
else
:
return
self
.
step_plms
(
model_output
=
model_output
,
timestep
=
timestep
,
sample
=
sample
)
def
step_prk
(
def
step_prk
(
self
,
self
,
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
num_inference_steps
,
):
):
"""
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.
solution to the differential equation.
"""
"""
t
=
timestep
diff_to_prev
=
0
if
self
.
counter
%
2
else
self
.
config
.
num_train_timesteps
//
self
.
num_inference_steps
//
2
prk_time_steps
=
self
.
prk_timesteps
prev_timestep
=
max
(
timestep
-
diff_to_prev
,
self
.
prk_timesteps
[
-
1
])
timestep
=
self
.
prk_timesteps
[
self
.
counter
//
4
*
4
]
t_orig
=
prk_time_steps
[
t
//
4
*
4
]
t_orig_prev
=
prk_time_steps
[
min
(
t
+
1
,
len
(
prk_time_steps
)
-
1
)]
if
t
%
4
==
0
:
if
self
.
counter
%
4
==
0
:
self
.
cur_model_output
+=
1
/
6
*
model_output
self
.
cur_model_output
+=
1
/
6
*
model_output
self
.
ets
.
append
(
model_output
)
self
.
ets
.
append
(
model_output
)
self
.
cur_sample
=
sample
self
.
cur_sample
=
sample
elif
(
t
-
1
)
%
4
==
0
:
elif
(
self
.
counter
-
1
)
%
4
==
0
:
self
.
cur_model_output
+=
1
/
3
*
model_output
self
.
cur_model_output
+=
1
/
3
*
model_output
elif
(
t
-
2
)
%
4
==
0
:
elif
(
self
.
counter
-
2
)
%
4
==
0
:
self
.
cur_model_output
+=
1
/
3
*
model_output
self
.
cur_model_output
+=
1
/
3
*
model_output
elif
(
t
-
3
)
%
4
==
0
:
elif
(
self
.
counter
-
3
)
%
4
==
0
:
model_output
=
self
.
cur_model_output
+
1
/
6
*
model_output
model_output
=
self
.
cur_model_output
+
1
/
6
*
model_output
self
.
cur_model_output
=
0
self
.
cur_model_output
=
0
# cur_sample should not be `None`
# cur_sample should not be `None`
cur_sample
=
self
.
cur_sample
if
self
.
cur_sample
is
not
None
else
sample
cur_sample
=
self
.
cur_sample
if
self
.
cur_sample
is
not
None
else
sample
return
{
"prev_sample"
:
self
.
get_prev_sample
(
cur_sample
,
t_orig
,
t_orig_prev
,
model_output
)}
prev_sample
=
self
.
_get_prev_sample
(
cur_sample
,
timestep
,
prev_timestep
,
model_output
)
self
.
counter
+=
1
return
{
"prev_sample"
:
prev_sample
}
def
step_plms
(
def
step_plms
(
self
,
self
,
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
num_inference_steps
,
):
):
"""
"""
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
times to approximate the solution.
"""
"""
t
=
timestep
if
len
(
self
.
ets
)
<
3
:
if
len
(
self
.
ets
)
<
3
:
raise
ValueError
(
raise
ValueError
(
f
"
{
self
.
__class__
}
can only be run AFTER scheduler has been run "
f
"
{
self
.
__class__
}
can only be run AFTER scheduler has been run "
...
@@ -159,17 +171,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -159,17 +171,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information."
"for more information."
)
)
timesteps
=
self
.
plms_timesteps
prev_timestep
=
max
(
timestep
-
self
.
config
.
num_train_timesteps
//
self
.
num_inference_steps
,
0
)
t_orig
=
timesteps
[
t
]
t_orig_prev
=
timesteps
[
min
(
t
+
1
,
len
(
timesteps
)
-
1
)]
self
.
ets
.
append
(
model_output
)
self
.
ets
.
append
(
model_output
)
model_output
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
model_output
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
return
{
"prev_sample"
:
self
.
get_prev_sample
(
sample
,
t_orig
,
t_orig_prev
,
model_output
)}
prev_sample
=
self
.
_get_prev_sample
(
sample
,
timestep
,
prev_timestep
,
model_output
)
self
.
counter
+=
1
return
{
"prev_sample"
:
prev_sample
}
def
get_prev_sample
(
self
,
sample
,
t
_orig
,
t_orig
_prev
,
model_output
):
def
_
get_prev_sample
(
self
,
sample
,
t
imestep
,
timestep
_prev
,
model_output
):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(t−δ) using the formula of (9)
# this function computes x_(t−δ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
# Note that x_t needs to be added to both sides of the equation
...
@@ -182,8 +194,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -182,8 +194,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t
# sample -> x_t
# model_output -> e_θ(x_t, t)
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
# prev_sample -> x_(t−δ)
alpha_prod_t
=
self
.
alphas_cumprod
[
t
_orig
+
1
]
alpha_prod_t
=
self
.
alphas_cumprod
[
t
imestep
+
1
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
t
_orig
_prev
+
1
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
t
imestep
_prev
+
1
]
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
...
...
tests/test_scheduler.py
View file @
394243ce
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
pdb
import
tempfile
import
tempfile
import
unittest
import
unittest
...
@@ -383,6 +382,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -383,6 +382,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def
check_over_configs
(
self
,
time_step
=
0
,
**
config
):
def
check_over_configs
(
self
,
time_step
=
0
,
**
config
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
...
@@ -390,14 +390,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -390,14 +390,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
(
**
config
)
scheduler_config
=
self
.
get_scheduler_config
(
**
config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
kwargs
[
"
num_inference_steps
"
]
)
scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler
.
ets
=
dummy_past_residuals
[:]
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
.
set_timesteps
(
kwargs
[
"
num_inference_steps
"
]
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
ets
=
dummy_past_residuals
[:]
...
@@ -416,7 +416,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -416,7 +416,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
.
update
(
forward_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
...
@@ -424,7 +424,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -424,7 +424,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
kwargs
[
"
num_inference_steps
"
]
)
scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler
.
ets
=
dummy_past_residuals
[:]
...
@@ -434,7 +434,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -434,7 +434,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_timesteps
(
kwargs
[
"
num_inference_steps
"
]
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
output
=
scheduler
.
step_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
output
=
scheduler
.
step_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
...
@@ -474,12 +474,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -474,12 +474,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output
=
scheduler
.
step_prk
(
residual
,
1
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output
=
scheduler
.
step_prk
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step_prk
(
residual_pt
,
1
,
sample_pt
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step_prk
(
residual_pt
,
1
,
sample_pt
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
output
=
scheduler
.
step_plms
(
residual
,
1
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output
=
scheduler
.
step_plms
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step_plms
(
residual_pt
,
1
,
sample_pt
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step_plms
(
residual_pt
,
1
,
sample_pt
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
...
@@ -503,14 +503,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -503,14 +503,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output_0
=
scheduler
.
step_prk
(
residual
,
0
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_0
=
scheduler
.
step_prk
(
residual
,
0
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step_prk
(
residual
,
1
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step_prk
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
output_0
=
scheduler
.
step_plms
(
residual
,
0
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_0
=
scheduler
.
step_plms
(
residual
,
0
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step_plms
(
residual
,
1
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step_plms
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
...
@@ -541,7 +541,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -541,7 +541,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
step_plms
(
self
.
dummy_sample
,
1
,
self
.
dummy_sample
,
50
)[
"prev_sample"
]
scheduler
.
step_plms
(
self
.
dummy_sample
,
1
,
self
.
dummy_sample
)[
"prev_sample"
]
def
test_full_loop_no_noise
(
self
):
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_class
=
self
.
scheduler_classes
[
0
]
...
@@ -555,11 +555,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -555,11 +555,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for
i
,
t
in
enumerate
(
scheduler
.
prk_timesteps
):
for
i
,
t
in
enumerate
(
scheduler
.
prk_timesteps
):
residual
=
model
(
sample
,
t
)
residual
=
model
(
sample
,
t
)
sample
=
scheduler
.
step_prk
(
residual
,
i
,
sample
,
num_inference_steps
)[
"prev_sample"
]
sample
=
scheduler
.
step_prk
(
residual
,
i
,
sample
)[
"prev_sample"
]
for
i
,
t
in
enumerate
(
scheduler
.
plms_timesteps
):
for
i
,
t
in
enumerate
(
scheduler
.
plms_timesteps
):
residual
=
model
(
sample
,
t
)
residual
=
model
(
sample
,
t
)
sample
=
scheduler
.
step_plms
(
residual
,
i
,
sample
,
num_inference_steps
)[
"prev_sample"
]
sample
=
scheduler
.
step_plms
(
residual
,
i
,
sample
)[
"prev_sample"
]
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
...
@@ -706,7 +706,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
...
@@ -706,7 +706,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
model_output
=
model
(
sample
,
sigma_t
)
model_output
=
model
(
sample
,
sigma_t
)
output
=
scheduler
.
step_pred
(
model_output
,
t
,
sample
,
**
kwargs
)
output
=
scheduler
.
step_pred
(
model_output
,
t
,
sample
,
**
kwargs
)
sample
,
sample_mean
=
output
[
"prev_sample"
],
output
[
"prev_sample_mean"
]
sample
,
_
=
output
[
"prev_sample"
],
output
[
"prev_sample_mean"
]
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment