util.py 7.13 KB
Newer Older
1
import importlib
2
from inspect import isfunction
3
4

import numpy as np
5
import torch
6
from PIL import Image, ImageDraw, ImageFont
7
from torch import optim
8
9
10
11
12
13
14
15
16
17


def log_txt_as_img(wh, xc, size=10):
    """Render each caption in ``xc`` as black text on a white RGB image.

    Args:
        wh: ``(width, height)`` tuple giving the size of every rendered image.
        xc: Sequence of caption strings; one image is produced per caption.
        size: Font size passed to ``ImageFont.truetype``.

    Returns:
        A float64 tensor of shape ``(len(xc), 3, height, width)`` whose pixel
        values are rescaled from ``[0, 255]`` to ``[-1, 1]``.
    """
    if len(xc) == 0:
        # np.stack rejects an empty list; return an empty batch with the same
        # layout and dtype that a non-empty call produces.
        return torch.zeros((0, 3, wh[1], wh[0]), dtype=torch.float64)

    # The font and the wrap width do not depend on the caption: compute once.
    font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
    # 40 characters per 256 px of width; clamp to 1 so the slicing step below
    # is never zero for very narrow images.
    nc = max(1, int(40 * (wh[0] / 256)))

    txts = []
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        # Hard-wrap the caption into chunks of ``nc`` characters.
        lines = "\n".join(caption[start : start + nc] for start in range(0, len(caption), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Best effort: the image for this caption stays blank.
            print("Cant encode string for logging. Skipping.")

        # HWC uint8 -> CHW float in [-1, 1].
        txts.append(np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0)

    return torch.tensor(np.stack(txts))


def ismap(x):
    """Report whether ``x`` is a rank-4 ``torch.Tensor`` with more than three entries along dim 1.

    Anything that is not a ``torch.Tensor`` yields ``False``.
    """
    if isinstance(x, torch.Tensor):
        rank_is_four = x.dim() == 4
        # Dim 1 is only inspected once the rank is known to be 4.
        return rank_is_four and x.shape[1] > 3
    return False


def isimage(x):
    """Return True when ``x`` is a rank-4 ``torch.Tensor`` whose dim 1 has size 1 or 3.

    Non-tensor inputs return ``False``.
    """
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] in (1, 3))


def exists(x):
    """Tell whether ``x`` holds a value, i.e. is anything other than ``None``."""
    is_missing = x is None
    return not is_missing


def default(val, d):
    """Return ``val`` unless it is ``None``, in which case fall back to ``d``.

    When the fallback is a plain Python function (per ``inspect.isfunction``)
    it is called with no arguments and its result is returned; any other
    fallback is returned as-is.
    """
    fallback_needed = not exists(val)
    if fallback_needed:
        if isfunction(d):
            return d()
        return d
    return val


def mean_flat(tensor):
    """
    Average ``tensor`` over every dimension except the first (batch) one.

    Adapted from
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)


def count_params(model, verbose=False):
    """Return the total number of elements across ``model.parameters()``.

    Args:
        model: Object exposing ``parameters()`` whose items provide ``numel()``.
        verbose: When true, also print the count in millions together with
            the model's class name.

    Returns:
        The summed element count as an int.
    """
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config):
    """Build the object described by ``config``.

    Args:
        config: Mapping with a ``"target"`` entry holding a dotted import
            path and an optional ``"params"`` mapping of keyword arguments.
            The sentinel strings ``"__is_first_stage__"`` and
            ``"__is_unconditional__"`` are also accepted.

    Returns:
        ``None`` for either sentinel string; otherwise the result of calling
        the object resolved from ``config["target"]`` with ``**params``.

    Raises:
        KeyError: If ``config`` has no ``"target"`` and is not a sentinel.
    """
    if "target" not in config:
        if config == "__is_first_stage__" or config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.mod.Name"`` to the named attribute.

    The text before the last dot is imported as a module and the text after
    it is looked up on that module. With ``reload=True`` the module is
    re-executed via ``importlib.reload`` before the lookup.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    resolved_module = importlib.import_module(module_path, package=None)
    return getattr(resolved_module, attr_name)


Fazzie's avatar
Fazzie committed
89
90
class AdamWwithEMAandWings(optim.Optimizer):
    """AdamW that saves EMA versions of the parameters.

    Besides the usual AdamW moment estimates, each parameter's optimizer
    state holds ``param_exp_avg``: a float copy of the parameter that is
    updated as an exponential moving average after every step.
    """

    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
    def __init__(
        self,
        params,
        lr=1.0e-3,
        betas=(0.9, 0.999),
        eps=1.0e-8,  # TODO: check hyperparameters before using
        weight_decay=1.0e-2,
        amsgrad=False,
        ema_decay=0.9999,  # ema decay to match previous code
        ema_power=1.0,
        param_names=(),
    ):
        """Validate the hyperparameters and register them as group defaults.

        Args:
            params: Parameters or parameter groups, forwarded to
                ``optim.Optimizer``.
            lr: Learning rate; must be non-negative.
            betas: Pair of coefficients, each in ``[0, 1)``.
            eps: Epsilon forwarded to the AdamW update; must be non-negative.
            weight_decay: Weight decay coefficient; must be non-negative.
            amsgrad: Whether to track the running max of the second moment.
            ema_decay: Upper bound on the EMA decay; must lie in ``[0, 1]``.
            ema_power: Exponent of the step-dependent decay warm-up used in
                ``step``.
            param_names: Stored in the group defaults; not read by this class.

        Raises:
            ValueError: If any hyperparameter is outside its allowed range.
        """
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError(f"Invalid ema_decay value: {ema_decay}")
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            ema_decay=ema_decay,
            ema_power=ema_power,
            param_names=param_names,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, defaulting ``amsgrad`` to False in groups that lack it."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.

        Raises:
            RuntimeError: If any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group["amsgrad"]
            beta1, beta2 = group["betas"]
            ema_decay = group["ema_decay"]
            ema_power = group["ema_power"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError("AdamW does not support sparse gradients")
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of parameter values
                    state["param_exp_avg"] = p.detach().float().clone()

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                ema_params_with_grad.append(state["param_exp_avg"])

                if amsgrad:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

                # update the steps for each param group update
                state["step"] += 1
                # record the step after step update
                state_steps.append(state["step"])

            if not params_with_grad:
                # No parameter in this group received a gradient, so there is
                # nothing to update and no step count to derive the EMA decay
                # from.
                continue

            # NOTE(review): optim._functional is a private torch API; confirm
            # this call signature against the torch version in use.
            optim._functional.adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=False,
            )

            # Warm up the EMA decay with the step count of the last updated
            # parameter in this group, capped at the configured ema_decay.
            cur_ema_decay = min(ema_decay, 1 - state_steps[-1] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

        return loss