scheduler.py 568 Bytes
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import torch
PengGao's avatar
PengGao committed
2

gushiqiao's avatar
gushiqiao committed
3
from lightx2v.utils.envs import *
helloyongyang's avatar
helloyongyang committed
4
5


Dongz's avatar
Dongz committed
6
class BaseScheduler:
    """Minimal base class holding state shared by schedulers.

    This base class only initialises bookkeeping fields, optionally casts
    ``latents`` before each step, and provides a no-op ``clear`` hook.
    Subclasses are presumably expected to populate ``latents`` and override
    ``clear`` (TODO confirm against the subclasses).
    """

    def __init__(self, config):
        """Initialise scheduler state from ``config``.

        Args:
            config: Configuration object. Only ``config.infer_steps`` is read
                here. It must be an int, since it sizes a list via
                ``[True] * n``.
        """
        self.config = config
        # Starts as None; must be assigned (e.g. by a subclass) before
        # step_pre() tries to cast it.
        self.latents = None
        self.step_index = 0

        self.infer_steps = config.infer_steps
        # One flag per inference step, all True initially. Presumably marks
        # which steps may use caching -- confirm with callers.
        self.caching_records = [True] * self.infer_steps

        # NOTE(review): the meaning of ``flag_df`` is not visible here; this
        # class only initialises it to False.
        self.flag_df = False
        # Placeholder; never assigned a real value inside this class.
        self.transformer_infer = None

    def step_pre(self, step_index):
        """Prepare for inference step ``step_index``.

        Records the current step index. When ``GET_DTYPE()`` reports
        ``"BF16"``, it also casts ``self.latents`` to ``torch.bfloat16``.

        Args:
            step_index: Index of the step about to run; stored on
                ``self.step_index``.

        Raises:
            AttributeError: if BF16 mode is active and ``self.latents`` is
                still ``None``.
        """
        self.step_index = step_index

        # GET_DTYPE is presumably provided by the module-level star import of
        # lightx2v.utils.envs (it is not defined in this class).
        if GET_DTYPE() == "BF16":
            self.latents = self.latents.to(dtype=torch.bfloat16)

    def clear(self):
        """Release per-run state; a no-op in the base class."""
        pass