main.py 3.58 KB
Newer Older
zachteed's avatar
zachteed committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from lietorch import SO3, SE3, LieGroupParameter

import argparse
import numpy as np
import time
import torch.optim as optim
import torch.nn.functional as F


def draw(verticies):
    """ Visualize the pose graph trajectory with Open3D.

    The first three entries of each vertex pose (``entry[1][:3]``) are used
    as a 3D point, and each vertex is joined by a line segment to the next
    vertex in list order. The result is shown with
    ``open3d.visualization.draw_geometries``.

    Args:
        verticies: list of ``[id, pose]`` entries (as produced by ``read_g2o``).
    """
    import open3d as o3d

    count = len(verticies)
    xyz = np.array([entry[1][:3] for entry in verticies])

    # segment k connects vertex k to vertex k + 1
    starts = np.arange(0, count - 1)
    ends = np.arange(1, count)
    segments = np.stack([starts, ends], 1)

    geometry = o3d.geometry.LineSet(
        points=o3d.utility.Vector3dVector(xyz),
        lines=o3d.utility.Vector2iVector(segments),
    )
    o3d.visualization.draw_geometries([geometry])

def info2mat(info):
    """ Expand a g2o upper-triangular information vector to a 6x6 matrix.

    ``info`` is read row by row as the upper triangle (diagonal included):
    row ``r`` consumes ``6 - r`` consecutive values, 21 in total. Each row
    chunk is mirrored into the matching column, so the result is symmetric.

    Args:
        info: 1-D sequence/array holding at least the 21 upper-triangular
            values.

    Returns:
        A symmetric ``(6, 6)`` float64 numpy array.
    """
    size = 6
    out = np.zeros((size, size))
    start = 0
    for row in range(size):
        stop = start + (size - row)
        chunk = info[start:stop]
        out[row, row:] = chunk   # upper triangle, this row
        out[row:, row] = chunk   # mirrored lower triangle, this column
        start = stop
    return out

def read_g2o(fn):
    """ Parse a g2o pose graph file containing SE3:QUAT vertices and edges.

    Args:
        fn: path to the ``.g2o`` file.

    Returns:
        ``(verticies, edges)`` where

        * ``verticies`` is a list of ``[id, pose]``; ``pose`` is a float32
          array of every token after the id on a ``VERTEX_SE3:QUAT`` line.
        * ``edges`` is a list of ``[u, v, pose, info, tokens]``; ``pose`` is
          the 7 float32 measurement tokens, ``info`` is the symmetric 6x6
          matrix built by ``info2mat`` from the remaining tokens, and
          ``tokens`` is the raw whitespace-split line (``write_g2o`` writes
          it back out unchanged).

    Lines with any other leading tag are ignored; blank lines are skipped.
    """
    verticies, edges = [], []
    with open(fn) as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                # blank / whitespace-only line (e.g. a trailing newline):
                # indexing tokens[0] would raise IndexError
                continue

            tag = tokens[0]
            if tag == 'VERTEX_SE3:QUAT':
                v = int(tokens[1])
                pose = np.array(tokens[2:], dtype=np.float32)
                verticies.append([v, pose])

            elif tag == 'EDGE_SE3:QUAT':
                u = int(tokens[1])
                v = int(tokens[2])
                pose = np.array(tokens[3:10], dtype=np.float32)
                info = np.array(tokens[10:], dtype=np.float32)

                info = info2mat(info)
                edges.append([u, v, pose, info, tokens])

    return verticies, edges

def write_g2o(pose_graph, fn):
    """ Write a pose graph to ``fn`` in g2o format.

    Each vertex is written as ``VERTEX_SE3:QUAT <id> <pose values...>``.
    Each edge is written from its last element, the raw token list that
    ``read_g2o`` stores, so edge lines come out unchanged.

    Args:
        pose_graph: ``(verticies, edges)`` tuple; vertices are ``[id, pose]``
            with ``pose`` a numpy array.
        fn: output path (overwritten if it exists).
    """
    import csv
    verticies, edges = pose_graph
    # csv.writer terminates rows with '\r\n' by default, and text-mode newline
    # translation would turn that into '\r\r\n' on Windows. g2o is plain
    # '\n'-terminated text, so disable translation and use '\n' explicitly.
    with open(fn, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=' ', lineterminator='\n')
        for (v, pose) in verticies:
            row = ['VERTEX_SE3:QUAT', v] + pose.tolist()
            writer.writerow(row)
        for edge in edges:
            writer.writerow(edge[-1])

def reshaping_fn(dE, b=1.5):
    """ Reshaping function from "Intrinsic consensus on SO(3), Tron et al."

    Each residual is mapped to a bounded cost and the costs are summed:

        err(theta) = 1/b - (1/b + theta) * exp(-b * theta)

    Here ``theta`` is the Euclidean norm, over the last dimension, of
    ``dE.log()``. For b > 0 the cost is 0 at theta = 0 and tends to 1/b as
    theta grows, so large (outlier) residuals contribute a bounded amount.
    A larger ``b`` makes the cost saturate sooner.

    Args:
        dE: object exposing ``.log()`` that returns a tensor of tangent-space
            coordinates in its last dimension (a lietorch ``SO3`` residual in
            this file).
        b: shape parameter of the reshaping function.

    Returns:
        Scalar tensor holding the sum of the reshaped costs.
    """
    ang = dE.log().norm(dim=-1)
    err = 1/b - (1/b + ang) * torch.exp(-b*ang)
    return err.sum()

def gradient_initializer(pose_graph, n_steps=500, lr_init=0.2, device='cuda'):
    """ Riemannian Gradient Descent

    Rotation averaging: only the rotation part of each vertex is optimized.
    For every edge (i, j) the residual is ``(R_i^-1 * R_j) * E_ij^-1``, and
    the summed ``reshaping_fn`` of these residuals is minimized with SGD
    (momentum 0.5) on a lietorch ``LieGroupParameter``.

    Args:
        pose_graph: ``(verticies, edges)`` as returned by ``read_g2o``. The
            entries ``[3:]`` of every vertex pose and edge measurement are
            passed directly to ``SO3`` as the rotation.
        n_steps: number of gradient steps.
        lr_init: initial learning rate; step ``i`` uses ``lr_init * 0.995**i``.
        device: torch device to optimize on. Defaults to ``'cuda'``, the
            previously hard-coded device.

    Returns:
        ``(verticies, edges)``: the same lists, with each vertex pose's
        entries ``[3:]`` overwritten in place by the optimized rotation.
    """
    verticies, edges = pose_graph
    if not edges:
        # No relative measurements constrain the rotations (and np.stack
        # below cannot stack an empty list), so leave the graph unchanged.
        return verticies, edges

    # Edge endpoints are g2o vertex ids, not list positions. Map them to row
    # positions in R so non-contiguous or unsorted ids pick the right rotation.
    index_of = {entry[0]: k for k, entry in enumerate(verticies)}
    ii = torch.tensor([index_of[e[0]] for e in edges], dtype=torch.long, device=device)
    jj = torch.tensor([index_of[e[1]] for e in edges], dtype=torch.long, device=device)

    Eij = np.stack([e[2][3:] for e in edges])
    Eij = SO3(torch.from_numpy(Eij).float().to(device))

    R = np.stack([entry[1][3:] for entry in verticies])
    R = SO3(torch.from_numpy(R).float().to(device))
    R = LieGroupParameter(R)

    # use gradient descent with momentum
    optimizer = optim.SGD([R], lr=lr_init, momentum=0.5)

    for i in range(n_steps):
        optimizer.zero_grad()

        # exponential learning-rate decay, set manually each step
        lr = lr_init * .995**i
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # rotation error
        dE = (R[ii].inv() * R[jj]) * Eij.inv()
        loss = reshaping_fn(dE)

        loss.backward()
        optimizer.step()

        if i % 25 == 0:
            print(i, lr, loss.item())

    # write the optimized rotations back into the vertex poses
    quats = R.group.data.detach().cpu().numpy()
    for entry, q in zip(verticies, quats):
        entry[1][3:] = q

    return verticies, edges


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--problem', help="input pose graph optimization file (.g2o format)")
    args = parser.parse_args()

    output_path = args.problem.replace('.g2o', '_rotavg.g2o')
    input_pose_graph = read_g2o(args.problem)

    rot_pose_graph = gradient_initializer(input_pose_graph)
    write_g2o(rot_pose_graph, output_path)