Commit 2dc84d14 authored by comfyanonymous's avatar comfyanonymous
Browse files

Add a way to set the timestep multiplier in the flow sampling.

parent ff63893d
......@@ -190,11 +190,12 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.0))
self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000))
def set_parameters(self, shift=1.0, timesteps=1000):
def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
self.shift = shift
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
self.multiplier = multiplier
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier)
self.register_buffer('sigmas', ts)
@property
......@@ -206,10 +207,10 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma * 1000
return sigma * self.multiplier
def sigma(self, timestep):
return time_snr_shift(self.shift, timestep / 1000)
return time_snr_shift(self.shift, timestep / self.multiplier)
def percent_to_sigma(self, percent):
if percent <= 0.0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment