Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
878af0e1
Unverified
Commit
878af0e1
authored
Sep 04, 2022
by
Partho
Committed by
GitHub
Sep 04, 2022
Browse files
[Type Hint] DDPM schedulers (#349)
parent
dea5ec50
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
11 deletions
+17
-11
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+17
-11
No files found.
src/diffusers/schedulers/scheduling_ddpm.py
View file @
878af0e1
...
...
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
from
typing
import
Union
from
typing
import
Optional
,
Union
import
numpy
as
np
import
torch
...
...
@@ -51,14 +51,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
1000
,
beta_start
=
0.0001
,
beta_end
=
0.02
,
beta_schedule
=
"linear"
,
trained_betas
=
None
,
variance_type
=
"fixed_small"
,
clip_sample
=
True
,
tensor_format
=
"pt"
,
num_train_timesteps
:
int
=
1000
,
beta_start
:
float
=
0.0001
,
beta_end
:
float
=
0.02
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
variance_type
:
str
=
"fixed_small"
,
clip_sample
:
bool
=
True
,
tensor_format
:
str
=
"pt"
,
):
if
trained_betas
is
not
None
:
...
...
@@ -87,7 +87,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self
.
variance_type
=
variance_type
def
set_timesteps
(
self
,
num_inference_steps
):
def
set_timesteps
(
self
,
num_inference_steps
:
int
):
num_inference_steps
=
min
(
self
.
config
.
num_train_timesteps
,
num_inference_steps
)
self
.
num_inference_steps
=
num_inference_steps
self
.
timesteps
=
np
.
arange
(
...
...
@@ -179,7 +179,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
{
"prev_sample"
:
pred_prev_sample
}
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
def
add_noise
(
self
,
original_samples
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
noise
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timesteps
:
Union
[
torch
.
IntTensor
,
np
.
ndarray
],
)
->
Union
[
torch
.
FloatTensor
,
np
.
ndarray
]:
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
self
.
match_shape
(
sqrt_alpha_prod
,
original_samples
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment