dense_motion.py 2.78 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from torch import nn
import torch
import functools
from modules.util import (
    Hourglass,
    make_coordinate_grid,
    LayerNorm2d,
)


class DenseMotionNetworkReg(nn.Module):
    """Regress a dense 2-D flow field (and optionally an occlusion map).

    An ``Hourglass`` backbone produces a feature map. From it, ``reger``
    regresses a 2-channel flow. The flow is rescaled into normalized grid
    coordinates and added to an identity coordinate grid, giving a
    channels-last ``deformation``. When ``occlusion`` is enabled,
    ``occlusion_net`` predicts a 1-channel sigmoid map that is resized to the
    spatial size of the input image.

    Args:
        block_expansion: forwarded to ``Hourglass``.
        num_blocks: forwarded to ``Hourglass``.
        max_features: forwarded to ``Hourglass``.
        Lwarp: forwarded to ``Hourglass``.
        AdaINc: forwarded to ``Hourglass``.
        dec_lease: forwarded to ``Hourglass``. When ``> 0``, the prediction
            heads are LayerNorm2d -> LeakyReLU(0.1) -> 7x7 conv. Otherwise
            each head is a single 7x7 conv.
        label_nc: number of extra input channels on top of the 3 base
            channels.
        ldmkimg: if True, the input carries 2 * 3 = 6 additional channels.
        occlusion: if True, ``forward`` also returns ``"occlusion_map"``.
    """

    def __init__(
        self,
        block_expansion,
        num_blocks,
        max_features,
        Lwarp=False,
        AdaINc=0,
        dec_lease=0,
        label_nc=0,
        ldmkimg=False,
        occlusion=False,
    ):
        super().__init__()
        # 3 base channels (NOTE(review): presumably RGB — confirm) plus the
        # optional label channels; ``ldmkimg`` adds 2 * 3 more channels.
        in_c = 3 + label_nc + (2 * 3 if ldmkimg else 0)
        self.hourglass = Hourglass(
            block_expansion=block_expansion,
            in_features=in_c,
            max_features=max_features,
            num_blocks=num_blocks,
            Lwarp=Lwarp,
            AdaINc=AdaINc,
            dec_lease=dec_lease,
        )
        out_c = self.hourglass.out_filters

        self.occlusion = occlusion
        if dec_lease > 0:
            norm_layer = functools.partial(LayerNorm2d, affine=True)
            self.reger = nn.Sequential(
                norm_layer(out_c),
                nn.LeakyReLU(0.1),
                nn.Conv2d(out_c, 2, kernel_size=7, stride=1, padding=3),
            )
            if occlusion:
                self.occlusion_net = nn.Sequential(
                    norm_layer(out_c),
                    nn.LeakyReLU(0.1),
                    nn.Conv2d(out_c, 1, kernel_size=7, stride=1, padding=3),
                )
        else:
            self.reger = nn.Conv2d(out_c, 2, kernel_size=(7, 7), padding=(3, 3))
            if occlusion:
                # forward() uses self.occlusion_net whenever self.occlusion is
                # set, so this configuration needs a head too; mirror the
                # plain-conv flow head above.
                self.occlusion_net = nn.Conv2d(
                    out_c, 1, kernel_size=(7, 7), padding=(3, 3)
                )

    def forward(self, source_image, drv_deca):
        """Predict the flow-based deformation and, optionally, an occlusion map.

        Args:
            source_image: input tensor fed to the Hourglass. When occlusion is
                enabled, its last two dims give the output size of the
                occlusion map. Assumed (B, C, H, W) — confirm.
            drv_deca: passed to the Hourglass as ``drv_exp``
                (NOTE(review): presumably a driving expression code — confirm).

        Returns:
            dict with:
              ``"flow"``: (B, 2, h, w) regressed flow, scaled by
                  2 / (size - 1) per axis. Channel 0 is scaled by the width
                  and channel 1 by the height.
              ``"deformation"``: coordinate grid + flow in channels-last
                  layout, i.e. the flow permuted to (B, h, w, 2).
              ``"occlusion_map"``: only when ``occlusion`` is enabled. A
                  sigmoid map in (0, 1), bilinearly resized to the spatial
                  size of ``source_image`` if it differs.
        """
        prediction = self.hourglass(source_image, drv_exp=drv_deca)

        out_dict = {}
        flow = self.reger(prediction)
        _, _, h, w = flow.shape
        # Rescale displacements (presumably in pixels) into the normalized
        # [-1, 1] coordinate range, where a span of (size - 1) maps to 2.
        # max(..., 1) avoids dividing by zero on a degenerate size-1 axis.
        flow_norm = 2 * torch.cat(
            [flow[:, :1, ...] / max(w - 1, 1), flow[:, 1:, ...] / max(h - 1, 1)],
            1,
        )
        out_dict["flow"] = flow_norm
        grid = make_coordinate_grid((h, w), type=torch.FloatTensor).to(
            flow_norm.device
        )
        # NOTE(review): assumes make_coordinate_grid returns an (h, w, 2) grid
        # that broadcasts against the (B, h, w, 2) permuted flow — confirm.
        deformation = grid + flow_norm.permute(0, 2, 3, 1)
        out_dict["deformation"] = deformation

        if self.occlusion:
            occlusion_map = torch.sigmoid(self.occlusion_net(prediction))
            src_h, src_w = source_image.shape[-2:]
            # Only the occlusion map is resized to the source resolution; the
            # flow/deformation stay at the prediction resolution.
            if occlusion_map.shape[-2:] != (src_h, src_w):
                occlusion_map = torch.nn.functional.interpolate(
                    occlusion_map,
                    size=(src_h, src_w),
                    mode="bilinear",
                    align_corners=False,
                )
            out_dict["occlusion_map"] = occlusion_map
        return out_dict