import argparse
import glob
import numpy as np
import os
import torch
import torchvision.transforms as transforms
from skimage import io

from basicsr.archs.dfdnet_arch import DFDNet
from basicsr.utils import imwrite, tensor2img

try:
    from facexlib.utils.face_restoration_helper import FaceRestoreHelper
except ImportError as e:
    # fail fast: the rest of the script cannot run without FaceRestoreHelper
    raise ImportError('Please install facexlib: pip install facexlib') from e

# TODO: need to modify, as we have updated the FaceRestorationHelper
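
# Example invocation (the paths are just the argparse defaults below; adjust
# them to wherever your pretrained models and test images actually live):
#   python inference_dfdnet.py --test_path datasets/TestWhole --upscale_factor 2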


def get_part_location(landmarks):
    """Get part locations from landmarks."""
    map_left_eye = list(np.hstack((range(17, 22), range(36, 42))))  # left brow + left eye
    map_right_eye = list(np.hstack((range(22, 27), range(42, 48))))  # right brow + right eye
    map_nose = list(range(29, 36))  # lower nose
    map_mouth = list(range(48, 68))  # outer + inner lips

    # left eye
    mean_left_eye = np.mean(landmarks[map_left_eye], 0)  # (x, y) center
    half_len_left_eye = np.max(
        (np.max(np.max(landmarks[map_left_eye], 0) - np.min(landmarks[map_left_eye], 0)) / 2, 16))  # scalar, >= 16
    loc_left_eye = np.hstack((mean_left_eye - half_len_left_eye + 1, mean_left_eye + half_len_left_eye)).astype(int)
    loc_left_eye = torch.from_numpy(loc_left_eye).unsqueeze(0)
    # (1, 4): the four numbers form the two diagonal corner coordinates

    # right eye
    mean_right_eye = np.mean(landmarks[map_right_eye], 0)
    half_len_right_eye = np.max(
        (np.max(np.max(landmarks[map_right_eye], 0) - np.min(landmarks[map_right_eye], 0)) / 2, 16))
    loc_right_eye = np.hstack(
        (mean_right_eye - half_len_right_eye + 1, mean_right_eye + half_len_right_eye)).astype(int)
    loc_right_eye = torch.from_numpy(loc_right_eye).unsqueeze(0)
    # nose
    mean_nose = np.mean(landmarks[map_nose], 0)
    half_len_nose = np.max(
        (np.max(np.max(landmarks[map_nose], 0) - np.min(landmarks[map_nose], 0)) / 2, 16))  # noqa: E126
    loc_nose = np.hstack((mean_nose - half_len_nose + 1, mean_nose + half_len_nose)).astype(int)
    loc_nose = torch.from_numpy(loc_nose).unsqueeze(0)
    # mouth
    mean_mouth = np.mean(landmarks[map_mouth], 0)
    half_len_mouth = np.max(
        (np.max(np.max(landmarks[map_mouth], 0) - np.min(landmarks[map_mouth], 0)) / 2, 16))  # noqa: E126
    loc_mouth = np.hstack((mean_mouth - half_len_mouth + 1, mean_mouth + half_len_mouth)).astype(int)
    loc_mouth = torch.from_numpy(loc_mouth).unsqueeze(0)

    return loc_left_eye, loc_right_eye, loc_nose, loc_mouth
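
# For illustration only (hypothetical landmarks; shapes are what matter):
#   landmarks = np.random.rand(68, 2) * 512
#   left_eye, right_eye, nose, mouth = get_part_location(landmarks)
#   left_eye.shape  # torch.Size([1, 4]) -- (x1, y1, x2, y2)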


if __name__ == '__main__':
    """We try to align to the official codes. But there are still slight
    differences: 1) we use dlib for 68 landmark detection; 2) the used image
    package are different (especially for reading and writing.)
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    parser = argparse.ArgumentParser()

    parser.add_argument('--upscale_factor', type=int, default=2)
    parser.add_argument(
        '--model_path',
        type=str,
        default=  # noqa: E251
        'experiments/pretrained_models/DFDNet/DFDNet_official-d1fa5650.pth')
    parser.add_argument(
        '--dict_path',
        type=str,
        default=  # noqa: E251
        'experiments/pretrained_models/DFDNet/DFDNet_dict_512-f79685f0.pth')
    parser.add_argument('--test_path', type=str, default='datasets/TestWhole')
    parser.add_argument('--upsample_num_times', type=int, default=1)
    parser.add_argument('--save_inverse_affine', action='store_true')
    parser.add_argument('--only_keep_largest', action='store_true')
    # The official code uses skimage.io to read the cropped images from disk
    # instead of directly using the intermediate results in memory (as we
    # do). This difference brings slight discrepancies due to skimage.io. To
    # align with the official results, set official_adaption to True.
    # Note: argparse's `type=bool` treats every non-empty string as True, so
    # we parse the flag explicitly instead.
    parser.add_argument('--official_adaption', type=lambda v: v.lower() in ('true', '1'), default=True)

    # The following are the paths for dlib models
    parser.add_argument(
        '--detection_path',
        type=str,
        default=  # noqa: E251
        'experiments/pretrained_models/dlib/mmod_human_face_detector-4cb19393.dat'  # noqa: E501
    )
    parser.add_argument(
        '--landmark5_path',
        type=str,
        default=  # noqa: E251
        'experiments/pretrained_models/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat'  # noqa: E501
    )
    parser.add_argument(
        '--landmark68_path',
        type=str,
        default=  # noqa: E251
        'experiments/pretrained_models/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat'  # noqa: E501
    )

    args = parser.parse_args()
    if args.test_path.endswith('/'):  # strip the trailing slash
        args.test_path = args.test_path[:-1]
    result_root = f'results/DFDNet/{os.path.basename(args.test_path)}'

    # set up the DFDNet
    net = DFDNet(64, dict_path=args.dict_path).to(device)
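    # BasicSR pretrained checkpoints typically keep the weights under the
    # 'params' key; map_location loads the tensors onto CPU first, and
    # load_state_dict then copies them into the (possibly CUDA) model.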
    checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
    net.load_state_dict(checkpoint['params'])
    net.eval()

    save_crop_root = os.path.join(result_root, 'cropped_faces')
    save_inverse_affine_root = os.path.join(result_root, 'inverse_affine')
    os.makedirs(save_inverse_affine_root, exist_ok=True)
    save_restore_root = os.path.join(result_root, 'restored_faces')
    save_final_root = os.path.join(result_root, 'final_results')

    face_helper = FaceRestoreHelper(args.upscale_factor, face_size=512)
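    # (face_size=512 because DFDNet restores 512x512 aligned face crops)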

    # scan all the jpg and png images
    for img_path in sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
        img_name = os.path.basename(img_path)
        print(f'Processing {img_name} image ...')
        save_crop_path = os.path.join(save_crop_root, img_name)
        if args.save_inverse_affine:
            save_inverse_affine_path = os.path.join(save_inverse_affine_root, img_name)
        else:
            save_inverse_affine_path = None

        face_helper.init_dlib(args.detection_path, args.landmark5_path, args.landmark68_path)
        # detect faces
        num_det_faces = face_helper.detect_faces(
            img_path, upsample_num_times=args.upsample_num_times, only_keep_largest=args.only_keep_largest)
        # get 5 face landmarks for each face
        num_landmarks = face_helper.get_face_landmarks_5()
        print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.')
        # warp and crop each face
        face_helper.warp_crop_faces(save_crop_path, save_inverse_affine_path)

        if args.official_adaption:
            path = os.path.splitext(save_crop_path)[0]
            paths = sorted(glob.glob(f'{path}_[0-9]*.png'))
            cropped_faces = [io.imread(p) for p in paths]
        else:
            cropped_faces = face_helper.cropped_faces

        # get 68 landmarks for each cropped face
        num_landmarks = face_helper.get_face_landmarks_68()
        print(f'\tDetect {num_landmarks} faces for 68 landmarks.')

        face_helper.free_dlib_gpu_memory()

        print('\tFace restoration ...')
        # face restoration for each cropped face
        assert len(cropped_faces) == len(face_helper.all_landmarks_68)
        for idx, (cropped_face, landmarks) in enumerate(zip(cropped_faces, face_helper.all_landmarks_68)):
            if landmarks is None:
                print(f'Landmarks are None; skip the cropped face with idx {idx}.')
                # just copy the cropped face to the restored faces
                restored_face = cropped_face
            else:
                # prepare data
                part_locations = get_part_location(landmarks)
                cropped_face = transforms.ToTensor()(cropped_face)  # HWC uint8 -> CHW float in [0, 1]
                cropped_face = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(cropped_face)  # [0, 1] -> [-1, 1]
                cropped_face = cropped_face.unsqueeze(0).to(device)  # add the batch dimension

                try:
                    with torch.no_grad():
                        output = net(cropped_face, part_locations)
                        restored_face = tensor2img(output, min_max=(-1, 1))
                    del output
                    torch.cuda.empty_cache()
                except Exception as e:
                    print(f'DFDNet inference failed: {e}')
                    # fall back to the unrestored input so the pipeline continues
                    restored_face = tensor2img(cropped_face, min_max=(-1, 1))

            path = os.path.splitext(os.path.join(save_restore_root, img_name))[0]
            save_path = f'{path}_{idx:02d}.png'
            imwrite(restored_face, save_path)
            face_helper.add_restored_face(restored_face)

        print('\tGenerate the final result ...')
        # paste each restored face to the input image
        face_helper.paste_faces_to_input_image(os.path.join(save_final_root, img_name))

        # clean all the intermediate results to process the next image
        face_helper.clean_all()

    print(f'\nAll results are saved in {result_root}')