ddpm_loss.py 22 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from copy import deepcopy
from functools import partial

import mmcv
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.utils import digit_version

from mmgen.models.builder import MODULES
from .pixelwise_loss import (DiscretizedGaussianLogLikelihoodLoss,
                             GaussianKLDLoss, _reduction_modes, mse_loss)
from .utils import reduce_loss


class DDPMLoss(nn.Module):
    """Base module for DDPM losses. We support loss weight rescale and log
    collection for DDPM models in this module.

    We support two kinds of loss rescale methods, which can be
    controlled by ``rescale_mode`` and ``rescale_cfg``:

    1. ``rescale_mode == 'constant'``: ``constant_rescale`` would be called,
        and ``rescale_cfg`` should be passed as ``dict(scale=SCALE)``,
        e.g., ``dict(scale=1.2)``. Then, all loss terms would be rescaled by
        multiplying with ``SCALE``.
    2. ``rescale_mode == 'timestep_weight'``: ``timestep_weight_rescale``
        would be called, and ``weight`` or a ``sampler`` which contains a
        ``weight`` attribute must be passed. Then, the loss at timestep ``t``
        would be multiplied with ``weight[t]``. Users may further apply a
        constant rescale factor to all loss terms, e.g.
        ``rescale_cfg=dict(scale=SCALE)``. The overall rescale function for
        the loss at timestep ``t`` can be formulated as
        ``loss[t] := weight[t] * loss[t] * SCALE``. Note that ``weight`` (or
        ``sampler.weight``) is held by reference, so the outer code must
        modify it inplace for the update to be visible to this module, e.g.,

        .. code-block:: python
            :linenos:

            # 1. define weight
            weight = torch.randn(10, )

            # 2. define loss function
            loss_fn = DDPMLoss(rescale_mode='timestep_weight', weight=weight)

            # 3 update weight
            # wrong usage: `weight` in `loss_fn` is not accessible from now
            # because you assign a new tensor to variable `weight`
            # weight = torch.randn(10, )

            # correct usage: update `weight` inplace
            weight[2] = 2

    If ``rescale_mode`` is not passed, ``rescale_cfg`` would be ignored, and
    all loss terms would not be rescaled.

    For loss log collection, we support users to pass a list of (or single)
    config by ``log_cfgs`` argument to define how they want to collect loss
    terms and show them in the log. Each log collection returns a dict whose
    keys and values are the names and values of collected loss terms. And
    the dict will be merged into ``log_vars`` after the loss used for
    parameter optimization is calculated. The log updating process for the
    class which uses ddpm_loss can be referred to the following pseudo-code:

    .. code-block:: python
        :linenos:

        # 1. loss dict for parameter optimization
        losses_dict = {}

        # 2. calculate losses
        for loss_fn in self.ddpm_loss:
            losses_dict[loss_fn.loss_name()] = loss_fn(outputs_dict)

        # 3. init log_vars
        log_vars = OrderedDict()

        # 4. update log_vars with loss terms used for parameter optimization
        for loss_name, loss_value in losses_dict.items():
            log_vars[loss_name] = loss_value.mean()

        # 5. sum all loss terms used for backward
        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        # 6. update log_var with log collection functions
        for loss_fn in self.ddpm_loss:
            if hasattr(loss_fn, 'log_vars'):
                log_vars.update(loss_fn.log_vars)

    Each log config must contain the ``type`` keyword, and may contain the
    ``prefix_name`` and ``reduction`` keywords. All remaining keywords are
    forwarded to the log collection function.

    ``type``: Used to get the corresponding collection function. Functions
        are named as ``f'{type}_log_collect'``. In `DDPMLoss`, we only
        support ``type=='quartile'``, but users may define their own log
        collection functions and use them in this way.
    ``prefix_name``: This keyword is set for avoiding the name of displayed
        loss terms being too long. The name of each loss term will be set as
        ``'{prefix_name}_{log_coll_fn_spec_name}'``, where
        ``{log_coll_fn_spec_name}`` is the name specific to the log
        collection function. If passed, it must start with ``'loss'``. If not
        passed, ``'loss'`` would be used.
    ``reduction``: Control the reduction method of the collected loss terms.
        Defaults to the ``reduction`` of this module.

    We implement ``quartile_log_collect`` in this module. In detail, we
    divide total timesteps into four parts and collect the loss in the
    corresponding timestep intervals.

    To use those collection methods, users may pass ``log_cfgs`` as the
    following example:

    .. code-block:: python
        :linenos:

        log_cfgs = [
            dict(type='quartile', total_timesteps=TOTAL_TIMESTEPS,
                 reduction=REDUCTION, prefix_name=PREFIX),
            ...
        ]

    Args:
        rescale_mode (str, optional): Mode of the loss rescale method.
            Defaults to None.
        rescale_cfg (dict, optional): Config of the loss rescale method.
            Defaults to None.
        log_cfgs (list[dict] | dict | optional): Configs to collect logs.
            Defaults to None.
        weight (torch.Tensor, optional): Weight used for rescale losses.
            Defaults to None.
        sampler (object): Weight sampler. Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        loss_name (str, optional): Name of the loss item. Defaults to None.
    """

    def __init__(self,
                 rescale_mode=None,
                 rescale_cfg=None,
                 log_cfgs=None,
                 weight=None,
                 sampler=None,
                 reduction='mean',
                 loss_name=None):
        super().__init__()

        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.reduction = reduction
        self._loss_name = loss_name

        self.log_fn_list = []

        # deepcopy so that popping ``type`` does not mutate caller's configs
        log_cfgs_ = deepcopy(log_cfgs)
        if log_cfgs_ is not None:
            if not isinstance(log_cfgs_, list):
                log_cfgs_ = [log_cfgs_]
            assert mmcv.is_list_of(log_cfgs_, dict)
            for log_cfg_ in log_cfgs_:
                log_type = log_cfg_.pop('type')
                log_collect_fn = f'{log_type}_log_collect'
                assert hasattr(self, log_collect_fn)
                log_collect_fn = getattr(self, log_collect_fn)

                log_cfg_.setdefault('prefix_name', 'loss')
                assert log_cfg_['prefix_name'].startswith('loss')
                log_cfg_.setdefault('reduction', reduction)

                self.log_fn_list.append(partial(log_collect_fn, **log_cfg_))
        self.log_vars = dict()

        # handle rescale mode
        if not rescale_mode:
            self.rescale_fn = lambda loss, t: loss
        else:
            rescale_fn_name = f'{rescale_mode}_rescale'
            assert hasattr(self, rescale_fn_name)
            # shallow copy: never mutate the caller's config, and tolerate a
            # missing ``rescale_cfg``
            rescale_cfg = dict() if rescale_cfg is None else dict(rescale_cfg)
            if rescale_mode == 'timestep_weight':
                if sampler is not None and hasattr(sampler, 'weight'):
                    weight = sampler.weight
                else:
                    assert weight is not None and isinstance(
                        weight, torch.Tensor), (
                            '\'weight\' or a \'sampler\' contains weight '
                            'attribute is must be \'torch.Tensor\' for '
                            '\'timestep_weight\' rescale_mode.')

                mmcv.print_log(
                    'Apply \'timestep_weight\' rescale_mode for '
                    f'{self._loss_name}. Please make sure the passed weight '
                    'can be updated by external functions.', 'mmgen')

                # inject the weight by reference while keeping any
                # user-provided ``scale`` (documented above)
                rescale_cfg['weight'] = weight
            self.rescale_fn = partial(
                getattr(self, rescale_fn_name), **rescale_cfg)

    @staticmethod
    def constant_rescale(loss, timesteps, scale):
        """Rescale losses at all timesteps with a constant factor.

        Args:
            loss (torch.Tensor): Losses to rescale.
            timesteps (torch.Tensor): Timesteps of each loss item. Unused,
                kept for a signature shared with other rescale functions.
            scale (int | float): Rescale factor.

        Returns:
            torch.Tensor: Rescaled losses.
        """

        return loss * scale

    @staticmethod
    def timestep_weight_rescale(loss, timesteps, weight, scale=1):
        """Rescale losses corresponding to timestep.

        Args:
            loss (torch.Tensor): Losses to rescale.
            timesteps (torch.Tensor): Timesteps of each loss item, used to
                index ``weight``.
            weight (torch.Tensor): Weight corresponding to each timestep.
            scale (int | float): Extra constant rescale factor. Defaults to 1.

        Returns:
            torch.Tensor: Rescaled losses.
        """

        return loss * weight[timesteps] * scale

    @torch.no_grad()
    def collect_log(self, loss, timesteps):
        """Collect logs and store them in ``self.log_vars``.

        When ``torch.distributed`` is initialized, losses and timesteps from
        all ranks are gathered first and only rank 0 computes the logs; other
        ranks store an empty dict.

        Args:
            loss (torch.Tensor): Losses to collect.
            timesteps (torch.Tensor): Timesteps of each loss items.
        """
        if not self.log_fn_list:
            return

        if dist.is_initialized():
            ws = dist.get_world_size()
            placeholder_l = [torch.zeros_like(loss) for _ in range(ws)]
            placeholder_t = [torch.zeros_like(timesteps) for _ in range(ws)]
            dist.all_gather(placeholder_l, loss)
            dist.all_gather(placeholder_t, timesteps)
            loss = torch.cat(placeholder_l, dim=0)
            timesteps = torch.cat(placeholder_t, dim=0)
        log_vars = dict()

        if not dist.is_initialized() or dist.get_rank() == 0:
            for log_fn in self.log_fn_list:
                log_vars.update(log_fn(loss, timesteps))
        self.log_vars = log_vars

    @torch.no_grad()
    def quartile_log_collect(self,
                             loss,
                             timesteps,
                             total_timesteps,
                             prefix_name,
                             reduction='mean'):
        """Collect loss logs by quartile timesteps.

        Args:
            loss (torch.Tensor): Loss value of each input. Each loss tensor
                should be shape as [bz, ]
            timesteps (torch.Tensor): Timesteps corresponding to each loss.
                Each loss tensor should be shape as [bz, ].
            total_timesteps (int): Total timesteps of diffusion process.
            prefix_name (str): Prefix want to show in logs.
            reduction (str, optional): Specifies the reduction to apply to the
                output losses. Defaults to `mean`.

        Returns:
            dict: Collected log variables, keyed as
                ``f'{prefix_name}_quartile_{idx}'`` for ``idx`` in 0..3. A
                quartile without any sample is logged as 0.
        """
        if digit_version(torch.__version__) <= digit_version('1.6.0'):
            # use true_divide in older torch version
            quartile = torch.true_divide(timesteps, total_timesteps) * 4
        else:
            quartile = (timesteps / total_timesteps * 4)
        # truncate to the integer quartile index
        quartile = quartile.type(torch.LongTensor)

        log_vars = dict()

        for idx in range(4):
            if not (quartile == idx).any():
                loss_quartile = torch.zeros((1, ))
            else:
                loss_quartile = reduce_loss(loss[quartile == idx], reduction)
            log_vars[f'{prefix_name}_quartile_{idx}'] = loss_quartile.item()

        return log_vars

    def forward(self, *args, **kwargs):
        """Forward function.

        A dictionary containing the network outputs (it must contain the
        ``'timesteps'`` key) should be passed into this function. If this
        dictionary is given as a non-keyword argument, it should be the only
        positional argument. If you are using a keyword argument, please
        name it ``outputs_dict`` (``output_dict`` is accepted as well).

        Returns:
            torch.Tensor: The rescaled loss, reduced by ``self.reduction``.
        """
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            output_dict = args[0]
        elif 'outputs_dict' in kwargs or 'output_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            # ``outputs_dict`` is the documented name; ``output_dict`` is
            # kept for callers that followed the previous membership check
            key = 'outputs_dict' if 'outputs_dict' in kwargs else 'output_dict'
            output_dict = kwargs.pop(key)
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')

        # check keys in output_dict
        assert 'timesteps' in output_dict, (
            '\'timesteps\' is must for DDPM-based losses, but found'
            f'{output_dict.keys()} in \'output_dict\'')

        timesteps = output_dict['timesteps']
        loss = self._forward_loss(output_dict)

        # update log_vars of this class (logs use the un-rescaled loss)
        self.collect_log(loss, timesteps=timesteps)

        loss_rescaled = self.rescale_fn(loss, timesteps)
        return reduce_loss(loss_rescaled, self.reduction)

    @abstractmethod
    def _forward_loss(self, output_dict):
        """Forward function for loss calculation. This method should be
        implemented by each subclass.

        Args:
            output_dict (dict): Outputs of the model used to calculate losses.

        Returns:
            torch.Tensor: Calculated per-sample loss.
        """

        raise NotImplementedError(
            '\'self._forward_loss\' must be implemented.')

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name


@MODULES.register_module()
class DDPMVLBLoss(DDPMLoss):
    """Variational lower-bound loss for DDPM-based models.

    In this loss, we calculate the VLB of different timesteps with different
    methods. In detail, ``DiscretizedGaussianLogLikelihoodLoss`` is used at
    ``timesteps == 0`` and ``GaussianKLDLoss`` at other timesteps.
    To control the data flow for loss calculation, users should define
    ``data_info`` and ``data_info_t_0`` for ``GaussianKLDLoss`` and
    ``DiscretizedGaussianLogLikelihoodLoss`` respectively. If not passed,
    ``_default_data_info`` and ``_default_data_info_t_0`` would be used.
    To be noted that, we only penalize 'variance' in this loss term, and
    tensors in output dict corresponding to 'mean' would be detached.

    Additionally, we support another log collection function called
    ``name_log_collect``. In this collection method, we would directly
    collect loss terms calculated by different methods.
    To use this collection method, users may pass ``log_cfgs`` as the
    following example:

    .. code-block:: python
        :linenos:

        log_cfgs = [
            dict(type='name', reduction=REDUCTION, prefix_name=PREFIX),
            ...
        ]

    Args:
        rescale_mode (str, optional): Mode of the loss rescale method.
            Defaults to None.
        rescale_cfg (dict, optional): Config of the loss rescale method.
            Defaults to None.
        sampler (object): Weight sampler. Defaults to None.
        weight (torch.Tensor, optional): Weight used for rescale losses.
            Defaults to None.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary for ``timesteps != 0``.
            Defaults to None.
        data_info_t_0 (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary for ``timesteps == 0``.
            Defaults to None.
        log_cfgs (list[dict] | dict | optional): Configs to collect logs.
            Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        loss_name (str, optional): Name of the loss item. Defaults to
            'loss_ddpm_vlb'.
    """
    _default_data_info = dict(
        mean_pred='mean_pred',
        mean_target='mean_target',
        logvar_pred='logvar_pred',
        logvar_target='logvar_target')
    _default_data_info_t_0 = dict(
        x='real_imgs', mean='mean_pred', logvar='logvar_pred')

    def __init__(self,
                 rescale_mode=None,
                 rescale_cfg=None,
                 sampler=None,
                 weight=None,
                 data_info=None,
                 data_info_t_0=None,
                 log_cfgs=None,
                 reduction='mean',
                 loss_name='loss_ddpm_vlb'):
        super().__init__(rescale_mode, rescale_cfg, log_cfgs, weight, sampler,
                         reduction, loss_name)

        # copy the class-level defaults so instances never share one mutable
        # mapping dict
        self.data_info = dict(self._default_data_info) \
            if data_info is None else data_info
        self.data_info_t_0 = dict(self._default_data_info_t_0) \
            if data_info_t_0 is None else data_info_t_0

        # ``loss_list[i]`` is applied to the samples selected by
        # ``loss_select_fn_list[i]`` (t == 0 and t != 0 respectively)
        self.loss_list = [
            DiscretizedGaussianLogLikelihoodLoss(
                reduction='flatmean',
                data_info=self.data_info_t_0,
                base='2',
                loss_weight=-1,
                only_update_var=True),
            GaussianKLDLoss(
                reduction='flatmean',
                data_info=self.data_info,
                base='2',
                only_update_var=True)
        ]
        self.loss_select_fn_list = [lambda t: t == 0, lambda t: t != 0]

    @torch.no_grad()
    def name_log_collect(self, loss, timesteps, prefix_name, reduction='mean'):
        """Collect loss logs by name (GaussianKLD and
        DiscGaussianLogLikelihood).

        Args:
            loss (torch.Tensor): Loss value of each input. Each loss tensor
                should be in the shape of [bz, ].
            timesteps (torch.Tensor): Timesteps corresponding to each losses.
                Each loss tensor should be in the shape of [bz, ].
            prefix_name (str): Prefix want to show in logs.
            reduction (str, optional): Specifies the reduction to apply to the
                output losses. Defaults to `mean`.

        Returns:
            dict: Collected log variables, keyed as
                ``f'{prefix_name}_{name}'`` where ``name`` is the sub-loss
                name with ``'loss_'`` removed. A term without any selected
                sample is logged as 0.
        """
        log_vars = dict()
        for select_fn, loss_fn in zip(self.loss_select_fn_list,
                                      self.loss_list):
            mask = select_fn(timesteps)
            if not mask.any():
                loss_reduced = torch.zeros((1, ))
            else:
                loss_reduced = reduce_loss(loss[mask], reduction)
            # remove original prefix in loss names
            loss_term_name = loss_fn.loss_name().replace('loss_', '')
            log_vars[f'{prefix_name}_{loss_term_name}'] = loss_reduced.item()

        return log_vars

    def _forward_loss(self, outputs_dict):
        """Forward function for loss calculation.

        Each sub-loss is evaluated only on the samples its selector picks;
        tensors and lists in ``outputs_dict`` are filtered with the same
        mask, other values are passed through unchanged.

        Args:
            outputs_dict (dict): Outputs of the model used to calculate losses.

        Returns:
            torch.Tensor: Calculated per-sample loss, same shape as
                ``outputs_dict['timesteps']``.
        """
        timesteps = outputs_dict['timesteps']
        # per-sample loss buffer; force float since ``timesteps`` holds
        # integer indices
        loss = torch.zeros_like(timesteps, dtype=torch.float)
        for select_fn, loss_fn in zip(self.loss_select_fn_list,
                                      self.loss_list):
            mask = select_fn(timesteps)
            # no sample in this batch falls into this term (e.g. no t == 0):
            # the masked assignment below would be a no-op, so skip the work
            if not mask.any():
                continue
            outputs_dict_ = {}
            for k, v in outputs_dict.items():
                if isinstance(v, torch.Tensor):
                    outputs_dict_[k] = v[mask]
                elif isinstance(v, list):
                    outputs_dict_[k] = [
                        v[idx] for idx, m in enumerate(mask) if m
                    ]
                else:
                    outputs_dict_[k] = v
            loss[mask] = loss_fn(outputs_dict_)
        return loss


@MODULES.register_module()
class DDPMMSELoss(DDPMLoss):
    """Mean square loss for DDPM-based models.

    The per-sample loss is ``mse_loss`` with ``'flatmean'`` reduction, fed
    with the entries of the output dict named by ``data_info``.

    Args:
        rescale_mode (str, optional): Mode of the loss rescale method.
            Defaults to None.
        rescale_cfg (dict, optional): Config of the loss rescale method.
            Defaults to None.
        sampler (object): Weight sampler. Defaults to None.
        weight (torch.Tensor, optional): Weight used for rescale losses.
            Defaults to None.
        log_cfgs (list[dict] | dict | optional): Configs to collect logs.
            Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        data_info (dict, optional): Dictionary contains the mapping between
            ``mse_loss`` input args and keys of the output dict. If not
            passed, ``dict(pred='eps_t_pred', target='noise')`` is used.
            Defaults to None.
        loss_name (str, optional): Name of the loss item. Defaults to
            'loss_ddpm_mse'.
    """
    _default_data_info = dict(pred='eps_t_pred', target='noise')

    def __init__(self,
                 rescale_mode=None,
                 rescale_cfg=None,
                 sampler=None,
                 weight=None,
                 log_cfgs=None,
                 reduction='mean',
                 data_info=None,
                 loss_name='loss_ddpm_mse'):
        super().__init__(rescale_mode, rescale_cfg, log_cfgs, weight, sampler,
                         reduction, loss_name)

        # copy the class-level default so instances never share one mutable
        # mapping dict
        self.data_info = dict(self._default_data_info) \
            if data_info is None else data_info

        self.loss_fn = partial(mse_loss, reduction='flatmean')

    def _forward_loss(self, outputs_dict):
        """Forward function for loss calculation.

        Args:
            outputs_dict (dict): Outputs of the model used to calculate losses.

        Returns:
            torch.Tensor: Calculated loss.
        """
        loss_input_dict = {
            k: outputs_dict[v]
            for k, v in self.data_info.items()
        }
        loss = self.loss_fn(**loss_input_dict)
        return loss