exhaustive_raft.py 3.65 KB
Newer Older
bailuo's avatar
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
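This script computes all pairwise RAFT optical flow fields (in both directions) for the
images in <data_dir>/color and saves them to <data_dir>/raft_exhaustive.
For each source image, the flow computed for the previous (nearer) target image is used
as the initialization for the current flow.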
"""

import sys

sys.path.append('core')

import argparse
import os
import glob
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from raft import RAFT
from utils.utils import InputPadder
import warnings

# Suppress every Python warning for the rest of the process.
warnings.filterwarnings(action="ignore")

# Device used for the model and every input tensor (a CUDA GPU is required as written).
DEVICE = "cuda"


def load_image(imfile):
    """Load an image file as a batched float tensor for RAFT.

    The image is always converted to 3-channel RGB, so grayscale, palette
    and RGBA files produce the same layout as plain RGB files.

    Args:
        imfile: Path to the image file.

    Returns:
        A float32 tensor of shape (1, 3, H, W) with values in [0, 255],
        placed on ``DEVICE``.
    """
    # The context manager releases the file handle as soon as the pixels are
    # read, instead of whenever the Image object happens to be collected.
    with Image.open(imfile) as pil_img:
        # Force 3 channels: grayscale/palette images would otherwise give a
        # 2-D array (breaking the permute below) and RGBA would give 4 channels.
        img = np.array(pil_img.convert('RGB'), dtype=np.uint8)
    # (H, W, 3) uint8 -> (3, H, W) float32, then add a leading batch dimension.
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(DEVICE)


def _compute_flows_from_source(model, img_files, src_idx, tgt_indices, flow_out_dir, pbar, iters=20):
    """Compute and save RAFT flow from ``img_files[src_idx]`` to each target image.

    Targets are processed in the given order. The low-resolution flow of each
    pair is passed as ``flow_init`` for the next pair, so callers order the
    targets by increasing distance from the source.

    Args:
        model: RAFT model, already in eval mode on ``DEVICE``.
        img_files: Sorted list of image paths.
        src_idx: Index of the source image in ``img_files``.
        tgt_indices: Iterable of target indices, in processing order.
        flow_out_dir: Directory that receives ``<src name>_<tgt name>.npy`` files.
        pbar: tqdm progress bar, advanced once per saved flow.
        iters: Number of RAFT refinement iterations per pair.
    """
    imfile1 = img_files[src_idx]
    name1 = os.path.basename(imfile1)
    # The source image (and therefore the padder, which depends only on its
    # shape) is the same for every target: load it once, not once per pair.
    image1 = load_image(imfile1)
    padder = InputPadder(image1.shape)

    flow_low_prev = None
    for j in tgt_indices:
        imfile2 = img_files[j]
        image2 = load_image(imfile2)
        image1_pad, image2_pad = padder.pad(image1, image2)

        flow_low, flow_up = model(image1_pad, image2_pad, iters=iters, test_mode=True,
                                  flow_init=flow_low_prev)
        flow_up = padder.unpad(flow_up)

        # Drop only the batch dimension; squeeze() would also drop H or W if
        # either were 1. Assumes flow_up is (1, C, H, W) -> saved as (H, W, C).
        flow_up_np = flow_up[0].permute(1, 2, 0).cpu().numpy()
        save_file = os.path.join(flow_out_dir,
                                 '{}_{}.npy'.format(name1, os.path.basename(imfile2)))
        np.save(save_file, flow_up_np)

        # Chain the initialization: the next (farther) target starts from this flow.
        flow_low_prev = flow_low
        pbar.update(1)


def run_exhaustive_flow(args):
    """Compute RAFT optical flow for every ordered pair of images in ``args.data_dir``.

    Images are read from ``<data_dir>/color/*`` (sorted by path). The flow from
    image A to image B is saved as ``<data_dir>/raft_exhaustive/<A>_<B>.npy``.
    Both directions are computed, i.e. N * (N - 1) flows for N images.

    Args:
        args: Parsed namespace. Must provide ``model`` (checkpoint path),
            ``data_dir`` and the options consumed by ``RAFT(args)``. An optional
            ``iters`` attribute overrides the default of 20 RAFT iterations.
    """
    model = torch.nn.DataParallel(RAFT(args))
    # Load the checkpoint onto the CPU first; the model is moved to DEVICE
    # below, so this does not depend on the device the checkpoint was saved from.
    model.load_state_dict(torch.load(args.model, map_location='cpu'))

    # Presumably the checkpoint keys carry DataParallel's 'module.' prefix,
    # hence loading through the wrapper and unwrapping afterwards.
    model = model.module
    model.to(DEVICE)
    model.eval()

    iters = getattr(args, 'iters', 20)

    data_dir = args.data_dir
    print('computing all pairwise optical flows for {}...'.format(data_dir))

    flow_out_dir = os.path.join(data_dir, 'raft_exhaustive')
    os.makedirs(flow_out_dir, exist_ok=True)

    img_dir = os.path.join(data_dir, 'color')
    img_files = sorted(glob.glob(os.path.join(img_dir, '*')))
    num_imgs = len(img_files)
    if num_imgs < 2:
        # Otherwise a wrong --data_dir silently produces no output at all.
        print('warning: found {} image(s) in {}; at least 2 are needed to compute flow'.format(
            num_imgs, img_dir))

    with tqdm(total=num_imgs * (num_imgs - 1)) as pbar, torch.no_grad():
        # Forward flows i -> j for j > i, nearest target first.
        for i in range(num_imgs - 1):
            _compute_flows_from_source(model, img_files, i, range(i + 1, num_imgs),
                                       flow_out_dir, pbar, iters)

        # Backward flows i -> j for j < i, nearest target first.
        for i in range(num_imgs - 1, 0, -1):
            _compute_flows_from_source(model, img_files, i, range(i - 1, -1, -1),
                                       flow_out_dir, pbar, iters)

    print('computing all pairwise optical flows for {} is done \n'.format(data_dir))


def parse_args(argv=None):
    """Parse the command-line options for the exhaustive RAFT flow script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model``, ``small``, ``mixed_precision``,
        ``alternate_corr`` and ``data_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='models/raft-things.pth', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
    parser.add_argument('--data_dir', type=str, default='', help='dataset dir')
    return parser.parse_args(argv)


if __name__ == '__main__':
    run_exhaustive_flow(parse_args())