inference.py 3.67 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import argparse
import os

import cv2
import numpy as np
import torch
import yaml
from moviepy.editor import ImageSequenceClip
from torch.utils.data import DataLoader
from tqdm import tqdm

import utils
from dataset import PersonalDataset
from modules.discriminator import MultiScaleDiscriminator
from modules.generator import Generator
from modules.model import GeneratorFullModel


def build_model(args, conf):
    utils.set_random_seed(conf['general']['random_seed'])
    G = Generator(conf['model'].get('arch', None), **conf['model']['generator'], **conf['model']['common'])
    utils.load_ckpt(args['ckpt'], {'generator': G}, device=args['device'], strict=True)
    G.eval()
    
    D = MultiScaleDiscriminator(**conf['model']['discriminator'], **conf['model']['common'])

    G_full = GeneratorFullModel(None, G.cuda(), D.cuda(), conf['train'], conf['model'].get('arch', None), conf=conf)
    return G_full


def build_data_loader(conf, name='personal'):
    dataset = PersonalDataset(conf, name, is_train=False)
    sampler = torch.utils.data.SequentialSampler(dataset)
    test_dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=2, drop_last=False, pin_memory=True)
    return test_dataloader


def model_forward(G_full, data):
    for key, value in data.items():
        if isinstance(value, list):
            if isinstance(value[0], (str, list)):
                continue
            data[key] = [v.cuda() for v in value]
        elif isinstance(value, str):
            continue
        else:
            data[key] = value.cuda()
    
    generated = G_full(data, stage="Full", inference=True)
    return generated


def save_images(args, generated, data):
    for j in range(len(generated['prediction'])):
        final = np.transpose(generated['prediction'][j].data.cpu().numpy(), [1, 2, 0])
        final = np.clip(final * 255, 0, 255).astype(np.uint8)[:, :, ::-1]
        cv2.imwrite(os.path.join(args["save_dir"], "output_256", str(data['driving_name'][j])), final)


def save_to_video(args, gt_path):
    img_list = []
    for file in tqdm(gt_path):
        img_name = file.split("/")[-1]
        final_img_path = os.path.join(args["save_dir"], 'output_256', img_name)
        final_img = cv2.resize(cv2.cvtColor(cv2.imread(final_img_path), cv2.COLOR_BGR2RGB), (256, 256))
        img_list.append(final_img)
    
    imgseqclip = ImageSequenceClip(img_list, 23.98)
    imgseqclip.write_videofile(os.path.join(args["save_dir"], "out_256.mp4"), logger=None)


def evaluation(args, conf):
    os.makedirs(os.path.join(args["save_dir"], "output_256"), exist_ok=True)
    
    G_full = build_model(args, conf)
    name = conf["dataset"]["train_data"][0]
    test_dataloader = build_data_loader(conf, name)

    print("Evaluation using {} images.".format(len(test_dataloader.dataset)))
    
    for data in tqdm(test_dataloader):
        with torch.inference_mode():
            generated = model_forward(G_full, data)

        save_images(args, generated, data)

    save_to_video(args, test_dataloader.dataset.data['imgs'])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Image Animation for Immersive Meeting Evaluation Scripts',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--save_dir', type=str, default='../../', help='image save dir')
    parser.add_argument('--ckpt', type=str, help='load checkpoint path')
    parser.add_argument("--config", type=str, default="config/test.yaml", help="path to config")

    args = vars(parser.parse_args())
    args['device'] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    with open(args['config']) as f:
        conf = yaml.safe_load(f)

    evaluation(args, conf)