# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
import apex
import einops
import torch
import torch.nn.functional as F

from megatron.training import get_args, print_rank_0
from megatron.legacy.model.utils import get_linear_layer
from megatron.legacy.model.vision.vit_backbone import VitBackbone
from megatron.legacy.model.module import MegatronModule
from megatron.legacy.model.vision.mit_backbone import mit_b3
from megatron.legacy.model.vision.utils import resize


class VitInpaintingModel(MegatronModule):
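    """Vision Transformer inpainting model: a ViT encoder (no class token)
    followed by a linear decoder that reconstructs the image patch by patch."""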

    def __init__(self, config, pre_process=True, post_process=True):
        super(VitInpaintingModel, self).__init__()
        args = get_args()

        self.config = config
        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = config.hidden_size
        self.backbone = VitBackbone(
            config=config,
            pre_process=self.pre_process,
            post_process=self.post_process,
            class_token=False,
        )
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.seq_length = args.seq_length
        # full mask

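        # Linear head decoding each ViT token back to a flattened patch of
        # pixels (backbone.flatten_dim = patch_dim * patch_dim * num_channels).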
        if self.post_process:
            self.linear_decoder = get_linear_layer(
                self.hidden_size,
                self.backbone.flatten_dim,
                torch.nn.init.zeros_
            )

    def set_input_tensor(self, input_tensor):
        self.backbone.set_input_tensor(input_tensor)

    def forward(self, input):

        hidden_states = self.backbone(input)

        if not self.post_process:
            return hidden_states
        decoded_output = self.linear_decoder(hidden_states)
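        # Un-patchify: fold each token's (p1*p2*c) pixel prediction back into
        # its patch_dim x patch_dim block of the (b, c, img_h, img_w) image.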
        output = einops.rearrange(
            decoded_output,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            p1=self.patch_dim,
            p2=self.patch_dim,
            h=self.img_h//self.patch_dim,
            w=self.img_w//self.patch_dim,
        )

        return output


class MLP(torch.nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class MitInpaintingModel(MegatronModule):
    """Mix vision Transformer Model."""

    def __init__(self, pre_process=True, post_process=True):
        super(MitInpaintingModel, self).__init__()
        self.pre_process = pre_process
        self.post_process = post_process

        args = get_args()
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.flatten_dim = self.patch_dim * self.patch_dim * 3
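        # mit_b3: MixVisionTransformer (SegFormer-B3) encoder producing four
        # feature maps at strides 4, 8, 16 and 32.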
        self.backbone = mit_b3()

        self.in_channels = [64, 128, 320, 512]
        self.embedding_dim = 768

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

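        # Project each backbone stage to a shared embedding dimension.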
        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)

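        # A 1x1 conv fuses the concatenated multi-scale features; SyncBatchNorm
        # keeps normalization statistics consistent across data-parallel ranks.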
        self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4, self.embedding_dim, 1, 1, bias=False)
        self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
        self.dropout = torch.nn.Dropout2d(0.1)

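        # 1x1 conv head predicting a flattened patch of pixels per spatial location.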
        self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1)

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
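        # Multi-scale features from the MiT encoder: c1 (highest resolution)
        # through c4 (lowest).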
        c1, c2, c3, c4 = self.backbone(input)

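        # Project each stage to embedding_dim and upsample to c1's spatial size.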
        n, _, h, w = c4.shape
        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, h, w)
        _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
        _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
        _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])

        _c = torch.cat([_c4, _c3, _c2, _c1], dim=1)
        _c = self.conv_fuse(_c)

        x = self.norm(_c)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)

        x = self.linear_pred(x)

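        # Un-patchify: each location's (c*p1*p2) prediction becomes one
        # patch_dim x patch_dim block of the full-resolution output.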
        output = einops.rearrange(
            x,
            "b (c p1 p2) h w -> b c (h p1) (w p2)",
            p1=self.patch_dim,
            p2=self.patch_dim,
            h=self.img_h//self.patch_dim,
            w=self.img_w//self.patch_dim,
        )

        return output
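

# Minimal usage sketch (illustrative only; assumes Megatron's arguments have
# already been parsed during initialization, so that get_args() works inside
# the constructors, and that input shapes match args.img_h and args.img_w):
#
#   model = MitInpaintingModel().cuda()
#   masked = torch.randn(2, 3, args.img_h, args.img_w, device="cuda")
#   out = model(masked)  # reconstructed images, (2, 3, args.img_h, args.img_w)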