# scheduler.py
# Authors (per blame history): helloyongyang, gushiqiao.
import torch
from ..scheduler import WanScheduler


class WanSchedulerTeaCaching(WanScheduler):
    """WanScheduler extended with TeaCache bookkeeping state.

    Reads ``infer_steps``, ``teacache_thresh``, ``use_ret_steps`` and
    ``coefficients`` from ``self.config``. Keeps separate cached tensors and
    accumulated distances for "even" and "odd" forward passes.
    NOTE(review): the even/odd split and the ``* 2`` step counts presumably
    reflect two forward passes per denoising step (e.g. cond/uncond under
    CFG). Confirm against the code that consumes this scheduler.
    """

    def __init__(self, config):
        super().__init__(config)
        # Pass counter; presumably advanced/reset by the caller (not visible here).
        self.cnt = 0
        # Total number of passes: two per inference step.
        self.num_steps = self.config.infer_steps * 2
        self.teacache_thresh = self.config.teacache_thresh

        # Per-parity TeaCache state: accumulated relative L1 distance, the
        # previously seen ``e0`` tensor and the previously computed residual.
        self.accumulated_rel_l1_distance_even = 0
        self.accumulated_rel_l1_distance_odd = 0
        self.previous_e0_even = None
        self.previous_e0_odd = None
        self.previous_residual_even = None
        self.previous_residual_odd = None

        # NOTE(review): ret_steps / cutoff_steps look like the pass-index window
        # in which caching is considered; confirm against the consumer.
        self.use_ret_steps = self.config.use_ret_steps
        if self.use_ret_steps:
            # coefficients[0] is the set used in ret-steps mode.
            self.coefficients = self.config.coefficients[0]
            self.ret_steps = 5 * 2
            self.cutoff_steps = self.num_steps
        else:
            # coefficients[1] is the set used otherwise.
            self.coefficients = self.config.coefficients[1]
            self.ret_steps = 1 * 2
            self.cutoff_steps = self.num_steps - 2

    def clear(self):
        """Drop all cached per-parity tensors and release cached CUDA memory.

        The references are simply discarded. Copying the tensors to the host
        first would only add a device-to-host transfer and a host allocation
        for data that is thrown away immediately, and would not free any GPU
        memory still referenced elsewhere.
        Counters and accumulated distances are left untouched.
        """
        self.previous_e0_even = None
        self.previous_e0_odd = None
        self.previous_residual_even = None
        self.previous_residual_odd = None
        # Return now-unreferenced cached blocks from PyTorch's CUDA allocator.
        torch.cuda.empty_cache()