ba.py 2.3 KB
Newer Older
zachteed's avatar
zachteed committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import lietorch
import torch
import torch.nn.functional as F

from .chol import block_solve
import geom.projective_ops as pops

# utility functions for scattering ops
def safe_scatter_add_mat(H, data, ii, jj, B, M, D):
    """Accumulate per-edge D x D blocks into a flattened block matrix, in place.

    The target slot along dim 1 of ``H`` is ``row * M + col``. Any edge whose
    row index (``ii``) or column index (``jj``) is negative is dropped.

    Args:
        H: accumulator tensor, updated in place via ``scatter_add_`` on dim 1
            (assumed shape (B, M*M, D, D) — TODO confirm against callers).
        data: per-edge blocks selected as ``data[:, edge]``; each selected
            slice must match the (B, n, D, D) index built below.
        ii, jj: 1-D integer tensors of block row / column indices per edge.
        B, M, D: batch size, number of block rows (and columns), block size.
    """
    keep = torch.logical_and(ii >= 0, jj >= 0)
    flat_slot = ii[keep] * M + jj[keep]
    index = flat_slot.view(1, -1, 1, 1).repeat(B, 1, D, D)
    H.scatter_add_(1, index, data[:, keep])

def safe_scatter_add_vec(b, data, ii, B, M, D):
    """Accumulate per-edge length-D vectors into ``b``, in place.

    Edges whose index in ``ii`` is negative are dropped.

    Args:
        b: accumulator tensor, updated in place via ``scatter_add_`` on dim 1
            (assumed shape (B, M, D) — TODO confirm against callers).
        data: per-edge vectors selected as ``data[:, edge]``.
        ii: 1-D integer tensor of destination indices per edge.
        B, M, D: batch size, number of slots (unused here), vector length.
    """
    keep = ~(ii < 0)
    index = ii[keep].view(1, -1, 1).repeat(B, 1, D)
    b.scatter_add_(1, index, data[:, keep])

def MoBA(target, weight, poses, disps, intrinsics, ii, jj, fixedp=1, lm=0.0001, ep=0.1):
    """ MoBA: Motion Only Bundle Adjustment

    Performs one linearize-and-solve update of the poses. ``disps`` are only
    used to compute the reprojection and are not updated or returned. The
    first ``fixedp`` poses along dim 1 are held fixed; the remaining poses
    are updated with ``retr``.

    Args:
        target: reprojection targets; ``target - coords`` must be viewable
            as (B, N, -1, 1).
        weight: per-residual weights; ``valid * weight`` must be viewable as
            (B, N, -1, 1).
        poses: lietorch group tensor whose first two dims are (B, M); it must
            expose ``manifold_dim`` and ``retr``.
        disps, intrinsics: forwarded unchanged to
            ``pops.projective_transform``.
        ii, jj: 1-D index tensors of length N. For each edge they give the
            two pose indices it connects (presumably in 0..M-1 before the
            ``fixedp`` shift below).
        fixedp: number of leading poses kept fixed.
        lm, ep: damping parameters forwarded to ``block_solve``; their exact
            semantics are defined there.

    Returns:
        A lietorch group tensor. It is the first ``fixedp`` poses unchanged,
        concatenated along dim 1 with the remaining poses retracted by the
        solved update.
    """

    B, M = poses.shape[:2]
    D = poses.manifold_dim
    N = ii.shape[0]

    ### 1: compute jacobians and residuals ###
    coords, valid, (Ji, Jj) = pops.projective_transform(
        poses, disps, intrinsics, ii, jj, jacobian=True)

    r = (target - coords).view(B, N, -1, 1)
    w = (valid * weight).view(B, N, -1, 1)

    ### 2: construct linear system ###
    Ji = Ji.view(B, N, -1, D)
    Jj = Jj.view(B, N, -1, D)
    # The .001 factor scales every H block and v entry. It is presumably
    # chosen to balance against the ep/lm damping in block_solve — TODO confirm.
    wJiT = (.001 * w * Ji).transpose(2,3)
    wJjT = (.001 * w * Jj).transpose(2,3)

    # per-edge D x D blocks of the (scaled) J^T W J
    Hii = torch.matmul(wJiT, Ji)
    Hij = torch.matmul(wJiT, Jj)
    Hji = torch.matmul(wJjT, Ji)
    Hjj = torch.matmul(wJjT, Jj)

    # per-edge length-D blocks of the (scaled) J^T W r
    vi = torch.matmul(wJiT, r).squeeze(-1)
    vj = torch.matmul(wJjT, r).squeeze(-1)

    # Only optimize keyframe poses. Shifting the indices makes the fixed
    # poses negative, and the safe_scatter_* helpers then skip them.
    M = M - fixedp
    ii = ii - fixedp
    jj = jj - fixedp

    # Allocate the accumulators in the dtype of the blocks being added.
    # scatter_add_ requires matching dtypes, so a default-dtype buffer would
    # fail for e.g. float64 inputs.
    H = torch.zeros(B, M*M, D, D, device=target.device, dtype=Hii.dtype)
    safe_scatter_add_mat(H, Hii, ii, ii, B, M, D)
    safe_scatter_add_mat(H, Hij, ii, jj, B, M, D)
    safe_scatter_add_mat(H, Hji, jj, ii, B, M, D)
    safe_scatter_add_mat(H, Hjj, jj, jj, B, M, D)
    H = H.reshape(B, M, M, D, D)

    v = torch.zeros(B, M, D, device=target.device, dtype=vi.dtype)
    safe_scatter_add_vec(v, vi, ii, B, M, D)
    safe_scatter_add_vec(v, vj, jj, B, M, D)

    ### 3: solve the system + apply retraction ###
    dx = block_solve(H, v, ep=ep, lm=lm)

    poses1, poses2 = poses[:,:fixedp], poses[:,fixedp:]
    poses2 = poses2.retr(dx)

    poses = lietorch.cat([poses1, poses2], dim=1)
    return poses

def SLessBA(target, weight, poses, disps, intrinsics, ii, jj, fixedp=1):
    """ Structureless Bundle Adjustment

    Placeholder that is not implemented yet. Every argument is ignored and
    the function returns ``None``.
    """
    return None


def BA(target, weight, poses, disps, intrinsics, ii, jj, fixedp=1):
    """ Full Bundle Adjustment

    Placeholder that is not implemented yet. Every argument is ignored and
    the function returns ``None``.
    """
    return None