scheduler.py 288 Bytes
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
10
11
12
import torch


class BaseScheduler:
    """Minimal base class for step-based schedulers.

    Holds the run configuration (``args``), the current step index and the
    working ``latents`` tensor. ``latents`` starts out as ``None``; a subclass
    or caller must assign it before the first call to :meth:`step_pre`
    (presumably done in subclass preparation logic — confirm against
    subclasses).
    """

    # dtype that ``step_pre`` casts ``self.latents`` to before each step.
    # Subclasses may override this (e.g. with ``torch.float16``) where
    # bfloat16 is unsuitable; the default preserves the original behaviour.
    latents_dtype = torch.bfloat16

    def __init__(self, args):
        """Store ``args`` and reset per-run state.

        Args:
            args: Configuration object; stored as-is and not inspected here.
        """
        self.args = args
        self.step_index = 0
        self.latents = None  # must be assigned before step_pre is called

    def step_pre(self, step_index):
        """Prepare scheduler state for the step ``step_index``.

        Records ``step_index`` and casts ``self.latents`` to
        ``self.latents_dtype`` (``Tensor.to`` returns a new tensor unless the
        dtype already matches, so the attribute is rebound).

        Args:
            step_index: Index of the step about to run.

        Raises:
            RuntimeError: If ``self.latents`` has not been set yet.
        """
        if self.latents is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object has no attribute 'to'", and before any
            # state is mutated.
            raise RuntimeError(
                f"{type(self).__name__}.step_pre called before latents were set"
            )
        self.step_index = step_index
        self.latents = self.latents.to(dtype=self.latents_dtype)