Unverified Commit ec1c7a79 authored by hlky's avatar hlky Committed by GitHub
Browse files

Add `set_shift` to FlowMatchEulerDiscreteScheduler (#10269)

parent 9c68c945
...@@ -99,10 +99,19 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -99,10 +99,19 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index = None self._step_index = None
self._begin_index = None self._begin_index = None
self._shift = shift
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item() self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item() self.sigma_max = self.sigmas[0].item()
@property
def shift(self):
"""
The value used for shifting.
"""
return self._shift
@property @property
def step_index(self): def step_index(self):
""" """
...@@ -128,6 +137,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -128,6 +137,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
self._begin_index = begin_index self._begin_index = begin_index
def set_shift(self, shift: float):
self._shift = shift
def scale_noise( def scale_noise(
self, self,
sample: torch.FloatTensor, sample: torch.FloatTensor,
...@@ -236,7 +248,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -236,7 +248,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
if self.config.use_dynamic_shifting: if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas) sigmas = self.time_shift(mu, 1.0, sigmas)
else: else:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
if self.config.shift_terminal: if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas) sigmas = self.stretch_shift_to_terminal(sigmas)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment