Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
5321f3e2
Unverified
Commit
5321f3e2
authored
Aug 22, 2022
by
Suraj Patil
Committed by
GitHub
Aug 22, 2022
Browse files
add add_noise method in LMSDiscreteScheduler, PNDMScheduler (#227)
add add_noise method in more schedulers
parent
3f1861ee
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
0 deletions
+18
-0
src/diffusers/schedulers/scheduling_lms_discrete.py
src/diffusers/schedulers/scheduling_lms_discrete.py
+9
-0
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+9
-0
No files found.
src/diffusers/schedulers/scheduling_lms_discrete.py
View file @
5321f3e2
...
@@ -130,5 +130,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -130,5 +130,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
noisy_samples
=
(
alpha_prod
**
0.5
)
*
original_samples
+
((
1
-
alpha_prod
)
**
0.5
)
*
noise
noisy_samples
=
(
alpha_prod
**
0.5
)
*
original_samples
+
((
1
-
alpha_prod
)
**
0.5
)
*
noise
return
noisy_samples
return
noisy_samples
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
self
.
match_shape
(
sqrt_alpha_prod
,
original_samples
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
self
.
match_shape
(
sqrt_one_minus_alpha_prod
,
original_samples
)
noisy_samples
=
sqrt_alpha_prod
*
original_samples
+
sqrt_one_minus_alpha_prod
*
noise
return
noisy_samples
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
config
.
num_train_timesteps
return
self
.
config
.
num_train_timesteps
src/diffusers/schedulers/scheduling_pndm.py
View file @
5321f3e2
...
@@ -250,5 +250,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -250,5 +250,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return
prev_sample
return
prev_sample
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
self
.
match_shape
(
sqrt_alpha_prod
,
original_samples
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
self
.
match_shape
(
sqrt_one_minus_alpha_prod
,
original_samples
)
noisy_samples
=
sqrt_alpha_prod
*
original_samples
+
sqrt_one_minus_alpha_prod
*
noise
return
noisy_samples
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
config
.
num_train_timesteps
return
self
.
config
.
num_train_timesteps
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment