scheduler.py 2.03 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from ..scheduler import WanScheduler


class WanSchedulerFeatureCaching(WanScheduler):
    def __init__(self, args):
        super().__init__(args)
        self.cnt = 0
        self.num_steps = self.args.infer_steps * 2
        self.teacache_thresh = self.args.teacache_thresh
        self.accumulated_rel_l1_distance_even = 0
        self.accumulated_rel_l1_distance_odd = 0
        self.previous_e0_even = None
        self.previous_e0_odd = None
        self.previous_residual_even = None
        self.previous_residual_odd = None
        self.use_ret_steps = self.args.use_ret_steps

        if self.use_ret_steps:
            if self.args.target_width == 480 or self.args.target_height == 480:
                self.coefficients = [
                    2.57151496e05,
                    -3.54229917e04,
                    1.40286849e03,
                    -1.35890334e01,
                    1.32517977e-01,
                ]
            if self.args.target_width == 720 or self.args.target_height == 720:
                self.coefficients = [
                    8.10705460e03,
                    2.13393892e03,
                    -3.72934672e02,
                    1.66203073e01,
                    -4.17769401e-02,
                ]
            self.ret_steps = 5 * 2
            self.cutoff_steps = self.args.infer_steps * 2
        else:
            if self.args.target_width == 480 or self.args.target_height == 480:
                self.coefficients = [
                    -3.02331670e02,
                    2.23948934e02,
                    -5.25463970e01,
                    5.87348440e00,
                    -2.01973289e-01,
                ]
            if self.args.target_width == 720 or self.args.target_height == 720:
                self.coefficients = [
                    -114.36346466,
                    65.26524496,
                    -18.82220707,
                    4.91518089,
                    -0.23412683,
                ]
            self.ret_steps = 1 * 2
            self.cutoff_steps = self.args.infer_steps * 2 - 2