scheduler.py 424 Bytes
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import torch
gushiqiao's avatar
gushiqiao committed
2
from lightx2v.utils.envs import *
helloyongyang's avatar
helloyongyang committed
3
4


Dongz's avatar
Dongz committed
5
class BaseScheduler:
    """Base scheduler holding the state shared by concrete schedulers.

    Tracks the config it was built with, the current step index, and the
    latent tensor being stepped. ``latents`` starts as ``None``; it is
    presumably assigned by subclasses or callers before stepping (TODO
    confirm against subclasses).
    """

    def __init__(self, config):
        """Store ``config`` and initialise per-run scheduler state.

        Args:
            config: Configuration object, stored as-is on ``self.config``.
                Its schema is not inspected in this class.
        """
        self.config = config
        # Index of the step most recently passed to ``step_pre``.
        self.step_index = 0
        # Latent tensor being stepped; ``None`` until something assigns it.
        self.latents = None
        # NOTE(review): the meaning of ``flag_df`` is not visible here; this
        # class only initialises it to False. Confirm against subclasses.
        self.flag_df = False

    def step_pre(self, step_index):
        """Prepare scheduler state before running step ``step_index``.

        Records ``step_index`` on the instance. When the environment-selected
        dtype (``GET_DTYPE()``) is ``"BF16"``, casts ``self.latents`` to
        ``torch.bfloat16``. If ``self.latents`` is still ``None``, the cast is
        skipped instead of raising ``AttributeError`` on ``None.to``.

        Args:
            step_index: Index of the step about to run.
        """
        self.step_index = step_index
        # Test latents first so GET_DTYPE() is not consulted when there is
        # nothing to cast.
        if self.latents is not None and GET_DTYPE() == "BF16":
            self.latents = self.latents.to(dtype=torch.bfloat16)

    def clear(self):
        """Hook for clearing scheduler state; does nothing in the base class."""
        pass