Commit 6af19588 authored by Zhuguanyu Wu's avatar Zhuguanyu Wu Committed by GitHub
Browse files

Use CM scheduler for distill_models as default (#127)

parent d594d5ac
......@@ -12,6 +12,10 @@ class WanStepDistillScheduler(WanScheduler):
self.infer_steps = len(self.denoising_step_list)
self.sample_shift = self.config.sample_shift
self.num_train_timesteps = 1000
self.sigma_max = 1.0
self.sigma_min = 0.0
def prepare(self, image_encoder_output):
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(self.config.seed)
......@@ -23,43 +27,30 @@ class WanStepDistillScheduler(WanScheduler):
elif self.config.task in ["i2v"]:
self.seq_len = self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1]
alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
self.sigmas = sigmas
self.timesteps = sigmas * self.num_train_timesteps
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.sigmas = self.sigmas.to("cpu")
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
self.set_denoising_timesteps(device=self.device)
def set_denoising_timesteps(self, device: Union[str, torch.device] = None):
self.timesteps = torch.tensor(self.denoising_step_list, device=device, dtype=torch.int64)
self.sigmas = torch.cat([self.timesteps / self.num_train_timesteps, torch.tensor([0.0], device=device)])
self.sigmas = self.sigmas.to("cpu")
self.infer_steps = len(self.timesteps)
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min)
self.sigmas = torch.linspace(sigma_start, self.sigma_min, self.num_train_timesteps + 1)[:-1]
self.sigmas = self.sample_shift * self.sigmas / (1 + (self.sample_shift - 1) * self.sigmas)
self.timesteps = self.sigmas * self.num_train_timesteps
self.model_outputs = [
None,
] * self.solver_order
self.lower_order_nums = 0
self.last_sample = None
self._begin_index = None
self.denoising_step_index = [self.num_train_timesteps - x for x in self.denoising_step_list]
self.timesteps = self.timesteps[self.denoising_step_index].to(device)
self.sigmas = self.sigmas[self.denoising_step_index].to("cpu")
def reset(self):
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.noise_pred = None
self.this_order = None
self.lower_order_nums = 0
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
def add_noise(self, original_samples, noise, sigma):
sample = (1 - sigma) * original_samples + sigma * noise
return sample.type_as(noise)
def step_post(self):
flow_pred = self.noise_pred.to(torch.float32)
sigma = self.sigmas[self.step_index].item()
noisy_image_or_video = self.latents.to(torch.float32) - sigma * flow_pred
if self.step_index < self.infer_steps - 1:
sigma = self.sigmas[self.step_index + 1].item()
noisy_image_or_video = self.add_noise(noisy_image_or_video, torch.randn_like(noisy_image_or_video), self.sigmas[self.step_index + 1].item())
self.latents = noisy_image_or_video.to(self.latents.dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment