Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
7222a8ea
Commit
7222a8ea
authored
Dec 02, 2022
by
Patrick von Platen
Browse files
make style
parent
155d272c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
4 deletions
+2
-4
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+2
-4
No files found.
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
View file @
7222a8ea
...
...
@@ -102,9 +102,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
sigmas
=
((
1
-
self
.
alphas_cumprod
)
/
self
.
alphas_cumprod
)
**
0.5
,
)
def
scale_model_input
(
self
,
state
:
LMSDiscreteSchedulerState
,
sample
:
jnp
.
ndarray
,
timestep
:
int
)
->
jnp
.
ndarray
:
def
scale_model_input
(
self
,
state
:
LMSDiscreteSchedulerState
,
sample
:
jnp
.
ndarray
,
timestep
:
int
)
->
jnp
.
ndarray
:
"""
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
...
...
@@ -119,7 +117,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
Returns:
`jnp.ndarray`: scaled input sample
"""
step_index
,
=
jnp
.
where
(
scheduler_state
.
timesteps
==
timestep
,
size
=
1
)
(
step_index
,
)
=
jnp
.
where
(
scheduler_state
.
timesteps
==
timestep
,
size
=
1
)
sigma
=
scheduler_state
.
sigmas
[
step_index
]
sample
=
sample
/
((
sigma
**
2
+
1
)
**
0.5
)
return
sample
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment