import torch
from ..scheduler import WanScheduler


class WanSchedulerTeaCaching(WanScheduler):
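    """TeaCache variant of WanScheduler.

    Holds the bookkeeping state used for TeaCache-style step skipping:
    accumulated relative L1 distances plus the previously cached e0 tensors
    and residuals, tracked separately for even and odd model calls.
    """
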
    def __init__(self, args):
        super().__init__(args)
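        # TeaCache bookkeeping: running call counter and total number of
        # model calls (the scheduler counts two calls per inference step).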
        self.cnt = 0
        self.num_steps = self.args.infer_steps * 2
        self.teacache_thresh = self.args.teacache_thresh
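        # Per-parity caches: accumulated relative L1 distance, last e0 tensor
        # and last residual, kept separately for even and odd calls.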
        self.accumulated_rel_l1_distance_even = 0
        self.accumulated_rel_l1_distance_odd = 0
        self.previous_e0_even = None
        self.previous_e0_odd = None
        self.previous_residual_even = None
        self.previous_residual_odd = None
        self.use_ret_steps = self.args.use_ret_steps

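        # Polynomial coefficients used to rescale the relative L1 distance,
        # selected by task and by resolution (i2v) or model size (t2v).
        # ret_steps / cutoff_steps mark the warm-up and cutoff call indices
        # for caching; only 480/720 targets and 1.3B/14B checkpoints are
        # covered below.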
        if self.args.task == "i2v":
            if self.use_ret_steps:
                if self.args.target_width == 480 or self.args.target_height == 480:
                    self.coefficients = [
                        2.57151496e05,
                        -3.54229917e04,
                        1.40286849e03,
                        -1.35890334e01,
                        1.32517977e-01,
                    ]
                if self.args.target_width == 720 or self.args.target_height == 720:
                    self.coefficients = [
                        8.10705460e03,
                        2.13393892e03,
                        -3.72934672e02,
                        1.66203073e01,
                        -4.17769401e-02,
                    ]
                self.ret_steps = 5 * 2
                self.cutoff_steps = self.args.infer_steps * 2
            else:
                if self.args.target_width == 480 or self.args.target_height == 480:
                    self.coefficients = [
                        -3.02331670e02,
                        2.23948934e02,
                        -5.25463970e01,
                        5.87348440e00,
                        -2.01973289e-01,
                    ]
                if self.args.target_width == 720 or self.args.target_height == 720:
                    self.coefficients = [
                        -114.36346466,
                        65.26524496,
                        -18.82220707,
                        4.91518089,
                        -0.23412683,
                    ]
                self.ret_steps = 1 * 2
                self.cutoff_steps = self.args.infer_steps * 2 - 2

        elif self.args.task == "t2v":
            if self.use_ret_steps:
                if "1.3B" in self.args.model_path:
                    self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02]
                if "14B" in self.args.model_path:
                    self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01]
                self.ret_steps = 5 * 2
                self.cutoff_steps = self.args.infer_steps * 2
            else:
                if "1.3B" in self.args.model_path:
                    self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
                if "14B" in self.args.model_path:
                    self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
                self.ret_steps = 1 * 2
                self.cutoff_steps = self.args.infer_steps * 2 - 2

    def clear(self):
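        # Move any cached tensors to the CPU before dropping the references,
        # then release cached GPU memory.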
        if self.previous_e0_even is not None:
            self.previous_e0_even = self.previous_e0_even.cpu()
        if self.previous_e0_odd is not None:
            self.previous_e0_odd = self.previous_e0_odd.cpu()
        if self.previous_residual_even is not None:
            self.previous_residual_even = self.previous_residual_even.cpu()
        if self.previous_residual_odd is not None:
            self.previous_residual_odd = self.previous_residual_odd.cpu()
        self.previous_e0_even = None
        self.previous_e0_odd = None
        self.previous_residual_even = None
        self.previous_residual_odd = None
        torch.cuda.empty_cache()