main.py 3.58 KB
Newer Older
zachteed's avatar
zachteed committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
from lietorch import SO3, SE3, LieGroupParameter

import argparse
import numpy as np
import time
import torch.optim as optim
import torch.nn.functional as F


def draw(verticies):
    """Show the pose graph trajectory in an Open3D viewer.

    The first three entries of every vertex pose become a 3D point, and each
    vertex is joined to the next one in list order by a line segment. The
    resulting ``LineSet`` is displayed with ``draw_geometries``.

    Args:
        verticies: list of ``[vertex_id, pose]`` entries (e.g. the vertex
            list returned by ``read_g2o``); only ``pose[:3]`` is read.
    """
    import open3d as o3d

    count = len(verticies)
    xyz = np.array([entry[1][:3] for entry in verticies])

    # segment k joins vertex k to vertex k + 1 (list order, not id order)
    segments = np.column_stack((np.arange(0, count - 1), np.arange(1, count)))

    geometry = o3d.geometry.LineSet(
        points=o3d.utility.Vector3dVector(xyz),
        lines=o3d.utility.Vector2iVector(segments),
    )
    o3d.visualization.draw_geometries([geometry])

def info2mat(info):
    """Unpack a packed upper-triangular vector into a symmetric 6x6 matrix.

    The packed values are consumed row by row: row ``i`` takes the next
    ``6 - i`` values, starting at the diagonal. Each value is mirrored into
    the lower triangle.

    Args:
        info: sequence of at least 21 numbers. Values after the 21st are
            ignored; fewer than 21 raises ``ValueError``.

    Returns:
        Symmetric ``(6, 6)`` float64 numpy array.
    """
    size = 6
    # np.triu_indices enumerates the upper triangle in row-major order,
    # which matches the packing order described above.
    rows, cols = np.triu_indices(size)
    packed = np.asarray(info)[:len(rows)]

    mat = np.zeros((size, size))
    mat[rows, cols] = packed
    mat[cols, rows] = packed
    return mat

def read_g2o(fn):
    """Parse the SE(3) vertices and edges of a g2o text file.

    Only ``VERTEX_SE3:QUAT`` and ``EDGE_SE3:QUAT`` records are read. Blank
    lines and all other record types are skipped.

    Args:
        fn: path to the g2o file.

    Returns:
        ``(verticies, edges)`` where

        * ``verticies`` is a list of ``[id, pose]``. ``pose`` is a float32
          array of every token after the id (presumably
          ``tx ty tz qx qy qz qw`` — TODO confirm against the data).
        * ``edges`` is a list of ``[u, v, pose, info, tokens]``. ``pose`` is
          the 7 float32 values after the two ids. ``info`` is the 6x6 matrix
          built by ``info2mat`` from the remaining values. ``tokens`` is the
          raw whitespace-split record, kept so the edge can be written back
          unchanged.
    """
    verticies, edges = [], []
    with open(fn) as f:
        for line in f:
            line = line.split()
            # Blank/whitespace-only lines (e.g. a trailing empty line) split
            # to [] and would otherwise raise IndexError on line[0].
            if not line:
                continue

            if line[0] == 'VERTEX_SE3:QUAT':
                v = int(line[1])
                pose = np.array(line[2:], dtype=np.float32)
                verticies.append([v, pose])

            elif line[0] == 'EDGE_SE3:QUAT':
                u = int(line[1])
                v = int(line[2])
                pose = np.array(line[3:10], dtype=np.float32)
                info = np.array(line[10:], dtype=np.float32)

                info = info2mat(info)
                edges.append([u, v, pose, info, line])

    return verticies, edges

def write_g2o(pose_graph, fn):
    """Write a pose graph to a g2o text file.

    Each vertex is written as ``VERTEX_SE3:QUAT <id> <pose values...>`` from
    its current pose array. Each edge is written from the raw token list
    stored as its last element, i.e. exactly as it was read.

    Args:
        pose_graph: ``(verticies, edges)`` tuple in the layout produced by
            ``read_g2o``.
        fn: output path; an existing file is overwritten.
    """
    import csv
    verticies, edges = pose_graph
    # csv.writer defaults to '\r\n' line endings. Opening with newline='' and
    # forcing '\n' keeps the g2o output LF-only on every platform (without
    # newline='' Windows would even produce '\r\r\n').
    with open(fn, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=' ', lineterminator='\n')
        for v, pose in verticies:
            writer.writerow(['VERTEX_SE3:QUAT', v] + pose.tolist())
        for edge in edges:
            writer.writerow(edge[-1])

def reshaping_fn(dE, b=1.5):
    """ Reshaping function from "Intrinsic consensus on SO(3), Tron et al."

    For every residual rotation in ``dE``, let ``a = ||log(dE)||`` (the
    rotation angle). The per-residual cost is ``1/b - (1/b + a) * exp(-b*a)``.
    It is 0 at ``a = 0`` and approaches ``1/b`` as ``a`` grows, so large
    residuals contribute a bounded cost.

    Args:
        dE: lietorch group batch of residual rotations.
        b: shape parameter of the reshaping function.

    Returns:
        Scalar tensor: sum of the per-residual costs.
    """
    # ``log`` is a method on lietorch groups; ``dE.log.norm`` raised
    # AttributeError because it tried to read ``.norm`` off the bound method.
    ang = dE.log().norm(dim=-1)
    err = 1/b - (1/b + ang) * torch.exp(-b*ang)
    return err.sum()

def gradient_initializer(pose_graph, n_steps=500, lr_init=0.2, device='cuda'):
    """ Riemannian Gradient Descent for rotation averaging.

    Optimizes the rotation part of every vertex pose so that, for each edge
    ``(i, j)``, ``R_i^{-1} * R_j`` agrees with the edge's measured rotation
    ``E_ij``. The cost is ``reshaping_fn`` applied to the residuals
    ``R_i^{-1} * R_j * E_ij^{-1}``.

    The optimizer is SGD with momentum 0.5. The learning rate is set to
    ``lr_init * 0.995**step`` at every step. Progress (step, lr, loss) is
    printed every 25 steps.

    Args:
        pose_graph: ``(verticies, edges)`` in the layout produced by
            ``read_g2o``. The last 4 values of each vertex/edge pose are used
            as the rotation.
        n_steps: number of gradient steps.
        lr_init: initial learning rate.
        device: torch device to optimize on (default ``'cuda'``, as before).

    Returns:
        The same ``(verticies, edges)`` objects. Each vertex pose array is
        modified IN PLACE: ``pose[3:]`` is replaced by the optimized
        rotation, and the translations are left untouched.
    """
    verticies, edges = pose_graph

    # Map g2o vertex ids to positions in ``verticies``. Using raw ids as
    # tensor indices is only correct when ids are exactly 0..n-1 in file
    # order; unknown ids now raise KeyError instead of silently misindexing.
    index = {entry[0]: k for k, entry in enumerate(verticies)}

    # edge endpoint positions (ii, jj)
    ii = torch.tensor([index[e[0]] for e in edges], dtype=torch.long, device=device)
    jj = torch.tensor([index[e[1]] for e in edges], dtype=torch.long, device=device)

    # measured relative rotation of every edge
    Eij = np.stack([e[2][3:] for e in edges])
    Eij = SO3(torch.from_numpy(Eij).float().to(device))

    # initial absolute rotation of every vertex, as an optimizable parameter
    R = np.stack([entry[1][3:] for entry in verticies])
    R = SO3(torch.from_numpy(R).float().to(device))
    R = LieGroupParameter(R)

    # use gradient descent with momentum
    optimizer = optim.SGD([R], lr=lr_init, momentum=0.5)

    for i in range(n_steps):
        optimizer.zero_grad()

        lr = lr_init * .995**i
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # rotation error
        dE = (R[ii].inv() * R[jj]) * Eij.inv()
        loss = reshaping_fn(dE)

        loss.backward()
        optimizer.step()

        if i % 25 == 0:
            print(i, lr, loss.item())

    # write the optimized rotations back into the vertex poses
    quats = R.group.data.detach().cpu().numpy()
    for entry, quat in zip(verticies, quats):
        entry[1][3:] = quat

    return verticies, edges


def default_output_path(problem):
    """Return the output path for the rotation-averaged graph.

    Inserts ``_rotavg`` before the file extension, e.g. ``x.g2o`` becomes
    ``x_rotavg.g2o``. Only the final extension is considered. As a result,
    directory names containing ``.g2o`` are left alone, and an input without a
    ``.g2o`` extension still gets a distinct name. The old code used
    ``str.replace``, which rewrote directory names and returned the input path
    unchanged when no ``.g2o`` occurred, so the input file was overwritten.
    """
    import os
    root, ext = os.path.splitext(problem)
    return root + '_rotavg' + ext


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # required: without it args.problem is None and the script crashed later
    parser.add_argument('--problem', required=True,
                        help="input pose graph optimization file (.g2o format)")
    args = parser.parse_args()

    output_path = default_output_path(args.problem)
    input_pose_graph = read_g2o(args.problem)

    rot_pose_graph = gradient_initializer(input_pose_graph)
    write_g2o(rot_pose_graph, output_path)