scheduler.py 3.88 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
import torch
from ..scheduler import WanScheduler


5
class WanSchedulerTeaCaching(WanScheduler):
    """``WanScheduler`` extended with the bookkeeping state for TeaCache-style caching.

    The constructor initialises separate "even" and "odd" accumulators and
    cached values, then selects a five-element coefficient list plus the
    ``ret_steps`` / ``cutoff_steps`` bounds from the configuration.

    NOTE(review): the consumers of ``coefficients``, ``ret_steps`` and
    ``cutoff_steps`` live outside this class; their exact meaning should be
    confirmed there.
    """

    # Candidate coefficient sets for task "i2v", keyed by whether
    # ``use_ret_steps`` is enabled. Each entry is (side_length, coefficients);
    # an entry applies when target_width or target_height equals side_length.
    _I2V_COEFFICIENTS = {
        True: (
            (480, (2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01)),
            (720, (8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02)),
        ),
        False: (
            (480, (-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01)),
            (720, (-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683)),
        ),
    }

    # Candidate coefficient sets for task "t2v", keyed by whether
    # ``use_ret_steps`` is enabled. Each entry is (tag, coefficients); an entry
    # applies when tag is a substring of ``config.model_path``.
    _T2V_COEFFICIENTS = {
        True: (
            ("1.3B", (-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02)),
            ("14B", (-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01)),
        ),
        False: (
            ("1.3B", (2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01)),
            ("14B", (-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404)),
        ),
    }

    def __init__(self, config):
        """Set up caching state and choose coefficients for ``config``.

        Args:
            config: Configuration object forwarded to ``WanScheduler``. This
                constructor reads ``infer_steps``, ``teacache_thresh``,
                ``use_ret_steps`` and ``task`` from ``self.config``, plus
                ``target_width`` / ``target_height`` when the task is
                ``"i2v"`` or ``model_path`` when it is ``"t2v"``.

        Notes:
            * For any task other than ``"i2v"`` / ``"t2v"``, none of
              ``coefficients``, ``ret_steps`` or ``cutoff_steps`` is set.
            * If no resolution / model-size candidate matches,
              ``coefficients`` is left unset while ``ret_steps`` and
              ``cutoff_steps`` are still assigned.
            * If several candidates match, the last one listed wins.
        """
        super().__init__(config)

        self.cnt = 0
        # Twice the configured inference steps; presumably one "even" and one
        # "odd" call per step -- TODO confirm against the code that bumps cnt.
        self.num_steps = self.config.infer_steps * 2
        self.teacache_thresh = self.config.teacache_thresh

        # Independent accumulators and cached values for even and odd calls.
        self.accumulated_rel_l1_distance_even = 0
        self.accumulated_rel_l1_distance_odd = 0
        self.previous_e0_even = None
        self.previous_e0_odd = None
        self.previous_residual_even = None
        self.previous_residual_odd = None

        self.use_ret_steps = self.config.use_ret_steps
        # Normalise to a real bool so any truthy value selects the same table
        # row that a plain ``if self.use_ret_steps`` would have selected.
        with_ret_steps = bool(self.use_ret_steps)

        task = self.config.task
        if task == "i2v":
            # No ``break``: when both sizes match, the later entry overrides.
            for side, coeffs in self._I2V_COEFFICIENTS[with_ret_steps]:
                if self.config.target_width == side or self.config.target_height == side:
                    # Copy so instances never share one mutable list.
                    self.coefficients = list(coeffs)
        elif task == "t2v":
            # No ``break``: when both tags match, the later entry overrides.
            for tag, coeffs in self._T2V_COEFFICIENTS[with_ret_steps]:
                if tag in self.config.model_path:
                    self.coefficients = list(coeffs)
        else:
            # Unknown task: leave the task-dependent attributes unset.
            return

        if with_ret_steps:
            self.ret_steps = 5 * 2
            self.cutoff_steps = self.config.infer_steps * 2
        else:
            self.ret_steps = 1 * 2
            self.cutoff_steps = self.config.infer_steps * 2 - 2

    def clear(self):
        """Drop the cached even/odd values and release unused cached CUDA memory.

        The references are simply reset to ``None``; copying the values to the
        host first would only produce a copy that is discarded immediately.
        ``torch.cuda.empty_cache()`` is then called so memory held by PyTorch's
        caching allocator can be returned to the device.
        """
        self.previous_e0_even = None
        self.previous_e0_odd = None
        self.previous_residual_even = None
        self.previous_residual_odd = None
        torch.cuda.empty_cache()