"""Guiders for classifier-free guidance (CFG) during diffusion sampling."""
import logging
import math
from abc import ABC, abstractmethod
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import torch
from einops import rearrange, repeat

from ...util import append_dims, default, instantiate_from_config


class Guider(ABC):
    """Abstract interface for guidance strategies used by the sampler."""

    @abstractmethod
    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        """Combine the batched model outputs into one guided prediction."""

    def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict,
                       uc: Dict) -> Tuple[torch.Tensor, float, Dict]:
        """Assemble batched model inputs from cond/uncond conditioning.

        Stub implementation: returns None; concrete guiders override this.
        """


class VanillaCFG:
    """Classifier-free guidance with a step-independent scale.

    Runs the unconditional and conditional branches in a single batched
    forward pass ("parallelized" CFG) and blends the two halves with a
    dynamic-thresholding strategy (``NoDynamicThresholding`` by default).
    """

    def __init__(self, scale, dyn_thresh_config=None):
        self.scale = scale

        # Constant schedule: the same scale is used at every sigma.
        def _constant_schedule(s, sigma):
            return s

        self.scale_schedule = partial(_constant_schedule, scale)
        thresh_cfg = default(
            dyn_thresh_config,
            {
                'target':
                'sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding'
            },
        )
        self.dyn_thresh = instantiate_from_config(thresh_cfg)

    def __call__(self, x, sigma, scale=None):
        """Split the batched prediction and apply guidance.

        ``x`` stacks the unconditional output first, then the conditional
        one. Passing an explicit ``scale`` overrides the schedule.
        """
        uncond, cond = x.chunk(2)
        effective_scale = default(scale, self.scale_schedule(sigma))
        return self.dyn_thresh(uncond, cond, effective_scale)

    def prepare_inputs(self, x, s, c, uc):
        """Duplicate ``x``/``s`` and merge (uc, c) conditioning dicts.

        Tensor-valued keys (vector/crossattn/concat) are concatenated with
        the unconditional entries first; every other key must agree between
        the two dicts and is forwarded unchanged.
        """
        merged = dict()
        for key, value in c.items():
            if key in ('vector', 'crossattn', 'concat'):
                merged[key] = torch.cat((uc[key], value), 0)
            else:
                assert value == uc[key]
                merged[key] = value
        return torch.cat([x] * 2), torch.cat([s] * 2), merged


# NOTE: superseded by the active DynamicCFG implementation below; kept for reference.
# class DynamicCFG(VanillaCFG):

#     def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
#         super().__init__(scale, dyn_thresh_config)
#         scale_schedule = (lambda scale, sigma, step_index: 1 + scale *
#                           (1 - math.cos(math.pi *
#                                         (step_index / num_steps)**exp)) / 2)
#         self.scale_schedule = partial(scale_schedule, scale)
#         self.dyn_thresh = instantiate_from_config(
#             default(
#                 dyn_thresh_config,
#                 {
#                     'target':
#                     'sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding'
#                 },
#             ))

#     def __call__(self, x, sigma, step_index, scale=None):
#         x_u, x_c = x.chunk(2)
#         scale_value = self.scale_schedule(sigma, step_index.item())
#         x_pred = self.dyn_thresh(x_u, x_c, scale_value)
#         return x_pred


class DynamicCFG(VanillaCFG):
    """Classifier-free guidance whose scale ramps up over the sampling run.

    The effective scale follows a cosine schedule: 1 at step 0, rising to
    ``1 + scale`` at the final step, with ``exp`` shaping the ramp
    (exp > 1 delays the increase, exp < 1 front-loads it).
    """

    def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
        # The parent constructor already stores self.scale and builds
        # self.dyn_thresh from dyn_thresh_config (NoDynamicThresholding by
        # default), so the duplicate instantiation and the unused local
        # scale_schedule lambda from the original are removed here.
        super().__init__(scale, dyn_thresh_config)
        self.num_steps = num_steps
        self.exp = exp

    def scale_schedule_dy(self, sigma, step_index):
        """Return the guidance scale for ``step_index``.

        ``sigma`` is accepted for signature symmetry with the parent's
        schedule but is unused; the schedule depends only on step progress.
        """
        progress = (step_index / self.num_steps)**self.exp
        return 1 + self.scale * (1 - math.cos(math.pi * progress)) / 2

    def __call__(self, x, sigma, step_index, scale=None):
        """Apply dynamically-scaled CFG to a batched prediction.

        ``x`` stacks the unconditional output first, then the conditional
        one; ``step_index`` is expected to be a 0-dim tensor (``.item()``
        is called on it — TODO confirm against callers). ``scale`` is
        accepted for interface compatibility but ignored, matching the
        original behavior: the schedule always determines the scale.
        """
        x_u, x_c = x.chunk(2)
        scale_value = self.scale_schedule_dy(sigma, step_index.item())
        return self.dyn_thresh(x_u, x_c, scale_value)


class IdentityGuider:
    """No-op guider: model output and conditioning pass through unchanged."""

    def __call__(self, x, sigma):
        # No guidance is applied.
        return x

    def prepare_inputs(self, x, s, c, uc):
        # Unconditional conditioning is ignored; copy the cond dict as-is.
        return x, s, {key: value for key, value in c.items()}