import torch
class BaseScheduler:
    """Shared state and per-step hooks for schedulers.

    Holds the constructor ``args``, the index of the current step, and a
    ``latents`` tensor, which starts as ``None``. Subclasses are presumably
    responsible for assigning ``latents`` (TODO: confirm against subclasses).
    """

    # dtype that ``step_pre`` casts ``latents`` to. Subclasses may override it.
    latents_dtype: torch.dtype = torch.bfloat16

    def __init__(self, args) -> None:
        # ``args`` is stored as-is; its structure is defined by callers/subclasses.
        self.args = args
        self.step_index: int = 0
        self.latents = None  # populated later; None until then

    def step_pre(self, step_index: int) -> None:
        """Prepare for step ``step_index``.

        Records ``step_index`` and, if ``latents`` has been set, casts it to
        ``latents_dtype`` (``torch.bfloat16`` by default).
        """
        self.step_index = step_index
        # Guard: latents is None until assigned, and calling ``.to`` on None
        # would raise AttributeError.
        if self.latents is not None:
            self.latents = self.latents.to(dtype=self.latents_dtype)

    def clear(self) -> None:
        """Release per-run state. No-op in the base class; presumably overridden."""
        pass