Unverified Commit 249b36cc authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Flax: add shape argument to `set_timesteps` (#690)

* Flax: add shape argument to set_timesteps

* style
parent 500ca5a9
...@@ -156,7 +156,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -156,7 +156,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
return variance return variance
def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState: def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
""" """
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
......
...@@ -133,7 +133,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -133,7 +133,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
self.variance_type = variance_type self.variance_type = variance_type
def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState: def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
""" """
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
......
...@@ -99,7 +99,9 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -99,7 +99,9 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
): ):
self.state = KarrasVeSchedulerState.create() self.state = KarrasVeSchedulerState.create()
def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState: def set_timesteps(
self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
) -> KarrasVeSchedulerState:
""" """
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
......
...@@ -111,7 +111,9 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -111,7 +111,9 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return integrated_coeff return integrated_coeff
def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState: def set_timesteps(
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
) -> LMSDiscreteSchedulerState:
""" """
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
......
...@@ -156,7 +156,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -156,7 +156,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
def create_state(self): def create_state(self):
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState: def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
""" """
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
......
...@@ -95,7 +95,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -95,7 +95,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps) self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
def set_timesteps( def set_timesteps(
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
) -> ScoreSdeVeSchedulerState: ) -> ScoreSdeVeSchedulerState:
""" """
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment