from typing import overload, Optional
import torch
from torch.nn import functional as F


class Guidance:

    def __init__(
        self,
        scale: float,
        t_start: int,
        t_stop: int,
        space: str,
        repeat: int
    ) -> "Guidance":
        """
        Initialize latent image guidance.
        
        Args:
            scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale,
                the closer the final result will be to the output of the first-stage model.
            t_start (int), t_stop (int): The timesteps at which guidance starts and stops. Since the
                sampling process runs from t=1000 down to t=0, `t_start` must be larger than `t_stop`.
            space (str): The data space in which the loss is computed ("rgb" or "latent").
            repeat (int): Number of times the gradient descent step is repeated.

        Our latent image guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior).
        Thanks for their work!
        """
        self.scale = scale
        self.t_start = t_start
        self.t_stop = t_stop
        self.target = None
        self.space = space
        self.repeat = repeat
    
    def load_target(self, target: torch.Tensor) -> None:
        self.target = target

    def __call__(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Optional[torch.Tensor]:
        if self.t_stop < t < self.t_start:
            # apply guidance only inside the (t_stop, t_start) window;
            # detach to avoid propagating gradients out of this scope
            pred_x0 = pred_x0.detach().clone()
            target_x0 = target_x0.detach().clone()
            return self.scale * self._forward(target_x0, pred_x0)
        else:
            return None
    
    @overload
    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor) -> torch.Tensor:
        ...


class MSEGuidance(Guidance):
    
    def __init__(
        self,
        scale: float,
        t_start: int,
        t_stop: int,
        space: str,
        repeat: int
    ) -> "MSEGuidance":
        super().__init__(
            scale, t_start, t_stop, space, repeat
        )
    
    @torch.enable_grad()
    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor) -> torch.Tensor:
        # inputs: [-1, 1], nchw, rgb
        pred_x0.requires_grad_(True)
        
        # This is what we actually use: per-image mean squared error, summed over the batch.
        loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum()
        
        print(f"loss = {loss.item()}")
        # return the negative gradient so that adding it to pred_x0 moves
        # the prediction in the direction that decreases the guidance loss
        return -torch.autograd.grad(loss, pred_x0)[0]
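

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): a minimal illustration
# of how a sampler loop might apply MSEGuidance to the current x0 estimate.
# The tensor shapes, hyperparameter values, and the update rule
# `pred_x0 + grad` are assumptions made for illustration only; a real sampler
# may integrate the returned gradient differently.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    guidance = MSEGuidance(scale=0.1, t_start=601, t_stop=-1, space="rgb", repeat=1)

    # stand-in tensors: one 3-channel 64x64 image in [-1, 1]
    target_x0 = torch.rand(1, 3, 64, 64) * 2 - 1  # output of the first-stage model
    pred_x0 = torch.rand(1, 3, 64, 64) * 2 - 1    # current x0 estimate from the diffusion model
    guidance.load_target(target_x0)

    t = 500  # a timestep inside the (t_stop, t_start) guidance window
    for _ in range(guidance.repeat):
        grad = guidance(target_x0, pred_x0, t)
        if grad is None:
            # t is outside the guidance window, nothing to apply
            break
        # hypothetical update: nudge the prediction toward the target
        pred_x0 = (pred_x0 + grad).clamp(-1, 1)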