Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
249b36cc
Unverified
Commit
249b36cc
authored
Oct 03, 2022
by
Pedro Cuenca
Committed by
GitHub
Oct 03, 2022
Browse files
Flax: add shape argument to `set_timesteps` (#690)
* Flax: add shape argument to set_timesteps * style
parent
500ca5a9
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
10 additions
and
6 deletions
+10
-6
src/diffusers/schedulers/scheduling_ddim_flax.py
src/diffusers/schedulers/scheduling_ddim_flax.py
+1
-1
src/diffusers/schedulers/scheduling_ddpm_flax.py
src/diffusers/schedulers/scheduling_ddpm_flax.py
+1
-1
src/diffusers/schedulers/scheduling_karras_ve_flax.py
src/diffusers/schedulers/scheduling_karras_ve_flax.py
+3
-1
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+3
-1
src/diffusers/schedulers/scheduling_pndm_flax.py
src/diffusers/schedulers/scheduling_pndm_flax.py
+1
-1
src/diffusers/schedulers/scheduling_sde_ve_flax.py
src/diffusers/schedulers/scheduling_sde_ve_flax.py
+1
-1
No files found.
src/diffusers/schedulers/scheduling_ddim_flax.py
View file @
249b36cc
...
...
@@ -156,7 +156,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
return
variance
def
set_timesteps
(
self
,
state
:
DDIMSchedulerState
,
num_inference_steps
:
int
)
->
DDIMSchedulerState
:
def
set_timesteps
(
self
,
state
:
DDIMSchedulerState
,
num_inference_steps
:
int
,
shape
:
Tuple
)
->
DDIMSchedulerState
:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
...
...
src/diffusers/schedulers/scheduling_ddpm_flax.py
View file @
249b36cc
...
...
@@ -133,7 +133,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
self
.
variance_type
=
variance_type
def
set_timesteps
(
self
,
state
:
DDPMSchedulerState
,
num_inference_steps
:
int
)
->
DDPMSchedulerState
:
def
set_timesteps
(
self
,
state
:
DDPMSchedulerState
,
num_inference_steps
:
int
,
shape
:
Tuple
)
->
DDPMSchedulerState
:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
...
...
src/diffusers/schedulers/scheduling_karras_ve_flax.py
View file @
249b36cc
...
...
@@ -99,7 +99,9 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
):
self
.
state
=
KarrasVeSchedulerState
.
create
()
def
set_timesteps
(
self
,
state
:
KarrasVeSchedulerState
,
num_inference_steps
:
int
)
->
KarrasVeSchedulerState
:
def
set_timesteps
(
self
,
state
:
KarrasVeSchedulerState
,
num_inference_steps
:
int
,
shape
:
Tuple
)
->
KarrasVeSchedulerState
:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
...
...
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
View file @
249b36cc
...
...
@@ -111,7 +111,9 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return
integrated_coeff
def
set_timesteps
(
self
,
state
:
LMSDiscreteSchedulerState
,
num_inference_steps
:
int
)
->
LMSDiscreteSchedulerState
:
def
set_timesteps
(
self
,
state
:
LMSDiscreteSchedulerState
,
num_inference_steps
:
int
,
shape
:
Tuple
)
->
LMSDiscreteSchedulerState
:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
...
...
src/diffusers/schedulers/scheduling_pndm_flax.py
View file @
249b36cc
...
...
@@ -156,7 +156,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
def
create_state
(
self
):
return
PNDMSchedulerState
.
create
(
num_train_timesteps
=
self
.
config
.
num_train_timesteps
)
def
set_timesteps
(
self
,
state
:
PNDMSchedulerState
,
shape
:
Tuple
,
num_inference_steps
:
int
)
->
PNDMSchedulerState
:
def
set_timesteps
(
self
,
state
:
PNDMSchedulerState
,
num_inference_steps
:
int
,
shape
:
Tuple
)
->
PNDMSchedulerState
:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
...
...
src/diffusers/schedulers/scheduling_sde_ve_flax.py
View file @
249b36cc
...
...
@@ -95,7 +95,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self
.
state
=
self
.
set_sigmas
(
state
,
num_train_timesteps
,
sigma_min
,
sigma_max
,
sampling_eps
)
def
set_timesteps
(
self
,
state
:
ScoreSdeVeSchedulerState
,
num_inference_steps
:
int
,
sampling_eps
:
float
=
None
self
,
state
:
ScoreSdeVeSchedulerState
,
num_inference_steps
:
int
,
shape
:
Tuple
,
sampling_eps
:
float
=
None
)
->
ScoreSdeVeSchedulerState
:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment