"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "c82a7f9c49613117221fb844c4d04e1f628cbced"
Unverified Commit 21682bab authored by Disty0's avatar Disty0 Committed by GitHub
Browse files

Custom sampler support for Stable Cascade Decoder (#9132)

Custom sampler support Stable Cascade Decoder
parent 214990e5
...@@ -281,6 +281,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): ...@@ -281,6 +281,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
def num_timesteps(self): def num_timesteps(self):
return self._num_timesteps return self._num_timesteps
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
    """Map a discrete scheduler timestep to the continuous timestep-ratio conditioning in [0, 1].

    Inverts the cosine noise schedule: the cumulative alpha at ``t`` is clamped to [0, 1],
    scaled by the schedule floor derived from the small offset ``s = 0.008``, and passed
    through ``acos`` to recover the normalized schedule position.

    Args:
        t: Index (or index tensor) into ``alphas_cumprod``.
        alphas_cumprod: Cumulative product of the scheduler's alphas — assumed 1-D,
            indexed by timestep (TODO confirm against the scheduler in use).

    Returns:
        Tensor with the timestep ratio on the same device as ``alphas_cumprod[t]``.
    """
    shift = torch.tensor([0.008])
    # Schedule floor: squared cosine of the offset angle; keeps acos away from its endpoints.
    floor = torch.cos(shift / (1 + shift) * torch.pi * 0.5) ** 2
    alpha_bar = alphas_cumprod[t].clamp(0, 1)
    shift = shift.to(alpha_bar.device)
    floor = floor.to(alpha_bar.device)
    # Invert the cosine schedule; the (1 + s)/-s terms undo the offset applied at schedule build time.
    angle = torch.sqrt(alpha_bar * floor).acos()
    return (angle / (torch.pi * 0.5)) * (1 + shift) - shift
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -434,10 +444,30 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): ...@@ -434,10 +444,30 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
) )
if isinstance(self.scheduler, DDPMWuerstchenScheduler):
timesteps = timesteps[:-1]
else:
if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
self.scheduler.config.clip_sample = False # disable sample clipping
logger.warning(" set `clip_sample` to be False")
# 6. Run denoising loop # 6. Run denoising loop
self._num_timesteps = len(timesteps[:-1]) if hasattr(self.scheduler, "betas"):
for i, t in enumerate(self.progress_bar(timesteps[:-1])): alphas = 1.0 - self.scheduler.betas
timestep_ratio = t.expand(latents.size(0)).to(dtype) alphas_cumprod = torch.cumprod(alphas, dim=0)
else:
alphas_cumprod = []
self._num_timesteps = len(timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
if len(alphas_cumprod) > 0:
timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
else:
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
else:
timestep_ratio = t.expand(latents.size(0)).to(dtype)
# 7. Denoise latents # 7. Denoise latents
predicted_latents = self.decoder( predicted_latents = self.decoder(
...@@ -454,6 +484,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): ...@@ -454,6 +484,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale) predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
# 9. Renoise latents to next timestep # 9. Renoise latents to next timestep
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
timestep_ratio = t
latents = self.scheduler.step( latents = self.scheduler.step(
model_output=predicted_latents, model_output=predicted_latents,
timestep=timestep_ratio, timestep=timestep_ratio,
......
...@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline): ...@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
return self._num_timesteps return self._num_timesteps
def get_timestep_ratio_conditioning(self, t, alphas_cumprod): def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
s = torch.tensor([0.003]) s = torch.tensor([0.008])
clamp_range = [0, 1] clamp_range = [0, 1]
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2 min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
var = alphas_cumprod[t] var = alphas_cumprod[t]
...@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline): ...@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
if isinstance(self.scheduler, DDPMWuerstchenScheduler): if isinstance(self.scheduler, DDPMWuerstchenScheduler):
timesteps = timesteps[:-1] timesteps = timesteps[:-1]
else: else:
if self.scheduler.config.clip_sample: if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
self.scheduler.config.clip_sample = False # disable sample clipping self.scheduler.config.clip_sample = False # disable sample clipping
logger.warning(" set `clip_sample` to be False") logger.warning(" set `clip_sample` to be False")
# 6. Run denoising loop # 6. Run denoising loop
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment