scheduler.py 352 Bytes
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
import torch


Dongz's avatar
Dongz committed
4
class BaseScheduler:
    """Base scheduler holding the config, the current step index, and latents.

    ``self.latents`` starts as ``None``. It must be assigned a tensor (e.g. by a
    subclass or caller) before :meth:`step_pre` is called.
    """

    def __init__(self, config):
        """Store ``config`` and initialize per-run state.

        Args:
            config: Scheduler configuration; stored as-is, not inspected here.
        """
        self.config = config

        # Index of the step most recently passed to ``step_pre``.
        self.step_index = 0
        # Tensor state cast by ``step_pre``; populated outside this class.
        self.latents = None
        # NOTE(review): meaning of ``flag_df`` is not established in this file;
        # it is only initialized here — confirm its use in subclasses.
        self.flag_df = False

    def step_pre(self, step_index, *, dtype=torch.bfloat16):
        """Prepare for step ``step_index``.

        Records the step index and casts ``self.latents`` to ``dtype``.

        Args:
            step_index: Index of the step about to run.
            dtype: Target dtype for the latents. Defaults to ``torch.bfloat16``.

        Note:
            ``self.latents`` must already be a tensor; if it is still the
            initial ``None`` this raises ``AttributeError``.
        """
        self.step_index = step_index
        self.latents = self.latents.to(dtype=dtype)

    def clear(self):
        """Hook for releasing per-run state; a no-op in the base class."""
        pass