# scheduler.py
from lightx2v.utils.envs import *


class BaseScheduler:
    """Base class holding scheduler state that is shared across inference steps.

    Stores the run configuration, the current latents, the step counter and a
    few flags that subclasses and the inference loop read or toggle.

    Args:
        config: Mapping of run options. It must contain ``"infer_steps"``, the
            number of inference steps. The whole mapping is kept on
            ``self.config``.
    """

    def __init__(self, config):
        self.config = config
        # Starts empty. It is never assigned in this class, so it is presumably
        # filled in by a subclass or the caller before ``step_pre`` is used.
        self.latents = None
        self.step_index = 0

        self.infer_steps = config["infer_steps"]
        # One flag per inference step, all initially True.
        self.caching_records = [True] * self.infer_steps
        self.flag_df = False
        self.transformer_infer = None
        self.infer_condition = True  # cfg status
        # When True, ``step_pre`` leaves the dtype of ``self.latents`` untouched.
        self.keep_latents_dtype_in_scheduler = False

    def step_pre(self, step_index):
        """Record the step about to run and optionally re-cast the latents.

        The latents are cast to ``GET_DTYPE()`` only when that dtype equals
        ``GET_SENSITIVE_DTYPE()`` and ``keep_latents_dtype_in_scheduler`` is
        False. Otherwise they are left unchanged.

        Args:
            step_index: Index of the inference step about to run.
        """
        self.step_index = step_index
        if GET_DTYPE() == GET_SENSITIVE_DTYPE() and not self.keep_latents_dtype_in_scheduler:
            # NOTE(review): assumes self.latents has already been set to an
            # object with a ``.to(dtype)`` method. While it is still None,
            # this line raises AttributeError.
            self.latents = self.latents.to(GET_DTYPE())

    def clear(self):
        """Do nothing in the base class; presumably overridden to reset per-run state."""
        pass