Commit d86b6917 authored by wangshankun's avatar wangshankun
Browse files

Update audio scheduler: simplify the Euler scheduler's timestep setup and add a consistency-model scheduler variant.

parent 3dc1fafb
...@@ -30,10 +30,6 @@ class EulerSchedulerTimestepFix(BaseScheduler): ...@@ -30,10 +30,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
self.init_noise_sigma = 1.0 self.init_noise_sigma = 1.0
self.config = config self.config = config
self.latents = None self.latents = None
self.caching_records = [True] * config.infer_steps
self.flag_df = False
self.transformer_infer = None
self.solver_order = 2
self.device = torch.device("cuda") self.device = torch.device("cuda")
self.infer_steps = self.config.infer_steps self.infer_steps = self.config.infer_steps
self.target_video_length = self.config.target_video_length self.target_video_length = self.config.target_video_length
...@@ -41,47 +37,12 @@ class EulerSchedulerTimestepFix(BaseScheduler): ...@@ -41,47 +37,12 @@ class EulerSchedulerTimestepFix(BaseScheduler):
self.shift = 1 self.shift = 1
self.num_train_timesteps = 1000 self.num_train_timesteps = 1000
self.step_index = None self.step_index = None
self.noise_pred = None
self._step_index = None
self._begin_index = None
def step_pre(self, step_index):
    """Prepare scheduler state before running one denoising step.

    Args:
        step_index: Index of the diffusion step about to execute.
    """
    self.step_index = step_index
    # Inference in bfloat16 is opt-in via the global dtype flag.
    wants_bf16 = GET_DTYPE() == "BF16"
    if wants_bf16:
        self.latents = self.latents.to(dtype=torch.bfloat16)
def set_timesteps(
    self,
    infer_steps: Union[int, None] = None,
    device: Union[str, torch.device] = None,
    sigmas: Optional[List[float]] = None,
    mu: Optional[Union[float, None]] = None,
    shift: Optional[Union[float, None]] = None,
):
    """Build the sigma/timestep schedule for Euler flow-matching sampling.

    Args:
        infer_steps: Number of denoising steps to schedule.
        device: Device for the integer ``timesteps`` tensor.
        sigmas: Ignored — the schedule is always rebuilt from
            ``self.sigma_max``/``self.sigma_min`` (kept for interface
            compatibility with callers that pass it).
        mu: Unused; kept for interface compatibility.
        shift: Flow-shift factor; falls back to ``self.shift`` when None.

    Raises:
        ValueError: If the produced schedule length differs from
            ``self.infer_steps``.
    """
    # Evenly spaced sigmas from sigma_max down to sigma_min; drop the last
    # endpoint so len(sigmas) == infer_steps.
    sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1]
    if shift is None:
        shift = self.shift
    # Flow-matching shift: warps the schedule toward high noise for shift > 1.
    sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
    sigma_last = 0
    # Timesteps are derived from the *shifted* sigmas, before the float32 cast.
    timesteps = sigmas * self.num_train_timesteps
    # Append the terminal sigma so sigmas has infer_steps + 1 entries
    # (step_post indexes both step_index and step_index + 1).
    sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

    self.sigmas = torch.from_numpy(sigmas)
    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(self.timesteps) != self.infer_steps:
        raise ValueError(
            f"Schedule length {len(self.timesteps)} does not match "
            f"infer_steps {self.infer_steps}"
        )

    # Reset multi-step solver bookkeeping.
    self.model_outputs = [
        None,
    ] * self.solver_order
    self.lower_order_nums = 0
    self.last_sample = None
    self._begin_index = None
    # Sigmas stay on CPU; they are indexed per step and moved on demand.
    self.sigmas = self.sigmas.to("cpu")
def prepare(self, image_encoder_output=None): def prepare(self, image_encoder_output=None):
self.prepare_latents(self.config.target_shape, dtype=torch.float32) self.prepare_latents(self.config.target_shape, dtype=torch.float32)
...@@ -90,24 +51,15 @@ class EulerSchedulerTimestepFix(BaseScheduler): ...@@ -90,24 +51,15 @@ class EulerSchedulerTimestepFix(BaseScheduler):
elif self.config.task in ["i2v"]: elif self.config.task in ["i2v"]:
self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2]) self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy() timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
sigmas = 1.0 - alphas
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
self.sigmas = sigmas self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.device)
self.timesteps = sigmas * self.num_train_timesteps self.timesteps_ori = self.timesteps.clone()
self.model_outputs = [None] * self.solver_order self.sigmas = self.timesteps_ori / self.num_train_timesteps
self.timestep_list = [None] * self.solver_order self.sigmas = self.sample_shift * self.sigmas / (1 + (self.sample_shift - 1) * self.sigmas)
self.last_sample = None
self.sigmas = self.sigmas.to("cpu") self.timesteps = self.sigmas * self.num_train_timesteps
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
def prepare_latents(self, target_shape, dtype=torch.float32): def prepare_latents(self, target_shape, dtype=torch.float32):
self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed) self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
...@@ -128,21 +80,25 @@ class EulerSchedulerTimestepFix(BaseScheduler): ...@@ -128,21 +80,25 @@ class EulerSchedulerTimestepFix(BaseScheduler):
model_output = self.noise_pred.to(torch.float32) model_output = self.noise_pred.to(torch.float32)
sample = self.latents.to(torch.float32) sample = self.latents.to(torch.float32)
sample = sample.to(torch.float32)
sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype) sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype) sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
# x0 = sample - model_output * sigma
x_t_next = sample + (sigma_next - sigma) * model_output x_t_next = sample + (sigma_next - sigma) * model_output
self.latents = x_t_next self.latents = x_t_next
def reset(self):
    """Reset scheduler state between generations.

    Draws a fresh latent tensor (re-seeded from ``self.config.seed`` via
    ``prepare_latents``) and releases cached GPU memory.
    """
    self.prepare_latents(self.config.target_shape, dtype=torch.float32)
    # Drop dangling Python references first, then return freed blocks to CUDA.
    gc.collect()
    torch.cuda.empty_cache()
class ConsistencyModelScheduler(EulerSchedulerTimestepFix):
    """Euler-schedule variant that applies a consistency-model update rule."""

    def step_post(self):
        """Advance latents one step: estimate x0, then re-noise to the next sigma."""
        logger.info(f"Step index: {self.step_index}, self.timestep: {self.timesteps[self.step_index]}")
        velocity = self.noise_pred.to(torch.float32)
        current = self.latents.to(torch.float32)
        sigma_now = unsqueeze_to_ndim(self.sigmas[self.step_index], current.ndim).to(current.device, current.dtype)
        sigma_nxt = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], current.ndim).to(current.device, current.dtype)
        # Predicted clean sample under the flow-matching parameterization.
        denoised = current - velocity * sigma_now
        # Consistency-style jump: blend the x0 estimate with fresh noise at the
        # next noise level instead of taking an Euler increment.
        self.latents = denoised * (1 - sigma_nxt) + sigma_nxt * torch.randn_like(denoised)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment