from .utils import cache_init, cal_type
from ..scheduler import HunyuanScheduler
import torch

class HunyuanSchedulerTeaCaching(HunyuanScheduler):
    """Hunyuan scheduler carrying TeaCache state.

    Holds the bookkeeping needed by TeaCache-style step caching: a step
    counter, the skip threshold, a running relative-L1 accumulator, and the
    previously seen modulated input / residual tensors.
    """

    def __init__(self, config):
        super().__init__(config)
        # Number of denoising steps processed so far.
        self.cnt = 0
        # Total inference steps and the caching threshold, read from config.
        self.num_steps = self.config.infer_steps
        self.teacache_thresh = self.config.teacache_thresh
        # Running distance accumulator plus the tensors cached from the
        # previous step; populated during inference, cleared by clear().
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        # Fitted polynomial coefficients used by the TeaCache distance
        # rescaling — presumably highest-degree first; confirm at the usage
        # site (not visible in this file).
        self.coefficients = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]

    def clear(self):
        """Drop the cached tensors and release the GPU memory they held."""
        for attr in ("previous_residual", "previous_modulated_input"):
            cached = getattr(self, attr)
            if cached is not None:
                # Move the tensor off the GPU before dropping the reference.
                setattr(self, attr, cached.cpu())
            setattr(self, attr, None)
        # NOTE(review): cnt and accumulated_rel_l1_distance are NOT reset
        # here — confirm clear() is not expected to prepare for reuse.
        torch.cuda.empty_cache()
class HunyuanSchedulerTaylorCaching(HunyuanScheduler):
    """Hunyuan scheduler variant that drives Taylor-expansion caching.

    Delegates cache bookkeeping to the ``cache_init``/``cal_type`` helpers
    from ``.utils``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Build the cache dictionary and the per-step state tracker for the
        # configured number of inference steps.
        cache_state = cache_init(self.infer_steps)
        self.cache_dic, self.current = cache_state

    def step_pre(self, step_index):
        """Record the active step and refresh the caching decision.

        Runs the base-class pre-step hook, stamps ``step_index`` into the
        current-state dict, then lets ``cal_type`` update the cache mode.
        """
        super().step_pre(step_index)
        self.current["step"] = step_index
        cal_type(self.cache_dic, self.current)