# LadderEncoder.py
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


def get_nonspade_norm_layer(norm_type="instance"):
    # helper function to get the number of output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, "out_channels"):
            return getattr(layer, "out_channels")
        return layer.weight.size(0)

    # this closure is returned; it wraps a layer with the requested normalization
    def add_norm_layer(layer):
        nonlocal norm_type
        subnorm_type = norm_type
        if norm_type.startswith("spectral"):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len("spectral") :]

        if subnorm_type == "none" or len(subnorm_type) == 0:
            return layer

        # remove the bias of the previous layer; it is meaningless
        # because the following normalization cancels its effect
        if getattr(layer, "bias", None) is not None:
            delattr(layer, "bias")
            layer.register_parameter("bias", None)

        if subnorm_type == "batch":
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "sync_batch":
            # no synchronized BatchNorm implementation here; fall back to plain BatchNorm2d
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "instance":
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError("normalization layer %s is not recognized" % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
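

# Hypothetical usage sketch (not from the original code): wrapping a conv with
# spectral norm followed by instance norm via norm_type="spectralinstance".
#
#   norm_layer = get_nonspade_norm_layer("spectralinstance")
#   block = norm_layer(nn.Conv2d(3, 64, 3, stride=2, padding=1))
#   # block == Sequential(spectral-norm-wrapped Conv2d, InstanceNorm2d(64, affine=False))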


class LadderEncoder(nn.Module):
    """ Same architecture as the image discriminator """

    def __init__(self, need_feat=False, use_mask=False, label_nc=0, z_dim=512, norm_type="spectralinstance"):
        super().__init__()
        self.need_feat = need_feat
        # NOTE: use_mask is accepted but currently unused in this class
        ldmk_img_nc = 3  # channels of a single landmark image (RGB)

        # input channels: RGB image + semantic label map + two landmark images
        nif = 3 + label_nc + 2 * ldmk_img_nc

        kw = 3
        pw = int(np.ceil((kw - 1.0) / 2))
        nef = 64
        norm_layer = get_nonspade_norm_layer(norm_type)
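        # six stride-2 convolutions, halving the resolution each time:
        # 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4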
        self.layer1 = norm_layer(nn.Conv2d(nif, nef, kw, stride=2, padding=pw))
        self.layer2 = norm_layer(nn.Conv2d(nef * 1, nef * 2, kw, stride=2, padding=pw))
        self.layer3 = norm_layer(nn.Conv2d(nef * 2, nef * 4, kw, stride=2, padding=pw))
        self.layer4 = norm_layer(nn.Conv2d(nef * 4, nef * 8, kw, stride=2, padding=pw))
        self.layer5 = norm_layer(nn.Conv2d(nef * 8, nef * 8, kw, stride=2, padding=pw))
        self.layer6 = norm_layer(nn.Conv2d(nef * 8, nef * 8, kw, stride=2, padding=pw))

        if need_feat:
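            # Ladder branch: each up_layerN projects the stage-N features to
            # nef * 2 channels and upsamples them back to the layer-2 grid
            # (64x64 for a 256 input) so they can be summed in forward().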
            self.up_layer2 = norm_layer(
                nn.Conv2d(nef * 2, nef * 2, kw, stride=1, padding=pw)
            )
            self.up_layer3 = nn.Sequential(
                norm_layer(nn.Conv2d(nef * 4, nef * 2, kw, stride=1, padding=pw)),
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            )
            self.up_layer4 = nn.Sequential(
                norm_layer(nn.Conv2d(nef * 8, nef * 2, kw, stride=1, padding=pw)),
                nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True),
            )
            self.up_layer5 = nn.Sequential(
                norm_layer(nn.Conv2d(nef * 8, nef * 2, kw, stride=1, padding=pw)),
                nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True),
            )
            self.up_layer6 = nn.Sequential(
                norm_layer(nn.Conv2d(nef * 8, nef * 2, kw, stride=1, padding=pw)),
                nn.Upsample(scale_factor=16, mode="bilinear", align_corners=True),
            )

        self.actvn = nn.LeakyReLU(0.2, inplace=False)
        # spatial size after six stride-2 convolutions on a 256x256 input: 256 / 2**6 = 4
        self.so = s0 = 4
        self.fc = nn.Linear(nef * 8 * s0 * s0, z_dim)

    def forward(self, x):
        features = None
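        # `features` will accumulate the multi-scale ladder outputs (see up_layer2-6)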
        if x.size(2) != 256 or x.size(3) != 256:
            # the fc layer assumes a 256x256 input, so resize when necessary
            x = F.interpolate(x, size=(256, 256), mode="bilinear", align_corners=False)

        x = self.layer1(x)
        x = self.layer2(self.actvn(x))
        if self.need_feat:
            features = self.up_layer2(x)
        x = self.layer3(self.actvn(x))
        if self.need_feat:
            features = self.up_layer3(x) + features
        x = self.layer4(self.actvn(x))
        if self.need_feat:
            features = self.up_layer4(x) + features
        x = self.layer5(self.actvn(x))
        if self.need_feat:
            features = self.up_layer5(x) + features
        x = self.layer6(self.actvn(x))
        if self.need_feat:
            features = self.up_layer6(x) + features

        x = self.actvn(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        # L2-normalize the embedding onto the unit hypersphere (eps guards against division by zero)
        x = x / (x.norm(dim=-1, p=2, keepdim=True) + 1e-5)

        return x, features
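

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original code). With the
    # default label_nc=0 the encoder expects nif = 3 + 0 + 2 * 3 = 9 channels.
    enc = LadderEncoder(need_feat=True, label_nc=0, z_dim=512)
    dummy = torch.randn(2, 9, 256, 256)
    z, feats = enc(dummy)
    print(z.shape)      # torch.Size([2, 512]); each row is L2-normalized
    print(feats.shape)  # torch.Size([2, 128, 64, 64]); nef * 2 channels at 1/4 resolution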