demo.py 2.5 KB
Newer Older
zachteed's avatar
zachteed committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import sys
sys.path.append('../core')

import argparse
import torch
import cv2
import numpy as np

from viz import sim3_visualization
from lietorch import SO3, SE3, Sim3
from networks.sim3_net import Sim3Net

def normalize_images(images):
    """Reverse the channel order on dim 2, scale by 1/255 and standardize.

    The three entries along dim 2 are reordered with index ``[2, 1, 0]``
    (presumably BGR -> RGB, since the demo images come from ``cv2.imread``),
    then each channel has its mean subtracted and is divided by its std.
    The last two dims are broadcast over. A new tensor is returned; the
    input is not modified.
    """
    channel_mean = torch.as_tensor([0.485, 0.456, 0.406], device=images.device)
    channel_std = torch.as_tensor([0.229, 0.224, 0.225], device=images.device)

    # Fancy indexing copies, so the in-place ops below never touch the input.
    reordered = images[:, :, [2, 1, 0]]
    normalized = reordered / 255.0

    # In-place on purpose: keeps the dtype produced by the division above.
    normalized.sub_(channel_mean.view(3, 1, 1))
    normalized.div_(channel_std.view(3, 1, 1))
    return normalized

def load_example(i=0):
    """Load one of the bundled two-frame demo examples as CUDA tensors.

    Args:
        i: Example index. ``0`` reads ``assets/image1.png`` / ``image2.png``
            with ``assets/depth1.npy`` / ``depth2.npy``; ``1`` reads images
            3 / 4 with depths 3 / 4.

    Returns:
        Tuple ``(images, depths, intrinsics)``, each moved to the GPU with a
        leading singleton dimension in front of the two-frame axis:
        ``images`` is the image pair as read by ``cv2.imread``, permuted to
        channels-first; ``depths`` is the float32 depth pair divided by
        ``DEPTH_SCALE``; ``intrinsics`` repeats the four-value vector
        ``[320, 320, 320, 240]`` once per frame (presumably fx, fy, cx, cy
        -- confirm against the model).

    Raises:
        ValueError: If ``i`` is not a known example index.
        FileNotFoundError: If an image file cannot be read.
    """
    DEPTH_SCALE = 5.0

    # example index -> asset numbers of the (first, second) frame
    frame_ids = {0: (1, 2), 1: (3, 4)}
    if i not in frame_ids:
        # Previously an unknown index fell through both branches and failed
        # later with an UnboundLocalError.
        raise ValueError('unknown demo example index {!r}; expected one of {}'.format(
            i, sorted(frame_ids)))

    frames, depth_maps = [], []
    for k in frame_ids[i]:
        image_path = 'assets/image{}.png'.format(k)
        frame = cv2.imread(image_path)
        if frame is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError('could not read image {!r}'.format(image_path))
        frames.append(frame)
        depth_maps.append(np.load('assets/depth{}.npy'.format(k)) / DEPTH_SCALE)

    images = np.stack(frames, 0)
    images = torch.from_numpy(images).permute(0, 3, 1, 2)

    depths = np.stack(depth_maps, 0)
    depths = torch.from_numpy(depths).float()

    intrinsics = np.array([320.0, 320.0, 320.0, 240.0])
    intrinsics = np.tile(intrinsics[None], (2, 1))
    intrinsics = torch.from_numpy(intrinsics).float()

    return images[None].cuda(), depths[None].cuda(), intrinsics[None].cuda()


@torch.no_grad()
def demo(model, index=0):
    """Run the model on one demo example and visualize the estimated transform.

    The transformation group is taken from the module-level
    ``args.transformation`` (``'SE3'`` or ``'Sim3'``). In the ``'Sim3'``
    case the first frame's depth is multiplied by a random factor in
    ``[0.5, 2)`` before the model runs.

    Args:
        model: Network called as
            ``model(Gs, images, depths, intrinsics, num_steps=12)``; the last
            element of its first return value is used as the final estimate.
        index: Example index forwarded to ``load_example``.

    Raises:
        ValueError: If ``args.transformation`` is neither ``'SE3'`` nor
            ``'Sim3'``.
    """
    group = args.transformation
    if group not in ('SE3', 'Sim3'):
        # Fail before touching the GPU; previously ``Gs`` was simply left
        # unbound and the model call raised an UnboundLocalError.
        raise ValueError(
            "unsupported transformation {!r}; expected 'SE3' or 'Sim3'".format(group))

    images, depths, intrinsics = load_example(index)

    # initial transformation estimate
    if group == 'SE3':
        Gs = SE3.Identity(1, 2, device='cuda')
    else:
        Gs = Sim3.Identity(1, 2, device='cuda')
        # rescale the first depth map by 2**u with u drawn from [-1, 1)
        exponent = (2*torch.rand(1) - 1.0).cuda()
        depths[:,0] *= 2**exponent

    images1 = normalize_images(images)
    ests, _ = model(Gs, images1, depths, intrinsics, num_steps=12)

    # only care about last transformation
    Gs = ests[-1]

    # first frame's estimate composed with the inverse of the second's
    T = Gs[:,0] * Gs[:,1].inv()

    T = T[0].matrix().double().cpu().numpy()
    # the un-normalized images are passed on, not ``images1``
    sim3_visualization(T, images, depths, intrinsics)


if __name__ == '__main__':
    # ``args`` is deliberately left at module scope: demo() reads it as a global.
    parser = argparse.ArgumentParser()
    parser.add_argument('--transformation', default='SE3', choices=['SE3', 'Sim3'],
        help='transformation group to estimate')
    # required: torch.load() below cannot run without a checkpoint path
    parser.add_argument('--ckpt', required=True, help='checkpoint to restore')
    args = parser.parse_args()

    model = Sim3Net(args)
    model.load_state_dict(torch.load(args.ckpt))

    model.cuda()
    model.eval()

    # run two demos
    demo(model, 0)
    demo(model, 1)