"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "01e355516bbeecf1588a24c75321895029ea1123"
Unverified Commit 35e92096 authored by Grigory Sizov's avatar Grigory Sizov Committed by GitHub
Browse files

Fix formula for noise levels in Karras scheduler and tests (#627)

fix formula for noise levels in karras scheduler and tests
parent d0aa899f
...@@ -110,7 +110,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -110,7 +110,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [ self.schedule = [
( (
self.config.sigma_max self.config.sigma_max**2
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
) )
for i in self.timesteps for i in self.timesteps
......
...@@ -113,7 +113,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -113,7 +113,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
timesteps = jnp.arange(0, num_inference_steps)[::-1].copy() timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()
schedule = [ schedule = [
( (
self.config.sigma_max self.config.sigma_max**2
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
) )
for i in timesteps for i in timesteps
......
...@@ -1104,7 +1104,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -1104,7 +1104,7 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3) assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837]) expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow @slow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment